OpenDAS / Megatron-LM · Commits

Commit ca7a2e34, authored Feb 28, 2022 by Vijay Korthikanti
Parent: 53f3efc4

    Fourth phase vision merge: classification and segmentation tasks

Showing 16 changed files with 2329 additions and 100 deletions (+2329, -100).
Changed files:

    megatron/model/transformer.py                      +4    -2
    megatron/model/vision/vit_backbone.py              +3    -0
    tasks/finetune_utils.py                            +1    -1
    tasks/vision/classification/classification.py      +46   -7
    tasks/vision/classification/eval_utils.py          +3    -5
    tasks/vision/finetune_utils.py                     +62   -63
    tasks/vision/main.py                               +24   -22
    tasks/vision/segmentation/cityscapes.py            +206  -0
    tasks/vision/segmentation/data.py                  +154  -0
    tasks/vision/segmentation/finetune_segformer.py    +251  -0
    tasks/vision/segmentation/finetune_setr.py         +225  -0
    tasks/vision/segmentation/metrics.py               +594  -0
    tasks/vision/segmentation/seg_heads.py             +143  -0
    tasks/vision/segmentation/seg_models.py            +95   -0
    tasks/vision/segmentation/transforms.py            +433  -0
    tasks/vision/segmentation/utils.py                 +85   -0
megatron/model/transformer.py (+4, -2)

@@ -660,6 +660,7 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
+                 post_layer_norm=True,
                  pre_process=True, post_process=True,
                  drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
@@ -667,6 +668,7 @@ class ParallelTransformer(MegatronModule):
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
+        self.post_layer_norm = post_layer_norm
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
@@ -739,7 +741,7 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])

-        if self.post_process:
+        if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
@@ -870,7 +872,7 @@ class ParallelTransformer(MegatronModule):
         if self.post_process:
             # Reverting data format change [s b h] --> [b s h].
             hidden_states = hidden_states.transpose(0, 1).contiguous()
-            output = self.final_layernorm(hidden_states)
+            output = self.final_layernorm(hidden_states) \
+                if self.post_layer_norm else hidden_states
         else:
             output = hidden_states
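The hunks above make ParallelTransformer's trailing LayerNorm optional: the norm is only built and applied when post_layer_norm is true, which lets vision models that normalize elsewhere reuse the transformer unchanged. A minimal standalone sketch of the same pattern (plain torch.nn; TinyEncoder is a hypothetical class, not Megatron's):

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, hidden_size, num_layers, post_layer_norm=True):
        super().__init__()
        self.layers = nn.ModuleList(
            nn.TransformerEncoderLayer(hidden_size, nhead=4)
            for _ in range(num_layers))
        # Only materialize the final norm when it will actually be used,
        # mirroring the constructor change in the diff.
        self.post_layer_norm = post_layer_norm
        if post_layer_norm:
            self.final_layernorm = nn.LayerNorm(hidden_size)

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        # Mirror of the forward change: apply the norm only when enabled.
        return self.final_layernorm(x) if self.post_layer_norm else x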
megatron/model/vision/vit_backbone.py (+3, -0)

@@ -148,6 +148,7 @@ class VitBackbone(MegatronModule):
                  post_process=True,
                  class_token=True,
                  single_token_output=False,
+                 post_layer_norm=True,
                  drop_path_rate=0.0):
         super(VitBackbone, self).__init__(share_word_embeddings=False)
         args = get_args()
@@ -165,6 +166,7 @@ class VitBackbone(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.class_token = class_token
+        self.post_layer_norm = post_layer_norm
         self.hidden_size = args.hidden_size
         self.patch_dim = args.patch_dim
         self.img_h = args.img_h
@@ -218,6 +220,7 @@ class VitBackbone(MegatronModule):
             self.scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process,
+            post_layer_norm=self.post_layer_norm,
             drop_path_rate=self.drop_path_rate
         )
tasks/finetune_utils.py (+1, -1)

@@ -229,7 +229,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
                 prefix = 'iteration {}'.format(iteration)
                 evaluate_and_print_results(prefix, forward_step,
                                            valid_dataloader, model,
-                                           iteration, False)
+                                           iteration, None, False)

         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
tasks/vision/classification/classification.py (+46, -7)

@@ -15,12 +15,15 @@
 """Vision-classification finetuning/evaluation."""

-from megatron import get_args
+import torch.nn.functional as F
+from functools import partial
+from megatron import get_args, get_timers
 from megatron import print_rank_0
-from megatron.model.vit_model import VitModel
+from megatron.model.vision.classification import VitClassificationModel
 from megatron.data.vit_dataset import build_train_valid_datasets
-from tasks.vision.eval_utils import accuracy_func_provider
+from tasks.vision.classification.eval_utils import accuracy_func_provider
 from tasks.vision.finetune_utils import finetune
+from megatron.utils import average_losses_across_data_parallel_group


 def classification():
@@ -30,7 +33,7 @@ def classification():
         train_ds, valid_ds = build_train_valid_datasets(
             data_path=args.data_path,
-            crop_size=args.img_dim,
+            image_size=(args.img_h, args.img_w),
         )
         return train_ds, valid_ds
@@ -40,16 +43,52 @@ def classification():
         print_rank_0("building classification model for ImageNet ...")

-        return VitModel(num_classes=args.num_classes, finetune=True,
+        return VitClassificationModel(num_classes=args.num_classes, finetune=True,
                         pre_process=pre_process, post_process=post_process)

+    def process_batch(batch):
+        """Process batch and produce inputs for the model."""
+        images = batch[0].cuda().contiguous()
+        labels = batch[1].cuda().contiguous()
+        return images, labels
+
+    def cross_entropy_loss_func(labels, output_tensor):
+        logits = output_tensor
+
+        # Cross-entropy loss.
+        loss = F.cross_entropy(logits.contiguous().float(), labels)
+
+        # Reduce loss for logging.
+        averaged_loss = average_losses_across_data_parallel_group([loss])
+
+        return loss, {'lm loss': averaged_loss[0]}
+
+    def _cross_entropy_forward_step(batch, model):
+        """Simple forward step with cross-entropy loss."""
+        timers = get_timers()
+
+        # Get the batch.
+        timers("batch generator").start()
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        images, labels = process_batch(batch_)
+        timers("batch generator").stop()
+
+        # Forward model.
+        output_tensor = model(images)
+
+        return output_tensor, partial(cross_entropy_loss_func, labels)
+
     """Finetune/evaluate."""
     finetune(
         train_valid_datasets_provider,
         model_provider,
+        forward_step=_cross_entropy_forward_step,
         end_of_epoch_callback_provider=accuracy_func_provider,
     )


 def main():
     classification()
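The forward step added here follows Megatron's deferred-loss contract: it returns the model output together with a loss function partially applied to the labels, so the pipeline schedule can evaluate the loss later on whichever stage holds the output. A toy sketch of that contract (toy_forward_step and run_schedule are illustrative names, not the library's API):

from functools import partial
import torch.nn.functional as F

def toy_loss_func(labels, output_tensor):
    loss = F.cross_entropy(output_tensor.float(), labels)
    return loss, {'lm loss': loss.detach()}

def toy_forward_step(batch, model):
    images, labels = batch
    output_tensor = model(images)
    # Bind everything except the model output, exactly like
    # partial(cross_entropy_loss_func, labels) in the diff above.
    return output_tensor, partial(toy_loss_func, labels)

def run_schedule(batch, model):
    output, loss_fn = toy_forward_step(batch, model)
    loss, metrics = loss_fn(output)  # caller decides when to compute the loss
    return loss, metrics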
tasks/vision/classification/eval_utils.py (+3, -5)

@@ -33,11 +33,10 @@ def accuracy_func_provider():
     """Provide function that calculates accuracies."""
     args = get_args()
     data_path = args.data_path
-    crop_size = args.img_dim
+    crop_size = (args.img_h, args.img_w)

-    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
     # Build dataloaders.
-    val_data_path = os.path.join(data_path[0], "val")
+    val_data_path = data_path[1]
     normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
     transform_val = transforms.Compose(
         [
@@ -54,6 +53,7 @@ def accuracy_func_provider():
         args.micro_batch_size,
         num_workers=args.num_workers,
         drop_last=(mpu.get_data_parallel_world_size() > 1),
+        shuffle=False
     )

     def metrics_func(model, epoch):
@@ -71,7 +71,6 @@ def accuracy_func_provider():
 def calculate_correct_answers(model, dataloader, epoch):
     """Calculate correct over total answers"""
-    args = get_args()
     forward_backward_func = get_forward_backward_func()
     for m in model:
         m.eval()
@@ -98,7 +97,6 @@ def calculate_correct_answers(model, dataloader, epoch):
         images, labels = process_batch(batch_)

         # Forward model.
-        args = get_args()
         output_tensor = model(images)

         return output_tensor, partial(loss_func, labels)
tasks/vision/finetune_utils.py (+62, -63)

@@ -17,11 +17,10 @@
 import torch
 import torch.nn.functional as F
-from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron import mpu, utils
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
@@ -29,7 +28,10 @@ from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
 from megatron.training import training_log
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module, ModelType


 def process_batch(batch):
@@ -39,45 +41,16 @@ def process_batch(batch):
     return images, labels


-def cross_entropy_loss_func(labels, output_tensor):
-    logits = output_tensor
-
-    # Cross-entropy loss.
-    loss = F.cross_entropy(logits.contiguous().float(), labels)
-
-    # Reduce loss for logging.
-    averaged_loss = average_losses_across_data_parallel_group([loss])
-
-    return loss, {'lm loss': averaged_loss[0]}
-
-
-def _cross_entropy_forward_step(batch, model):
-    """Simple forward step with cross-entropy loss."""
-    timers = get_timers()
-
-    # Get the batch.
-    timers("batch generator").start()
-    try:
-        batch_ = next(batch)
-    except BaseException:
-        batch_ = batch
-    images, labels = process_batch(batch_)
-    timers("batch generator").stop()
-
-    # Forward model.
-    output_tensor = model(images)
-
-    return output_tensor, partial(cross_entropy_loss_func, labels)
-
-
-def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
+def build_data_loader(dataset, micro_batch_size,
+                      num_workers, drop_last, shuffle):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""

     # Sampler.
     world_size = mpu.get_data_parallel_world_size()
     rank = mpu.get_data_parallel_rank()
     sampler = torch.utils.data.distributed.DistributedSampler(
-        dataset, num_replicas=world_size, rank=rank
+        dataset, num_replicas=world_size, rank=rank,
+        drop_last=drop_last, shuffle=shuffle
     )

     # Data loader. Note that batch size is the per GPU batch size.
@@ -112,14 +85,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     print_rank_0('building train and validation dataloaders ...')

     # Training dataset.
     train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
-                                         args.num_workers, not args.keep_last)
+                                         args.num_workers, False, True)

     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch

     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
     valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
-                                          args.num_workers, not args.keep_last)
+                                          args.num_workers, True, False)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

     # Now that we've built the data loaders, set batch_size arguments
@@ -132,6 +105,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     return train_dataloader, valid_dataloader


 def _train(
     model,
     optimizer,
@@ -140,6 +114,7 @@ def _train(
     train_dataloader,
     valid_dataloader,
     end_of_epoch_callback,
+    process_non_loss_data_func=None,
 ):
     """Train the model."""
     args = get_args()
@@ -167,10 +142,12 @@ def _train(
         # Set the data loader epoch to shuffle the index iterator.
         train_dataloader.sampler.set_epoch(args.seed + epoch)
+        train_dataloader.dataset.set_epoch(epoch)

         # For all the batches in the dataset.
         for iteration_, batch in enumerate(train_dataloader):
+            args.curr_iteration = iteration_

             # Ignore the iterations before starting value
             if iteration_ < start_iteration:
                 continue
@@ -185,8 +162,6 @@ def _train(
             # Logging.
             params_norm = None
-            if args.log_params_norm:
-                params_norm = calc_params_l2_norm(model)
             report_memory_flag = training_log(
                 losses_dict,
@@ -202,20 +177,16 @@ def _train(
             )

             # Autoresume
-            if args.adlr_autoresume and (
-                iteration % args.adlr_autoresume_interval == 0
-            ):
-                check_adlr_autoresume_termination(iteration, model, optimizer,
-                                                  opt_param_scheduler)
+            if args.adlr_autoresume and \
+               iteration % args.adlr_autoresume_interval == 0:
+                check_adlr_autoresume_termination(
+                    iteration, model, optimizer, opt_param_scheduler)

             # Checkpointing
-            if (
-                args.save
-                and args.save_interval
-                and iteration % args.save_interval == 0
-            ):
-                save_checkpoint(iteration, model, optimizer,
-                                opt_param_scheduler)
+            if args.save and args.save_interval and \
+               iteration % args.save_interval == 0:
+                save_checkpoint(iteration, model, optimizer,
+                                opt_param_scheduler)

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -226,12 +197,10 @@ def _train(
                     valid_dataloader,
                     model,
                     iteration,
+                    process_non_loss_data_func,
                     False,
                 )
-                end_of_epoch_callback(model, epoch)
-
-        # Checkpointing at the end of each epoch.
-        if args.save:
-            save_checkpoint(iteration, model, optimizer,
-                            opt_param_scheduler)

         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
@@ -241,7 +210,9 @@ def _train(
 def finetune(
     train_valid_datasets_provider,
     model_provider,
-    forward_step=_cross_entropy_forward_step,
+    forward_step,
+    model_type=ModelType.encoder_or_decoder,
+    process_non_loss_data_func=None,
     end_of_epoch_callback_provider=None,
 ):
     """Main finetune function used across all tasks."""
@@ -266,7 +237,12 @@ def finetune(
     # Build model, optimizer and learning rate scheduler.
     timers("model and optimizer").start()
-    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = \
+        setup_model_and_optimizer(
+            model_provider,
+            model_type,
+            scale_lr_cond=lambda name, param: ".head." in name,
+            lr_mult=args.head_lr_mult)
     timers("model and optimizer").stop()

     # If pretrained checkpoint is provided and we have not trained for
@@ -274,13 +250,34 @@ def finetune(
     # checkpoint.
     timers("pretrained checkpoint").start()
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
-        original_load = args.load
-        args.load = args.pretrained_checkpoint
-        _ = load_checkpoint(model, None, None, strict=False)
-        args.load = original_load
+        if args.pretrained_checkpoint_type == 'default':
+            original_load = args.load
+            args.load = args.pretrained_checkpoint
+            _ = load_checkpoint(model, None, None, strict=False)
+            args.load = original_load
+        elif args.pretrained_checkpoint_type == 'external':
+            unwrap_model = utils.unwrap_model(model)
+            state_dict = torch.load(args.pretrained_checkpoint,
+                                    map_location="cpu")
+            unwrap_model[0].module.backbone.load_state_dict(state_dict,
+                                                            strict=False)
+        elif args.pretrained_checkpoint_type == 'constrastive':
+            unwrap_model = utils.unwrap_model(model)
+            state_dict = torch.load(args.pretrained_checkpoint,
+                                    map_location="cpu")
+            state_dict = state_dict["model"]
+            state_dict = {k.replace("teacher.backbone.", ""): v
+                          for k, v in state_dict.items()
+                          if k.startswith("teacher.backbone.")}
+            unwrap_model[0].module.backbone.load_state_dict(state_dict,
+                                                            strict=False)
+        else:
+            raise Exception("pretrained checkpoint type {} not supported".format(
+                args.pretrained_checkpoint_type))

         # This is critical when only model is loaded. We should make sure
         # master parameters are also updated.
         optimizer.reload_model_params()
     timers("pretrained checkpoint").stop()

     # Print setup timing.
@@ -305,11 +302,13 @@ def finetune(
             train_dataloader,
             valid_dataloader,
             end_of_epoch_callback,
+            process_non_loss_data_func,
         )
     # Or just evaluate.
     else:
         if end_of_epoch_callback is not None:
             print_rank_0("evaluation only mode, setting epoch to -1")
-            end_of_epoch_callback(model, epoch=-1, output_predictions=True)
+            end_of_epoch_callback(model, epoch=-1)

     print_rank_0("done :-)")
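finetune() now passes scale_lr_cond and lr_mult into setup_model_and_optimizer, so parameters whose names contain ".head." train at args.head_lr_mult times the base learning rate. Roughly what such a predicate achieves, sketched with plain torch.optim parameter groups (build_param_groups is a hypothetical helper, not Megatron's internal code):

import torch

def build_param_groups(model, base_lr, head_lr_mult,
                       scale_lr_cond=lambda name, param: ".head." in name):
    head, backbone = [], []
    # Route each parameter by name, just as the lambda in the diff selects
    # head parameters for the scaled learning rate.
    for name, param in model.named_parameters():
        (head if scale_lr_cond(name, param) else backbone).append(param)
    return [
        {'params': backbone, 'lr': base_lr},
        {'params': head, 'lr': base_lr * head_lr_mult},
    ]

# usage sketch:
# optimizer = torch.optim.SGD(build_param_groups(model, 1e-3, 10.0), lr=1e-3)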
tasks/vision/main.py (+24, -22)

@@ -28,32 +28,24 @@ sys.path.append(
 )

 from megatron import get_args
 from megatron.initialize import initialize_megatron
-from classification import main

 def get_tasks_args(parser):
     """Provide extra arguments required for tasks."""
     group = parser.add_argument_group(title="tasks")
-    group.add_argument(
-        "--epochs",
-        type=int,
-        default=None,
-        help="Number of finetunning epochs. Zero results in "
-        "evaluation only.",
-    )
-    group.add_argument(
-        "--pretrained-checkpoint",
-        type=str,
-        default=None,
-        help="Pretrained checkpoint used for finetunning.",
-    )
-    group.add_argument(
-        "--keep-last",
-        action="store_true",
-        help="Keep the last batch (maybe incomplete) in"
-        "the data loader",
-    )
+    group.add_argument('--task', type=str, default='segment',
+                       choices=['classify', 'segment_setr', 'segment_segformer'],
+                       help='task name.')
+    group.add_argument("--epochs", type=int, default=None,
+                       help="Number of finetunning epochs. Zero results in "
+                       "evaluation only.")
+    group.add_argument('--pretrained-checkpoint-type', type=str, default='default',
+                       choices=['default', 'external', 'constrastive'],
+                       help='Type of pretrained checkpoint')
+    group.add_argument("--pretrained-checkpoint", type=str, default=None,
+                       help="Pretrained checkpoint used for finetunning.")
+    group.add_argument('--seg-stride', type=int, default=None,
+                       help='sliding window stride during evaluation')
     return parser
@@ -61,4 +53,14 @@ if __name__ == "__main__":
     initialize_megatron(extra_args_provider=get_tasks_args)
     args = get_args()
-    main()
+    if args.task == 'classify':
+        from tasks.vision.classification.classification import main
+        main()
+    elif args.task == 'segment_setr':
+        from tasks.vision.segmentation.finetune_setr import main
+        main()
+    elif args.task == 'segment_segformer':
+        from tasks.vision.segmentation.finetune_segformer import main
+        main()
tasks/vision/segmentation/cityscapes.py (new file, 0 → 100644, +206)

# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py
import torch
import json
import os
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
from torchvision.datasets import VisionDataset
from PIL import Image
from megatron import print_rank_0


class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Examples:
        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')
            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])
            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')
            img, smnt = dataset[0]
    """
    num_classes = 19
    ignore_index = 19
    color_table = torch.tensor(
        [[128, 64, 128],
         [244, 35, 232],
         [70, 70, 70],
         [102, 102, 156],
         [190, 153, 153],
         [153, 153, 153],
         [250, 170, 30],
         [220, 220, 0],
         [107, 142, 35],
         [152, 251, 152],
         [70, 130, 180],
         [220, 20, 60],
         [255, 0, 0],
         [0, 0, 142],
         [0, 0, 70],
         [0, 60, 100],
         [0, 80, 100],
         [0, 0, 230],
         [119, 11, 32],
         [0, 0, 0]],
        dtype=torch.float, device='cuda')

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple(
        'CityscapesClass',
        ['name', 'id', 'train_id', 'category', 'category_id',
         'has_instances', 'ignore_in_eval', 'color'])

    classes = [
        CityscapesClass('unlabeled', 0, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 19, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 19, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 19, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 19, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 19, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 19, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 19, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 19, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # label2trainid
    label2trainid = {label.id: label.train_id for label in classes}

    def __init__(
            self,
            root: str,
            split: str = "train",
            mode: str = "fine",
            resolution: int = 1024,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, 'gtFine_trainvaltest/gtFine', split)
        self.split = split
        self.resolution = resolution
        self.images = []
        self.targets = []

        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_name = '{}_{}_labelIds.png'.format(
                    file_name.split('_leftImg8bit')[0], self.mode)
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert('RGB')

        target = Image.open(self.targets[index])
        target = np.array(target)

        target_copy = target.copy()
        for k, v in Cityscapes.label2trainid.items():
            binary_target = (target == k)
            target_copy[binary_target] = v
        target = target_copy

        target = Image.fromarray(target.astype(np.uint8))

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        # len(self.images)
        return len(self.images)
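__getitem__ remaps raw Cityscapes label ids to the 19 train ids (everything else to the ignore index) with a Python loop over label2trainid. An equivalent vectorized lookup-table sketch, assuming uint8 label images as produced by labelIds.png (remap_with_lut is illustrative, not part of the commit):

import numpy as np

def remap_with_lut(target, label2trainid, ignore_id=19):
    # Build a 256-entry table once; raw labelIds values are uint8, so the
    # negative 'license plate' id never occurs in the image and is skipped.
    lut = np.full(256, ignore_id, dtype=np.uint8)
    for label_id, train_id in label2trainid.items():
        if 0 <= label_id < 256:
            lut[label_id] = train_id
    return lut[target]  # target: uint8 H x W array of raw label ids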
tasks/vision/segmentation/data.py (new file, 0 → 100644, +154)

import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron.data.autoaugment import ImageNetPolicy
from tasks.vision.segmentation.cityscapes import Cityscapes
import tasks.vision.segmentation.transforms as ET
from megatron.data.autoaugment import ImageNetPolicy
from megatron import get_args
from PIL import Image, ImageOps


class VitSegmentationJointTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train
        if self.train:
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if self.train:
            img, mask = self.transform0(img, mask)
            img, mask = self.transform1(img, mask)
        return img, mask


class VitSegmentationImageTransform():
    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        if self.train:
            assert resolution is not None
            self.transform = T.Compose([
                ET.PhotoMetricDistortion(),
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])

    def __call__(self, input):
        output = self.transform(input)
        return output


class VitSegmentationTargetTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train

    def __call__(self, input):
        output = torch.from_numpy(np.array(input, dtype=np.int32)).long()
        return output


class RandomSeedSegmentationDataset(Dataset):
    def __init__(self,
                 dataset,
                 joint_transform,
                 image_transform,
                 target_transform):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        img, mask = self.dataset[idx]
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        img, mask = self.joint_transform(img, mask)
        img = self.image_transform(img)
        mask = self.target_transform(mask)

        return img, mask


def build_cityscapes_train_valid_datasets(data_path, image_size):
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = \
        VitSegmentationJointTransform(train=True, resolution=image_size)
    val_joint_transform = \
        VitSegmentationJointTransform(train=False, resolution=image_size)
    train_image_transform = \
        VitSegmentationImageTransform(train=True, resolution=image_size)
    val_image_transform = \
        VitSegmentationImageTransform(train=False, resolution=image_size)
    train_target_transform = \
        VitSegmentationTargetTransform(train=True, resolution=image_size)
    val_target_transform = \
        VitSegmentationTargetTransform(train=False, resolution=image_size)

    # training dataset
    train_data = Cityscapes(
        root=data_path[0],
        split='train',
        mode='fine',
        resolution=image_size
    )
    train_data = RandomSeedSegmentationDataset(
        train_data,
        joint_transform=train_joint_transform,
        image_transform=train_image_transform,
        target_transform=train_target_transform)

    # validation dataset
    val_data = Cityscapes(
        root=data_path[0],
        split='val',
        mode='fine',
        resolution=image_size
    )
    val_data = RandomSeedSegmentationDataset(
        val_data,
        joint_transform=val_joint_transform,
        image_transform=val_image_transform,
        target_transform=val_target_transform)

    return train_data, val_data


def build_train_valid_datasets(data_path, image_size):
    return build_cityscapes_train_valid_datasets(data_path, image_size)
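RandomSeedSegmentationDataset reseeds torch, random, and numpy from base_seed + 100 * epoch + idx before applying the joint transform, so the random crop and flip applied to an image and its mask use identical randomness and every (epoch, index) pair is reproducible across runs and workers. A compact toy demonstration of that idea (independent of the classes above):

import random
import numpy as np
import torch

def deterministic_sample(idx, epoch, base_seed=1234):
    # Same recipe as set_epoch + __getitem__ above.
    seed = base_seed + 100 * epoch + idx
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    return random.random()  # stands in for a random crop/flip decision

a = deterministic_sample(3, epoch=2)
b = deterministic_sample(3, epoch=2)
c = deterministic_sample(3, epoch=3)
print(a == b, a == c)  # expects: True False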
tasks/vision/segmentation/finetune_segformer.py (new file, 0 → 100644, +251)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision-classification finetuning/evaluation."""

import numpy as np
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SegformerSegmentationModel
from megatron.model.vision.utils import resize


def calculate_iou(hist_data):
    acc = np.diag(hist_data).sum() / hist_data.sum()
    acc_cls = np.diag(hist_data) / hist_data.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    divisor = hist_data.sum(axis=1) + hist_data.sum(axis=0) - \
        np.diag(hist_data)
    iu = np.diag(hist_data) / divisor

    return iu, acc, acc_cls


def fast_hist(pred, gtruth, num_classes):
    # mask indicates pixels we care about
    mask = (gtruth >= 0) & (gtruth < num_classes)

    # stretch ground truth labels by num_classes
    #   class 0  -> 0
    #   class 1  -> 19
    #   class 18 -> 342
    #
    # TP at 0 + 0, 1 + 1, 2 + 2 ...
    #
    # TP exist where value == num_classes*class_id + class_id
    # FP = row[class].sum() - TP
    # FN = col[class].sum() - TP
    hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
                       minlength=num_classes ** 2)
    hist = hist.reshape(num_classes, num_classes)
    return hist


def segmentation():
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        model = SegformerSegmentationModel(num_classes=args.num_classes,
                                           pre_process=pre_process,
                                           post_process=post_process)
        print_rank_0("model = {}".format(model))
        return model

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        logits = output_tensor.contiguous().float()
        logits = resize(logits, size=masks.shape[1:],
                        mode='bilinear', align_corners=False)

        # Cross-entropy loss.
        # weight = calculate_weight(masks, num_classes)
        loss = F.cross_entropy(logits, masks, ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, output_tensor):
            args = get_args()
            logits = output_tensor
            logits = resize(logits, size=labels.shape[1:],
                            mode='bilinear', align_corners=False)

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.cpu().numpy()
            performs = fast_hist(preds.flatten(),
                                 labels.cpu().numpy().flatten(),
                                 args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            # Forward model.
            output_tensor = model(images)

            return output_tensor, partial(loss_func, labels)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()

        # Reduce.
        if mpu.is_pipeline_last_stage():
            performs_tensor = torch.cuda.FloatTensor(performs)
            torch.distributed.all_reduce(performs_tensor,
                                         group=mpu.get_data_parallel_group())
            hist = performs_tensor.cpu().numpy()
            iu, acc, acc_cls = calculate_iou(hist)
            miou = np.nanmean(iu)

            return iu, miou

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds, args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {},"
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )
        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
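fast_hist builds the num_classes x num_classes confusion matrix in a single np.bincount by encoding each pixel as num_classes * gt + pred; calculate_iou then reads true positives off the diagonal and divides by row sum + column sum - TP. A tiny numeric check of that pair (pure numpy, 3 classes, 4 pixels):

import numpy as np

pred = np.array([0, 1, 1, 2])
gtruth = np.array([0, 1, 2, 2])
num_classes = 3

hist = np.bincount(num_classes * gtruth + pred,
                   minlength=num_classes ** 2).reshape(num_classes, num_classes)
# hist[i, j] counts pixels with ground truth i predicted as j:
# [[1, 0, 0],
#  [0, 1, 0],
#  [0, 1, 1]]
tp = np.diag(hist)                           # [1, 1, 1]
iou = tp / (hist.sum(1) + hist.sum(0) - tp)  # [1.0, 0.5, 0.5]
print(iou, iou.mean())                       # mIoU = 2/3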
tasks/vision/segmentation/finetune_setr.py (new file, 0 → 100644, +225)

# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision-classification finetuning/evaluation."""

import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.metrics import CFMatrix
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SetrSegmentationModel
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins


def segmentation():
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        return SetrSegmentationModel(num_classes=args.num_classes,
                                     pre_process=pre_process,
                                     post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        weight = calculate_weight(masks, args.num_classes)
        logits = output_tensor.contiguous().float()
        loss = F.cross_entropy(logits, masks, weight=weight,
                               ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        args = get_args()
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        if not model.training:
            images, masks, _, _ = slidingcrops(images, masks)
        #print_rank_0("images size = {}".format(images.size()))

        if not model.training:
            output_tensor = torch.cat(
                [model(image) for image in torch.split(images, args.micro_batch_size)])
        else:
            output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, slices_info, img_size, output_tensor):
            args = get_args()
            logits = output_tensor

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.int()
            preds, labels = slidingjoins(preds, max_probs, labels,
                                         slices_info, img_size)
            _, performs = CFMatrix()(preds, labels, args.ignore_index)

            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            args = get_args()
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            assert not model.training
            images, labels, slices_info, img_size = slidingcrops(images, labels)
            # Forward model.
            output_tensor = torch.cat(
                [model(image) for image in torch.split(images, args.micro_batch_size)])

            return output_tensor, partial(loss_func, labels, slices_info, img_size)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()

        # Reduce.
        if mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(performs,
                                         group=mpu.get_data_parallel_group())
            # Print on screen.
            # performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
            true_positive = performs[:, 0]
            false_positive = performs[:, 1]
            false_negative = performs[:, 3]

            iou = true_positive / (true_positive + false_positive + false_negative)
            miou = iou[~torch.isnan(iou)].mean()

            return iou.tolist(), miou.item()

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds, args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {},"
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )
        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
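During evaluation the SETR path crops each image into overlapping windows with slidingcrops, runs them through the model in micro-batch-sized chunks, and stitches the per-window predictions back together with slidingjoins (both live in tasks/vision/segmentation/utils.py, whose body is not shown in this excerpt). A generic sketch of that sliding-window inference idea, assuming H and W are at least the crop size (sliding_infer is illustrative only, not the commit's helper):

import torch

def sliding_infer(model, image, num_classes, crop=768, stride=512):
    """Average per-window logits back into a full-size prediction map."""
    _, h, w = image.shape
    logits = torch.zeros(num_classes, h, w)
    counts = torch.zeros(1, h, w)
    # Window offsets at every stride, plus one flush against each edge.
    ys = list(range(0, max(h - crop, 1), stride)) + [max(h - crop, 0)]
    xs = list(range(0, max(w - crop, 1), stride)) + [max(w - crop, 0)]
    for y in ys:
        for x in xs:
            window = image[:, y:y + crop, x:x + crop].unsqueeze(0)
            out = model(window)[0]  # assumed shape [num_classes, crop, crop]
            logits[:, y:y + crop, x:x + crop] += out
            counts[:, y:y + crop, x:x + crop] += 1
    # Overlapping regions are averaged before taking the argmax.
    return (logits / counts).argmax(dim=0)  # [H, W] class map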
tasks/vision/segmentation/metrics.py
0 → 100644
View file @
ca7a2e34
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# copyright (c) go-hiroaki & Chokurei
# email: guangmingwu2010@gmail.com
#        guozhilingty@gmail.com
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

eps = 1e-6


def _binarize(y_data, threshold):
    """
    args:
        y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
        threshold : [float] [0.0, 1.0]
    return 4-d binarized y_data
    """
    y_data[y_data < threshold] = 0.0
    y_data[y_data >= threshold] = 1.0
    return y_data


def _argmax(y_data, dim):
    """
    args:
        y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
        dim : int
    return 3-d [int] y_data
    """
    return torch.argmax(y_data, dim).int()


def _get_tp(y_pred, y_true):
    """
    args:
        y_true : [int] 3-d in [batch_size, img_rows, img_cols]
        y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
    return [float] true_positive
    """
    return torch.sum(y_true * y_pred).float()


def _get_fp(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_positive
    """
    return torch.sum((1 - y_true) * y_pred).float()


def _get_tn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] true_negative
    """
    return torch.sum((1 - y_true) * (1 - y_pred)).float()


def _get_fn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_negative
    """
    return torch.sum(y_true * (1 - y_pred)).float()


def _get_weights(y_true, nb_ch):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        nb_ch : int
    return [float] weights
    """
    batch_size, img_rows, img_cols = y_true.shape
    pixels = batch_size * img_rows * img_cols
    weights = [torch.sum(y_true == ch).item() / pixels
               for ch in range(nb_ch)]
    return weights
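On binary masks these four helpers are plain overlap counts. A quick standalone check on toy tensors:

import torch

y_true = torch.tensor([[[1, 1, 0, 0]]])  # [batch, rows, cols]
y_pred = torch.tensor([[[1, 0, 1, 0]]])

tp = torch.sum(y_true * y_pred).float()              # pred 1, true 1
fp = torch.sum((1 - y_true) * y_pred).float()        # pred 1, true 0
tn = torch.sum((1 - y_true) * (1 - y_pred)).float()  # pred 0, true 0
fn = torch.sum(y_true * (1 - y_pred)).float()        # pred 0, true 1
print(tp.item(), fp.item(), tn.item(), fn.item())    # 1.0 1.0 1.0 1.0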
class CFMatrix(object):
    def __init__(self, des=None):
        self.des = des

    def __repr__(self):
        return "ConfusionMatrix"

    def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
        """
        args:
            y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return confusion matrix
        """
        batch_size, img_rows, img_cols = y_pred.shape
        chs = ignore_index
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            performs = torch.zeros(chs, 4).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_false_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_false_ch[torch.logical_and((y_true != ch),
                                             (y_true != ignore_index))] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
                nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch), :] = torch.FloatTensor([nb_tp, nb_fp,
                                                          nb_tn, nb_fn])
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
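Note that, despite the docstring, __call__ unpacks y_pred as a 3-d label map, and ignore_index doubles as the number of valid classes (labels 0 .. ignore_index - 1). A hedged usage sketch on random CPU tensors, assuming the class and helpers above are in scope:

import torch

num_classes = 3  # here ignore_index is also the number of valid classes
y_pred = torch.randint(0, num_classes, (2, 8, 8))      # predicted labels
y_true = torch.randint(0, num_classes + 1, (2, 8, 8))  # value 3 is ignored

metric = CFMatrix()
mperforms, performs = metric(y_pred, y_true, ignore_index=num_classes)
print(performs)  # one [nb_tp, nb_fp, nb_tn, nb_fn] row per class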
class OAAcc(object):
    def __init__(self, des="Overall Accuracy"):
        self.des = des

    def __repr__(self):
        return "OAcc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (tp+tn)/total
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)

        nb_tp_tn = torch.sum(y_true == y_pred).float()
        mperforms = nb_tp_tn / (batch_size * img_rows * img_cols)
        performs = None
        return mperforms, performs
class Precision(object):
    def __init__(self, des="Precision"):
        self.des = des

    def __repr__(self):
        return "Prec"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fp)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fp + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fp + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Recall(object):
    def __init__(self, des="Recall"):
        self.des = des

    def __repr__(self):
        return "Reca"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fn)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fn + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class F1Score(object):
    def __init__(self, des="F1Score"):
        self.des = des

    def __repr__(self):
        return "F1Sc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return 2*precision*recall/(precision+recall)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            _precision = nb_tp / (nb_tp + nb_fp + eps)
            _recall = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = 2 * _precision * _recall / \
                (_precision + _recall + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                _precision = nb_tp / (nb_tp + nb_fp + eps)
                _recall = nb_tp / (nb_tp + nb_fn + eps)
                performs[int(ch)] = 2 * _precision * \
                    _recall / (_precision + _recall + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
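In the multi-class branches above, the scalar mperforms is a pixel-frequency weighted mean of the per-class scores, with weights from _get_weights. The aggregation on its own, with illustrative numbers:

# Per-class scores and pixel-frequency weights (illustrative numbers).
performs = [0.9, 0.6, 0.3]
weights = [0.7, 0.2, 0.1]  # fraction of all pixels per class
mperforms = sum(i * j for (i, j) in zip(performs, weights))
print(mperforms)  # 0.9*0.7 + 0.6*0.2 + 0.3*0.1 = 0.78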
class Kappa(object):
    def __init__(self, des="Kappa"):
        self.des = des

    def __repr__(self):
        return "Kapp"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (Po-Pe)/(1-Pe)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            nb_total = nb_tp + nb_fp + nb_tn + nb_fn
            Po = (nb_tp + nb_tn) / nb_total
            Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                  (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total ** 2)
            mperforms = (Po - Pe) / (1 - Pe + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                nb_total = nb_tp + nb_fp + nb_tn + nb_fn
                Po = (nb_tp + nb_tn) / nb_total
                Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                      (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total ** 2)
                performs[int(ch)] = (Po - Pe) / (1 - Pe + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Jaccard(object):
    def __init__(self, des="Jaccard"):
        self.des = des

    def __repr__(self):
        return "Jacc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return intersection / (sum-intersection)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            _intersec = torch.sum(y_true * y_pred).float()
            _sum = torch.sum(y_true + y_pred).float()
            mperforms = _intersec / (_sum - _intersec + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                _intersec = torch.sum(y_true_ch * y_pred_ch).float()
                _sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[int(ch)] = _intersec / (_sum - _intersec + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class MSE(object):
    def __init__(self, des="Mean Square Error"):
        self.des = des

    def __repr__(self):
        return "MSE"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return mean_squared_error, smaller the better
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        return torch.mean((y_pred - y_true) ** 2)
class PSNR(object):
    def __init__(self, des="Peak Signal to Noise Ratio"):
        self.des = des

    def __repr__(self):
        return "PSNR"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return PSNR, larger the better
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        mse = torch.mean((y_pred - y_true) ** 2)
        return 10 * torch.log10(1 / mse)
class SSIM(object):
    '''
    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''
    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        gauss = torch.Tensor(
            [math.exp(-(x - w_size // 2) ** 2 / float(2 * sigma ** 2))
             for x in range(w_size)])
        return gauss / gauss.sum()

    def create_window(self, w_size, channel=1):
        _1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(
            _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
        return window

    def __call__(self, y_pred, y_true, w_size=11, size_average=True,
                 full=False):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11
            size_average : boolean, default True
            full : boolean, default False
        return ssim, larger the better
        """
        # Value range can be different from 255.
        # Other common ranges are 1 (sigmoid) and 2 (tanh).
        if torch.max(y_pred) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(y_pred) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val

        padd = 0
        (_, channel, height, width) = y_pred.size()
        window = self.create_window(w_size, channel=channel).to(y_pred.device)

        mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
        mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(y_pred * y_pred, window,
                             padding=padd, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(y_true * y_true, window,
                             padding=padd, groups=channel) - mu2_sq
        sigma12 = F.conv2d(y_pred * y_true, window,
                           padding=padd, groups=channel) - mu1_mu2

        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)

        if full:
            return ret, cs
        return ret
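SSIM here uses an 11x11 Gaussian window with no padding, so inputs must be at least w_size on each side and the similarity map shrinks by w_size - 1. A small self-check, assuming the class above is in scope:

import torch

ssim = SSIM()
a = torch.rand(1, 3, 64, 64)
print(ssim(a, a.clone()).item())           # identical images -> 1.0
print(ssim(a, torch.rand_like(a)).item())  # unrelated noise -> much lower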
class AE(object):
    """
    Modified from matlab : colorangle.m, MATLAB V2019b
    angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
    angle = 180 / pi * angle;
    """
    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
        return average AE, smaller the better
        """
        dotP = torch.sum(y_pred * y_true, dim=1)
        Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        ae = 180 / math.pi * torch.acos(dotP /
                                        (Norm_pred * Norm_true + eps))
        return ae.mean(1).mean(1)
if __name__ == "__main__":
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        for cuda in [False, True]:
            if cuda:
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()
            print('#' * 20, 'Cuda : {} ; size : {}'.format(cuda,
                                                           y_true.size()))
            ########### similarity metrics
            metric = MSE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = PSNR()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = SSIM()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = AE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            ########### accuracy metrics
            metric = OAAcc()
            maccu, accu = metric(y_pred, y_true)
            print('mAccu:', maccu, 'Accu', accu)
            metric = Precision()
            mprec, prec = metric(y_pred, y_true)
            print('mPrec:', mprec, 'Prec', prec)
            metric = Recall()
            mreca, reca = metric(y_pred, y_true)
            print('mReca:', mreca, 'Reca', reca)
            metric = F1Score()
            mf1sc, f1sc = metric(y_pred, y_true)
            print('mF1sc:', mf1sc, 'F1sc', f1sc)
            metric = Kappa()
            mkapp, kapp = metric(y_pred, y_true)
            print('mKapp:', mkapp, 'Kapp', kapp)
            metric = Jaccard()
            mjacc, jacc = metric(y_pred, y_true)
            print('mJacc:', mjacc, 'Jacc', jacc)
tasks/vision/segmentation/seg_heads.py 0 → 100644
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) segmentation heads."""

import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.module import MegatronModule
from megatron.model.vision.utils import resize


class SetrSegmentationHead(MegatronModule):
    def __init__(self, hidden_size, num_classes):
        super(SetrSegmentationHead, self).__init__()
        args = get_args()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.patch_dim = args.patch_dim

        self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
        self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
                                      1, 1, bias=False)
        self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
        self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)

    def to_2D(self, x):
        n, hw, c = x.shape
        h = self.img_h // self.patch_dim
        w = self.img_w // self.patch_dim
        assert (hw == h * w)
        x = x.transpose(1, 2).reshape(n, c, h, w)
        return x

    def forward(self, hidden_states):
        # [b hw c] -> [b c h w]
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.to_2D(hidden_states)

        hidden_states = self.conv_0(hidden_states)
        hidden_states = self.norm_0(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.conv_1(hidden_states)

        # [b c h w]
        result = F.interpolate(hidden_states,
                               size=(self.img_h, self.img_w),
                               mode='bilinear')
        return result
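to_2D is the inverse of the patch flattening done by the backbone: token k of the [b, hw, c] sequence lands at row k // w, column k % w of the [b, c, h, w] grid. A standalone sketch with assumed sizes (224x224 input, patch 16, hidden 768):

import torch

b, img_h, img_w, patch, c = 2, 224, 224, 16, 768  # assumed sizes
h, w = img_h // patch, img_w // patch             # 14 x 14 patch grid
tokens = torch.randn(b, h * w, c)                 # [b hw c] from a backbone

grid = tokens.transpose(1, 2).reshape(b, c, h, w)  # [b c h w]
assert grid.shape == (b, c, h, w)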
class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class SegformerSegmentationHead(MegatronModule):
    def __init__(self, feature_strides, in_channels,
                 embedding_dim, dropout_ratio):
        super(SegformerSegmentationHead, self).__init__()
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]
        args = get_args()
        self.feature_strides = feature_strides
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.num_classes = args.num_classes
        self.dropout_ratio = dropout_ratio

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels,
                             embed_dim=self.embedding_dim)

        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4,
                                         self.embedding_dim, 1, 1)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
        self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
                                           self.num_classes,
                                           kernel_size=1)

    def forward(self, inputs):
        c1, c2, c3, c4 = inputs

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(
            n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(
            n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(
            n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(
            n, -1, c1.shape[2], c1.shape[3])

        _c = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)
        x = self.linear_pred(x)

        return x
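The decoder projects each stage's feature map to a common embedding width, upsamples everything to the stride-4 resolution of c1, concatenates, and fuses with a 1x1 convolution. A plain-torch sketch of the shape flow for the deepest stage (channel sizes follow the mit_b5 configuration used by the model below; the Linear stands in for MLP):

import torch
import torch.nn.functional as F

n, embed = 2, 768
c1 = torch.randn(n, 64, 56, 56)   # stride-4 feature map
c4 = torch.randn(n, 512, 7, 7)    # stride-32 feature map

proj = torch.nn.Linear(512, embed)
_c4 = proj(c4.flatten(2).transpose(1, 2))           # [n, 49, 768]
_c4 = _c4.permute(0, 2, 1).reshape(n, embed, 7, 7)  # back to a 2-d map
_c4 = F.interpolate(_c4, size=c1.shape[2:],
                    mode='bilinear', align_corners=False)
print(_c4.shape)  # [2, 768, 56, 56], ready to concat with _c1 ... _c3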
tasks/vision/segmentation/seg_models.py 0 → 100644
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""

import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.module import MegatronModule
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, \
    SegformerSegmentationHead


class SetrSegmentationModel(MegatronModule):
    def __init__(self, num_classes, pre_process=True, post_process=True):
        super(SetrSegmentationModel, self).__init__()
        args = get_args()
        assert post_process and pre_process
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.backbone = VitBackbone(
            pre_process=pre_process,
            post_process=post_process,
            class_token=False,
            post_layer_norm=False,
            drop_path_rate=0.1
        )
        self.head = SetrSegmentationHead(self.hidden_size,
                                         self.num_classes)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        result_final = self.head(hidden_states)
        return result_final


class SegformerSegmentationModel(MegatronModule):
    def __init__(self, num_classes, pre_process=True, post_process=True):
        super(SegformerSegmentationModel, self).__init__()
        args = get_args()
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process
        self.backbone = mit_b5()
        self.head = SegformerSegmentationHead(
            feature_strides=[4, 8, 16, 32],
            in_channels=[64, 128, 320, 512],
            embedding_dim=768,
            dropout_ratio=0.1
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        hidden_states = self.head(hidden_states)
        return hidden_states
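Both wrappers follow the same two-stage pattern: a backbone that extracts features and a head that maps them to per-pixel class logits. A toy stand-in of the composition (plain torch modules, not the real backbones or heads):

import torch

backbone = torch.nn.Conv2d(3, 8, 3, padding=1)  # stand-in feature extractor
head = torch.nn.Conv2d(8, 19, 1)                # stand-in per-pixel classifier

x = torch.randn(2, 3, 32, 32)
logits = head(backbone(x))  # [2, 19, 32, 32]: one score map per class
print(logits.shape)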
tasks/vision/segmentation/transforms.py 0 → 100644
# Copyright (c) 2020 The MMSegmentation Authors.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron import print_rank_0
from megatron import get_args
from PIL import Image, ImageOps, ImageEnhance
import torchvision.transforms as torch_tr


def _is_pil_image(img):
    return isinstance(img, Image.Image)
class PhotoMetricDistortion(object):
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Multiply with alpha and add beta, then clip."""
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img):
        """Brightness distortion."""
        if random.randint(0, 1):
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,
                                    self.brightness_delta))
        return img

    def contrast(self, img):
        """Contrast distortion."""
        if random.randint(0, 1):
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower,
                                     self.contrast_upper))
        return img

    def saturation(self, img):
        """Saturation distortion."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):
        """Hue distortion."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            img[:, :, 0] = (img[:, :, 0].astype(int) +
                            random.randint(-self.hue_delta,
                                           self.hue_delta)) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, img):
        """Call function to perform photometric distortion on the image.
        Args:
            img (PIL Image): input image.
        Returns:
            PIL Image: image with photometric distortions applied.
        """
        img = np.array(img)

        # random brightness
        img = self.brightness(img)

        # mode == 0 --> do random contrast first
        # mode == 1 --> do random contrast last
        mode = random.randint(0, 1)
        if mode == 1:
            img = self.contrast(img)

        # random saturation
        img = self.saturation(img)

        # random hue
        img = self.hue(img)

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
        return img
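The transform takes and returns a PIL image, working on a numpy array in between; the saturation and hue branches additionally require mmcv's BGR/HSV helpers. A hedged usage sketch on a synthetic image, assuming the module's imports (including mmcv) are available:

import numpy as np
from PIL import Image

img = Image.fromarray(
    np.random.randint(0, 255, (64, 64, 3), dtype=np.uint8))

distort = PhotoMetricDistortion()
out = distort(img)         # PIL image, same size, randomly jittered
print(out.size, out.mode)  # (64, 64) RGB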
class RandomCrop(object):
    """
    Take a random crop from the image.
    First the image or crop size may need to be adjusted if the incoming image
    is too small...
    If the image is smaller than the crop, then:
         the image is padded up to the size of the crop
         unless 'nopad', in which case the crop size is shrunk to fit the image
    A random crop is taken such that the crop fits within the image.
    If cfg.DATASET.TRANSLATION_AUG_FIX is set, we ensure that there's always
    translation randomness of at least that value around the image.
    if image < crop_size:
        # slide crop within image, random offset
    else:
        # slide image within crop
    """
    def __init__(self, crop_size):
        args = get_args()
        self.size = crop_size
        self.cat_max_ratio = 0.75
        self.ignore_index = args.ignore_index
        self.pad_color = (0, 0, 0)

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box."""
        img_w, img_h = img.size
        target_h, target_w = self.size  # [H W]
        margin_h = max(img_h - target_h, 0)
        margin_w = max(img_w - target_w, 0)
        offset_h = random.randint(0, margin_h)
        offset_w = random.randint(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + target_h
        crop_x1, crop_x2 = offset_w, offset_w + target_w

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``"""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
        return img

    @staticmethod
    def crop_in_image(target_w, target_h, w, h, img, mask):
        if w == target_w:
            x1 = 0
        else:
            x1 = random.randint(0, w - target_w)
        if h == target_h:
            y1 = 0
        else:
            y1 = random.randint(0, h - target_h)

        return [img.crop((x1, y1, x1 + target_w, y1 + target_h)),
                mask.crop((x1, y1, x1 + target_w, y1 + target_h))]

    def __call__(self, img, mask):
        w, h = img.size
        target_h, target_w = self.size  # ASSUME H, W

        if w == target_w and h == target_h:
            return img, mask

        # Pad image if image < crop
        if target_h > h:
            pad_h = (target_h - h) // 2 + 1
        else:
            pad_h = 0
        if target_w > w:
            pad_w = (target_w - w) // 2 + 1
        else:
            pad_w = 0
        border = (pad_w, pad_h, pad_w, pad_h)
        if pad_h or pad_w:
            img = ImageOps.expand(img, border=border, fill=(0, 0, 0))
            mask = ImageOps.expand(mask, border=border,
                                   fill=self.ignore_index)
            w, h = img.size

        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Repeat 10 times
            for _ in range(10):
                seg_temp = self.crop(mask, crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) \
                        < self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)

        # crop the image
        img = self.crop(img, crop_bbox)
        # crop semantic seg
        mask = self.crop(mask, crop_bbox)
        assert (img.size[0] == self.size[1] and
                img.size[1] == self.size[0])
        return img, mask
class RandomSizeAndCrop(object):
    def __init__(self, crop_size, scale_min=0.5, scale_max=2.0):
        self.crop = RandomCrop(crop_size)
        self.scale_min = scale_min
        self.scale_max = scale_max

    def __call__(self, img, mask):
        scale_amt = random.uniform(self.scale_min, self.scale_max)
        w, h = [int(i * scale_amt) for i in img.size]
        resized_img = img.resize((w, h), Image.BICUBIC)
        resized_mask = mask.resize((w, h), Image.NEAREST)
        img, mask = self.crop(resized_img, resized_mask)
        return img, mask


class RandomHorizontallyFlip(object):
    def __call__(self, img, mask):
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT), \
                mask.transpose(Image.FLIP_LEFT_RIGHT)
        return img, mask
def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.
    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.
    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.
    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.
    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.
    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.
    See https://en.wikipedia.org/wiki/Hue for more details on Hue.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.
    Returns:
        PIL Image: Hue adjusted image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError('hue_factor {} is not in [-0.5, 0.5].'.format(
            hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.
    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue(float): How much to jitter hue. hue_factor is chosen uniformly from
            [-hue, hue]. Should be >=0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.
        Arguments are same as that of __init__.
        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness),
                                                  1 + brightness)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast),
                                                1 + contrast)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation),
                                                  1 + saturation)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = torch_tr.Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.
        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
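ColorJitter re-samples the four factors and their order on every call, so repeated calls on the same image give different results. A short usage sketch, assuming the class and adjust_* functions above are in scope:

import numpy as np
from PIL import Image

jitter = ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1)
img = Image.fromarray(
    np.random.randint(0, 255, (32, 32, 3), dtype=np.uint8))
print(jitter(img).size)  # (32, 32); colors differ on every call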
tasks/vision/segmentation/utils.py 0 → 100644
import math
import torch
import numpy as np
from megatron import get_args


def slidingcrops(img, mask):
    # img: [b c h w]
    # mask: [b h w]
    args = get_args()
    assert args.img_h == args.img_w
    crop_size = args.img_h
    stride = args.seg_stride
    ignore_index = args.ignore_index
    n, c, h, w = img.shape
    assert h >= crop_size
    assert w >= crop_size
    long_size = max(h, w)

    img_slices, mask_slices, slices_info = [], [], []
    if long_size > crop_size:
        assert stride <= crop_size
        h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
        w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1
        for yy in range(h_step_num):
            for xx in range(w_step_num):
                sy, sx = yy * stride, xx * stride
                ey, ex = sy + crop_size, sx + crop_size
                img_sub = img[:, :, sy: ey, sx: ex]
                mask_sub = mask[:, sy: ey, sx: ex]

                # padding
                sub_h, sub_w = img_sub.shape[2:]
                pad_h = max(crop_size - sub_h, 0)
                pad_w = max(crop_size - sub_w, 0)
                img_sub = torch.nn.functional.pad(
                    img_sub, pad=(0, pad_w, 0, pad_h), value=ignore_index)
                mask_sub = torch.nn.functional.pad(
                    mask_sub, pad=(0, pad_w, 0, pad_h))

                img_slices.append(img_sub)
                mask_slices.append(mask_sub)
                slices_info.append([sy, ey, sx, ex, sub_h, sub_w])

        return torch.cat(img_slices), torch.cat(mask_slices), \
            slices_info, (h, w)
    else:
        return img, mask, [[0, h, 0, w, h, w]], (h, w)


def slidingjoins(preds, probs, labels, slices_info, img_size):
    args = get_args()
    num_slices = len(slices_info)

    if num_slices == 1:
        return preds, labels

    h, w = img_size
    split_size = args.micro_batch_size

    preds_split = torch.split(preds, split_size)
    probs_split = torch.split(probs, split_size)
    labels_split = torch.split(labels, split_size)

    assert (len(preds_split) == num_slices)

    total_max_probs = torch.zeros((split_size, h, w),
                                  dtype=torch.float, device='cuda')
    total_preds = torch.zeros((split_size, h, w),
                              dtype=torch.int, device='cuda')
    total_labels = torch.zeros((split_size, h, w),
                               dtype=torch.int, device='cuda')

    for i in range(num_slices):
        sy, ey, sx, ex, sub_h, sub_w = slices_info[i]
        assert sy + sub_h <= h
        assert sx + sub_w <= w

        curr_max_probs = total_max_probs[:, sy: sy + sub_h, sx: sx + sub_w]
        curr_preds = total_preds[:, sy: sy + sub_h, sx: sx + sub_w]

        local_max_probs = probs_split[i][:, :sub_h, :sub_w]
        local_preds = preds_split[i][:, :sub_h, :sub_w]

        result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
        result_preds = torch.where(curr_max_probs >= local_max_probs,
                                   curr_preds, local_preds)

        total_max_probs[:, sy: sy + sub_h, sx: sx + sub_w] = result_max_probs
        total_preds[:, sy: sy + sub_h, sx: sx + sub_w] = result_preds
        total_labels[:, sy: sy + sub_h, sx: sx + sub_w] = \
            labels_split[i][0, :sub_h, :sub_w]

    return total_preds, total_labels
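slidingcrops tiles each side into ceil((dim - crop_size) / stride) + 1 windows, and slidingjoins resolves overlaps by keeping, per pixel, the prediction whose probability is highest. Both ideas in a standalone sketch with toy numbers:

import math
import torch

h, w, crop_size, stride = 1024, 2048, 768, 512
h_steps = int(math.ceil((h - crop_size) / float(stride))) + 1  # 2
w_steps = int(math.ceil((w - crop_size) / float(stride))) + 1  # 4
print(h_steps * w_steps, "windows")                            # 8 windows

# Max-probability merge of two overlapping windows on one pixel strip.
curr_probs = torch.tensor([0.9, 0.2])
curr_preds = torch.tensor([1, 1])
new_probs = torch.tensor([0.5, 0.7])
new_preds = torch.tensor([3, 3])
merged = torch.where(curr_probs >= new_probs, curr_preds, new_preds)
print(merged.tolist())  # [1, 3]: each pixel keeps its most confident window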