OpenDAS / Megatron-LM / Commits

Commit 1016e98a, authored Feb 18, 2023 by zhuww

megatron-lm0.3.2 based on dtk-22.10

Parent: 6c3f6c7b
Changes: 241. Showing 20 changed files with 4304 additions and 0 deletions (+4304 / -0):
tasks/orqa/supervised/data.py                      +300  -0
tasks/orqa/supervised/eval_utils.py                +206  -0
tasks/orqa/supervised/finetune.py                  +251  -0
tasks/orqa/unsupervised/nq.py                      +228  -0
tasks/orqa/unsupervised/qa_utils.py                +177  -0
tasks/orqa/unsupervised/tokenizers.py              +243  -0
tasks/race/data.py                                 +135  -0
tasks/race/finetune.py                             +67   -0
tasks/vision/classification/classification.py      +94   -0
tasks/vision/classification/eval_utils.py          +129  -0
tasks/vision/finetune_utils.py                     +312  -0
tasks/vision/main.py                               +66   -0
tasks/vision/segmentation/cityscapes.py            +207  -0
tasks/vision/segmentation/data.py                  +154  -0
tasks/vision/segmentation/finetune_segformer.py    +251  -0
tasks/vision/segmentation/finetune_setr.py         +225  -0
tasks/vision/segmentation/metrics.py               +594  -0
tasks/vision/segmentation/seg_heads.py             +140  -0
tasks/vision/segmentation/seg_models.py            +92   -0
tasks/vision/segmentation/transforms.py            +433  -0
tasks/orqa/supervised/data.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA dataset."""
import json
import random
from abc import ABC
from abc import abstractmethod

import numpy as np
from torch.utils.data import Dataset

from megatron import print_rank_0, get_args
from megatron.data.biencoder_dataset_utils import make_attention_mask


def build_token_types_from_context_list(ctx_list, tokenizer, max_seq_length):
    ctx_id_list, ctx_types_list = [], []
    for context in ctx_list:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        ctx_ids = title_ids + [tokenizer.sep_id] + ctx_ids

        ctx_ids, ctx_types, _ = build_tokens_types_paddings_from_ids(
            ctx_ids, max_seq_length, tokenizer.cls, tokenizer.sep,
            tokenizer.pad)
        ctx_id_list.append(ctx_ids)
        ctx_types_list.append(ctx_types)

    return ctx_id_list, ctx_types_list


def build_tokens_types_paddings_from_text(query, context,
                                          tokenizer, max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    query_ids = tokenizer.tokenize(query)
    query_ids, query_types, query_pad_mask = \
        build_tokens_types_paddings_from_ids(query_ids, max_seq_length, \
            tokenizer.cls, tokenizer.sep, tokenizer.pad)

    # Appending the title of the context at front
    extended_ctx_ids = None
    if context is not None:
        title_ids = tokenizer.tokenize(context['title'])
        ctx_ids = tokenizer.tokenize(context['text'])
        extended_ctx_ids = title_ids + [tokenizer.sep] + ctx_ids

    ctx_ids, ctx_types, ctx_pad_mask = \
        build_tokens_types_paddings_from_ids(extended_ctx_ids,
            max_seq_length, tokenizer.cls, tokenizer.sep, tokenizer.pad)

    return query_ids, query_types, query_pad_mask, \
           ctx_ids, ctx_types, ctx_pad_mask


# Similar to tasks/data_utils with some changes
def build_tokens_types_paddings_from_ids(text_ids, max_seq_length,
                                         cls_id, sep_id, pad_id):
    """Build token types and paddings, trim if needed, and pad if needed."""
    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(text_ids)
    enc_ids.extend(text_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    pad_mask = ([1] * num_tokens_enc) + ([0] * padding_length)
    pad_mask = np.array(pad_mask, dtype=np.int64)

    return enc_ids, tokentypes_enc, pad_mask


def build_sample(query_ids, query_types, query_pad_mask,
                 ctx_ids, ctx_types, ctx_pad_mask, answers,
                 neg_ctx_id_list=None, neg_ctx_types_list=None,
                 include_neg=False):
    """Convert to numpy and return a sample consumed by the batch producer."""

    query_ids = np.array(query_ids, dtype=np.int64)
    query_types = np.array(query_types, dtype=np.int64)
    query_mask = make_attention_mask(query_ids, query_ids)

    ctx_ids = np.array(ctx_ids, dtype=np.int64)
    ctx_types = np.array(ctx_types, dtype=np.int64)
    ctx_mask = make_attention_mask(ctx_ids, ctx_ids)

    sample = ({
        'query': query_ids,
        'query_mask': query_mask,
        'query_types': query_types,
        'query_pad_mask': query_pad_mask,
        'context': ctx_ids,
        'context_mask': ctx_mask,
        'context_types': ctx_types,
        'context_pad_mask': ctx_pad_mask,
        'reference': answers
    })

    if include_neg:
        neg_ctx_ids = np.array(neg_ctx_id_list, dtype=np.int64)
        neg_ctx_id_types = np.array(neg_ctx_types_list, dtype=np.int64)
        neg_ctx_mask = np.array([make_attention_mask(ids, ids) \
            for ids in neg_ctx_ids], dtype=np.int64)

        sample['neg_context'] = neg_ctx_ids
        sample['neg_context_types'] = neg_ctx_id_types
        sample['neg_context_mask'] = neg_ctx_mask

    return sample


class OpenRetrievalAbstractDataset(ABC, Dataset):
    """Open Retrieval base dataset class."""

    def __init__(self, task_name, dataset_name, datapaths, tokenizer, \
                 max_seq_length, evaluate=False):
        # Store inputs.
        args = get_args()
        self.evaluate = evaluate
        self.val_av_rank_hard_neg = args.val_av_rank_hard_neg
        self.val_av_rank_other_neg = args.val_av_rank_other_neg
        self.train_with_neg = args.train_with_neg
        self.train_hard_neg = args.train_hard_neg

        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(
            self.task_name, self.dataset_name))

        # Process the files.
        string = ' > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)

        self.samples = []
        for datapath in datapaths:
            self.samples.extend(
                self.process_samples_from_single_path(datapath))

        args = get_args()
        if args.sample_rate < 1:  # subsample
            k = int(len(self.samples) * args.sample_rate)
            self.samples = random.sample(self.samples, k)

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        query_ids, query_types, query_pad_mask, ctx_ids, ctx_types, \
        ctx_pad_mask = build_tokens_types_paddings_from_text( \
            raw_sample['question'], raw_sample['pos_context'], \
            self.tokenizer, self.max_seq_length)

        if self.evaluate:
            neg_ctx_list = \
                raw_sample['negative_context'][:self.val_av_rank_other_neg] + \
                raw_sample['hard_negative_context'][:self.val_av_rank_hard_neg]
            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list, \
                    self.tokenizer, self.max_seq_length)

        elif self.train_with_neg:
            hard_negative_ctx = raw_sample['hard_negative_context']
            negative_ctx = raw_sample['negative_context']
            if True:  # TODO: fix this or remove this condition
                random.shuffle(hard_negative_ctx)
                random.shuffle(negative_ctx)

            neg_ctx_list = hard_negative_ctx[:self.train_hard_neg]
            # In the Google NQ dataset from the DPR paper, more than 50
            # hard negatives are missing in the training data.
            # In those cases, substitute hard negatives by simple negatives.
            if len(neg_ctx_list) < self.train_hard_neg:
                neg_ctx_list += negative_ctx[:self.train_hard_neg - \
                    len(neg_ctx_list)]

            neg_ctx_id_list, neg_ctx_types_list = \
                build_token_types_from_context_list(neg_ctx_list,
                    self.tokenizer, self.max_seq_length)
        else:
            neg_ctx_id_list = None
            neg_ctx_types_list = None

        sample = build_sample(query_ids, query_types, query_pad_mask,
                              ctx_ids, ctx_types, ctx_pad_mask,
                              raw_sample['answers'],
                              neg_ctx_id_list, neg_ctx_types_list,
                              include_neg=self.evaluate or self.train_with_neg)

        return sample

    @staticmethod
    @abstractmethod
    def process_samples_from_single_path(filename):
        """Abstract method that takes a filename and
        returns a list of dataset samples, each sample being a dict of
            {'text': string, 'text': string}
        """
        pass


def normalize_question(question):
    if question[-1] == '?':
        question = question[:-1]
    return question


# The following class reads the datasets for training retriever as
# prepared by the DPR codebase (https://github.com/facebookresearch/DPR)

class NQSupervisedDataset(OpenRetrievalAbstractDataset):

    def __init__(self, name, datapaths, tokenizer, max_seq_length, \
                 evaluate=False):
        super().__init__('natural_questions_ret',
                         name,
                         datapaths,
                         tokenizer,
                         max_seq_length,
                         evaluate=evaluate)

    @staticmethod
    def process_samples_from_single_path(filename):
        """Implement abstract method."""
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r', encoding="utf-8") as f:
            data = json.load(f)
            for row in data:
                question = normalize_question(row['question'])
                pos_context = row['positive_ctxs'][0]

                # Hard Negative Contexts
                if len(row['hard_negative_ctxs']) > 0:
                    hard_neg_context = row['hard_negative_ctxs']
                else:
                    hard_neg_context = []

                # Negative Contexts
                if len(row['negative_ctxs']) > 0:
                    neg_context = row['negative_ctxs']
                else:
                    neg_context = []

                answers = row['answers']
                sample = {'question': question,
                          'pos_context': pos_context,
                          'hard_negative_context': hard_neg_context,
                          'negative_context': neg_context,
                          'answers': answers}
                total += 1
                samples.append(sample)

                if total % 5000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
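A minimal sketch of what build_tokens_types_paddings_from_ids produces, using hypothetical token ids (cls=101, sep=102, pad=0; not this repo's real vocabulary):

enc_ids, tokentypes, pad_mask = build_tokens_types_paddings_from_ids(
    text_ids=[7, 8, 9], max_seq_length=8, cls_id=101, sep_id=102, pad_id=0)
# enc_ids    -> [101, 7, 8, 9, 102, 0, 0, 0]    ([CLS] text [SEP] padding)
# tokentypes -> [0, 0, 0, 0, 0, 0, 0, 0]        (single segment, padded with pad_id)
# pad_mask   -> array([1, 1, 1, 1, 1, 0, 0, 0]) (1 for real tokens, 0 for padding)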
tasks/orqa/supervised/eval_utils.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Evaluation utilities."""
from collections import OrderedDict
import math
import numpy as np
import time

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from megatron import get_args, print_rank_0
from megatron import mpu
from megatron.utils import average_losses_across_data_parallel_group
from tasks.finetune_utils import build_data_loader


def task_collate_fn(batch_data):
    # generate batch
    batch_size = len(batch_data)
    tensorized = OrderedDict()
    for d in batch_data:
        for k, v in d.items():
            tensorized.setdefault(k, []).append(v)

    tensorized['query'] = torch.LongTensor(tensorized['query'])
    tensorized['query_mask'] = torch.LongTensor(tensorized['query_mask'])
    tensorized['query_types'] = torch.LongTensor(tensorized['query_types'])
    tensorized['query_pad_mask'] = \
        torch.LongTensor(tensorized['query_pad_mask'])
    tensorized['context'] = torch.LongTensor(tensorized['context'])
    tensorized['context_mask'] = \
        torch.LongTensor(tensorized['context_mask'])
    tensorized['context_types'] = \
        torch.LongTensor(tensorized['context_types'])
    tensorized['context_pad_mask'] = \
        torch.LongTensor(tensorized['context_pad_mask'])

    if 'neg_context' in tensorized:
        tensorized['neg_context'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context']))
        tensorized['neg_context_mask'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context_mask']))
        tensorized['neg_context_types'] = \
            torch.LongTensor(np.concatenate(tensorized['neg_context_types']))

    return tensorized


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    query_tokens = batch['query'].long().cuda()
    query_mask = (batch['query_mask'] < 0.5).cuda()
    query_types = batch['query_types'].long().cuda()
    query_pad_mask = batch['query_pad_mask'].long().cuda()

    context_tokens = batch['context'].long().cuda()
    context_mask = (batch['context_mask'] < 0.5).cuda()
    context_types = batch['context_types'].long().cuda()
    context_pad_mask = batch['context_pad_mask'].long().cuda()

    if 'neg_context' in batch:
        neg_context_tokens = batch['neg_context'].long().cuda()
        neg_context_mask = (batch['neg_context_mask'] < 0.5).cuda()
        neg_context_types = batch['neg_context_types'].long().cuda()
    else:
        neg_context_tokens = None
        neg_context_mask = None
        neg_context_types = None

    reference = batch['reference']

    return query_tokens, query_mask, query_types, query_pad_mask, \
           context_tokens, context_mask, context_types, context_pad_mask, \
           neg_context_tokens, neg_context_mask, neg_context_types, reference


def accuracy_func_provider(single_dataset_provider, rank0sampler=False):
    """Provide function that calculates accuracies."""
    args = get_args()

    print_rank_0("accuracy_func_provider is CALLED")

    # Build dataloaders
    datapath = args.valid_data
    dataset = single_dataset_provider(datapath)

    drop_last = False
    if mpu.get_data_parallel_world_size() > 1 and not rank0sampler:
        drop_last = True

    print_rank_0(datapath)
    print_rank_0(rank0sampler)

    dataloader = build_data_loader(dataset,
                                   args.eval_micro_batch_size,
                                   num_workers=args.num_workers,
                                   drop_last=drop_last,
                                   task_collate_fn=task_collate_fn)
    dataloaders = (dataset.dataset_name, dataloader)

    def metrics_func(model, epoch, output_predictions=False):
        print_rank_0('calculating metrics by accuracy func in ORQA...')

        if output_predictions:
            assert rank0sampler
            names = 'predictions'
        name, dataloader = dataloaders
        if args.task == "RET-FINETUNE-NQ":
            start_time = time.time()
            output = retrieval_loss(model, dataloader)
            stats_dict, total = output
            format_string = ""
            for k, v in stats_dict.items():
                format_string += "|{} = {:.2f}".format(k, v / total)
            print_rank_0("epoch:{}{}".format(epoch, format_string))
            print_rank_0("taken time to calculate metrics {:.3f}".format(\
                time.time() - start_time))
        else:
            raise AssertionError("{} Task not supported".format(args.task))

    return metrics_func


def retrieval_loss(model, dataloader):
    args = get_args()
    total = 0
    topk_stats_dict = {'top{}_acc'.format(k): 0 for k in \
        args.retriever_report_topk_accuracies}
    stats_dict = dict(rank=0, **topk_stats_dict)

    assert len(model) == 1
    unwrapped_model = model[0]
    unwrapped_model.eval()

    with torch.no_grad():
        # For all the batches in the dataset.
        for batch in dataloader:
            # Run the model forward.
            query_tokens, query_mask, query_types, _, \
            context_tokens, context_mask, context_types, _, \
            neg_context_tokens, neg_context_mask, neg_context_types, \
            reference = process_batch(batch)

            query_logits, context_logits = unwrapped_model(query_tokens,
                query_mask, query_types,
                torch.cat([context_tokens, neg_context_tokens]),
                torch.cat([context_mask, neg_context_mask]),
                torch.cat([context_types, neg_context_types]))

            retrieval_scores = torch.matmul(query_logits,
                                torch.transpose(context_logits, 0, 1))

            if args.retriever_score_scaling:
                retrieval_scores = retrieval_scores / \
                    math.sqrt(args.hidden_size)

            local_batch_size = query_logits.shape[0]
            labels = torch.arange(local_batch_size).long().cuda()

            softmax_scores = F.softmax(retrieval_scores, dim=1)
            sorted_vals, sorted_indices = torch.topk(softmax_scores,
                                k=softmax_scores.shape[1], sorted=True)

            def topk_accuracy(k):
                return torch.cuda.FloatTensor(
                    [sum([int(labels[i] in sorted_indices[i, :k]) for i in \
                        range(local_batch_size)])])

            def get_rank():
                return torch.cuda.FloatTensor(
                    [sum([torch.nonzero(labels[i] == sorted_indices[i])[0][0] \
                        for i in range(local_batch_size)])])

            topk_accs = [topk_accuracy(k) for k in \
                args.retriever_report_topk_accuracies]
            rank = get_rank()
            losses = average_losses_across_data_parallel_group([rank, \
                *topk_accs])

            # create stats_dict with retrieval loss and all specified
            # top-k accuracies
            topk_acc_dict = {'top{}_acc'.format(k): v * 100 for k, v in \
                zip(args.retriever_report_topk_accuracies, losses[1:])}
            temp_stats_dict = dict(rank=losses[0], **topk_acc_dict)
            for k in stats_dict.keys():
                stats_dict[k] += temp_stats_dict[k]
            total += local_batch_size

    unwrapped_model.train()

    return stats_dict, total
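To illustrate the top-k bookkeeping inside retrieval_loss: with in-batch labels, query i is scored against every context and counted as a top-k hit if its own context ranks within the first k. A CPU-only sketch with hypothetical scores:

import torch

scores = torch.tensor([[0.9, 0.05, 0.05],   # query 0: gold context 0 ranked first
                       [0.3, 0.2, 0.5],     # query 1: gold context 1 ranked last
                       [0.1, 0.1, 0.8]])    # query 2: gold context 2 ranked first
labels = torch.arange(3)
_, sorted_indices = torch.topk(scores, k=scores.shape[1], sorted=True)
top1 = sum(int(labels[i] in sorted_indices[i, :1]) for i in range(3))  # 2
top3 = sum(int(labels[i] in sorted_indices[i, :3]) for i in range(3))  # 3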
tasks/orqa/supervised/finetune.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""ORQA finetuning/evaluation."""
from functools import partial
import sys
import math

import torch
import torch.nn.functional as F

from megatron import get_args, get_timers, get_tokenizer
from megatron import mpu, print_rank_0
from megatron.indexer import IndexBuilder
from megatron.model.biencoder_model import biencoder_model_provider
from megatron.utils import average_losses_across_data_parallel_group
from pretrain_ict import get_group_world_size_rank
from tasks.finetune_utils import finetune
from tasks.orqa.supervised.eval_utils import accuracy_func_provider
from tasks.orqa.supervised.eval_utils import process_batch, task_collate_fn
from tasks.orqa.evaluate_utils import ORQAEvaluator


# input_ is a 2D tensor
def check_and_append_tensor_for_gather(group, rank, world_size, input_):

    # gather the size of the first dimension of the tensor from all ranks
    current_length = input_.size()[0]
    first_dim = torch.tensor([[current_length]],
        device=torch.cuda.current_device())
    input_list = [torch.empty_like(first_dim) for _ in range(world_size)]
    input_list[rank].copy_(first_dim)
    torch.distributed.all_gather(input_list, first_dim, group=group)
    all_input_list = torch.cat(input_list, dim=0).contiguous()
    max_length = torch.max(all_input_list)

    # if the sizes differ from the max, extend the tensor
    # accordingly
    if max_length > current_length:
        padding = tuple([0] * (input_.dim() * 2 - 1)) + \
            tuple([max_length - current_length])
        input_ = F.pad(input=input_, pad=padding)

    return input_


def orqa(Dataset):

    def cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()
        tokenizer = get_tokenizer()

        # Get the batch.
        timers('batch generator').start()
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch

        group, rank, world_size = get_group_world_size_rank()

        query_tokens, query_mask, query_types, query_pad_mask, \
        context_tokens, context_mask, context_types, context_pad_mask, \
        neg_context_tokens, neg_context_mask, neg_context_types, \
        reference = process_batch(batch_)

        timers('batch generator').stop()
        local_batch_size = query_tokens.shape[0]

        # Text representation of query and context
        query_list, context_list = [], []
        for i in range(local_batch_size):
            query_list.append(tokenizer.decode(query_tokens[i].tolist()))
            context_list.append(tokenizer.decode(context_tokens[i].tolist()))

        if neg_context_tokens is not None:
            neg_context_tokens = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_tokens)
            neg_context_mask = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_mask)
            neg_context_types = check_and_append_tensor_for_gather(group,
                rank, world_size, neg_context_types)

        if neg_context_tokens is not None:
            context_tokens = torch.cat([context_tokens, neg_context_tokens])
            context_mask = torch.cat([context_mask, neg_context_mask])
            context_types = torch.cat([context_types, neg_context_types])

        # Forward model.
        output_tensor = model(query_tokens, query_mask,
                              query_types, context_tokens,
                              context_mask, context_types)
        return output_tensor, partial(cross_entropy_loss_func, query_tokens,
                                      context_tokens)

    def cross_entropy_loss_func(query_tokens, context_tokens, output_tensor):
        args = get_args()

        local_batch_size = query_tokens.shape[0]
        group, rank, world_size = get_group_world_size_rank()
        # recall we assert that model_parallel_size == 1
        global_batch_size = world_size * local_batch_size

        query_logits, context_logits = output_tensor

        if world_size > 1:
            input_ = torch.empty_like(context_logits).copy_(\
                context_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check if all-gather happens in order
            assert tensor_list[rank].sum().item() == \
                context_logits.sum().item()

            # Preserves the gradient
            tensor_list[rank] = context_logits
            all_context_logits = torch.cat(tensor_list, dim=0).contiguous()

            # Query tensors
            input_ = torch.empty_like(query_logits).copy_(\
                query_logits).detach_()
            tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
            tensor_list[rank].copy_(input_)
            torch.distributed.all_gather(tensor_list, input_, group=group)

            # Check if all-gather happens in order
            assert tensor_list[rank].sum().item() == query_logits.sum().item()

            # Preserves the gradient
            tensor_list[rank] = query_logits
            all_query_logits = torch.cat(tensor_list, dim=0).contiguous()
        else:
            all_query_logits = query_logits
            all_context_logits = context_logits

        retrieval_scores = torch.matmul(all_query_logits,
                            torch.transpose(all_context_logits, 0, 1))
        # Scaling the retrieval scores
        if args.retriever_score_scaling:
            retrieval_scores = retrieval_scores / math.sqrt(args.hidden_size)

        if args.train_with_neg:
            # if the world size is 3, local batch size is 4, and
            # local context size is 8, what we want is
            # labels = [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]
            labels = []
            local_context_size = context_tokens.shape[0]
            for i in range(world_size):
                j = i * local_context_size
                labels.extend(list(range(j, j + local_batch_size)))
            labels = torch.LongTensor(labels).cuda()
            assert len(labels) == global_batch_size
        else:
            labels = torch.arange(global_batch_size).long().cuda()

        # Cross-entropy loss.
        softmax_scores = F.log_softmax(retrieval_scores, dim=1)

        loss = F.nll_loss(softmax_scores, labels, reduction='mean')

        max_score, max_idxs = torch.max(softmax_scores, 1)
        correct_predictions_count = (max_idxs == labels).sum().float()

        # Reduce loss for logging.
        reduced_loss = average_losses_across_data_parallel_group([loss, \
            correct_predictions_count])

        # Loss scaling for correct losses in Supervised Retrieval
        loss = loss * mpu.get_data_parallel_world_size()

        return loss, {'lm loss': reduced_loss[0],
                      'correct_prediction_count': reduced_loss[1]}

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        tokenizer = get_tokenizer()

        train_dataset = Dataset('training',
                                args.train_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=False)
        valid_dataset = Dataset('validation',
                                args.valid_data,
                                tokenizer,
                                args.retriever_seq_length,
                                evaluate=True)
        return train_dataset, valid_dataset

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        print_rank_0('building retriever model for {} ...'.format(args.task))

        model = biencoder_model_provider(only_context_model=False,
                    only_query_model=False,
                    biencoder_shared_query_context_model=\
                    args.biencoder_shared_query_context_model,
                    pre_process=pre_process, post_process=post_process)
        return model

    def single_dataset_provider(datapath):
        args = get_args()
        tokenizer = get_tokenizer()

        name = datapath[0].split('/')[-1].split('.')[0]
        return Dataset(name,
                       datapath,
                       tokenizer,
                       args.retriever_seq_length,
                       evaluate=True)

    def metrics_func_provider():
        """Provide metrics callback function."""
        return accuracy_func_provider(single_dataset_provider)

    """Finetune/evaluate."""
    finetune(train_valid_datasets_provider,
             model_provider,
             forward_step=cross_entropy_forward_step,
             end_of_epoch_callback_provider=metrics_func_provider,
             task_collate_fn=task_collate_fn)


def main():
    args = get_args()

    if args.task == 'RET-FINETUNE-NQ':
        from tasks.orqa.supervised.data import NQSupervisedDataset as Dataset
    else:
        raise NotImplementedError('ORQA task {} is not implemented.'.format(
            args.task))

    orqa(Dataset)
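The label layout under train_with_neg (worked through in the comment inside cross_entropy_loss_func) can be reproduced standalone; the values below are the hypothetical sizes from that comment:

world_size, local_batch_size, local_context_size = 3, 4, 8
labels = []
for i in range(world_size):
    j = i * local_context_size  # each rank contributes a block of contexts
    labels.extend(range(j, j + local_batch_size))
print(labels)  # [0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19]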
tasks/orqa/unsupervised/nq.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Data Loader for Google NQ dataset
"""
from abc import ABC
import csv
from collections import OrderedDict

import numpy as np
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset, BatchSampler

from megatron import print_rank_0, get_args, get_tokenizer, mpu
from megatron.data.biencoder_dataset_utils import make_attention_mask


def get_nq_dataset(qa_data, split):
    args = get_args()
    tokenizer = get_tokenizer()

    dataset = NQDataset('Google NQ {} Split'.format(split),
                        'Google Natural Questions',
                        qa_data,
                        tokenizer,
                        args.retriever_seq_length)
    return dataset


def process_nq_batch(batch):
    query_tokens = batch['token_ids'].long().cuda()
    query_mask = (batch['token_mask'] < 0.5).cuda()
    query_types = batch['token_types'].long().cuda()
    query_len = batch['seq_len'].long().cuda()
    reference = batch['reference']

    return query_tokens, query_mask, query_types, query_len, reference


class CustomDataLoader(DataLoader):
    def __init__(self, dataset, eval=False, **kwargs):
        if kwargs.get('collate_fn', None) is None:
            kwargs['collate_fn'] = self._collate_fn
        self.eval = eval
        super().__init__(dataset, **kwargs)

    def _collate_fn(self, batch_data):
        # generate batch
        batch_size = len(batch_data)
        tensorized = OrderedDict()
        for d in batch_data:
            for k, v in d.items():
                tensorized.setdefault(k, []).append(v)
        assert len(tensorized) == 5

        tensorized['token_ids'] = torch.LongTensor(tensorized['token_ids'])
        tensorized['token_mask'] = torch.LongTensor(tensorized['token_mask'])
        tensorized['token_types'] = torch.LongTensor(tensorized['token_types'])
        tensorized['seq_len'] = torch.LongTensor(tensorized['seq_len'])
        return tensorized


def get_one_epoch_nq_dataloader(dataset, micro_batch_size=None):
    """Data loader. Note that batch-size is the local (per GPU) batch-size.
       NOTE: This dataloader is not distributed !!!
    """

    args = get_args()
    if micro_batch_size is None:
        micro_batch_size = args.micro_batch_size
    num_workers = args.num_workers

    sampler = torch.utils.data.SequentialSampler(dataset)
    # importantly, drop_last must be False to get all the data.
    batch_sampler = BatchSampler(sampler,
                                 batch_size=micro_batch_size,
                                 drop_last=False)

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = CustomDataLoader(dataset,
                                   batch_sampler=batch_sampler,
                                   num_workers=num_workers,
                                   pin_memory=True)
    return data_loader


def build_tokens_types_paddings_from_text(src_text, tokenizer,
                                          max_seq_length):
    """Build token types and paddings, trim if needed, and pad if needed."""

    src_text_ids = tokenizer.tokenize(src_text)

    return build_tokens_types_paddings_from_ids(src_text_ids,
                                                max_seq_length,
                                                tokenizer.cls,
                                                tokenizer.sep,
                                                tokenizer.pad)


def build_tokens_types_paddings_from_ids(src_ids, max_seq_length, cls_id, \
    sep_id, pad_id):
    """
    Build token types and paddings, trim if needed, and pad if needed.
    TODO: Design modular interface to reuse this function. This is getting
    repeated multiple times in different tasks
    """

    enc_ids = []
    tokentypes_enc = []

    # [CLS].
    enc_ids.append(cls_id)
    tokentypes_enc.append(0)

    # A.
    len_src = len(src_ids)
    enc_ids.extend(src_ids)
    tokentypes_enc.extend([0] * len_src)

    # Cap the size.
    if len(enc_ids) > max_seq_length - 1:
        enc_ids = enc_ids[0: max_seq_length - 1]
        tokentypes_enc = tokentypes_enc[0: max_seq_length - 1]

    # [SEP].
    enc_ids.append(sep_id)
    tokentypes_enc.append(0)

    num_tokens_enc = len(enc_ids)

    # Padding.
    padding_length = max_seq_length - len(enc_ids)
    if padding_length > 0:
        enc_ids.extend([pad_id] * padding_length)
        tokentypes_enc.extend([pad_id] * padding_length)

    return enc_ids, tokentypes_enc, num_tokens_enc


def build_sample(token_ids, token_types, num_tokens, reference):
    """
    Convert to numpy and return a sample consumed by the
    batch producer.
    """

    token_ids = np.array(token_ids, dtype=np.int64)
    token_types = np.array(token_types, dtype=np.int64)
    token_mask = make_attention_mask(token_ids, token_ids)

    sample = ({
        'token_ids': token_ids,
        'token_mask': token_mask,
        'token_types': token_types,
        'seq_len': num_tokens,
        'reference': reference
    })
    return sample


class NQDataset(ABC, Dataset):
    """
    Open Retrieval Question Answering evaluation using Google NQ dataset.
    """

    def __init__(self, task_name, dataset_name, datapath,
                 tokenizer, max_seq_length):
        # Store inputs.
        self.task_name = task_name
        self.dataset_name = dataset_name
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        print_rank_0(' > building {} dataset for {}:'.format(
            self.task_name, self.dataset_name))
        print_rank_0(datapath)
        self.samples = self.process_samples_from_single_path(datapath)
        print_rank_0(' >> total number of samples: {}'.format(\
            len(self.samples)))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        raw_sample = self.samples[idx]

        ques_tokens, tokentypes_enc, num_tokens_ques = \
            build_tokens_types_paddings_from_text(raw_sample['question'],
                self.tokenizer, self.max_seq_length)

        sample = build_sample(ques_tokens,
                              tokentypes_enc,
                              num_tokens_ques,
                              raw_sample['answers'])
        return sample

    @staticmethod
    def process_samples_from_single_path(filename):
        print_rank_0(' > Processing {} ...'.format(filename))
        samples = []
        total = 0

        with open(filename, 'r') as ifile:
            reader = csv.reader(ifile, delimiter='\t')
            for row in reader:
                question = row[0]
                answers = eval(row[1])

                sample = {'question': question, 'answers': answers}
                total += 1
                samples.append(sample)

                if total % 1000 == 0:
                    print_rank_0(' > processed {} so far ...'.format(total))

        print_rank_0(' >> processed {} samples.'.format(len(samples)))
        return samples
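NQDataset.process_samples_from_single_path expects a tab-separated file of question and a Python-literal answer list (note that it applies eval to the second column). A sketch with a made-up row:

import csv, io

tsv = "who wrote hamlet\t['William Shakespeare', 'Shakespeare']\n"  # hypothetical row
for row in csv.reader(io.StringIO(tsv), delimiter='\t'):
    question, answers = row[0], eval(row[1])  # mirrors the loader above
    print(question, answers)  # who wrote hamlet ['William Shakespeare', 'Shakespeare']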
tasks/orqa/unsupervised/qa_utils.py (new file, mode 100644)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Set of utilities for Q&A results validation tasks - Retriever passage
validation and Reader predicted answer validation
"""
import collections
import logging
import string
import unicodedata
from functools import partial
from multiprocessing import Pool as ProcessPool
from typing import Tuple, List, Dict

import regex as re

from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer

logger = logging.getLogger(__name__)

QAMatchStats = collections.namedtuple('QAMatchStats', ['top_k_hits',\
    'questions_doc_hits'])


def calculate_matches(all_docs: Dict[object, Tuple[str, str]],
                      answers: List[List[str]],
                      closest_docs: List[Tuple[List[object], List[float]]],
                      workers_num: int,
                      match_type: str) -> QAMatchStats:
    """
    Evaluates answers presence in the set of documents. This function is
    supposed to be used with a large collection of documents and results.
    It internally forks multiple sub-processes for evaluation and then
    merges results
    :param all_docs: dictionary of the entire documents database.
        doc_id -> (doc_text, title)
    :param answers: list of answer lists. One list per question
    :param closest_docs: document ids of the top results along with their
        scores
    :param workers_num: amount of parallel threads to process data
    :param match_type: type of answer matching. Refer to has_answer code for
        available options
    :return: matching information tuple.
    top_k_hits - a list where the index is the amount of top documents retrieved
        and the value is the total amount of valid matches across an entire
        dataset.
    questions_doc_hits - more detailed info with answer matches for every
        question and every retrieved document
    """
    global dpr_all_documents
    dpr_all_documents = all_docs

    tok_opts = {}
    tokenizer = SimpleTokenizer(**tok_opts)

    processes = ProcessPool(processes=workers_num,)

    logger.info('Matching answers in top docs...')
    get_score_partial = partial(check_answer, match_type=match_type,
                                tokenizer=tokenizer)

    questions_answers_docs = zip(answers, closest_docs)

    scores = processes.map(get_score_partial, questions_answers_docs)

    logger.info('Per question validation results len=%d', len(scores))

    n_docs = len(closest_docs[0][0])
    top_k_hits = [0] * n_docs
    for question_hits in scores:
        best_hit = next((i for i, x in enumerate(question_hits) if x), None)
        if best_hit is not None:
            top_k_hits[best_hit:] = [v + 1 for v in top_k_hits[best_hit:]]

    return QAMatchStats(top_k_hits, scores)


def check_answer(questions_answers_docs, tokenizer, match_type) -> List[bool]:
    """
    Search through all the top docs to see if they have any of the answers.
    """
    answers, (doc_ids, doc_scores) = questions_answers_docs

    global dpr_all_documents
    hits = []

    for i, doc_id in enumerate(doc_ids):
        doc = dpr_all_documents[doc_id]
        text = doc[0]

        answer_found = False
        if text is None:  # cannot find the document for some reason
            logger.warning("no doc in db")
            hits.append(False)
            continue

        if has_answer(answers, text, tokenizer, match_type):
            answer_found = True
        hits.append(answer_found)
    return hits


def has_answer(answers, text, tokenizer, match_type) -> bool:
    """
    Check if a document contains an answer string.
    If `match_type` is string, token matching is done between the text
    and answer.
    If `match_type` is regex, we search the whole text with the regex.
    """
    text = _normalize(text)

    if match_type == 'string':
        # Answer is a list of possible strings
        text = tokenizer.tokenize(text).words(uncased=True)

        for single_answer in answers:
            single_answer = _normalize(single_answer)
            single_answer = tokenizer.tokenize(single_answer)
            single_answer = single_answer.words(uncased=True)

            for i in range(0, len(text) - len(single_answer) + 1):
                if single_answer == text[i: i + len(single_answer)]:
                    return True

    elif match_type == 'regex':
        # Answer is a regex
        for single_answer in answers:
            single_answer = _normalize(single_answer)
            if regex_match(text, single_answer):
                return True
    return False


def regex_match(text, pattern):
    """Test if a regex pattern is contained within a text."""
    try:
        pattern = re.compile(pattern,
                             flags=re.IGNORECASE + re.UNICODE + re.MULTILINE,)
    except BaseException:
        return False
    return pattern.search(text) is not None


# function for the reader model answer validation
def exact_match_score(prediction, ground_truth):
    return _normalize_answer(prediction) == _normalize_answer(ground_truth)


def _normalize_answer(s):
    def remove_articles(text):
        return re.sub(r'\b(a|an|the)\b', ' ', text)

    def white_space_fix(text):
        return ' '.join(text.split())

    def remove_punc(text):
        exclude = set(string.punctuation)
        return ''.join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))


def _normalize(text):
    return unicodedata.normalize('NFD', text)
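A small sketch of has_answer with hypothetical inputs (note that tokenizers.py imports spacy at module level, so spacy must be installed even when only SimpleTokenizer is used):

from tasks.orqa.unsupervised.qa_utils import has_answer
from tasks.orqa.unsupervised.tokenizers import SimpleTokenizer

tok = SimpleTokenizer()
text = 'The capital of France is Paris.'
print(has_answer(['Paris'], text, tok, 'string'))   # True: token-level match
print(has_answer(['Par.s'], text, tok, 'regex'))    # True: regex search over the text
print(has_answer(['London'], text, tok, 'string'))  # False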
tasks/orqa/unsupervised/tokenizers.py (new file, mode 100644)
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# The following code has been taken from
# https://github.com/facebookresearch/DPR, which is CC-BY-NC 4.0
# licensed as of now. More details on the license can be found
# at https://github.com/facebookresearch/DPR/blob/master/LICENSE
"""
Most of the tokenizers code here is copied from the DrQA codebase to avoid adding an extra dependency
"""
import copy
import logging

import regex
import spacy

logger = logging.getLogger(__name__)


class Tokens(object):
    """A class to represent a list of tokenized text."""
    TEXT = 0
    TEXT_WS = 1
    SPAN = 2
    POS = 3
    LEMMA = 4
    NER = 5

    def __init__(self, data, annotators, opts=None):
        self.data = data
        self.annotators = annotators
        self.opts = opts or {}

    def __len__(self):
        """The number of tokens."""
        return len(self.data)

    def slice(self, i=None, j=None):
        """Return a view of the list of tokens from [i, j)."""
        new_tokens = copy.copy(self)
        new_tokens.data = self.data[i: j]
        return new_tokens

    def untokenize(self):
        """Returns the original text (with whitespace reinserted)."""
        return ''.join([t[self.TEXT_WS] for t in self.data]).strip()

    def words(self, uncased=False):
        """Returns a list of the text of each token
        Args:
            uncased: lower cases text
        """
        if uncased:
            return [t[self.TEXT].lower() for t in self.data]
        else:
            return [t[self.TEXT] for t in self.data]

    def offsets(self):
        """Returns a list of [start, end) character offsets of each token."""
        return [t[self.SPAN] for t in self.data]

    def pos(self):
        """Returns a list of part-of-speech tags of each token.
        Returns None if this annotation was not included.
        """
        if 'pos' not in self.annotators:
            return None
        return [t[self.POS] for t in self.data]

    def lemmas(self):
        """Returns a list of the lemmatized text of each token.
        Returns None if this annotation was not included.
        """
        if 'lemma' not in self.annotators:
            return None
        return [t[self.LEMMA] for t in self.data]

    def entities(self):
        """Returns a list of named-entity-recognition tags of each token.
        Returns None if this annotation was not included.
        """
        if 'ner' not in self.annotators:
            return None
        return [t[self.NER] for t in self.data]

    def ngrams(self, n=1, uncased=False, filter_fn=None, as_strings=True):
        """Returns a list of all ngrams from length 1 to n.
        Args:
            n: upper limit of ngram length
            uncased: lower cases text
            filter_fn: user function that takes in an ngram list and returns
              True or False to keep or not keep the ngram
            as_string: return the ngram as a string vs list
        """

        def _skip(gram):
            if not filter_fn:
                return False
            return filter_fn(gram)

        words = self.words(uncased)
        ngrams = [(s, e + 1)
                  for s in range(len(words))
                  for e in range(s, min(s + n, len(words)))
                  if not _skip(words[s:e + 1])]

        # Concatenate into strings
        if as_strings:
            ngrams = ['{}'.format(' '.join(words[s:e])) for (s, e) in ngrams]

        return ngrams

    def entity_groups(self):
        """Group consecutive entity tokens with the same NER tag."""
        entities = self.entities()
        if not entities:
            return None
        non_ent = self.opts.get('non_ent', 'O')
        groups = []
        idx = 0
        while idx < len(entities):
            ner_tag = entities[idx]
            # Check for entity tag
            if ner_tag != non_ent:
                # Chomp the sequence
                start = idx
                while (idx < len(entities) and entities[idx] == ner_tag):
                    idx += 1
                groups.append((self.slice(start, idx).untokenize(), ner_tag))
            else:
                idx += 1
        return groups


class Tokenizer(object):
    """Base tokenizer class.
    Tokenizers implement tokenize, which should return a Tokens class.
    """

    def tokenize(self, text):
        raise NotImplementedError

    def shutdown(self):
        pass

    def __del__(self):
        self.shutdown()


class SimpleTokenizer(Tokenizer):
    ALPHA_NUM = r'[\p{L}\p{N}\p{M}]+'
    NON_WS = r'[^\p{Z}\p{C}]'

    def __init__(self, **kwargs):
        """
        Args:
            annotators: None or empty set (only tokenizes).
        """
        self._regexp = regex.compile(
            '(%s)|(%s)' % (self.ALPHA_NUM, self.NON_WS),
            flags=regex.IGNORECASE + regex.UNICODE + regex.MULTILINE
        )
        if len(kwargs.get('annotators', {})) > 0:
            logger.warning('%s only tokenizes! Skipping annotators: %s' %
                           (type(self).__name__, kwargs.get('annotators')))
        self.annotators = set()

    def tokenize(self, text):
        data = []
        matches = [m for m in self._regexp.finditer(text)]
        for i in range(len(matches)):
            # Get text
            token = matches[i].group()

            # Get whitespace
            span = matches[i].span()
            start_ws = span[0]
            if i + 1 < len(matches):
                end_ws = matches[i + 1].span()[0]
            else:
                end_ws = span[1]

            # Format data
            data.append((
                token,
                text[start_ws: end_ws],
                span,
            ))
        return Tokens(data, self.annotators)


class SpacyTokenizer(Tokenizer):

    def __init__(self, **kwargs):
        """
        Args:
            annotators: set that can include pos, lemma, and ner.
            model: spaCy model to use (either path, or keyword like 'en').
        """
        model = kwargs.get('model', 'en')
        self.annotators = copy.deepcopy(kwargs.get('annotators', set()))
        nlp_kwargs = {'parser': False}
        if not any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            nlp_kwargs['tagger'] = False
        if 'ner' not in self.annotators:
            nlp_kwargs['entity'] = False
        self.nlp = spacy.load(model, **nlp_kwargs)

    def tokenize(self, text):
        # We don't treat new lines as tokens.
        clean_text = text.replace('\n', ' ')
        tokens = self.nlp.tokenizer(clean_text)
        if any([p in self.annotators for p in ['lemma', 'pos', 'ner']]):
            self.nlp.tagger(tokens)
        if 'ner' in self.annotators:
            self.nlp.entity(tokens)

        data = []
        for i in range(len(tokens)):
            # Get whitespace
            start_ws = tokens[i].idx
            if i + 1 < len(tokens):
                end_ws = tokens[i + 1].idx
            else:
                end_ws = tokens[i].idx + len(tokens[i].text)

            data.append((
                tokens[i].text,
                text[start_ws: end_ws],
                (tokens[i].idx, tokens[i].idx + len(tokens[i].text)),
                tokens[i].tag_,
                tokens[i].lemma_,
                tokens[i].ent_type_,
            ))

        # Set special option for non-entity tag: '' vs 'O' in spaCy
        return Tokens(data, self.annotators, opts={'non_ent': ''})
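A quick sketch of SimpleTokenizer behavior on a hypothetical string: runs of letters/numbers become tokens, and every other non-space character is its own token:

t = SimpleTokenizer()
toks = t.tokenize('Megatron-LM 0.3.2')
print(toks.words())       # ['Megatron', '-', 'LM', '0', '.', '3', '.', '2']
print(toks.untokenize())  # 'Megatron-LM 0.3.2' (whitespace reinserted)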
tasks/race/data.py (new file, mode 100644)
import glob
import json
import os
import time

from torch.utils.data import Dataset

from megatron import print_rank_0
from tasks.data_utils import build_sample
from tasks.data_utils import build_tokens_types_paddings_from_ids
from tasks.data_utils import clean_text


NUM_CHOICES = 4
MAX_QA_LENGTH = 128


class RaceDataset(Dataset):

    def __init__(self, dataset_name, datapaths, tokenizer, max_seq_length,
                 max_qa_length=MAX_QA_LENGTH):

        self.dataset_name = dataset_name
        print_rank_0(' > building RACE dataset for {}:'.format(
            self.dataset_name))

        string = ' > paths:'
        for path in datapaths:
            string += ' ' + path
        print_rank_0(string)

        self.samples = []
        for datapath in datapaths:
            self.samples.extend(process_single_datapath(datapath, tokenizer,
                                                        max_qa_length,
                                                        max_seq_length))

        print_rank_0(' >> total number of samples: {}'.format(
            len(self.samples)))

        # This indicates that each "sample" has multiple samples that
        # will collapse into batch dimension
        self.sample_multiplier = NUM_CHOICES

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]


def process_single_datapath(datapath, tokenizer, max_qa_length,
                            max_seq_length):
    """Read in RACE files, combine, clean-up, tokenize, and convert to
    samples."""

    print_rank_0(' > working on {}'.format(datapath))
    start_time = time.time()

    # Get list of files.
    filenames = glob.glob(os.path.join(datapath, '*.txt'))

    samples = []
    num_docs = 0
    num_questions = 0
    num_samples = 0
    # Load all the files
    for filename in filenames:
        with open(filename, 'r') as f:
            for line in f:
                data = json.loads(line)
                num_docs += 1

                context = data["article"]
                questions = data["questions"]
                choices = data["options"]
                answers = data["answers"]
                # Check the length.
                assert len(questions) == len(answers)
                assert len(questions) == len(choices)

                # Context: clean up and convert to ids.
                context = clean_text(context)
                context_ids = tokenizer.tokenize(context)

                # Loop over questions.
                for qi, question in enumerate(questions):
                    num_questions += 1
                    # Label.
                    label = ord(answers[qi]) - ord("A")
                    assert label >= 0
                    assert label < NUM_CHOICES
                    assert len(choices[qi]) == NUM_CHOICES

                    # For each question, build num-choices samples.
                    ids_list = []
                    types_list = []
                    paddings_list = []
                    for ci in range(NUM_CHOICES):
                        choice = choices[qi][ci]
                        # Merge with choice.
                        if "_" in question:
                            qa = question.replace("_", choice)
                        else:
                            qa = " ".join([question, choice])
                        # Clean QA.
                        qa = clean_text(qa)
                        # Tokenize.
                        qa_ids = tokenizer.tokenize(qa)
                        # Trim if needed.
                        if len(qa_ids) > max_qa_length:
                            qa_ids = qa_ids[0:max_qa_length]

                        # Build the sample.
                        ids, types, paddings \
                            = build_tokens_types_paddings_from_ids(
                                qa_ids, context_ids, max_seq_length,
                                tokenizer.cls, tokenizer.sep, tokenizer.pad)

                        ids_list.append(ids)
                        types_list.append(types)
                        paddings_list.append(paddings)

                    # Convert to numpy and add to samples
                    samples.append(build_sample(ids_list, types_list,
                                                paddings_list, label,
                                                num_samples))
                    num_samples += 1

    elapsed_time = time.time() - start_time
    print_rank_0(' > processed {} documents, {} questions, and {} samples'
                 ' in {:.2f} seconds'.format(num_docs, num_questions,
                                             num_samples, elapsed_time))

    return samples
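A sketch of the question/choice merge and the letter-to-label mapping used above, with a made-up RACE-style row:

question, choice, answer = 'The writer thinks _ matters most.', 'honesty', 'C'
qa = question.replace('_', choice) if '_' in question \
    else ' '.join([question, choice])
label = ord(answer) - ord('A')
print(qa)     # The writer thinks honesty matters most.
print(label)  # 2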
tasks/race/finetune.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Race."""
from megatron import get_args
from megatron import print_rank_0
from megatron import get_tokenizer
from megatron import mpu
from megatron.model.multiple_choice import MultipleChoice
from tasks.eval_utils import accuracy_func_provider
from tasks.finetune_utils import finetune
from tasks.race.data import RaceDataset


def train_valid_datasets_provider():
    """Provide train and validation datasets."""
    args = get_args()
    tokenizer = get_tokenizer()

    train_dataset = RaceDataset('training', args.train_data,
                                tokenizer, args.seq_length)
    valid_dataset = RaceDataset('validation', args.valid_data,
                                tokenizer, args.seq_length)

    return train_dataset, valid_dataset


def model_provider(pre_process=True, post_process=True):
    """Build the model."""

    print_rank_0('building multichoice model for RACE ...')
    model = MultipleChoice(num_tokentypes=2,
                           pre_process=pre_process,
                           post_process=post_process)

    return model


def metrics_func_provider():
    """Provide metrics callback function."""
    args = get_args()
    tokenizer = get_tokenizer()

    def single_dataset_provider(datapath):
        name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
        return RaceDataset(name, [datapath], tokenizer, args.seq_length)

    return accuracy_func_provider(single_dataset_provider)


def main():
    finetune(train_valid_datasets_provider, model_provider,
             end_of_epoch_callback_provider=metrics_func_provider)
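The dataset-name derivation in single_dataset_provider can be checked in isolation (hypothetical path):

datapath = '/data/RACE/dev/middle'
name = datapath.split('RACE')[-1].strip('/').replace('/', '-')
print(name)  # dev-middle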
tasks/vision/classification/classification.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import print_rank_0
from megatron.model.vision.classification import VitClassificationModel
from megatron.data.vit_dataset import build_train_valid_datasets
from tasks.vision.classification.eval_utils import accuracy_func_provider
from tasks.vision.finetune_utils import finetune
from megatron.utils import average_losses_across_data_parallel_group


def classification():
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w),
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        print_rank_0("building classification model for ImageNet ...")

        return VitClassificationModel(num_classes=args.num_classes,
                                      finetune=True,
                                      pre_process=pre_process,
                                      post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        labels = batch[1].cuda().contiguous()
        return images, labels

    def cross_entropy_loss_func(labels, output_tensor):
        logits = output_tensor

        # Cross-entropy loss.
        loss = F.cross_entropy(logits.contiguous().float(), labels)

        # Reduce loss for logging.
        averaged_loss = average_losses_across_data_parallel_group([loss])

        return loss, {'lm loss': averaged_loss[0]}

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        images, labels = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, labels)

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    classification()
tasks/vision/classification/eval_utils.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Evaluation utilities."""

import os
from functools import partial

import torch

from megatron import get_args
from megatron import print_rank_0, print_rank_last
from megatron import mpu
from megatron.schedules import get_forward_backward_func
from tasks.vision.finetune_utils import build_data_loader
from tasks.vision.finetune_utils import process_batch
from torchvision import datasets, transforms


def accuracy_func_provider():
    """Provide function that calculates accuracies."""
    args = get_args()
    data_path = args.data_path
    crop_size = (args.img_h, args.img_w)

    # Build dataloaders.
    val_data_path = data_path[1]
    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    transform_val = transforms.Compose(
        [
            transforms.Resize(crop_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            normalize,
        ]
    )
    dataset = datasets.ImageFolder(root=val_data_path, transform=transform_val)

    dataloader = build_data_loader(
        dataset,
        args.micro_batch_size,
        num_workers=args.num_workers,
        drop_last=(mpu.get_data_parallel_world_size() > 1),
        shuffle=False
    )

    def metrics_func(model, epoch):
        print_rank_0("calculating metrics ...")
        correct, total = calculate_correct_answers(model, dataloader, epoch)
        percent = float(correct) * 100.0 / float(total)
        print_rank_last(
            " >> |epoch: {}| overall: correct / total = {} / {} = "
            "{:.4f} %".format(epoch, correct, total, percent)
        )

    return metrics_func


def calculate_correct_answers(model, dataloader, epoch):
    """Calculate correct over total answers"""

    forward_backward_func = get_forward_backward_func()
    for m in model:
        m.eval()

    def loss_func(labels, output_tensor):
        logits = output_tensor

        loss_dict = {}
        # Compute the correct answers.
        predicted = torch.argmax(logits, dim=-1)
        corrects = (predicted == labels).float()
        # Add to the counters.
        loss_dict['total'] = labels.size(0)
        loss_dict['correct'] = corrects.sum().item()

        return 0, loss_dict

    # defined inside to capture output_predictions
    def correct_answers_forward_step(batch, model):
        try:
            batch_ = next(batch)
        except BaseException:
            batch_ = batch
        images, labels = process_batch(batch_)

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(loss_func, labels)

    with torch.no_grad():
        # For all the batches in the dataset.
        total = 0
        correct = 0
        for _, batch in enumerate(dataloader):
            loss_dicts = forward_backward_func(
                correct_answers_forward_step, batch, model,
                optimizer=None, timers=None, forward_only=True
            )
            for loss_dict in loss_dicts:
                total += loss_dict['total']
                correct += loss_dict['correct']

    for m in model:
        m.train()

    # Reduce.
    if mpu.is_pipeline_last_stage():
        unreduced = torch.cuda.LongTensor([correct, total])
        torch.distributed.all_reduce(unreduced,
                                     group=mpu.get_data_parallel_group())

        # Print on screen.
        correct_ans = unreduced[0].item()
        total_count = unreduced[1].item()
        return correct_ans, total_count
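calculate_correct_answers sums per-batch counters locally and only all-reduces the final [correct, total] pair across the data-parallel group. The counting step itself is simple; a small CPU-only sketch with dummy logits (hypothetical values, no distributed setup required):

import torch

logits = torch.tensor([[2.0, 0.5], [0.1, 1.0], [3.0, 0.2]])
labels = torch.tensor([0, 1, 1])

predicted = torch.argmax(logits, dim=-1)   # tensor([0, 1, 0])
corrects = (predicted == labels).float()
total = labels.size(0)                     # 3
correct = corrects.sum().item()            # 2.0
print(correct / total)                     # 0.666...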
tasks/vision/finetune_utils.py (new file)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Finetune utilities."""

import torch
import torch.nn.functional as F
from megatron import get_args
from megatron import print_rank_0
from megatron import get_timers
from megatron import mpu, utils
from megatron.checkpointing import load_checkpoint
from megatron.checkpointing import save_checkpoint
from megatron.training import evaluate_and_print_results
from megatron.training import setup_model_and_optimizer
from megatron.training import train_step
from megatron.training import training_log
from megatron.utils import check_adlr_autoresume_termination
from megatron.utils import average_losses_across_data_parallel_group, \
    print_params_min_max_norm
from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
from megatron.model import DistributedDataParallel as LocalDDP
from megatron.model import Float16Module, ModelType


def process_batch(batch):
    """Process batch and produce inputs for the model."""
    images = batch[0].cuda().contiguous()
    labels = batch[1].cuda().contiguous()
    return images, labels


def build_data_loader(dataset, micro_batch_size,
                      num_workers, drop_last, shuffle):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""

    # Sampler.
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=drop_last,
        pin_memory=True,
    )
    return data_loader


def _build_infinite_size_dataloader(dataloader):
    """Build a looped dataloader with infinite size."""

    iterator = dataloader.__iter__()
    while True:
        try:
            yield iterator.__next__()
        except StopIteration:
            iterator = dataloader.__iter__()


def _build_train_valid_dataloaders(train_dataset, valid_dataset):
    """Training and validation dataloaders."""
    args = get_args()

    print_rank_0('building train and validation dataloaders ...')

    # Training dataset.
    train_dataloader = build_data_loader(
        train_dataset, args.micro_batch_size, args.num_workers, False, True
    )

    # Set the training iterations.
    args.train_iters_per_epoch = len(train_dataloader)
    args.train_iters = args.epochs * args.train_iters_per_epoch

    # Validation dataset. For this dataset, we do not need to set up
    # shuffling so we can just use a simple infinite loop.
    valid_dataloader_ = build_data_loader(
        valid_dataset, args.micro_batch_size, args.num_workers, True, False
    )
    valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

    # Now that we've built the data loaders, set batch_size arguments
    # to the actual batch size the model will see for this dataset.
    # This is necessary so pipeline transfers know what size they are
    # and the LR schedule, which is based on samples seen, gets set
    # correctly.
    args.orig_micro_batch_size = args.micro_batch_size
    args.orig_global_batch_size = args.global_batch_size

    return train_dataloader, valid_dataloader


def _train(model, optimizer, opt_param_scheduler, forward_step,
           train_dataloader, valid_dataloader, end_of_epoch_callback,
           process_non_loss_data_func=None):
    """Train the model."""
    args = get_args()
    timers = get_timers()

    # Turn on training mode which enables dropout.
    for m in model:
        m.train()

    # Tracking loss.
    losses_dict_sum = {}

    # Starting epoch and iteration.
    start_epoch = args.iteration // args.train_iters_per_epoch
    start_iteration = args.iteration % args.train_iters_per_epoch
    iteration = args.iteration

    # Memory reporting flag.
    report_memory_flag = True

    # For each remaining epoch.
    timers("interval-time").start()
    for epoch in range(start_epoch, args.epochs):
        print_rank_0("working on epoch {} ...".format(epoch + 1))

        # Set the data loader epoch to shuffle the index iterator.
        train_dataloader.sampler.set_epoch(args.seed + epoch)
        train_dataloader.dataset.set_epoch(epoch)

        # For all the batches in the dataset.
        for iteration_, batch in enumerate(train_dataloader):

            # Ignore the iterations before starting value.
            if iteration_ < start_iteration:
                continue
            # Set to zero so the next epoch does not skip any batches.
            start_iteration = 0

            # Train for one step.
            losses_dict, skipped_iter, grad_norm, num_zeros_in_grad = \
                train_step(forward_step, batch, model,
                           optimizer, opt_param_scheduler)
            iteration += 1

            # Logging.
            params_norm = None
            report_memory_flag = training_log(
                losses_dict,
                losses_dict_sum,
                optimizer.param_groups[0]["lr"],
                iteration,
                optimizer.get_loss_scale().item(),
                report_memory_flag,
                skipped_iter,
                grad_norm,
                params_norm,
                num_zeros_in_grad,
            )

            # Autoresume.
            if args.adlr_autoresume and \
               iteration % args.adlr_autoresume_interval == 0:
                check_adlr_autoresume_termination(
                    iteration, model, optimizer, opt_param_scheduler
                )

            # Checkpointing.
            if args.save and args.save_interval and \
               iteration % args.save_interval == 0:
                save_checkpoint(iteration, model, optimizer,
                                opt_param_scheduler)

            # Evaluation.
            if args.eval_interval and iteration % args.eval_interval == 0:
                prefix = "iteration {}".format(iteration)
                evaluate_and_print_results(
                    prefix,
                    forward_step,
                    valid_dataloader,
                    model,
                    iteration,
                    process_non_loss_data_func,
                    False,
                )

        # Callback at the end of each epoch.
        if end_of_epoch_callback is not None:
            end_of_epoch_callback(model, epoch)


def finetune(
    train_valid_datasets_provider,
    model_provider,
    forward_step,
    model_type=ModelType.encoder_or_decoder,
    process_non_loss_data_func=None,
    end_of_epoch_callback_provider=None,
):
    """Main finetune function used across all tasks."""
    args = get_args()
    timers = get_timers()

    # Train and validation data loaders.
    timers("train/valid/test dataset/dataloader").start()
    if args.epochs > 0:
        train_dataset, valid_dataset = train_valid_datasets_provider()
        train_dataloader, valid_dataloader = _build_train_valid_dataloaders(
            train_dataset, valid_dataset
        )
    timers("train/valid/test dataset/dataloader").stop()

    # Build callback function.
    timers("callback function").start()
    end_of_epoch_callback = None
    if end_of_epoch_callback_provider is not None:
        end_of_epoch_callback = end_of_epoch_callback_provider()
    timers("callback function").stop()

    # Build model, optimizer and learning rate scheduler.
    timers("model and optimizer").start()
    model, optimizer, opt_param_scheduler = \
        setup_model_and_optimizer(
            model_provider,
            model_type,
            scale_lr_cond=lambda name, param: ".head." in name,
            lr_mult=args.head_lr_mult)
    timers("model and optimizer").stop()

    # If pretrained checkpoint is provided and we have not trained for
    # any iteration (i.e., iteration is zero), then load the pretrained
    # checkpoint.
    timers("pretrained checkpoint").start()
    if args.iteration == 0 and args.pretrained_checkpoint is not None:
        if args.pretrained_checkpoint_type == 'default':
            original_load = args.load
            args.load = args.pretrained_checkpoint
            _ = load_checkpoint(model, None, None, strict=False)
            args.load = original_load
        elif args.pretrained_checkpoint_type == 'external':
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        elif args.pretrained_checkpoint_type == 'contrastive':
            unwrap_model = utils.unwrap_model(model)
            state_dict = torch.load(args.pretrained_checkpoint,
                                    map_location="cpu")
            state_dict = state_dict["model"]
            state_dict = {k.replace("teacher.backbone.", ""): v
                          for k, v in state_dict.items()
                          if k.startswith("teacher.backbone.")}
            unwrap_model[0].module.backbone.load_state_dict(state_dict,
                                                            strict=False)
        else:
            raise Exception("pretrained checkpoint type {} not supported".format(
                args.pretrained_checkpoint_type))

        # This is critical when only model is loaded. We should make sure
        # master parameters are also updated.
        optimizer.reload_model_params()
    timers("pretrained checkpoint").stop()

    # Print setup timing.
    print_rank_0("done with setups ...")
    timers.log(
        [
            "train/valid/test dataset/dataloader",
            "callback function",
            "model and optimizer",
            "pretrained checkpoint",
        ]
    )
    print_rank_0("training ...")

    # Finetune the model.
    if args.epochs > 0:
        _train(
            model,
            optimizer,
            opt_param_scheduler,
            forward_step,
            train_dataloader,
            valid_dataloader,
            end_of_epoch_callback,
            process_non_loss_data_func,
        )
    # Or just evaluate.
    else:
        if end_of_epoch_callback is not None:
            print_rank_0("evaluation only mode, setting epoch to -1")
            end_of_epoch_callback(model, epoch=-1)
    print_rank_0("done :-)")
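_train derives the resume position from a single scalar: args.iteration is split into an epoch index and an in-epoch offset, and the offset is zeroed once the first partially completed epoch has been fast-forwarded. The arithmetic in isolation, with illustrative numbers:

# Suppose a checkpoint stopped at iteration 2500 with 1000 iterations/epoch.
iteration = 2500
train_iters_per_epoch = 1000

start_epoch = iteration // train_iters_per_epoch      # 2: epochs 0 and 1 done
start_iteration = iteration % train_iters_per_epoch   # 500: skip into epoch 2

for epoch in range(start_epoch, 4):
    for iteration_ in range(train_iters_per_epoch):
        if iteration_ < start_iteration:
            continue                                  # fast-forward epoch 2
        iteration += 1                                # train_step would run here
    start_iteration = 0                               # epoch 3 starts at batch 0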
tasks/vision/main.py (new file)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Main tasks functionality."""

import os
import sys

sys.path.append(
    os.path.abspath(
        os.path.join(
            os.path.join(os.path.dirname(__file__), os.path.pardir),
            os.path.pardir,
        )
    )
)
from megatron import get_args
from megatron.initialize import initialize_megatron


def get_tasks_args(parser):
    """Provide extra arguments required for tasks."""
    group = parser.add_argument_group(title="tasks")

    group.add_argument('--task', type=str, default='segment',
                       choices=['classify', 'segment_setr',
                                'segment_segformer'],
                       help='task name.')
    group.add_argument("--epochs", type=int, default=None,
                       help="Number of finetuning epochs. Zero results in "
                       "evaluation only.")
    group.add_argument('--pretrained-checkpoint-type', type=str,
                       default='default',
                       choices=['default', 'external', 'contrastive'],
                       help='Type of pretrained checkpoint')
    group.add_argument("--pretrained-checkpoint", type=str, default=None,
                       help="Pretrained checkpoint used for finetuning.")
    group.add_argument('--seg-stride', type=int, default=None,
                       help='sliding window stride during evaluation')
    return parser


if __name__ == "__main__":

    initialize_megatron(extra_args_provider=get_tasks_args)
    args = get_args()

    if args.task == 'classify':
        from tasks.vision.classification.classification import main
        main()
    elif args.task == 'segment_setr':
        from tasks.vision.segmentation.finetune_setr import main
        main()
    elif args.task == 'segment_segformer':
        from tasks.vision.segmentation.finetune_segformer import main
        main()
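initialize_megatron accepts an extra_args_provider hook so each task script can layer its own flags onto the shared argument parser before parsing happens. The same pattern in isolation, using only argparse (the base flag here is an arbitrary stand-in for Megatron's own arguments):

import argparse

def get_tasks_args(parser):
    group = parser.add_argument_group(title="tasks")
    group.add_argument('--task', type=str, default='segment',
                       choices=['classify', 'segment_setr',
                                'segment_segformer'])
    return parser

parser = argparse.ArgumentParser(description='base parser')
parser.add_argument('--micro-batch-size', type=int, default=4)
parser = get_tasks_args(parser)          # task args layered onto the base
args = parser.parse_args(['--task', 'classify'])
print(args.task, args.micro_batch_size)  # classify 4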
tasks/vision/segmentation/cityscapes.py (new file)
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py
# modified it to change max label index from 255 to 19 (num_classes)

import torch
import json
import os
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import numpy as np
from torchvision.datasets.utils import extract_archive, verify_str_arg, \
    iterable_to_str
from torchvision.datasets import VisionDataset
from PIL import Image
from megatron import print_rank_0


class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.

    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Examples:
        Get semantic segmentation target

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')
            img, smnt = dataset[0]

        Get multiple targets

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])
            img, (inst, col, poly) = dataset[0]

        Validate on the "coarse" set

        .. code-block:: python

            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')
            img, smnt = dataset[0]
    """
    num_classes = 19
    ignore_index = 19
    color_table = torch.tensor(
        [[128, 64, 128],
         [244, 35, 232],
         [70, 70, 70],
         [102, 102, 156],
         [190, 153, 153],
         [153, 153, 153],
         [250, 170, 30],
         [220, 220, 0],
         [107, 142, 35],
         [152, 251, 152],
         [70, 130, 180],
         [220, 20, 60],
         [255, 0, 0],
         [0, 0, 142],
         [0, 0, 70],
         [0, 60, 100],
         [0, 80, 100],
         [0, 0, 230],
         [119, 11, 32],
         [0, 0, 0]],
        dtype=torch.float,
        device='cuda')

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple(
        'CityscapesClass',
        ['name', 'id', 'train_id', 'category', 'category_id',
         'has_instances', 'ignore_in_eval', 'color'])

    classes = [
        CityscapesClass('unlabeled', 0, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 19, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 19, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 19, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 19, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 19, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 19, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 19, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 19, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # label2trainid
    label2trainid = {label.id: label.train_id for label in classes}

    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        resolution: int = 1024,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(Cityscapes, self).__init__(root, transforms,
                                         transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(
            self.root, 'leftImg8bit_trainvaltest/leftImg8bit', split)
        self.targets_dir = os.path.join(
            self.root, 'gtFine_trainvaltest/gtFine', split)
        self.split = split
        self.resolution = resolution
        self.images = []
        self.targets = []

        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_name = '{}_{}_labelIds.png'.format(
                    file_name.split('_leftImg8bit')[0], self.mode
                )
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert('RGB')

        target = Image.open(self.targets[index])
        target = np.array(target)

        target_copy = target.copy()
        for k, v in Cityscapes.label2trainid.items():
            binary_target = (target == k)
            target_copy[binary_target] = v
        target = target_copy

        target = Image.fromarray(target.astype(np.uint8))

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        # len(self.images)
        return len(self.images)
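__getitem__ above remaps raw Cityscapes label ids to train ids with a Python loop over label2trainid. For intuition, the same remap can be expressed as a single vectorized lookup; the sketch below is an alternative illustration, not what the file does, and uses a hypothetical four-entry subset of the table (all unmapped ids fall back to 19, the ignore index):

import numpy as np

label2trainid = {7: 0, 8: 1, 11: 2, 26: 13}     # illustrative subset
lut = np.full(256, 19, dtype=np.uint8)          # default: ignore index
for raw_id, train_id in label2trainid.items():
    lut[raw_id] = train_id

raw = np.array([[7, 8], [26, 0]], dtype=np.uint8)
print(lut[raw])   # [[ 0  1]
                  #  [13 19]]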
tasks/vision/segmentation/data.py (new file)
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron.data.autoaugment import ImageNetPolicy
from tasks.vision.segmentation.cityscapes import Cityscapes
import tasks.vision.segmentation.transforms as ET
from megatron import get_args
from PIL import Image, ImageOps


class VitSegmentationJointTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train
        if self.train:
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if self.train:
            img, mask = self.transform0(img, mask)
            img, mask = self.transform1(img, mask)
        return img, mask


class VitSegmentationImageTransform():
    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        if self.train:
            assert resolution is not None
            self.transform = T.Compose([
                ET.PhotoMetricDistortion(),
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])

    def __call__(self, input):
        output = self.transform(input)
        return output


class VitSegmentationTargetTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train

    def __call__(self, input):
        output = torch.from_numpy(np.array(input, dtype=np.int32)).long()
        return output


class RandomSeedSegmentationDataset(Dataset):
    def __init__(self, dataset, joint_transform,
                 image_transform, target_transform):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        img, mask = self.dataset[idx]
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        img, mask = self.joint_transform(img, mask)
        img = self.image_transform(img)
        mask = self.target_transform(mask)
        return img, mask


def build_cityscapes_train_valid_datasets(data_path, image_size):
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = \
        VitSegmentationJointTransform(train=True, resolution=image_size)
    val_joint_transform = \
        VitSegmentationJointTransform(train=False, resolution=image_size)
    train_image_transform = \
        VitSegmentationImageTransform(train=True, resolution=image_size)
    val_image_transform = \
        VitSegmentationImageTransform(train=False, resolution=image_size)
    train_target_transform = \
        VitSegmentationTargetTransform(train=True, resolution=image_size)
    val_target_transform = \
        VitSegmentationTargetTransform(train=False, resolution=image_size)

    # training dataset
    train_data = Cityscapes(
        root=data_path[0],
        split='train',
        mode='fine',
        resolution=image_size
    )
    train_data = RandomSeedSegmentationDataset(
        train_data,
        joint_transform=train_joint_transform,
        image_transform=train_image_transform,
        target_transform=train_target_transform)

    # validation dataset
    val_data = Cityscapes(
        root=data_path[0],
        split='val',
        mode='fine',
        resolution=image_size
    )
    val_data = RandomSeedSegmentationDataset(
        val_data,
        joint_transform=val_joint_transform,
        image_transform=val_image_transform,
        target_transform=val_target_transform)

    return train_data, val_data


def build_train_valid_datasets(data_path, image_size):
    return build_cityscapes_train_valid_datasets(data_path, image_size)
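RandomSeedSegmentationDataset reseeds every RNG from base_seed + 100 * epoch + idx before applying the random joint transform, so each (epoch, sample) pair sees a reproducible augmentation regardless of dataloader worker scheduling. The effect in miniature, with a stand-in augmentation:

import random

def augment(idx, epoch, base_seed=1234):
    # Stand-in for the reseed-then-transform pattern above.
    random.seed(base_seed + 100 * epoch + idx)
    return random.random()

print(augment(idx=3, epoch=0) == augment(idx=3, epoch=0))  # True: reproducible
print(augment(idx=3, epoch=0) == augment(idx=3, epoch=1))  # False in practice:
                                                           # epoch shifts the seed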
tasks/vision/segmentation/finetune_segformer.py (new file)
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision-segmentation (SegFormer) finetuning/evaluation."""

import numpy as np
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SegformerSegmentationModel
from megatron.model.vision.utils import resize


def calculate_iou(hist_data):
    acc = np.diag(hist_data).sum() / hist_data.sum()
    acc_cls = np.diag(hist_data) / hist_data.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    divisor = hist_data.sum(axis=1) + hist_data.sum(axis=0) - \
        np.diag(hist_data)
    iu = np.diag(hist_data) / divisor
    return iu, acc, acc_cls


def fast_hist(pred, gtruth, num_classes):
    # mask indicates pixels we care about
    mask = (gtruth >= 0) & (gtruth < num_classes)

    # stretch ground truth labels by num_classes
    #   class 0  -> 0
    #   class 1  -> 19
    #   class 18 -> 342
    #
    # TP at 0 + 0, 1 + 1, 2 + 2 ...
    #
    # TP exist where value == num_classes * class_id + class_id
    # FP = row[class].sum() - TP
    # FN = col[class].sum() - TP
    hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
                       minlength=num_classes ** 2)
    hist = hist.reshape(num_classes, num_classes)
    return hist


def segmentation():
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        model = SegformerSegmentationModel(num_classes=args.num_classes,
                                           pre_process=pre_process,
                                           post_process=post_process)
        print_rank_0("model = {}".format(model))
        return model

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        logits = output_tensor.contiguous().float()
        logits = resize(logits, size=masks.shape[1:],
                        mode='bilinear', align_corners=False)

        # Cross-entropy loss.
        # weight = calculate_weight(masks, num_classes)
        loss = F.cross_entropy(logits, masks, ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask,
                                      color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, output_tensor):
            args = get_args()
            logits = output_tensor
            logits = resize(logits, size=labels.shape[1:],
                            mode='bilinear', align_corners=False)

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)

            preds = preds.cpu().numpy()
            # For Cityscapes, args.ignore_index equals num_classes (19),
            # so it doubles as the confusion-matrix size here.
            performs = fast_hist(preds.flatten(),
                                 labels.cpu().numpy().flatten(),
                                 args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            # Forward model.
            output_tensor = model(images)

            return output_tensor, partial(loss_func, labels)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(
                    correct_answers_forward_step, batch, model,
                    optimizer=None, timers=None, forward_only=True
                )
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()

        # Reduce.
        if mpu.is_pipeline_last_stage():
            performs_tensor = torch.cuda.FloatTensor(performs)
            torch.distributed.all_reduce(performs_tensor,
                                         group=mpu.get_data_parallel_group())
            hist = performs_tensor.cpu().numpy()
            iu, acc, acc_cls = calculate_iou(hist)
            miou = np.nanmean(iu)

            return iu, miou

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {},"
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )

        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
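fast_hist encodes each (ground-truth, prediction) pair as the single integer num_classes * gt + pred, so one bincount yields the full confusion matrix; diagonal entries are per-class true positives, row sums hold TP + FN, and column sums hold TP + FP. A worked example with three classes:

import numpy as np

def fast_hist(pred, gtruth, num_classes):
    mask = (gtruth >= 0) & (gtruth < num_classes)
    hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
                       minlength=num_classes ** 2)
    return hist.reshape(num_classes, num_classes)

pred = np.array([0, 1, 1, 2, 2, 2])
gt = np.array([0, 1, 2, 2, 2, 0])
hist = fast_hist(pred, gt, num_classes=3)
# hist[g, p] counts pixels with ground truth g predicted as p:
# [[1 0 1]
#  [0 1 0]
#  [0 1 2]]
iu = np.diag(hist) / (hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist))
print(iu)   # per-class IoU: TP / (TP + FP + FN)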
tasks/vision/segmentation/finetune_setr.py (new file)
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision-segmentation (SETR) finetuning/evaluation."""

import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.metrics import CFMatrix
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SetrSegmentationModel
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins


def segmentation():
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        return SetrSegmentationModel(num_classes=args.num_classes,
                                     pre_process=pre_process,
                                     post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        weight = calculate_weight(masks, args.num_classes)
        logits = output_tensor.contiguous().float()
        loss = F.cross_entropy(logits, masks, weight=weight,
                               ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask,
                                      color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        args = get_args()
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        if not model.training:
            images, masks, _, _ = slidingcrops(images, masks)
        # print_rank_0("images size = {}".format(images.size()))

        if not model.training:
            output_tensor = torch.cat(
                [model(image) for image in
                 torch.split(images, args.micro_batch_size)])
        else:
            output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, slices_info, img_size, output_tensor):
            args = get_args()
            logits = output_tensor

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.int()
            preds, labels = slidingjoins(preds, max_probs, labels,
                                         slices_info, img_size)
            _, performs = CFMatrix()(preds, labels, args.ignore_index)

            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            args = get_args()
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            assert not model.training
            images, labels, slices_info, img_size = slidingcrops(images,
                                                                 labels)
            # Forward model.
            output_tensor = torch.cat(
                [model(image) for image in
                 torch.split(images, args.micro_batch_size)])

            return output_tensor, partial(loss_func, labels,
                                          slices_info, img_size)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(
                    correct_answers_forward_step, batch, model,
                    optimizer=None, timers=None, forward_only=True
                )
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()

        # Reduce.
        if mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(performs,
                                         group=mpu.get_data_parallel_group())
            # Print on screen.
            # performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
            true_positive = performs[:, 0]
            false_positive = performs[:, 1]
            false_negative = performs[:, 3]

            iou = true_positive / (true_positive + false_positive +
                                   false_negative)
            miou = iou[~torch.isnan(iou)].mean()

            return iou.tolist(), miou.item()

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {},"
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )

        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
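The SETR evaluation path crops each image into overlapping windows (slidingcrops), runs the model per window in micro-batch-sized chunks, and stitches the predictions back together (slidingjoins). Those helpers live in tasks/vision/segmentation/utils.py, which is not part of this commit view, so the sketch below illustrates the general idea only, with hypothetical window and stride values; it assumes the image is at least window-sized and that (H - window) and (W - window) are multiples of stride so every pixel is covered:

import torch

def sliding_windows(image, window=512, stride=256):
    # image: [C, H, W]; yields top-left corners plus crops.
    _, H, W = image.shape
    for top in range(0, max(H - window, 0) + 1, stride):
        for left in range(0, max(W - window, 0) + 1, stride):
            yield top, left, image[:, top:top + window, left:left + window]

def sliding_inference(model, image, num_classes, window=512, stride=256):
    _, H, W = image.shape
    logits = torch.zeros(num_classes, H, W)
    counts = torch.zeros(1, H, W)
    for top, left, crop in sliding_windows(image, window, stride):
        out = model(crop.unsqueeze(0)).squeeze(0)   # [num_classes, wnd, wnd]
        logits[:, top:top + window, left:left + window] += out
        counts[:, top:top + window, left:left + window] += 1
    return logits / counts                          # average the overlaps

# toy usage: a 1x1 conv as a stand-in segmentation model
net = torch.nn.Conv2d(3, 19, kernel_size=1)
img = torch.randn(3, 512, 768)
full_logits = sliding_inference(net, img, num_classes=19)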
tasks/vision/segmentation/metrics.py (new file)
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
#copyright (c) go-hiroaki & Chokurei
#email: guangmingwu2010@gmail.com
# guozhilingty@gmail.com
#
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
eps
=
1e-6
def
_binarize
(
y_data
,
threshold
):
"""
args:
y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
threshold : [float] [0.0, 1.0]
return 4-d binarized y_data
"""
y_data
[
y_data
<
threshold
]
=
0.0
y_data
[
y_data
>=
threshold
]
=
1.0
return
y_data
def
_argmax
(
y_data
,
dim
):
"""
args:
y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
dim : int
return 3-d [int] y_data
"""
return
torch
.
argmax
(
y_data
,
dim
).
int
()
def
_get_tp
(
y_pred
,
y_true
):
"""
args:
y_true : [int] 3-d in [batch_size, img_rows, img_cols]
y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
return [float] true_positive
"""
return
torch
.
sum
(
y_true
*
y_pred
).
float
()
def
_get_fp
(
y_pred
,
y_true
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_positive
"""
return
torch
.
sum
((
1
-
y_true
)
*
y_pred
).
float
()
def
_get_tn
(
y_pred
,
y_true
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] true_negative
"""
return
torch
.
sum
((
1
-
y_true
)
*
(
1
-
y_pred
)).
float
()
def
_get_fn
(
y_pred
,
y_true
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
return [float] false_negative
"""
return
torch
.
sum
(
y_true
*
(
1
-
y_pred
)).
float
()
def
_get_weights
(
y_true
,
nb_ch
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
nb_ch : int
return [float] weights
"""
batch_size
,
img_rows
,
img_cols
=
y_true
.
shape
pixels
=
batch_size
*
img_rows
*
img_cols
weights
=
[
torch
.
sum
(
y_true
==
ch
).
item
()
/
pixels
for
ch
in
range
(
nb_ch
)]
return
weights
class
CFMatrix
(
object
):
def
__init__
(
self
,
des
=
None
):
self
.
des
=
des
def
__repr__
(
self
):
return
"ConfusionMatrix"
def
__call__
(
self
,
y_pred
,
y_true
,
ignore_index
,
threshold
=
0.5
):
"""
args:
y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return confusion matrix
"""
batch_size
,
img_rows
,
img_cols
=
y_pred
.
shape
chs
=
ignore_index
device
=
y_true
.
device
if
chs
==
1
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
y_true
=
_binarize
(
y_true
,
threshold
)
nb_tp
=
_get_tp
(
y_pred
,
y_true
)
nb_fp
=
_get_fp
(
y_pred
,
y_true
)
nb_tn
=
_get_tn
(
y_pred
,
y_true
)
nb_fn
=
_get_fn
(
y_pred
,
y_true
)
mperforms
=
[
nb_tp
,
nb_fp
,
nb_tn
,
nb_fn
]
performs
=
None
else
:
performs
=
torch
.
zeros
(
chs
,
4
).
to
(
device
)
weights
=
_get_weights
(
y_true
,
chs
)
for
ch
in
range
(
chs
):
y_true_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_false_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_pred_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_true_ch
[
y_true
==
ch
]
=
1
y_false_ch
[
torch
.
logical_and
((
y_true
!=
ch
),
(
y_true
!=
ignore_index
))]
=
1
y_pred_ch
[
y_pred
==
ch
]
=
1
nb_tp
=
_get_tp
(
y_pred_ch
,
y_true_ch
)
nb_fp
=
torch
.
sum
(
y_false_ch
*
y_pred_ch
).
float
()
nb_tn
=
torch
.
sum
(
y_false_ch
*
(
1
-
y_pred_ch
)).
float
()
nb_fn
=
_get_fn
(
y_pred_ch
,
y_true_ch
)
performs
[
int
(
ch
),
:]
=
torch
.
FloatTensor
([
nb_tp
,
nb_fp
,
nb_tn
,
nb_fn
])
mperforms
=
sum
([
i
*
j
for
(
i
,
j
)
in
zip
(
performs
,
weights
)])
return
mperforms
,
performs
class
OAAcc
(
object
):
def
__init__
(
self
,
des
=
"Overall Accuracy"
):
self
.
des
=
des
def
__repr__
(
self
):
return
"OAcc"
def
__call__
(
self
,
y_pred
,
y_true
,
threshold
=
0.5
):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return (tp+tn)/total
"""
batch_size
,
chs
,
img_rows
,
img_cols
=
y_true
.
shape
device
=
y_true
.
device
if
chs
==
1
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
y_true
=
_binarize
(
y_true
,
threshold
)
else
:
y_pred
=
_argmax
(
y_pred
,
1
)
y_true
=
_argmax
(
y_true
,
1
)
nb_tp_tn
=
torch
.
sum
(
y_true
==
y_pred
).
float
()
mperforms
=
nb_tp_tn
/
(
batch_size
*
img_rows
*
img_cols
)
performs
=
None
return
mperforms
,
performs
class
Precision
(
object
):
def
__init__
(
self
,
des
=
"Precision"
):
self
.
des
=
des
def
__repr__
(
self
):
return
"Prec"
def
__call__
(
self
,
y_pred
,
y_true
,
threshold
=
0.5
):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return tp/(tp+fp)
"""
batch_size
,
chs
,
img_rows
,
img_cols
=
y_true
.
shape
device
=
y_true
.
device
if
chs
==
1
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
y_true
=
_binarize
(
y_true
,
threshold
)
nb_tp
=
_get_tp
(
y_pred
,
y_true
)
nb_fp
=
_get_fp
(
y_pred
,
y_true
)
mperforms
=
nb_tp
/
(
nb_tp
+
nb_fp
+
esp
)
performs
=
None
else
:
y_pred
=
_argmax
(
y_pred
,
1
)
y_true
=
_argmax
(
y_true
,
1
)
performs
=
torch
.
zeros
(
chs
,
1
).
to
(
device
)
weights
=
_get_weights
(
y_true
,
chs
)
for
ch
in
range
(
chs
):
y_true_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_pred_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_true_ch
[
y_true
==
ch
]
=
1
y_pred_ch
[
y_pred
==
ch
]
=
1
nb_tp
=
_get_tp
(
y_pred_ch
,
y_true_ch
)
nb_fp
=
_get_fp
(
y_pred_ch
,
y_true_ch
)
performs
[
int
(
ch
)]
=
nb_tp
/
(
nb_tp
+
nb_fp
+
esp
)
mperforms
=
sum
([
i
*
j
for
(
i
,
j
)
in
zip
(
performs
,
weights
)])
return
mperforms
,
performs
class
Recall
(
object
):
def
__init__
(
self
,
des
=
"Recall"
):
self
.
des
=
des
def
__repr__
(
self
):
return
"Reca"
def
__call__
(
self
,
y_pred
,
y_true
,
threshold
=
0.5
):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return tp/(tp+fn)
"""
batch_size
,
chs
,
img_rows
,
img_cols
=
y_true
.
shape
device
=
y_true
.
device
if
chs
==
1
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
y_true
=
_binarize
(
y_true
,
threshold
)
nb_tp
=
_get_tp
(
y_pred
,
y_true
)
nb_fn
=
_get_fn
(
y_pred
,
y_true
)
mperforms
=
nb_tp
/
(
nb_tp
+
nb_fn
+
esp
)
performs
=
None
else
:
y_pred
=
_argmax
(
y_pred
,
1
)
y_true
=
_argmax
(
y_true
,
1
)
performs
=
torch
.
zeros
(
chs
,
1
).
to
(
device
)
weights
=
_get_weights
(
y_true
,
chs
)
for
ch
in
range
(
chs
):
y_true_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_pred_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_true_ch
[
y_true
==
ch
]
=
1
y_pred_ch
[
y_pred
==
ch
]
=
1
nb_tp
=
_get_tp
(
y_pred_ch
,
y_true_ch
)
nb_fn
=
_get_fn
(
y_pred_ch
,
y_true_ch
)
performs
[
int
(
ch
)]
=
nb_tp
/
(
nb_tp
+
nb_fn
+
esp
)
mperforms
=
sum
([
i
*
j
for
(
i
,
j
)
in
zip
(
performs
,
weights
)])
return
mperforms
,
performs
class
F1Score
(
object
):
def
__init__
(
self
,
des
=
"F1Score"
):
self
.
des
=
des
def
__repr__
(
self
):
return
"F1Sc"
def
__call__
(
self
,
y_pred
,
y_true
,
threshold
=
0.5
):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return 2*precision*recall/(precision+recall)
"""
batch_size
,
chs
,
img_rows
,
img_cols
=
y_true
.
shape
device
=
y_true
.
device
if
chs
==
1
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
y_true
=
_binarize
(
y_true
,
threshold
)
nb_tp
=
_get_tp
(
y_pred
,
y_true
)
nb_fp
=
_get_fp
(
y_pred
,
y_true
)
nb_fn
=
_get_fn
(
y_pred
,
y_true
)
_precision
=
nb_tp
/
(
nb_tp
+
nb_fp
+
esp
)
_recall
=
nb_tp
/
(
nb_tp
+
nb_fn
+
esp
)
mperforms
=
2
*
_precision
*
_recall
/
(
_precision
+
_recall
+
esp
)
performs
=
None
else
:
y_pred
=
_argmax
(
y_pred
,
1
)
y_true
=
_argmax
(
y_true
,
1
)
performs
=
torch
.
zeros
(
chs
,
1
).
to
(
device
)
weights
=
_get_weights
(
y_true
,
chs
)
for
ch
in
range
(
chs
):
y_true_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_pred_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_true_ch
[
y_true
==
ch
]
=
1
y_pred_ch
[
y_pred
==
ch
]
=
1
nb_tp
=
_get_tp
(
y_pred_ch
,
y_true_ch
)
nb_fp
=
_get_fp
(
y_pred_ch
,
y_true_ch
)
nb_fn
=
_get_fn
(
y_pred_ch
,
y_true_ch
)
_precision
=
nb_tp
/
(
nb_tp
+
nb_fp
+
esp
)
_recall
=
nb_tp
/
(
nb_tp
+
nb_fn
+
esp
)
performs
[
int
(
ch
)]
=
2
*
_precision
*
\
_recall
/
(
_precision
+
_recall
+
esp
)
mperforms
=
sum
([
i
*
j
for
(
i
,
j
)
in
zip
(
performs
,
weights
)])
return
mperforms
,
performs
class
Kappa
(
object
):
def
__init__
(
self
,
des
=
"Kappa"
):
self
.
des
=
des
def
__repr__
(
self
):
return
"Kapp"
def
__call__
(
self
,
y_pred
,
y_true
,
threshold
=
0.5
):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return (Po-Pe)/(1-Pe)
"""
batch_size
,
chs
,
img_rows
,
img_cols
=
y_true
.
shape
device
=
y_true
.
device
if
chs
==
1
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
y_true
=
_binarize
(
y_true
,
threshold
)
nb_tp
=
_get_tp
(
y_pred
,
y_true
)
nb_fp
=
_get_fp
(
y_pred
,
y_true
)
nb_tn
=
_get_tn
(
y_pred
,
y_true
)
nb_fn
=
_get_fn
(
y_pred
,
y_true
)
nb_total
=
nb_tp
+
nb_fp
+
nb_tn
+
nb_fn
Po
=
(
nb_tp
+
nb_tn
)
/
nb_total
Pe
=
((
nb_tp
+
nb_fp
)
*
(
nb_tp
+
nb_fn
)
+
(
nb_fn
+
nb_tn
)
*
(
nb_fp
+
nb_tn
))
/
(
nb_total
**
2
)
mperforms
=
(
Po
-
Pe
)
/
(
1
-
Pe
+
esp
)
performs
=
None
else
:
y_pred
=
_argmax
(
y_pred
,
1
)
y_true
=
_argmax
(
y_true
,
1
)
performs
=
torch
.
zeros
(
chs
,
1
).
to
(
device
)
weights
=
_get_weights
(
y_true
,
chs
)
for
ch
in
range
(
chs
):
y_true_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_pred_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_true_ch
[
y_true
==
ch
]
=
1
y_pred_ch
[
y_pred
==
ch
]
=
1
nb_tp
=
_get_tp
(
y_pred_ch
,
y_true_ch
)
nb_fp
=
_get_fp
(
y_pred_ch
,
y_true_ch
)
nb_tn
=
_get_tn
(
y_pred_ch
,
y_true_ch
)
nb_fn
=
_get_fn
(
y_pred_ch
,
y_true_ch
)
nb_total
=
nb_tp
+
nb_fp
+
nb_tn
+
nb_fn
Po
=
(
nb_tp
+
nb_tn
)
/
nb_total
Pe
=
((
nb_tp
+
nb_fp
)
*
(
nb_tp
+
nb_fn
)
+
(
nb_fn
+
nb_tn
)
*
(
nb_fp
+
nb_tn
))
/
(
nb_total
**
2
)
performs
[
int
(
ch
)]
=
(
Po
-
Pe
)
/
(
1
-
Pe
+
esp
)
mperforms
=
sum
([
i
*
j
for
(
i
,
j
)
in
zip
(
performs
,
weights
)])
return
mperforms
,
performs
class
Jaccard
(
object
):
def
__init__
(
self
,
des
=
"Jaccard"
):
self
.
des
=
des
def
__repr__
(
self
):
return
"Jacc"
def
__call__
(
self
,
y_pred
,
y_true
,
threshold
=
0.5
):
"""
args:
y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
threshold : [0.0, 1.0]
return intersection / (sum-intersection)
"""
batch_size
,
chs
,
img_rows
,
img_cols
=
y_true
.
shape
device
=
y_true
.
device
if
chs
==
1
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
y_true
=
_binarize
(
y_true
,
threshold
)
_intersec
=
torch
.
sum
(
y_true
*
y_pred
).
float
()
_sum
=
torch
.
sum
(
y_true
+
y_pred
).
float
()
mperforms
=
_intersec
/
(
_sum
-
_intersec
+
esp
)
performs
=
None
else
:
y_pred
=
_argmax
(
y_pred
,
1
)
y_true
=
_argmax
(
y_true
,
1
)
performs
=
torch
.
zeros
(
chs
,
1
).
to
(
device
)
weights
=
_get_weights
(
y_true
,
chs
)
for
ch
in
range
(
chs
):
y_true_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_pred_ch
=
torch
.
zeros
(
batch_size
,
img_rows
,
img_cols
)
y_true_ch
[
y_true
==
ch
]
=
1
y_pred_ch
[
y_pred
==
ch
]
=
1
_intersec
=
torch
.
sum
(
y_true_ch
*
y_pred_ch
).
float
()
_sum
=
torch
.
sum
(
y_true_ch
+
y_pred_ch
).
float
()
performs
[
int
(
ch
)]
=
_intersec
/
(
_sum
-
_intersec
+
esp
)
mperforms
=
sum
([
i
*
j
for
(
i
,
j
)
in
zip
(
performs
,
weights
)])
return
mperforms
,
performs
class
MSE
(
object
):
def
__init__
(
self
,
des
=
"Mean Square Error"
):
self
.
des
=
des
def
__repr__
(
self
):
return
"MSE"
def
__call__
(
self
,
y_pred
,
y_true
,
dim
=
1
,
threshold
=
None
):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
threshold : [0.0, 1.0]
return mean_squared_error, smaller the better
"""
if
threshold
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
return
torch
.
mean
((
y_pred
-
y_true
)
**
2
)
class
PSNR
(
object
):
def
__init__
(
self
,
des
=
"Peak Signal to Noise Ratio"
):
self
.
des
=
des
def
__repr__
(
self
):
return
"PSNR"
def
__call__
(
self
,
y_pred
,
y_true
,
dim
=
1
,
threshold
=
None
):
"""
args:
y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
threshold : [0.0, 1.0]
return PSNR, larger the better
"""
if
threshold
:
y_pred
=
_binarize
(
y_pred
,
threshold
)
mse
=
torch
.
mean
((
y_pred
-
y_true
)
**
2
)
return
10
*
torch
.
log10
(
1
/
mse
)


class SSIM(object):
    '''
    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''
    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        gauss = torch.Tensor(
            [math.exp(-(x - w_size // 2) ** 2 / float(2 * sigma ** 2))
             for x in range(w_size)])
        return gauss / gauss.sum()

    def create_window(self, w_size, channel=1):
        _1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(
            _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
        return window

    def __call__(self, y_pred, y_true, w_size=11, size_average=True,
                 full=False):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11
            size_average : boolean, default True
            full : boolean, default False
        return ssim, larger the better
        """
        # Value range can be different from 255. Other common ranges are
        # 1 (sigmoid) and 2 (tanh).
        if torch.max(y_pred) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(y_pred) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val

        padd = 0
        (_, channel, height, width) = y_pred.size()
        window = self.create_window(w_size, channel=channel).to(y_pred.device)

        mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
        mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(
            y_pred * y_pred, window, padding=padd, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(
            y_true * y_true, window, padding=padd, groups=channel) - mu2_sq
        sigma12 = F.conv2d(
            y_pred * y_true, window, padding=padd, groups=channel) - mu1_mu2

        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)

        if full:
            return ret, cs
        return ret
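The `ssim_map` line is the standard SSIM expression; in the code's variables (with the usual stabilizers C1 = (0.01 L)^2 and C2 = (0.03 L)^2, and L the dynamic range inferred above), it reads:

\mathrm{SSIM}(x, y) = \frac{(2\mu_1\mu_2 + C_1)\,(2\sigma_{12} + C_2)}{(\mu_1^2 + \mu_2^2 + C_1)\,(\sigma_1^2 + \sigma_2^2 + C_2)}

where v1 = 2σ12 + C2 and v2 = σ1² + σ2² + C2 correspond to the second factor's numerator and denominator.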


class AE(object):
    """
    Modified from matlab : colorangle.m, MATLAB V2019b
    angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
    angle = 180 / pi * angle;
    """
    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
        return average AE, smaller the better
        """
        dotP = torch.sum(y_pred * y_true, dim=1)
        Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        ae = 180 / math.pi * torch.acos(dotP / (Norm_pred * Norm_true + eps))
        return ae.mean(1).mean(1)
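A small sanity check of the angular-error metric (illustrative only, not part of metrics.py; it assumes the module-level `eps` constant the class references): identical per-pixel color vectors give roughly 0 degrees, orthogonal ones roughly 90.

import torch

red = torch.zeros(1, 3, 2, 2)
red[:, 0] = 1.0
green = torch.zeros(1, 3, 2, 2)
green[:, 1] = 1.0

metric = AE()
print(metric(red, red))    # ~0 degrees (eps keeps acos in range)
print(metric(red, green))  # ~90 degrees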


if __name__ == "__main__":
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        for cuda in [False, True]:
            if cuda:
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()
            print('#' * 20,
                  'Cuda : {} ; size : {}'.format(cuda, y_true.size()))
            ########### similarity metrics
            metric = MSE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = PSNR()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = SSIM()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = LPIPS(cuda)
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = AE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            ########### accuracy metrics
            metric = OAAcc()
            maccu, accu = metric(y_pred, y_true)
            print('mAccu:', maccu, 'Accu', accu)
            metric = Precision()
            mprec, prec = metric(y_pred, y_true)
            print('mPrec:', mprec, 'Prec', prec)
            metric = Recall()
            mreca, reca = metric(y_pred, y_true)
            print('mReca:', mreca, 'Reca', reca)
            metric = F1Score()
            mf1sc, f1sc = metric(y_pred, y_true)
            print('mF1sc:', mf1sc, 'F1sc', f1sc)
            metric = Kappa()
            mkapp, kapp = metric(y_pred, y_true)
            print('mKapp:', mkapp, 'Kapp', kapp)
            metric = Jaccard()
            mjacc, jacc = metric(y_pred, y_true)
            print('mJacc:', mjacc, 'Jacc', jacc)
tasks/vision/segmentation/seg_heads.py
0 → 100644
View file @
1016e98a
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.module import MegatronModule
from megatron.model.vision.utils import resize


class SetrSegmentationHead(MegatronModule):
    def __init__(self, hidden_size, num_classes):
        super(SetrSegmentationHead, self).__init__()
        args = get_args()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.patch_dim = args.patch_dim

        self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
        self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
                                      1, 1, bias=False)
        self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
        self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)

    def to_2D(self, x):
        n, hw, c = x.shape
        h = self.img_h // self.patch_dim
        w = self.img_w // self.patch_dim
        assert (hw == h * w)
        x = x.transpose(1, 2).reshape(n, c, h, w)
        return x

    def forward(self, hidden_states):
        # [b c h w]
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.to_2D(hidden_states)

        hidden_states = self.conv_0(hidden_states)
        hidden_states = self.norm_0(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.conv_1(hidden_states)

        # [b c h w]
        result = F.interpolate(hidden_states,
                               size=(self.img_h, self.img_w),
                               mode='bilinear')
        return result


class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
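A shape sketch for the MLP embedding above (illustrative only, not part of seg_heads.py): `flatten(2).transpose(1, 2)` turns a [n, c, h, w] feature map into [n, h*w, c] tokens before the linear projection.

import torch

mlp = MLP(input_dim=512, embed_dim=768)
feat = torch.randn(2, 512, 16, 16)   # [n, c, h, w]
tokens = mlp(feat)                   # [n, h*w, embed_dim] == [2, 256, 768]
print(tokens.shape)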


class SegformerSegmentationHead(MegatronModule):
    def __init__(self, feature_strides, in_channels,
                 embedding_dim, dropout_ratio):
        super(SegformerSegmentationHead, self).__init__()
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]
        args = get_args()
        self.feature_strides = feature_strides
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.num_classes = args.num_classes
        self.dropout_ratio = dropout_ratio

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels,
                             embed_dim=self.embedding_dim)

        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4,
                                         self.embedding_dim, 1, 1)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
        self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
                                           self.num_classes,
                                           kernel_size=1)

    def forward(self, inputs):
        c1, c2, c3, c4 = inputs

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(
            n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(
            n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(
            n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(
            n, -1, c1.shape[2], c1.shape[3])

        _c = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)
        x = self.linear_pred(x)

        return x
tasks/vision/segmentation/seg_models.py
0 → 100644
View file @
1016e98a
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.module import MegatronModule
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, \
    SegformerSegmentationHead


class SetrSegmentationModel(MegatronModule):
    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        super(SetrSegmentationModel, self).__init__()
        args = get_args()
        assert post_process & pre_process
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.backbone = VitBackbone(
            pre_process=pre_process,
            post_process=post_process,
            class_token=False,
            post_layer_norm=False,
            drop_path_rate=0.1
        )

        self.head = SetrSegmentationHead(
            self.hidden_size,
            self.num_classes
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        result_final = self.head(hidden_states)
        return result_final


class SegformerSegmentationModel(MegatronModule):
    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        super(SegformerSegmentationModel, self).__init__()
        args = get_args()
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process

        self.backbone = mit_b5()
        self.head = SegformerSegmentationHead(
            feature_strides=[4, 8, 16, 32],
            in_channels=[64, 128, 320, 512],
            embedding_dim=768,
            dropout_ratio=0.1
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        hidden_states = self.head(hidden_states)
        return hidden_states
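Both wrappers read Megatron's global args (`hidden_size`, `num_classes`, etc.) at construction time, so they can only be built after Megatron has been initialized. A hypothetical outline (the initialization step and argument values are assumptions for illustration, not part of this commit):

# Hypothetical outline only: construct the Segformer variant once
# Megatron's global args exist. The mit_b5 backbone emits the four
# feature maps matching in_channels=[64, 128, 320, 512] above.
model = SegformerSegmentationModel(num_classes=19)  # e.g. Cityscapes
logits = model(images)  # images: [b, 3, img_h, img_w] -> class logits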
tasks/vision/segmentation/transforms.py
0 → 100644
View file @
1016e98a
# Copyright (c) 2020 The MMSegmentation Authors.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron import print_rank_0
from megatron import get_args
from PIL import Image, ImageOps, ImageEnhance
import torchvision.transforms as torch_tr


def _is_pil_image(img):
    return isinstance(img, Image.Image)


class PhotoMetricDistortion(object):
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.
    1. random brightness
    2. random contrast (mode 1)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 0)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """
    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Multiply with alpha and add beta, then clip."""
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img):
        """Brightness distortion."""
        if random.randint(0, 1):
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,
                                    self.brightness_delta))
        return img

    def contrast(self, img):
        """Contrast distortion."""
        if random.randint(0, 1):
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower,
                                     self.contrast_upper))
        return img

    def saturation(self, img):
        """Saturation distortion."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):
        """Hue distortion."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            img[:, :, 0] = (img[:, :, 0].astype(int) +
                            random.randint(-self.hue_delta,
                                           self.hue_delta)) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, img):
        """Call function to perform photometric distortion on the image.
        Args:
            img (PIL Image): image to be distorted.
        Returns:
            PIL Image: distorted image.
        """
        img = np.array(img)

        # random brightness
        img = self.brightness(img)

        # mode == 1 --> apply random contrast first
        # mode == 0 --> apply random contrast last
        mode = random.randint(0, 1)
        if mode == 1:
            img = self.contrast(img)

        # random saturation
        img = self.saturation(img)

        # random hue
        img = self.hue(img)

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
        return img
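A minimal usage sketch (illustrative, not part of this file; the image path is hypothetical and mmcv plus Pillow must be installed): the transform takes and returns a PIL image, converting through numpy and BGR/HSV internally.

from PIL import Image

aug = PhotoMetricDistortion()  # default deltas/ranges as above
img = Image.open('example.png').convert('RGB')  # hypothetical path
out = aug(img)  # PIL RGB image with random photometric changes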


class RandomCrop(object):
    """
    Take a random crop from the image.

    First the image or crop size may need to be adjusted if the incoming image
    is too small...

    If the image is smaller than the crop, then:
         the image is padded up to the size of the crop
         unless 'nopad', in which case the crop size is shrunk to fit the image

    A random crop is taken such that the crop fits within the image.

    If cfg.DATASET.TRANSLATION_AUG_FIX is set, we ensure that there's always
    translation randomness of at least that value around the image.

    if image < crop_size:
        # slide crop within image, random offset
    else:
        # slide image within crop
    """
    def __init__(self, crop_size):
        args = get_args()
        self.size = crop_size
        self.cat_max_ratio = 0.75
        self.ignore_index = args.ignore_index
        self.pad_color = (0, 0, 0)

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box."""
        img_w, img_h = img.size
        target_h, target_w = self.size  # [H W]
        margin_h = max(img_h - target_h, 0)
        margin_w = max(img_w - target_w, 0)
        offset_h = random.randint(0, margin_h)
        offset_w = random.randint(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + target_h
        crop_x1, crop_x2 = offset_w, offset_w + target_w

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``"""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
        return img

    @staticmethod
    def crop_in_image(target_w, target_h, w, h, img, mask):
        if w == target_w:
            x1 = 0
        else:
            x1 = random.randint(0, w - target_w)
        if h == target_h:
            y1 = 0
        else:
            y1 = random.randint(0, h - target_h)

        return [img.crop((x1, y1, x1 + target_w, y1 + target_h)),
                mask.crop((x1, y1, x1 + target_w, y1 + target_h))]

    def __call__(self, img, mask):
        w, h = img.size
        target_h, target_w = self.size  # ASSUME H, W

        if w == target_w and h == target_h:
            return img, mask

        # Pad image if image < crop
        if target_h > h:
            pad_h = (target_h - h) // 2 + 1
        else:
            pad_h = 0
        if target_w > w:
            pad_w = (target_w - w) // 2 + 1
        else:
            pad_w = 0

        border = (pad_w, pad_h, pad_w, pad_h)
        if pad_h or pad_w:
            img = ImageOps.expand(img, border=border, fill=(0, 0, 0))
            mask = ImageOps.expand(mask, border=border,
                                   fill=self.ignore_index)
            w, h = img.size

        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Repeat 10 times
            for _ in range(10):
                seg_temp = self.crop(mask, crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and \
                        np.max(cnt) / np.sum(cnt) < self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)

        # crop the image
        img = self.crop(img, crop_bbox)
        # crop semantic seg
        mask = self.crop(mask, crop_bbox)
        assert (img.size[0] == self.size[1] and img.size[1] == self.size[0])
        return img, mask
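Usage sketch (illustrative, not part of this file): the crop operates jointly on an image/mask pair, padding the mask with the configured `ignore_index` when the source is smaller than the crop; Megatron's global args (which supply `ignore_index`) must already be initialized.

crop = RandomCrop(crop_size=(512, 512))  # (H, W)
img_c, mask_c = crop(img, mask)          # both PIL images, same region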


class RandomSizeAndCrop(object):
    def __init__(self, crop_size, scale_min=0.5, scale_max=2.0):
        self.crop = RandomCrop(crop_size)
        self.scale_min = scale_min
        self.scale_max = scale_max

    def __call__(self, img, mask):
        scale_amt = random.uniform(self.scale_min, self.scale_max)
        w, h = [int(i * scale_amt) for i in img.size]
        resized_img = img.resize((w, h), Image.BICUBIC)
        resized_mask = mask.resize((w, h), Image.NEAREST)
        img, mask = self.crop(resized_img, resized_mask)
        return img, mask


class RandomHorizontallyFlip(object):
    def __call__(self, img, mask):
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT), \
                mask.transpose(Image.FLIP_LEFT_RIGHT)
        return img, mask


def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.
    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.
    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.
    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.
    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.
    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.
    See https://en.wikipedia.org/wiki/Hue for more details on Hue.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.
    Returns:
        PIL Image: Hue adjusted image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(
            'hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img


class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.
    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly
            from [-hue, hue]. Should be >= 0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.
        Arguments are same as that of __init__.
        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(
                max(0, 1 - brightness), 1 + brightness)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(
                max(0, 1 - contrast), 1 + contrast)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(
                max(0, 1 - saturation), 1 + saturation)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = torch_tr.Compose(transforms)

        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.
        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
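Finally, a usage sketch for ColorJitter (illustrative, not part of this file): each call re-draws the factors and shuffles the transform order, so repeated calls on the same image generally differ.

jitter = ColorJitter(brightness=0.4, contrast=0.4,
                     saturation=0.4, hue=0.1)
out1 = jitter(img)  # img: PIL Image; factors re-sampled per call
out2 = jitter(img)  # generally different from out1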