OpenDAS / Megatron-LM

Commit 2eea6216, authored Jul 18, 2022 by rprenger

    Merging with main and fixing merge conflict

Parents: ed6806ac, 5f694372
Changes: 63

Showing 20 changed files with 3163 additions and 158 deletions (+3163, -158):
megatron/text_generation_server.py               +57    -6
megatron/training.py                             +26    -53
megatron/utils.py                                +19    -1
tasks/finetune_utils.py                          +1     -1
tasks/vision/classification/classification.py    +46    -7
tasks/vision/classification/eval_utils.py        +3     -5
tasks/vision/finetune_utils.py                   +60    -63
tasks/vision/main.py                             +24    -22
tasks/vision/segmentation/cityscapes.py          +207   -0
tasks/vision/segmentation/data.py                +154   -0
tasks/vision/segmentation/finetune_segformer.py  +251   -0
tasks/vision/segmentation/finetune_setr.py       +225   -0
tasks/vision/segmentation/metrics.py             +594   -0
tasks/vision/segmentation/seg_heads.py           +140   -0
tasks/vision/segmentation/seg_models.py          +92    -0
tasks/vision/segmentation/transforms.py          +433   -0
tasks/vision/segmentation/utils.py               +85    -0
tools/checkpoint_loader_megatron.py              +273   -0
tools/checkpoint_saver_megatron.py               +322   -0
tools/checkpoint_util.py                         +151   -0
megatron/text_generation_server.py (+57, -6)

@@ -20,9 +20,11 @@ from flask import Flask, request, jsonify, current_app
 from flask_restful import Resource, Api
 from megatron import get_args
 from megatron.text_generation import generate_and_post_process
+from megatron.text_generation import beam_search_and_post_process

 GENERATE_NUM = 0
+BEAM_NUM = 1
 lock = threading.Lock()

 class MegatronGenerate(Resource):
@@ -34,6 +36,11 @@ class MegatronGenerate(Resource):
         choice = torch.cuda.LongTensor([GENERATE_NUM])
         torch.distributed.broadcast(choice, 0)

+    @staticmethod
+    def send_do_beam_search():
+        choice = torch.cuda.LongTensor([BEAM_NUM])
+        torch.distributed.broadcast(choice, 0)
+
     def put(self):
         args = get_args()
@@ -148,15 +155,57 @@ class MegatronGenerate(Resource):
         if not isinstance(no_log, bool):
             return "no_log must be a boolean value"

+        beam_width = None
+        if "beam_width" in request.get_json():
+            beam_width = request.get_json()["beam_width"]
+            if not isinstance(beam_width, int):
+                return "beam_width must be integer"
+            if beam_width < 1:
+                return "beam_width must be an integer > 1"
+            if len(prompts) > 1:
+                return "When doing beam_search, batch size must be 1"
+
+        stop_token = 50256
+        if "stop_token" in request.get_json():
+            stop_token = request.get_json()["stop_token"]
+            if not isinstance(stop_token, int):
+                return "stop_token must be an integer"
+
+        length_penalty = 1
+        if "length_penalty" in request.get_json():
+            length_penalty = request.get_json()["length_penalty"]
+            if not isinstance(length_penalty, float):
+                return "length_penalty must be a float"
+
         with lock:  # Need to get lock to keep multiple threads from hitting code
             if not no_log:
                 print("request IP: " + str(request.remote_addr))
                 print(json.dumps(request.get_json()), flush=True)
                 print("start time: ", datetime.datetime.now())

-            MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
             try:
-                response, response_seg, response_logprobs, _ = \
-                    generate_and_post_process(
+                if beam_width is not None:
+                    MegatronGenerate.send_do_beam_search()  # Tell other ranks we're doing beam_search
+                    response, response_seg, response_scores = \
+                        beam_search_and_post_process(
+                        self.model,
+                        prompts=prompts,
+                        tokens_to_generate=tokens_to_generate,
+                        beam_size=beam_width,
+                        add_BOS=add_BOS,
+                        stop_token=stop_token,
+                        num_return_gen=beam_width,  # Returning whole beam
+                        length_penalty=length_penalty)
+
+                    return jsonify({"text": response,
+                        "segments": response_seg,
+                        "scores": response_scores})
+                else:
+                    MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
+                    response, response_seg, response_logprobs, _ = \
+                        generate_and_post_process(
                         self.model,
                         prompts=prompts,
                         tokens_to_generate=tokens_to_generate,
@@ -171,13 +220,15 @@ class MegatronGenerate(Resource):
                         stop_on_double_eol=stop_on_double_eol,
                         stop_on_eol=stop_on_eol,
                         random_seed=random_seed)
+
+                    return jsonify({"text": response,
+                        "segments": response_seg,
+                        "logprobs": response_logprobs})
+
             except ValueError as ve:
                 return "Length of prompt + tokens_to_generate longer than allowed"
             print("end time: ", datetime.datetime.now())
-
-            return jsonify({"text": response,
-                "segments": response_seg,
-                "logprobs": response_logprobs})

 class MegatronServer(object):
     def __init__(self, model):
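The new request fields above are read straight off the PUT body, so a client only needs to add plain JSON keys to switch the server into beam search. A minimal hypothetical client sketch (the localhost URL, port 5000, and the "/api" route are assumptions, not taken from this diff):

import requests

# Hypothetical client for the beam-search path added above.
# The server address and route are assumptions.
resp = requests.put(
    "http://localhost:5000/api",
    json={
        "prompts": ["Megatron-LM is"],   # beam search requires batch size 1
        "tokens_to_generate": 32,
        "beam_width": 4,                 # triggers beam_search_and_post_process
        "stop_token": 50256,             # server default, GPT-2 end-of-text id
        "length_penalty": 1.0,           # must be a float per the validation above
    },
    headers={"Content-Type": "application/json"},
)
print(resp.json()["text"], resp.json()["scores"])

If beam_width is omitted, the same request falls through to the existing generate path and the response carries "logprobs" instead of "scores".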
megatron/training.py (+26, -53)

@@ -42,6 +42,7 @@ from megatron.model import ModelType
 from megatron.optimizer import get_megatron_optimizer
 from megatron.initialize import initialize_megatron
 from megatron.initialize import write_args_to_tensorboard
+from megatron.initialize import set_jit_fusion_options
 from megatron.optimizer_param_scheduler import OptimizerParamScheduler
 from megatron.model import DistributedDataParallel as LocalDDP
 from megatron.utils import check_adlr_autoresume_termination
@@ -99,6 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
     # Initalize and get arguments, timers, and Tensorboard writer.
     initialize_megatron(extra_args_provider=extra_args_provider,
                         args_defaults=args_defaults)
+    # Set pytorch JIT layer fusion options and warmup JIT functions.
+    set_jit_fusion_options()

     # Adjust the startup time so it reflects the largest value.
     # This will be closer to what scheduler will see (outside of
@@ -361,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
     args = get_args()

     model = get_model(model_provider_func, model_type)
     unwrapped_model = unwrap_model(model,
                                    (torchDDP, LocalDDP, Float16Module))

-    optimizer = get_megatron_optimizer(unwrapped_model, no_wd_decay_cond,
+    optimizer = get_megatron_optimizer(model, no_wd_decay_cond,
                                        scale_lr_cond, lr_mult)
     opt_param_scheduler = get_optimizer_param_scheduler(optimizer)

     if args.load is not None:
@@ -409,78 +411,44 @@ def train_step(forward_step_func, data_iterator,
         partition.zero_grad_buffer()
     optimizer.zero_grad()

+    # Forward pass.
     forward_backward_func = get_forward_backward_func()
     losses_reduced = forward_backward_func(
         forward_step_func, data_iterator, model,
         optimizer, timers, forward_only=False)

-    # Empty unused memory
+    # Empty unused memory.
     if args.empty_unused_memory_level >= 1:
         torch.cuda.empty_cache()

-    # All-reduce if needed.
-    if args.DDP_impl == 'local':
-        timers('backward-params-all-reduce').start()
-        for model_module in model:
-            model_module.allreduce_gradients()
-        timers('backward-params-all-reduce').stop()
-
-    # All-reduce word_embeddings' grad across first and last stages to ensure
-    # that word_embeddings parameters stay in sync.
-    # This should only run for models that support pipelined model parallelism
-    # (BERT and GPT-2).
-    timers('backward-embedding-all-reduce').start()
-    if mpu.is_rank_in_embedding_group(ignore_virtual=True) and \
-            mpu.get_pipeline_model_parallel_world_size() > 1:
-        if mpu.is_pipeline_first_stage(ignore_virtual=True):
-            unwrapped_model = model[0]
-        elif mpu.is_pipeline_last_stage(ignore_virtual=True):
-            unwrapped_model = model[-1]
-        else:
-            # We do not support the interleaved schedule for T5 yet.
-            unwrapped_model = model[0]
-        unwrapped_model = unwrap_model(
-            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-
-        if unwrapped_model.share_word_embeddings:
-            word_embeddings_weight = unwrapped_model.word_embeddings_weight()
-            if args.DDP_impl == 'local':
-                grad = word_embeddings_weight.main_grad
-            else:
-                grad = word_embeddings_weight.grad
-            torch.distributed.all_reduce(grad, group=mpu.get_embedding_group())
-
-    # All-reduce position_embeddings grad across first (encoder) and split (decoder)
-    # stages to ensure that position embeddings parameters stay in sync.
-    # This should only run for T5 models with pipeline parallelism
-    if mpu.is_rank_in_position_embedding_group() and \
-            mpu.get_pipeline_model_parallel_world_size() > 1 and \
-            args.pipeline_model_parallel_split_rank is not None:
-        unwrapped_model = model[0]
-        unwrapped_model = unwrap_model(
-            unwrapped_model, (torchDDP, LocalDDP, Float16Module))
-        assert args.DDP_impl == 'local', \
-            'T5 model is only supported with local DDP mode'
-        grad = unwrapped_model.language_model.embedding.position_embeddings.weight.main_grad
-        torch.distributed.all_reduce(grad, group=mpu.get_position_embedding_group())
-    timers('backward-embedding-all-reduce').stop()
+    # Reduce gradients.
+    timers('backward-reduce-model-grads').start()
+    optimizer.reduce_model_grads(args, timers)
+    timers('backward-reduce-model-grads').stop()

     # Vision gradients.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.cancel_gradients_last_layer(args.curr_iteration)

     # Update parameters.
     timers('optimizer').start()
-    update_successful, grad_norm, num_zeros_in_grad = optimizer.step()
+    update_successful, grad_norm, num_zeros_in_grad = optimizer.step(args, timers)
     timers('optimizer').stop()

+    # Gather params.
+    if update_successful:
+        timers('backward-gather-model-params').start()
+        optimizer.gather_model_params(args, timers)
+        timers('backward-gather-model-params').stop()
+
     # Vision momentum.
     if args.vision_pretraining and args.vision_pretraining_type == "dino":
         unwrapped_model = unwrap_model(model[0],
                                        (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.update_momentum(args.curr_iteration)

     # Update learning rate.
     if update_successful:
         increment = get_num_microbatches() * \
@@ -491,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
     else:
         skipped_iter = 1

-    # Empty unused memory
+    # Empty unused memory.
     if args.empty_unused_memory_level >= 2:
         torch.cuda.empty_cache()
@@ -558,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
     add_to_logging('backward-send-forward-recv')
     add_to_logging('backward-send-backward-recv')
     add_to_logging('backward-params-all-reduce')
-    add_to_logging('backward-layernorm-all-reduce')
     add_to_logging('backward-embedding-all-reduce')
+    add_to_logging('backward-reduce-model-grads')
+    add_to_logging('backward-gather-model-params')
     add_to_logging('optimizer-copy-to-main-grad')
     add_to_logging('optimizer-unscale-and-check-inf')
     add_to_logging('optimizer-clip-main-grad')
+    add_to_logging('optimizer-count-zeros')
+    add_to_logging('optimizer-inner-step')
     add_to_logging('optimizer-copy-main-to-model-params')
     add_to_logging('optimizer')
     add_to_logging('batch-generator')
megatron/utils.py (+19, -1)

@@ -24,7 +24,6 @@ from apex.multi_tensor_apply import multi_tensor_applier
 import amp_C

 from megatron import get_args
-from megatron import print_rank_0
 from megatron import get_adlr_autoresume
 from megatron import mpu
 from megatron.model.module import param_is_not_shared
@@ -204,3 +203,22 @@ def get_ltor_masks_and_position_ids(data,
     return attention_mask, loss_mask, position_ids

+
+def print_rank_0(message):
+    """If distributed is initialized, print only on rank 0."""
+    if torch.distributed.is_initialized():
+        if torch.distributed.get_rank() == 0:
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
+
+def is_last_rank():
+    return torch.distributed.get_rank() == (
+        torch.distributed.get_world_size() - 1)
+
+def print_rank_last(message):
+    """If distributed is initialized, print only on last rank."""
+    if torch.distributed.is_initialized():
+        if is_last_rank():
+            print(message, flush=True)
+    else:
+        print(message, flush=True)
tasks/finetune_utils.py (+1, -1)

@@ -229,7 +229,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
             prefix = 'iteration {}'.format(iteration)
             evaluate_and_print_results(prefix, forward_step,
                                        valid_dataloader, model,
-                                       iteration, False)
+                                       iteration, None, False)

         # Exiting based on iterations
         if args.exit_interval and iteration % args.exit_interval == 0:
tasks/vision/classification/classification.py (+46, -7)

@@ -15,12 +15,15 @@
 """Vision-classification finetuning/evaluation."""

-from megatron import get_args
+import torch.nn.functional as F
+from functools import partial
+from megatron import get_args, get_timers
 from megatron import print_rank_0
-from megatron.model.vit_model import VitModel
+from megatron.model.vision.classification import VitClassificationModel
 from megatron.data.vit_dataset import build_train_valid_datasets
-from tasks.vision.eval_utils import accuracy_func_provider
+from tasks.vision.classification.eval_utils import accuracy_func_provider
 from tasks.vision.finetune_utils import finetune
+from megatron.utils import average_losses_across_data_parallel_group


 def classification():
@@ -30,7 +33,7 @@ def classification():
         train_ds, valid_ds = build_train_valid_datasets(
             data_path=args.data_path,
-            crop_size=args.img_dim,
+            image_size=(args.img_h, args.img_w),
         )
         return train_ds, valid_ds
@@ -40,16 +43,52 @@ def classification():
         print_rank_0("building classification model for ImageNet ...")

-        return VitModel(num_classes=args.num_classes, finetune=True,
-                        pre_process=pre_process, post_process=post_process)
+        return VitClassificationModel(num_classes=args.num_classes,
+                                      finetune=True,
+                                      pre_process=pre_process,
+                                      post_process=post_process)
+
+    def process_batch(batch):
+        """Process batch and produce inputs for the model."""
+        images = batch[0].cuda().contiguous()
+        labels = batch[1].cuda().contiguous()
+        return images, labels
+
+    def cross_entropy_loss_func(labels, output_tensor):
+        logits = output_tensor
+
+        # Cross-entropy loss.
+        loss = F.cross_entropy(logits.contiguous().float(), labels)
+
+        # Reduce loss for logging.
+        averaged_loss = average_losses_across_data_parallel_group([loss])
+
+        return loss, {'lm loss': averaged_loss[0]}
+
+    def _cross_entropy_forward_step(batch, model):
+        """Simple forward step with cross-entropy loss."""
+        timers = get_timers()
+
+        # Get the batch.
+        timers("batch generator").start()
+        try:
+            batch_ = next(batch)
+        except BaseException:
+            batch_ = batch
+        images, labels = process_batch(batch_)
+        timers("batch generator").stop()
+
+        # Forward model.
+        output_tensor = model(images)
+
+        return output_tensor, partial(cross_entropy_loss_func, labels)

     """Finetune/evaluate."""
     finetune(
         train_valid_datasets_provider,
         model_provider,
+        forward_step=_cross_entropy_forward_step,
         end_of_epoch_callback_provider=accuracy_func_provider,
     )


 def main():
     classification()
tasks/vision/classification/eval_utils.py (+3, -5)

@@ -33,11 +33,10 @@ def accuracy_func_provider():
     """Provide function that calculates accuracies."""
     args = get_args()

     data_path = args.data_path
-    crop_size = args.img_dim
-    # mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
+    crop_size = (args.img_h, args.img_w)

     # Build dataloaders.
-    val_data_path = os.path.join(data_path[0], "val")
+    val_data_path = data_path[1]
     normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
     transform_val = transforms.Compose(
         [
@@ -54,6 +53,7 @@ def accuracy_func_provider():
         args.micro_batch_size,
         num_workers=args.num_workers,
         drop_last=(mpu.get_data_parallel_world_size() > 1),
+        shuffle=False
     )

     def metrics_func(model, epoch):
@@ -71,7 +71,6 @@ def accuracy_func_provider():
 def calculate_correct_answers(model, dataloader, epoch):
     """Calculate correct over total answers"""
-    args = get_args()
     forward_backward_func = get_forward_backward_func()
     for m in model:
         m.eval()
@@ -98,7 +97,6 @@ def calculate_correct_answers(model, dataloader, epoch):
         images, labels = process_batch(batch_)

         # Forward model.
-        args = get_args()
         output_tensor = model(images)

         return output_tensor, partial(loss_func, labels)
tasks/vision/finetune_utils.py (+60, -63)

@@ -17,11 +17,10 @@
 import torch
 import torch.nn.functional as F
-from functools import partial
 from megatron import get_args
 from megatron import print_rank_0
 from megatron import get_timers
-from megatron import mpu
+from megatron import mpu, utils
 from megatron.checkpointing import load_checkpoint
 from megatron.checkpointing import save_checkpoint
 from megatron.training import evaluate_and_print_results
@@ -29,7 +28,10 @@ from megatron.training import setup_model_and_optimizer
 from megatron.training import train_step
 from megatron.training import training_log
 from megatron.utils import check_adlr_autoresume_termination
-from megatron.utils import average_losses_across_data_parallel_group
+from megatron.utils import average_losses_across_data_parallel_group, print_params_min_max_norm
+from torch.nn.parallel.distributed import DistributedDataParallel as torchDDP
+from megatron.model import DistributedDataParallel as LocalDDP
+from megatron.model import Float16Module, ModelType


 def process_batch(batch):
@@ -39,45 +41,16 @@ def process_batch(batch):
     return images, labels


-def cross_entropy_loss_func(labels, output_tensor):
-    logits = output_tensor
-
-    # Cross-entropy loss.
-    loss = F.cross_entropy(logits.contiguous().float(), labels)
-
-    # Reduce loss for logging.
-    averaged_loss = average_losses_across_data_parallel_group([loss])
-
-    return loss, {'lm loss': averaged_loss[0]}
-
-
-def _cross_entropy_forward_step(batch, model):
-    """Simple forward step with cross-entropy loss."""
-    timers = get_timers()
-
-    # Get the batch.
-    timers("batch generator").start()
-    try:
-        batch_ = next(batch)
-    except BaseException:
-        batch_ = batch
-    images, labels = process_batch(batch_)
-    timers("batch generator").stop()
-
-    # Forward model.
-    output_tensor = model(images)
-
-    return output_tensor, partial(cross_entropy_loss_func, labels)
-
-
-def build_data_loader(dataset, micro_batch_size, num_workers, drop_last):
+def build_data_loader(dataset, micro_batch_size,
+                      num_workers, drop_last, shuffle):
     """Data loader. Note that batch-size is the local (per GPU) batch-size."""

     # Sampler.
     world_size = mpu.get_data_parallel_world_size()
     rank = mpu.get_data_parallel_rank()
     sampler = torch.utils.data.distributed.DistributedSampler(
-        dataset, num_replicas=world_size, rank=rank
+        dataset, num_replicas=world_size, rank=rank,
+        drop_last=drop_last, shuffle=shuffle
     )

     # Data loader. Note that batch size is the per GPU batch size.
@@ -112,14 +85,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     print_rank_0('building train and validation dataloaders ...')

     # Training dataset.
     train_dataloader = build_data_loader(train_dataset, args.micro_batch_size,
-                                         args.num_workers, not args.keep_last)
+                                         args.num_workers, False, True)

     # Set the training iterations.
     args.train_iters_per_epoch = len(train_dataloader)
     args.train_iters = args.epochs * args.train_iters_per_epoch

     # Validation dataset. For this dataset, we do not need to set up
     # shuffling so we can just use a simple infinite loop.
     valid_dataloader_ = build_data_loader(valid_dataset, args.micro_batch_size,
-                                          args.num_workers, not args.keep_last)
+                                          args.num_workers, True, False)
     valid_dataloader = _build_infinite_size_dataloader(valid_dataloader_)

     # Now that we've built the data loaders, set batch_size arguments
@@ -132,6 +105,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
     return train_dataloader, valid_dataloader


 def _train(
     model,
     optimizer,
@@ -140,6 +114,7 @@ def _train(
     train_dataloader,
     valid_dataloader,
     end_of_epoch_callback,
+    process_non_loss_data_func=None,
 ):
     """Train the model."""
     args = get_args()
@@ -167,6 +142,7 @@ def _train(
         # Set the data loader epoch to shuffle the index iterator.
         train_dataloader.sampler.set_epoch(args.seed + epoch)
+        train_dataloader.dataset.set_epoch(epoch)

         # For all the batches in the dataset.
         for iteration_, batch in enumerate(train_dataloader):
@@ -185,8 +161,6 @@ def _train(
             # Logging.
             params_norm = None
-            if args.log_params_norm:
-                params_norm = calc_params_l2_norm(model)
             report_memory_flag = training_log(
                 losses_dict,
@@ -202,20 +176,16 @@ def _train(
             )

             # Autoresume
-            if args.adlr_autoresume and (
-                iteration % args.adlr_autoresume_interval == 0
-            ):
-                check_adlr_autoresume_termination(
-                    iteration, model, optimizer, opt_param_scheduler
-                )
+            if args.adlr_autoresume and \
+               iteration % args.adlr_autoresume_interval == 0:
+                check_adlr_autoresume_termination(iteration, model, optimizer,
+                                                  opt_param_scheduler)

             # Checkpointing
-            if (
-                args.save
-                and args.save_interval
-                and iteration % args.save_interval == 0
-            ):
-                save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
+            if args.save and args.save_interval and \
+               iteration % args.save_interval == 0:
+                save_checkpoint(iteration, model, optimizer,
+                                opt_param_scheduler)

             # Evaluation
             if args.eval_interval and iteration % args.eval_interval == 0:
@@ -226,13 +196,10 @@ def _train(
                     valid_dataloader,
                     model,
                     iteration,
+                    process_non_loss_data_func,
                     False,
                 )

-        # Checkpointing at the end of each epoch.
-        if args.save:
-            save_checkpoint(iteration, model, optimizer, opt_param_scheduler)
-
         # Callback at the end of each epoch.
         if end_of_epoch_callback is not None:
             end_of_epoch_callback(model, epoch)
@@ -241,7 +208,9 @@ def _train(
 def finetune(
     train_valid_datasets_provider,
     model_provider,
-    forward_step=_cross_entropy_forward_step,
+    forward_step,
+    model_type=ModelType.encoder_or_decoder,
+    process_non_loss_data_func=None,
     end_of_epoch_callback_provider=None,
 ):
     """Main finetune function used across all tasks."""
@@ -266,7 +235,12 @@ def finetune(
     # Build model, optimizer and learning rate scheduler.
     timers("model and optimizer").start()
-    model, optimizer, opt_param_scheduler = setup_model_and_optimizer(model_provider)
+    model, optimizer, opt_param_scheduler = \
+        setup_model_and_optimizer(
+            model_provider,
+            model_type,
+            scale_lr_cond=lambda name, param: ".head." in name,
+            lr_mult=args.head_lr_mult)
     timers("model and optimizer").stop()

     # If pretrained checkpoint is provided and we have not trained for
@@ -274,13 +248,34 @@ def finetune(
     # checkpoint.
     timers("pretrained checkpoint").start()
     if args.iteration == 0 and args.pretrained_checkpoint is not None:
-        original_load = args.load
-        args.load = args.pretrained_checkpoint
-        _ = load_checkpoint(model, None, None, strict=False)
-        args.load = original_load
+        if args.pretrained_checkpoint_type == 'default':
+            original_load = args.load
+            args.load = args.pretrained_checkpoint
+            _ = load_checkpoint(model, None, None, strict=False)
+            args.load = original_load
+        elif args.pretrained_checkpoint_type == 'external':
+            unwrap_model = utils.unwrap_model(model)
+            state_dict = torch.load(args.pretrained_checkpoint,
+                                    map_location="cpu")
+            unwrap_model[0].module.backbone.load_state_dict(state_dict,
+                                                            strict=False)
+        elif args.pretrained_checkpoint_type == 'constrastive':
+            unwrap_model = utils.unwrap_model(model)
+            state_dict = torch.load(args.pretrained_checkpoint,
+                                    map_location="cpu")
+            state_dict = state_dict["model"]
+            state_dict = {k.replace("teacher.backbone.", ""): v
+                          for k, v in state_dict.items()
+                          if k.startswith("teacher.backbone.")}
+            unwrap_model[0].module.backbone.load_state_dict(state_dict,
+                                                            strict=False)
+        else:
+            raise Exception("pretrained checkpoint type {} not supported".format(
+                args.pretrained_checkpoint_type))
+
         # This is critical when only model is loaded. We should make sure
         # master parameters are also updated.
         optimizer.reload_model_params()
     timers("pretrained checkpoint").stop()

     # Print setup timing.
@@ -305,11 +300,13 @@ def finetune(
             train_dataloader,
             valid_dataloader,
             end_of_epoch_callback,
+            process_non_loss_data_func,
         )
     # Or just evaluate.
     else:
         if end_of_epoch_callback is not None:
             print_rank_0("evaluation only mode, setting epoch to -1")
-            end_of_epoch_callback(model, epoch=-1, output_predictions=True)
+            end_of_epoch_callback(model, epoch=-1)

     print_rank_0("done :-)")
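build_data_loader now pushes drop_last and shuffle down into the DistributedSampler instead of hard-coding them, which is what lets the train loader shuffle per epoch while the validation loader keeps every sample. A self-contained sketch of that pattern, with a toy dataset standing in for the real ones (num_replicas=1 and rank=0 emulate a single data-parallel rank outside a launched job):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

# Toy dataset standing in for the Cityscapes splits built elsewhere.
dataset = TensorDataset(torch.arange(10).float())

# drop_last/shuffle live on the sampler, as in the new build_data_loader.
sampler = DistributedSampler(dataset, num_replicas=1, rank=0,
                             drop_last=False, shuffle=True)

loader = DataLoader(dataset, batch_size=4, sampler=sampler, num_workers=0)

for epoch in range(2):
    sampler.set_epoch(epoch)   # reshuffles deterministically per epoch
    for (batch,) in loader:
        pass                   # training step would go here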
tasks/vision/main.py (+24, -22)

@@ -28,32 +28,24 @@ sys.path.append(
 )
 from megatron import get_args
 from megatron.initialize import initialize_megatron
-from classification import main

 def get_tasks_args(parser):
     """Provide extra arguments required for tasks."""
     group = parser.add_argument_group(title="tasks")
-    group.add_argument(
-        "--epochs",
-        type=int,
-        default=None,
-        help="Number of finetunning epochs. Zero results in "
-        "evaluation only.",
-    )
-    group.add_argument(
-        "--pretrained-checkpoint",
-        type=str,
-        default=None,
-        help="Pretrained checkpoint used for finetunning.",
-    )
-    group.add_argument(
-        "--keep-last",
-        action="store_true",
-        help="Keep the last batch (maybe incomplete) in" "the data loader",
-    )
+    group.add_argument('--task', type=str, default='segment',
+                       choices=['classify', 'segment_setr', 'segment_segformer'],
+                       help='task name.')
+    group.add_argument("--epochs", type=int, default=None,
+                       help="Number of finetunning epochs. Zero results in "
+                       "evaluation only.")
+    group.add_argument('--pretrained-checkpoint-type', type=str, default='default',
+                       choices=['default', 'external', 'constrastive'],
+                       help='Type of pretrained checkpoint')
+    group.add_argument("--pretrained-checkpoint", type=str, default=None,
+                       help="Pretrained checkpoint used for finetunning.")
+    group.add_argument('--seg-stride', type=int, default=None,
+                       help='sliding window stride during evaluation')
     return parser
@@ -61,4 +53,14 @@ if __name__ == "__main__":
     initialize_megatron(extra_args_provider=get_tasks_args)
     args = get_args()
-    main()
+    if args.task == 'classify':
+        from tasks.vision.classification.classification import main
+        main()
+    elif args.task == 'segment_setr':
+        from tasks.vision.segmentation.finetune_setr import main
+        main()
+    elif args.task == 'segment_segformer':
+        from tasks.vision.segmentation.finetune_segformer import main
+        main()
tasks/vision/segmentation/cityscapes.py (new file, 0 → 100644, +207)

# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
#   list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
#   contributors may be used to endorse or promote products derived from
#   this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py
# modified it to change max label index from 255 to 19 (num_classes)

import torch
import json
import os
from collections import namedtuple
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import numpy as np
from torchvision.datasets.utils import extract_archive, verify_str_arg, iterable_to_str
from torchvision.datasets import VisionDataset
from PIL import Image
from megatron import print_rank_0


class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    Examples:
        Get semantic segmentation target
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type='semantic')
            img, smnt = dataset[0]
        Get multiple targets
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='train', mode='fine',
                                 target_type=['instance', 'color', 'polygon'])
            img, (inst, col, poly) = dataset[0]
        Validate on the "coarse" set
        .. code-block:: python
            dataset = Cityscapes('./data/cityscapes', split='val', mode='coarse',
                                 target_type='semantic')
            img, smnt = dataset[0]
    """
    num_classes = 19
    ignore_index = 19
    color_table = torch.tensor(
        [[128, 64, 128],
         [244, 35, 232],
         [70, 70, 70],
         [102, 102, 156],
         [190, 153, 153],
         [153, 153, 153],
         [250, 170, 30],
         [220, 220, 0],
         [107, 142, 35],
         [152, 251, 152],
         [70, 130, 180],
         [220, 20, 60],
         [255, 0, 0],
         [0, 0, 142],
         [0, 0, 70],
         [0, 60, 100],
         [0, 80, 100],
         [0, 0, 230],
         [119, 11, 32],
         [0, 0, 0]],
        dtype=torch.float, device='cuda')

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple('CityscapesClass',
                                 ['name', 'id', 'train_id', 'category',
                                  'category_id', 'has_instances',
                                  'ignore_in_eval', 'color'])

    classes = [
        CityscapesClass('unlabeled', 0, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 19, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 19, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 19, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 19, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 19, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 19, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 19, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 19, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # label2trainid
    label2trainid = {label.id: label.train_id for label in classes}

    def __init__(
        self,
        root: str,
        split: str = "train",
        mode: str = "fine",
        resolution: int = 1024,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super(Cityscapes, self).__init__(root, transforms, transform, target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(self.root, 'leftImg8bit_trainvaltest/leftImg8bit', split)
        self.targets_dir = os.path.join(self.root, 'gtFine_trainvaltest/gtFine', split)
        self.split = split
        self.resolution = resolution
        self.images = []
        self.targets = []

        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_name = '{}_{}_labelIds.png'.format(
                    file_name.split('_leftImg8bit')[0], self.mode)
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is a tuple of all target types if target_type is a list with more
            than one item. Otherwise target is a json object if target_type="polygon", else the image segmentation.
        """
        image = Image.open(self.images[index]).convert('RGB')

        target = Image.open(self.targets[index])
        target = np.array(target)

        target_copy = target.copy()
        for k, v in Cityscapes.label2trainid.items():
            binary_target = (target == k)
            target_copy[binary_target] = v
        target = target_copy

        target = Image.fromarray(target.astype(np.uint8))

        if self.transforms is not None:
            image, target = self.transforms(image, target)

        return image, target

    def __len__(self) -> int:
        # len(self.images)
        return len(self.images)
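The __getitem__ remap above rewrites raw Cityscapes label ids into the 19 train ids, with everything else collapsing onto the ignore index 19. A small self-contained illustration of the same masking loop on a toy label array (the id subset below is copied from the classes table above):

import numpy as np

# Subset of Cityscapes.label2trainid: raw labelIds id -> train id.
label2trainid = {0: 19, 7: 0, 8: 1, 11: 2, 26: 13}

# Toy "labelIds" image: road, sidewalk, building, car, unlabeled, road.
target = np.array([[7, 8], [11, 26], [0, 7]], dtype=np.uint8)

target_copy = target.copy()
for k, v in label2trainid.items():
    target_copy[target == k] = v   # same masking trick as __getitem__

print(target_copy)
# [[ 0  1]
#  [ 2 13]
#  [19  0]]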
tasks/vision/segmentation/data.py (new file, 0 → 100644, +154)

import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron.data.autoaugment import ImageNetPolicy
from tasks.vision.segmentation.cityscapes import Cityscapes
import tasks.vision.segmentation.transforms as ET
from megatron.data.autoaugment import ImageNetPolicy
from megatron import get_args
from PIL import Image, ImageOps


class VitSegmentationJointTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train
        if self.train:
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if self.train:
            img, mask = self.transform0(img, mask)
            img, mask = self.transform1(img, mask)
        return img, mask


class VitSegmentationImageTransform():
    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        if self.train:
            assert resolution is not None
            self.transform = T.Compose([
                ET.PhotoMetricDistortion(),
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])
        else:
            self.transform = T.Compose([
                T.ToTensor(),
                T.Normalize(*self.mean_std),
                T.ConvertImageDtype(self.data_type)
            ])

    def __call__(self, input):
        output = self.transform(input)
        return output


class VitSegmentationTargetTransform():
    def __init__(self, train=True, resolution=None):
        self.train = train

    def __call__(self, input):
        output = torch.from_numpy(np.array(input, dtype=np.int32)).long()
        return output


class RandomSeedSegmentationDataset(Dataset):
    def __init__(self,
                 dataset,
                 joint_transform,
                 image_transform,
                 target_transform):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        img, mask = self.dataset[idx]
        torch.manual_seed(seed)
        random.seed(seed)
        np.random.seed(seed)
        img, mask = self.joint_transform(img, mask)
        img = self.image_transform(img)
        mask = self.target_transform(mask)
        return img, mask


def build_cityscapes_train_valid_datasets(data_path, image_size):
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = \
        VitSegmentationJointTransform(train=True, resolution=image_size)
    val_joint_transform = \
        VitSegmentationJointTransform(train=False, resolution=image_size)
    train_image_transform = \
        VitSegmentationImageTransform(train=True, resolution=image_size)
    val_image_transform = \
        VitSegmentationImageTransform(train=False, resolution=image_size)
    train_target_transform = \
        VitSegmentationTargetTransform(train=True, resolution=image_size)
    val_target_transform = \
        VitSegmentationTargetTransform(train=False, resolution=image_size)

    # training dataset
    train_data = Cityscapes(
        root=data_path[0],
        split='train',
        mode='fine',
        resolution=image_size
    )
    train_data = RandomSeedSegmentationDataset(
        train_data,
        joint_transform=train_joint_transform,
        image_transform=train_image_transform,
        target_transform=train_target_transform)

    # validation dataset
    val_data = Cityscapes(
        root=data_path[0],
        split='val',
        mode='fine',
        resolution=image_size
    )
    val_data = RandomSeedSegmentationDataset(
        val_data,
        joint_transform=val_joint_transform,
        image_transform=val_image_transform,
        target_transform=val_target_transform)

    return train_data, val_data


def build_train_valid_datasets(data_path, image_size):
    return build_cityscapes_train_valid_datasets(data_path, image_size)
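RandomSeedSegmentationDataset reseeds torch, random, and numpy from base_seed + 100 * epoch + idx before the random transforms run, so the image and its mask see identical randomness and a given (epoch, index) pair always produces the same crop and flip. A minimal sketch of the same idea with a stand-in random "transform" (the class and numbers below are illustrative, not from the diff):

import random
import numpy as np
import torch

class SeededDataset(torch.utils.data.Dataset):
    """Stand-in for RandomSeedSegmentationDataset: deterministic augmentation."""
    def __init__(self, n, base_seed=1234):
        self.n, self.base_seed, self.curr_seed = n, base_seed, base_seed

    def set_epoch(self, epoch):
        self.curr_seed = self.base_seed + 100 * epoch

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        seed = idx + self.curr_seed
        torch.manual_seed(seed); random.seed(seed); np.random.seed(seed)
        # A "random" crop offset that is now a pure function of (epoch, idx).
        return random.randint(0, 100)

ds = SeededDataset(4)
ds.set_epoch(0)
first = [ds[i] for i in range(4)]
ds.set_epoch(0)
assert first == [ds[i] for i in range(4)]   # same epoch -> same augmentations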
tasks/vision/segmentation/finetune_segformer.py (new file, 0 → 100644, +251)

# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Vision-classification finetuning/evaluation."""

import numpy as np
import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SegformerSegmentationModel
from megatron.model.vision.utils import resize


def calculate_iou(hist_data):
    acc = np.diag(hist_data).sum() / hist_data.sum()
    acc_cls = np.diag(hist_data) / hist_data.sum(axis=1)
    acc_cls = np.nanmean(acc_cls)
    divisor = hist_data.sum(axis=1) + hist_data.sum(axis=0) - \
        np.diag(hist_data)
    iu = np.diag(hist_data) / divisor
    return iu, acc, acc_cls


def fast_hist(pred, gtruth, num_classes):
    # mask indicates pixels we care about
    mask = (gtruth >= 0) & (gtruth < num_classes)

    # stretch ground truth labels by num_classes
    #   class 0  -> 0
    #   class 1  -> 19
    #   class 18 -> 342
    #
    # TP at 0 + 0, 1 + 1, 2 + 2 ...
    #
    # TP exist where value == num_classes*class_id + class_id
    # FP = row[class].sum() - TP
    # FN = col[class].sum() - TP
    hist = np.bincount(num_classes * gtruth[mask].astype(int) + pred[mask],
                       minlength=num_classes ** 2)
    hist = hist.reshape(num_classes, num_classes)
    return hist


def segmentation():

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        model = SegformerSegmentationModel(num_classes=args.num_classes,
                                           pre_process=pre_process,
                                           post_process=post_process)
        print_rank_0("model = {}".format(model))
        return model

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        logits = output_tensor.contiguous().float()
        logits = resize(logits, size=masks.shape[1:],
                        mode='bilinear', align_corners=False)

        # Cross-entropy loss.
        # weight = calculate_weight(masks, num_classes)
        loss = F.cross_entropy(logits, masks, ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""

        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, output_tensor):
            args = get_args()
            logits = output_tensor
            logits = resize(logits, size=labels.shape[1:],
                            mode='bilinear', align_corners=False)

            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)

            preds = preds.cpu().numpy()
            performs = fast_hist(preds.flatten(),
                                 labels.cpu().numpy().flatten(),
                                 args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            # Forward model.
            output_tensor = model(images)

            return output_tensor, partial(loss_func, labels)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()

        # Reduce.
        if mpu.is_pipeline_last_stage():
            performs_tensor = torch.cuda.FloatTensor(performs)
            torch.distributed.all_reduce(performs_tensor,
                                         group=mpu.get_data_parallel_group())
            hist = performs_tensor.cpu().numpy()
            iu, acc, acc_cls = calculate_iou(hist)
            miou = np.nanmean(iu)

            return iu, miou

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds, args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {},"
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0)
            )
        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
tasks/vision/segmentation/finetune_setr.py 0 → 100644
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-segmentation finetuning/evaluation."""

import torch
import torch.nn.functional as F
from functools import partial
from megatron import get_args, get_timers
from megatron import mpu, print_rank_0, print_rank_last
from tasks.vision.finetune_utils import finetune
from tasks.vision.finetune_utils import build_data_loader
from megatron.utils import average_losses_across_data_parallel_group
from megatron.schedules import get_forward_backward_func
from tasks.vision.segmentation.metrics import CFMatrix
from tasks.vision.segmentation.data import build_train_valid_datasets
from tasks.vision.segmentation.seg_models import SetrSegmentationModel
from tasks.vision.segmentation.utils import slidingcrops, slidingjoins
def segmentation():
    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()

        return SetrSegmentationModel(num_classes=args.num_classes,
                                     pre_process=pre_process,
                                     post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        weight = calculate_weight(masks, args.num_classes)
        logits = output_tensor.contiguous().float()
        loss = F.cross_entropy(logits, masks, weight=weight,
                               ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        args = get_args()
        timers = get_timers()

        # Get the batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        if not model.training:
            images, masks, _, _ = slidingcrops(images, masks)
        # print_rank_0("images size = {}".format(images.size()))

        if not model.training:
            output_tensor = torch.cat(
                [model(image) for image in
                 torch.split(images, args.micro_batch_size)])
        else:
            output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""
        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, slices_info, img_size, output_tensor):
            args = get_args()
            logits = output_tensor
            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.int()
            preds, labels = slidingjoins(preds, max_probs, labels,
                                         slices_info, img_size)
            _, performs = CFMatrix()(preds, labels, args.ignore_index)

            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            args = get_args()
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            assert not model.training
            images, labels, slices_info, img_size = slidingcrops(images,
                                                                 labels)
            # Forward model.
            output_tensor = torch.cat(
                [model(image) for image in
                 torch.split(images, args.micro_batch_size)])

            return output_tensor, partial(loss_func, labels,
                                          slices_info, img_size)

        with torch.no_grad():
            # For all the batches in the dataset.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        for m in model:
            m.train()

        # Reduce.
        if mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(performs,
                                         group=mpu.get_data_parallel_group())
            # Print on screen.
            # performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
            true_positive = performs[:, 0]
            false_positive = performs[:, 1]
            false_negative = performs[:, 3]

            iou = true_positive / (true_positive +
                                   false_positive + false_negative)
            miou = iou[~torch.isnan(iou)].mean()

            return iou.tolist(), miou.item()

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()

        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(
                " >> |epoch: {}| overall: iou = {}, "
                "miou = {:.4f} %".format(epoch, iou, miou * 100.0))

        return metrics_func

    def dump_output_data(data, iteration, writer):
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )


def main():
    segmentation()
tasks/vision/segmentation/metrics.py 0 → 100644
#!/usr/bin/env python
# -*- coding: UTF-8 -*-
# copyright (c) go-hiroaki & Chokurei
# email: guangmingwu2010@gmail.com
#        guozhilingty@gmail.com
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

eps = 1e-6


def _binarize(y_data, threshold):
    """
    args:
        y_data : [float] 4-d tensor in [batch_size, channels, img_rows, img_cols]
        threshold : [float] [0.0, 1.0]
    return 4-d binarized y_data
    """
    y_data[y_data < threshold] = 0.0
    y_data[y_data >= threshold] = 1.0
    return y_data


def _argmax(y_data, dim):
    """
    args:
        y_data : 4-d tensor in [batch_size, chs, img_rows, img_cols]
        dim : int
    return 3-d [int] y_data
    """
    return torch.argmax(y_data, dim).int()


def _get_tp(y_pred, y_true):
    """
    args:
        y_true : [int] 3-d in [batch_size, img_rows, img_cols]
        y_pred : [int] 3-d in [batch_size, img_rows, img_cols]
    return [float] true_positive
    """
    return torch.sum(y_true * y_pred).float()


def _get_fp(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_positive
    """
    return torch.sum((1 - y_true) * y_pred).float()


def _get_tn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] true_negative
    """
    return torch.sum((1 - y_true) * (1 - y_pred)).float()


def _get_fn(y_pred, y_true):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
    return [float] false_negative
    """
    return torch.sum(y_true * (1 - y_pred)).float()


def _get_weights(y_true, nb_ch):
    """
    args:
        y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
        nb_ch : int
    return [float] weights
    """
    batch_size, img_rows, img_cols = y_true.shape
    pixels = batch_size * img_rows * img_cols
    weights = [torch.sum(y_true == ch).item() / pixels for ch in range(nb_ch)]
    return weights
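
# Minimal sketch (illustrative): _get_weights returns, per class, the fraction
# of all pixels carrying that label; CFMatrix below uses these fractions to
# reduce per-class scores into a frequency-weighted mean.
def _demo_get_weights():
    y_true = torch.tensor([[[0, 0], [1, 2]]])   # one 2x2 label map
    assert _get_weights(y_true, nb_ch=3) == [0.5, 0.25, 0.25]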
class CFMatrix(object):
    def __init__(self, des=None):
        self.des = des

    def __repr__(self):
        return "ConfusionMatrix"

    def __call__(self, y_pred, y_true, ignore_index, threshold=0.5):
        """
        args:
            y_true : 3-d ndarray in [batch_size, img_rows, img_cols]
            y_pred : 3-d ndarray in [batch_size, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return confusion matrix
        """
        batch_size, img_rows, img_cols = y_pred.shape
        chs = ignore_index
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = [nb_tp, nb_fp, nb_tn, nb_fn]
            performs = None
        else:
            performs = torch.zeros(chs, 4).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_false_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_false_ch[torch.logical_and((y_true != ch),
                                             (y_true != ignore_index))] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = torch.sum(y_false_ch * y_pred_ch).float()
                nb_tn = torch.sum(y_false_ch * (1 - y_pred_ch)).float()
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch), :] = torch.FloatTensor([nb_tp, nb_fp,
                                                          nb_tn, nb_fn])
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
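
# Usage sketch for CFMatrix (illustrative values). Note that ignore_index
# doubles as the number of classes, matching how the finetuning scripts call
# it: pixels labeled with the ignore value count as neither positive nor
# negative for any class.
def _demo_cfmatrix():
    num_classes = 3  # also used as the ignore label
    y_pred = torch.tensor([[[0, 1], [2, 2]]]).int()
    y_true = torch.tensor([[[0, 1], [2, 3]]])   # label 3 is ignored
    mperforms, performs = CFMatrix()(y_pred, y_true, num_classes)
    print(performs)  # performs[ch] = [tp, fp, tn, fn] for each class ch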
class OAAcc(object):
    def __init__(self, des="Overall Accuracy"):
        self.des = des

    def __repr__(self):
        return "OAcc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (tp+tn)/total
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)

        nb_tp_tn = torch.sum(y_true == y_pred).float()
        mperforms = nb_tp_tn / (batch_size * img_rows * img_cols)
        performs = None
        return mperforms, performs
class Precision(object):
    def __init__(self, des="Precision"):
        self.des = des

    def __repr__(self):
        return "Prec"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fp)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fp + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fp + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Recall(object):
    def __init__(self, des="Recall"):
        self.des = des

    def __repr__(self):
        return "Reca"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return tp/(tp+fn)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            mperforms = nb_tp / (nb_tp + nb_fn + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                performs[int(ch)] = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class F1Score(object):
    def __init__(self, des="F1Score"):
        self.des = des

    def __repr__(self):
        return "F1Sc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return 2*precision*recall/(precision+recall)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            _precision = nb_tp / (nb_tp + nb_fp + eps)
            _recall = nb_tp / (nb_tp + nb_fn + eps)
            mperforms = 2 * _precision * _recall / (_precision + _recall + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                _precision = nb_tp / (nb_tp + nb_fp + eps)
                _recall = nb_tp / (nb_tp + nb_fn + eps)
                performs[int(ch)] = 2 * _precision * \
                    _recall / (_precision + _recall + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Kappa(object):
    def __init__(self, des="Kappa"):
        self.des = des

    def __repr__(self):
        return "Kapp"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return (Po-Pe)/(1-Pe)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            nb_tp = _get_tp(y_pred, y_true)
            nb_fp = _get_fp(y_pred, y_true)
            nb_tn = _get_tn(y_pred, y_true)
            nb_fn = _get_fn(y_pred, y_true)
            nb_total = nb_tp + nb_fp + nb_tn + nb_fn
            Po = (nb_tp + nb_tn) / nb_total
            Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                  (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total ** 2)
            mperforms = (Po - Pe) / (1 - Pe + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                nb_tp = _get_tp(y_pred_ch, y_true_ch)
                nb_fp = _get_fp(y_pred_ch, y_true_ch)
                nb_tn = _get_tn(y_pred_ch, y_true_ch)
                nb_fn = _get_fn(y_pred_ch, y_true_ch)
                nb_total = nb_tp + nb_fp + nb_tn + nb_fn
                Po = (nb_tp + nb_tn) / nb_total
                Pe = ((nb_tp + nb_fp) * (nb_tp + nb_fn) +
                      (nb_fn + nb_tn) * (nb_fp + nb_tn)) / (nb_total ** 2)
                performs[int(ch)] = (Po - Pe) / (1 - Pe + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class Jaccard(object):
    def __init__(self, des="Jaccard"):
        self.des = des

    def __repr__(self):
        return "Jacc"

    def __call__(self, y_pred, y_true, threshold=0.5):
        """
        args:
            y_true : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, chs, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return intersection / (sum-intersection)
        """
        batch_size, chs, img_rows, img_cols = y_true.shape
        device = y_true.device
        if chs == 1:
            y_pred = _binarize(y_pred, threshold)
            y_true = _binarize(y_true, threshold)
            _intersec = torch.sum(y_true * y_pred).float()
            _sum = torch.sum(y_true + y_pred).float()
            mperforms = _intersec / (_sum - _intersec + eps)
            performs = None
        else:
            y_pred = _argmax(y_pred, 1)
            y_true = _argmax(y_true, 1)
            performs = torch.zeros(chs, 1).to(device)
            weights = _get_weights(y_true, chs)
            for ch in range(chs):
                y_true_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_pred_ch = torch.zeros(batch_size, img_rows, img_cols)
                y_true_ch[y_true == ch] = 1
                y_pred_ch[y_pred == ch] = 1
                _intersec = torch.sum(y_true_ch * y_pred_ch).float()
                _sum = torch.sum(y_true_ch + y_pred_ch).float()
                performs[int(ch)] = _intersec / (_sum - _intersec + eps)
            mperforms = sum([i * j for (i, j) in zip(performs, weights)])
        return mperforms, performs
class MSE(object):
    def __init__(self, des="Mean Square Error"):
        self.des = des

    def __repr__(self):
        return "MSE"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return mean_squared_error, smaller the better
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        return torch.mean((y_pred - y_true) ** 2)


class PSNR(object):
    def __init__(self, des="Peak Signal to Noise Ratio"):
        self.des = des

    def __repr__(self):
        return "PSNR"

    def __call__(self, y_pred, y_true, dim=1, threshold=None):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            threshold : [0.0, 1.0]
        return PSNR, larger the better
        """
        if threshold:
            y_pred = _binarize(y_pred, threshold)
        mse = torch.mean((y_pred - y_true) ** 2)
        return 10 * torch.log10(1 / mse)
class SSIM(object):
    '''
    modified from https://github.com/jorge-pessoa/pytorch-msssim
    '''
    def __init__(self, des="structural similarity index"):
        self.des = des

    def __repr__(self):
        return "SSIM"

    def gaussian(self, w_size, sigma):
        gauss = torch.Tensor(
            [math.exp(-(x - w_size // 2) ** 2 / float(2 * sigma ** 2))
             for x in range(w_size)])
        return gauss / gauss.sum()

    def create_window(self, w_size, channel=1):
        _1D_window = self.gaussian(w_size, 1.5).unsqueeze(1)
        _2D_window = _1D_window.mm(
            _1D_window.t()).float().unsqueeze(0).unsqueeze(0)
        window = _2D_window.expand(channel, 1, w_size, w_size).contiguous()
        return window

    def __call__(self, y_pred, y_true, w_size=11, size_average=True,
                 full=False):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            w_size : int, default 11
            size_average : boolean, default True
            full : boolean, default False
        return ssim, larger the better
        """
        # Value range can be different from 255. Other common ranges are
        # 1 (sigmoid) and 2 (tanh).
        if torch.max(y_pred) > 128:
            max_val = 255
        else:
            max_val = 1

        if torch.min(y_pred) < -0.5:
            min_val = -1
        else:
            min_val = 0
        L = max_val - min_val

        padd = 0
        (_, channel, height, width) = y_pred.size()
        window = self.create_window(w_size, channel=channel).to(y_pred.device)

        mu1 = F.conv2d(y_pred, window, padding=padd, groups=channel)
        mu2 = F.conv2d(y_true, window, padding=padd, groups=channel)

        mu1_sq = mu1.pow(2)
        mu2_sq = mu2.pow(2)
        mu1_mu2 = mu1 * mu2

        sigma1_sq = F.conv2d(y_pred * y_pred, window,
                             padding=padd, groups=channel) - mu1_sq
        sigma2_sq = F.conv2d(y_true * y_true, window,
                             padding=padd, groups=channel) - mu2_sq
        sigma12 = F.conv2d(y_pred * y_true, window,
                           padding=padd, groups=channel) - mu1_mu2

        C1 = (0.01 * L) ** 2
        C2 = (0.03 * L) ** 2

        v1 = 2.0 * sigma12 + C2
        v2 = sigma1_sq + sigma2_sq + C2
        cs = torch.mean(v1 / v2)  # contrast sensitivity

        ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)

        if size_average:
            ret = ssim_map.mean()
        else:
            ret = ssim_map.mean(1).mean(1).mean(1)

        if full:
            return ret, cs
        return ret
class AE(object):
    """
    Modified from matlab : colorangle.m, MATLAB V2019b
    angle = acos(RGB1' * RGB2 / (norm(RGB1) * norm(RGB2)));
    angle = 180 / pi * angle;
    """
    def __init__(self, des='average Angular Error'):
        self.des = des

    def __repr__(self):
        return "AE"

    def __call__(self, y_pred, y_true):
        """
        args:
            y_true : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
            y_pred : 4-d ndarray in [batch_size, channels, img_rows, img_cols]
        return average AE, smaller the better
        """
        dotP = torch.sum(y_pred * y_true, dim=1)
        Norm_pred = torch.sqrt(torch.sum(y_pred * y_pred, dim=1))
        Norm_true = torch.sqrt(torch.sum(y_true * y_true, dim=1))
        ae = 180 / math.pi * torch.acos(dotP / (Norm_pred * Norm_true + eps))
        return ae.mean(1).mean(1)
if __name__ == "__main__":
    for ch in [3, 1]:
        batch_size, img_row, img_col = 1, 224, 224
        y_true = torch.rand(batch_size, ch, img_row, img_col)
        noise = torch.zeros(y_true.size()).data.normal_(0, std=0.1)
        y_pred = y_true + noise
        for cuda in [False, True]:
            if cuda:
                y_pred = y_pred.cuda()
                y_true = y_true.cuda()
            print('#' * 20, 'Cuda : {} ; size : {}'.format(cuda,
                                                           y_true.size()))
            ########### similarity metrics
            metric = MSE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = PSNR()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            metric = SSIM()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            # LPIPS is not defined in this module; it needs an external
            # perceptual-similarity implementation, so this check is skipped.
            # metric = LPIPS(cuda)
            # acc = metric(y_pred, y_true).item()
            # print("{} ==> {}".format(repr(metric), acc))
            metric = AE()
            acc = metric(y_pred, y_true).item()
            print("{} ==> {}".format(repr(metric), acc))
            ########### accuracy metrics
            metric = OAAcc()
            maccu, accu = metric(y_pred, y_true)
            print('mAccu:', maccu, 'Accu', accu)
            metric = Precision()
            mprec, prec = metric(y_pred, y_true)
            print('mPrec:', mprec, 'Prec', prec)
            metric = Recall()
            mreca, reca = metric(y_pred, y_true)
            print('mReca:', mreca, 'Reca', reca)
            metric = F1Score()
            mf1sc, f1sc = metric(y_pred, y_true)
            print('mF1sc:', mf1sc, 'F1sc', f1sc)
            metric = Kappa()
            mkapp, kapp = metric(y_pred, y_true)
            print('mKapp:', mkapp, 'Kapp', kapp)
            metric = Jaccard()
            mjacc, jacc = metric(y_pred, y_true)
            print('mJacc:', mjacc, 'Jacc', jacc)
tasks/vision/segmentation/seg_heads.py 0 → 100644
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.module import MegatronModule
from megatron.model.vision.utils import resize


class SetrSegmentationHead(MegatronModule):
    def __init__(self, hidden_size, num_classes):
        super(SetrSegmentationHead, self).__init__()
        args = get_args()
        self.hidden_size = hidden_size
        self.num_classes = num_classes
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.patch_dim = args.patch_dim

        self.layernorm = LayerNorm(hidden_size, eps=args.layernorm_epsilon)
        self.conv_0 = torch.nn.Conv2d(hidden_size, hidden_size,
                                      1, 1, bias=False)
        self.norm_0 = apex.parallel.SyncBatchNorm(hidden_size)
        self.conv_1 = torch.nn.Conv2d(hidden_size, num_classes, 1, 1)

    def to_2D(self, x):
        n, hw, c = x.shape
        h = self.img_h // self.patch_dim
        w = self.img_w // self.patch_dim
        assert (hw == h * w)
        x = x.transpose(1, 2).reshape(n, c, h, w)
        return x

    def forward(self, hidden_states):
        # [b c h w]
        hidden_states = self.layernorm(hidden_states)
        hidden_states = self.to_2D(hidden_states)

        hidden_states = self.conv_0(hidden_states)
        hidden_states = self.norm_0(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.conv_1(hidden_states)

        # [b c h w]
        result = F.interpolate(hidden_states,
                               size=(self.img_h, self.img_w),
                               mode='bilinear')
        return result
class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
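
# Shape sketch (illustrative): MLP turns a [n, c, h, w] feature map into
# [n, h*w, c] tokens and projects each token to embed_dim, which is how the
# decode head below brings the four backbone stages into one embedding space.
def _demo_mlp_shapes():
    x = torch.rand(2, 512, 16, 16)              # [n, c, h, w]
    out = MLP(input_dim=512, embed_dim=768)(x)
    assert out.shape == (2, 16 * 16, 768)       # [n, h*w, embed_dim]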
class SegformerSegmentationHead(MegatronModule):
    def __init__(self, feature_strides, in_channels,
                 embedding_dim, dropout_ratio):
        super(SegformerSegmentationHead, self).__init__()
        assert len(feature_strides) == len(in_channels)
        assert min(feature_strides) == feature_strides[0]
        args = get_args()
        self.feature_strides = feature_strides
        self.in_channels = in_channels
        self.embedding_dim = embedding_dim
        self.num_classes = args.num_classes
        self.dropout_ratio = dropout_ratio

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels,
                             embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels,
                             embed_dim=self.embedding_dim)

        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4,
                                         self.embedding_dim, 1, 1)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)

        self.dropout = torch.nn.Dropout2d(self.dropout_ratio)
        self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
                                           self.num_classes,
                                           kernel_size=1)

    def forward(self, inputs):
        c1, c2, c3, c4 = inputs

        ############## MLP decoder on C1-C4 ###########
        n, _, h, w = c4.shape

        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1,
                                                          c4.shape[2],
                                                          c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1,
                                                          c3.shape[2],
                                                          c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1,
                                                          c2.shape[2],
                                                          c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:],
                     mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1,
                                                          c1.shape[2],
                                                          c1.shape[3])

        _c = self.conv_fuse(torch.cat([_c4, _c3, _c2, _c1], dim=1))

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)
        x = self.linear_pred(x)

        return x
tasks/vision/segmentation/seg_models.py 0 → 100644
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.module import MegatronModule
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3, mit_b5
from tasks.vision.segmentation.seg_heads import SetrSegmentationHead, \
    SegformerSegmentationHead


class SetrSegmentationModel(MegatronModule):

    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        super(SetrSegmentationModel, self).__init__()
        args = get_args()
        assert post_process & pre_process
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.backbone = VitBackbone(
            pre_process=pre_process,
            post_process=post_process,
            class_token=False,
            post_layer_norm=False,
            drop_path_rate=0.1
        )

        self.head = SetrSegmentationHead(
            self.hidden_size,
            self.num_classes
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        result_final = self.head(hidden_states)
        return result_final


class SegformerSegmentationModel(MegatronModule):

    def __init__(self,
                 num_classes,
                 pre_process=True,
                 post_process=True):
        super(SegformerSegmentationModel, self).__init__()
        args = get_args()
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.pre_process = pre_process
        self.post_process = post_process
        self.backbone = mit_b5()
        self.head = SegformerSegmentationHead(
            feature_strides=[4, 8, 16, 32],
            in_channels=[64, 128, 320, 512],
            embedding_dim=768,
            dropout_ratio=0.1
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        # [b hw c]
        hidden_states = self.backbone(input)
        hidden_states = self.head(hidden_states)
        return hidden_states
tasks/vision/segmentation/transforms.py 0 → 100644
# Copyright (c) 2020 The MMSegmentation Authors.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
import random
import os
import math
import mmcv
import torch
import numpy as np
import torchvision.transforms as T
from torchvision import datasets
from torch.utils.data import Dataset
from megatron import print_rank_0
from megatron import get_args
from PIL import Image, ImageOps, ImageEnhance
import torchvision.transforms as torch_tr


def _is_pil_image(img):
    return isinstance(img, Image.Image)
class PhotoMetricDistortion(object):
    """Apply photometric distortion to image sequentially, every transformation
    is applied with a probability of 0.5. The position of random contrast is in
    second or second to last.
    1. random brightness
    2. random contrast (mode 0)
    3. convert color from BGR to HSV
    4. random saturation
    5. random hue
    6. convert color from HSV to BGR
    7. random contrast (mode 1)
    8. randomly swap channels
    Args:
        brightness_delta (int): delta of brightness.
        contrast_range (tuple): range of contrast.
        saturation_range (tuple): range of saturation.
        hue_delta (int): delta of hue.
    """

    def __init__(self,
                 brightness_delta=32,
                 contrast_range=(0.5, 1.5),
                 saturation_range=(0.5, 1.5),
                 hue_delta=18):
        self.brightness_delta = brightness_delta
        self.contrast_lower, self.contrast_upper = contrast_range
        self.saturation_lower, self.saturation_upper = saturation_range
        self.hue_delta = hue_delta

    def convert(self, img, alpha=1, beta=0):
        """Multiply with alpha and add beta, with clipping."""
        img = img.astype(np.float32) * alpha + beta
        img = np.clip(img, 0, 255)
        return img.astype(np.uint8)

    def brightness(self, img):
        """Brightness distortion."""
        if random.randint(0, 1):
            return self.convert(
                img,
                beta=random.uniform(-self.brightness_delta,
                                    self.brightness_delta))
        return img

    def contrast(self, img):
        """Contrast distortion."""
        if random.randint(0, 1):
            return self.convert(
                img,
                alpha=random.uniform(self.contrast_lower,
                                     self.contrast_upper))
        return img

    def saturation(self, img):
        """Saturation distortion."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            img[:, :, 1] = self.convert(
                img[:, :, 1],
                alpha=random.uniform(self.saturation_lower,
                                     self.saturation_upper))
            img = mmcv.hsv2bgr(img)
        return img

    def hue(self, img):
        """Hue distortion."""
        if random.randint(0, 1):
            img = mmcv.bgr2hsv(img)
            img[:, :, 0] = (img[:, :, 0].astype(int) +
                            random.randint(-self.hue_delta,
                                           self.hue_delta)) % 180
            img = mmcv.hsv2bgr(img)
        return img

    def __call__(self, img):
        """Call function to perform photometric distortion on images.
        Args:
            img (PIL Image): Input image.
        Returns:
            PIL Image: Distorted image.
        """
        img = np.array(img)

        # random brightness
        img = self.brightness(img)

        # mode == 1 --> do random contrast first
        # mode == 0 --> do random contrast last
        mode = random.randint(0, 1)
        if mode == 1:
            img = self.contrast(img)

        # random saturation
        img = self.saturation(img)

        # random hue
        img = self.hue(img)

        # random contrast
        if mode == 0:
            img = self.contrast(img)

        img = Image.fromarray(img.astype(np.uint8)).convert('RGB')
        return img
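
# Usage sketch (illustrative): the transform consumes and returns a PIL image,
# round-tripping through numpy/mmcv HSV conversions internally, so it can sit
# in front of the tensor transforms of a training pipeline.
def _demo_photometric_distortion():
    img = Image.new('RGB', (64, 64), color=(128, 64, 32))
    distorted = PhotoMetricDistortion()(img)
    assert distorted.size == img.size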
class RandomCrop(object):
    """
    Take a random crop from the image.
    First the image or crop size may need to be adjusted if the incoming image
    is too small...
    If the image is smaller than the crop, then:
        the image is padded up to the size of the crop
        unless 'nopad', in which case the crop size is shrunk to fit the image
    A random crop is taken such that the crop fits within the image.
    If cfg.DATASET.TRANSLATION_AUG_FIX is set, we ensure that there's always
    translation randomness of at least that value around the image.
    if image < crop_size:
        # slide crop within image, random offset
    else:
        # slide image within crop
    """
    def __init__(self, crop_size):
        args = get_args()
        self.size = crop_size
        self.cat_max_ratio = 0.75
        self.ignore_index = args.ignore_index
        self.pad_color = (0, 0, 0)

    def get_crop_bbox(self, img):
        """Randomly get a crop bounding box."""
        img_w, img_h = img.size
        target_h, target_w = self.size  # [H W]
        margin_h = max(img_h - target_h, 0)
        margin_w = max(img_w - target_w, 0)
        offset_h = random.randint(0, margin_h)
        offset_w = random.randint(0, margin_w)
        crop_y1, crop_y2 = offset_h, offset_h + target_h
        crop_x1, crop_x2 = offset_w, offset_w + target_w

        return crop_y1, crop_y2, crop_x1, crop_x2

    def crop(self, img, crop_bbox):
        """Crop from ``img``"""
        crop_y1, crop_y2, crop_x1, crop_x2 = crop_bbox
        img = img.crop((crop_x1, crop_y1, crop_x2, crop_y2))
        return img

    @staticmethod
    def crop_in_image(target_w, target_h, w, h, img, mask):
        if w == target_w:
            x1 = 0
        else:
            x1 = random.randint(0, w - target_w)
        if h == target_h:
            y1 = 0
        else:
            y1 = random.randint(0, h - target_h)

        return [img.crop((x1, y1, x1 + target_w, y1 + target_h)),
                mask.crop((x1, y1, x1 + target_w, y1 + target_h))]

    def __call__(self, img, mask):
        w, h = img.size
        target_h, target_w = self.size  # ASSUME H, W

        if w == target_w and h == target_h:
            return img, mask

        # Pad image if image < crop
        if target_h > h:
            pad_h = (target_h - h) // 2 + 1
        else:
            pad_h = 0
        if target_w > w:
            pad_w = (target_w - w) // 2 + 1
        else:
            pad_w = 0
        border = (pad_w, pad_h, pad_w, pad_h)
        if pad_h or pad_w:
            img = ImageOps.expand(img, border=border, fill=(0, 0, 0))
            mask = ImageOps.expand(mask, border=border,
                                   fill=self.ignore_index)
            w, h = img.size

        crop_bbox = self.get_crop_bbox(img)
        if self.cat_max_ratio < 1.:
            # Repeat 10 times
            for _ in range(10):
                seg_temp = self.crop(mask, crop_bbox)
                labels, cnt = np.unique(seg_temp, return_counts=True)
                cnt = cnt[labels != self.ignore_index]
                if len(cnt) > 1 and np.max(cnt) / np.sum(cnt) < \
                        self.cat_max_ratio:
                    break
                crop_bbox = self.get_crop_bbox(img)

        # crop the image
        img = self.crop(img, crop_bbox)
        # crop semantic seg
        mask = self.crop(mask, crop_bbox)
        assert (img.size[0] == self.size[1] and img.size[1] == self.size[0])
        return img, mask
class RandomSizeAndCrop(object):
    def __init__(self, crop_size, scale_min=0.5, scale_max=2.0):
        self.crop = RandomCrop(crop_size)
        self.scale_min = scale_min
        self.scale_max = scale_max

    def __call__(self, img, mask):
        scale_amt = random.uniform(self.scale_min, self.scale_max)
        w, h = [int(i * scale_amt) for i in img.size]
        resized_img = img.resize((w, h), Image.BICUBIC)
        resized_mask = mask.resize((w, h), Image.NEAREST)
        img, mask = self.crop(resized_img, resized_mask)
        return img, mask


class RandomHorizontallyFlip(object):
    def __call__(self, img, mask):
        if random.random() < 0.5:
            return img.transpose(Image.FLIP_LEFT_RIGHT), \
                mask.transpose(Image.FLIP_LEFT_RIGHT)
        return img, mask
def adjust_brightness(img, brightness_factor):
    """Adjust brightness of an Image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        brightness_factor (float): How much to adjust the brightness. Can be
            any non negative number. 0 gives a black image, 1 gives the
            original image while 2 increases the brightness by a factor of 2.
    Returns:
        PIL Image: Brightness adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Brightness(img)
    img = enhancer.enhance(brightness_factor)
    return img


def adjust_contrast(img, contrast_factor):
    """Adjust contrast of an Image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        contrast_factor (float): How much to adjust the contrast. Can be any
            non negative number. 0 gives a solid gray image, 1 gives the
            original image while 2 increases the contrast by a factor of 2.
    Returns:
        PIL Image: Contrast adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Contrast(img)
    img = enhancer.enhance(contrast_factor)
    return img


def adjust_saturation(img, saturation_factor):
    """Adjust color saturation of an image.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        saturation_factor (float): How much to adjust the saturation. 0 will
            give a black and white image, 1 will give the original image while
            2 will enhance the saturation by a factor of 2.
    Returns:
        PIL Image: Saturation adjusted image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    enhancer = ImageEnhance.Color(img)
    img = enhancer.enhance(saturation_factor)
    return img


def adjust_hue(img, hue_factor):
    """Adjust hue of an image.
    The image hue is adjusted by converting the image to HSV and
    cyclically shifting the intensities in the hue channel (H).
    The image is then converted back to original image mode.
    `hue_factor` is the amount of shift in H channel and must be in the
    interval `[-0.5, 0.5]`.
    See https://en.wikipedia.org/wiki/Hue for more details on Hue.
    Args:
        img (PIL Image): PIL Image to be adjusted.
        hue_factor (float): How much to shift the hue channel. Should be in
            [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
            HSV space in positive and negative direction respectively.
            0 means no shift. Therefore, both -0.5 and 0.5 will give an image
            with complementary colors while 0 gives the original image.
    Returns:
        PIL Image: Hue adjusted image.
    """
    if not (-0.5 <= hue_factor <= 0.5):
        raise ValueError(
            'hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    input_mode = img.mode
    if input_mode in {'L', '1', 'I', 'F'}:
        return img

    h, s, v = img.convert('HSV').split()

    np_h = np.array(h, dtype=np.uint8)
    # uint8 addition takes care of rotation across boundaries
    with np.errstate(over='ignore'):
        np_h += np.uint8(hue_factor * 255)
    h = Image.fromarray(np_h, 'L')

    img = Image.merge('HSV', (h, s, v)).convert(input_mode)
    return img
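
# Why the uint8 overflow trick in adjust_hue works (illustrative): hue lives
# on a circle, so a shift past 255 should wrap around to 0, and uint8
# addition performs exactly that modular arithmetic.
def _demo_hue_wrap():
    np_h = np.array([250], dtype=np.uint8)
    with np.errstate(over='ignore'):
        np_h += np.uint8(10)
    assert np_h[0] == 4  # (250 + 10) % 256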
class ColorJitter(object):
    """Randomly change the brightness, contrast and saturation of an image.
    Args:
        brightness (float): How much to jitter brightness. brightness_factor
            is chosen uniformly from [max(0, 1 - brightness), 1 + brightness].
        contrast (float): How much to jitter contrast. contrast_factor
            is chosen uniformly from [max(0, 1 - contrast), 1 + contrast].
        saturation (float): How much to jitter saturation. saturation_factor
            is chosen uniformly from [max(0, 1 - saturation), 1 + saturation].
        hue (float): How much to jitter hue. hue_factor is chosen uniformly
            from [-hue, hue]. Should be >= 0 and <= 0.5.
    """
    def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue

    @staticmethod
    def get_params(brightness, contrast, saturation, hue):
        """Get a randomized transform to be applied on image.
        Arguments are same as that of __init__.
        Returns:
            Transform which randomly adjusts brightness, contrast and
            saturation in a random order.
        """
        transforms = []
        if brightness > 0:
            brightness_factor = np.random.uniform(max(0, 1 - brightness),
                                                  1 + brightness)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_brightness(img, brightness_factor)))

        if contrast > 0:
            contrast_factor = np.random.uniform(max(0, 1 - contrast),
                                                1 + contrast)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_contrast(img, contrast_factor)))

        if saturation > 0:
            saturation_factor = np.random.uniform(max(0, 1 - saturation),
                                                  1 + saturation)
            transforms.append(
                torch_tr.Lambda(
                    lambda img: adjust_saturation(img, saturation_factor)))

        if hue > 0:
            hue_factor = np.random.uniform(-hue, hue)
            transforms.append(
                torch_tr.Lambda(lambda img: adjust_hue(img, hue_factor)))

        np.random.shuffle(transforms)
        transform = torch_tr.Compose(transforms)
        return transform

    def __call__(self, img):
        """
        Args:
            img (PIL Image): Input image.
        Returns:
            PIL Image: Color jittered image.
        """
        transform = self.get_params(self.brightness, self.contrast,
                                    self.saturation, self.hue)
        return transform(img)
tasks/vision/segmentation/utils.py 0 → 100644
import math
import torch
import numpy as np
from megatron import get_args


def slidingcrops(img, mask):
    # img: [b c h w]
    # mask: [b h w]
    args = get_args()
    assert args.img_h == args.img_w
    crop_size = args.img_h
    stride = args.seg_stride
    ignore_index = args.ignore_index
    n, c, h, w = img.shape
    assert h >= crop_size
    assert w >= crop_size
    long_size = max(h, w)

    img_slices, mask_slices, slices_info = [], [], []
    if long_size > crop_size:
        assert stride <= crop_size
        h_step_num = int(math.ceil((h - crop_size) / float(stride))) + 1
        w_step_num = int(math.ceil((w - crop_size) / float(stride))) + 1
        for yy in range(h_step_num):
            for xx in range(w_step_num):
                sy, sx = yy * stride, xx * stride
                ey, ex = sy + crop_size, sx + crop_size
                img_sub = img[:, :, sy: ey, sx: ex]
                mask_sub = mask[:, sy: ey, sx: ex]

                # padding
                sub_h, sub_w = img_sub.shape[2:]
                pad_h = max(crop_size - sub_h, 0)
                pad_w = max(crop_size - sub_w, 0)
                img_sub = torch.nn.functional.pad(img_sub,
                                                  pad=(0, pad_w, 0, pad_h),
                                                  value=ignore_index)
                mask_sub = torch.nn.functional.pad(mask_sub,
                                                   pad=(0, pad_w, 0, pad_h))

                img_slices.append(img_sub)
                mask_slices.append(mask_sub)
                slices_info.append([sy, ey, sx, ex, sub_h, sub_w])

        return torch.cat(img_slices), torch.cat(mask_slices), \
            slices_info, (h, w)
    else:
        return img, mask, [[0, h, 0, w, h, w]], (h, w)
def slidingjoins(preds, probs, labels, slices_info, img_size):
    args = get_args()
    num_slices = len(slices_info)

    if num_slices == 1:
        return preds, labels

    h, w = img_size
    split_size = args.micro_batch_size

    preds_split = torch.split(preds, split_size)
    probs_split = torch.split(probs, split_size)
    labels_split = torch.split(labels, split_size)

    assert (len(preds_split) == num_slices)

    total_max_probs = torch.zeros((split_size, h, w),
                                  dtype=torch.float, device='cuda')
    total_preds = torch.zeros((split_size, h, w),
                              dtype=torch.int, device='cuda')
    total_labels = torch.zeros((split_size, h, w),
                               dtype=torch.int, device='cuda')

    for i in range(num_slices):
        sy, ey, sx, ex, sub_h, sub_w = slices_info[i]
        assert sy + sub_h <= h
        assert sx + sub_w <= w

        curr_max_probs = total_max_probs[:, sy: sy + sub_h, sx: sx + sub_w]
        curr_preds = total_preds[:, sy: sy + sub_h, sx: sx + sub_w]

        local_max_probs = probs_split[i][:, :sub_h, :sub_w]
        local_preds = preds_split[i][:, :sub_h, :sub_w]

        result_max_probs = torch.maximum(curr_max_probs, local_max_probs)
        result_preds = torch.where(curr_max_probs >= local_max_probs,
                                   curr_preds, local_preds)

        total_max_probs[:, sy: sy + sub_h, sx: sx + sub_w] = result_max_probs
        total_preds[:, sy: sy + sub_h, sx: sx + sub_w] = result_preds
        total_labels[:, sy: sy + sub_h, sx: sx + sub_w] = \
            labels_split[i][0, :sub_h, :sub_w]

    return total_preds, total_labels
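
# Round-trip sketch (hypothetical numbers): with img_h == img_w == crop_size
# == 512 and seg_stride == 256, a [1, 3, 1024, 1024] input yields
# h_step_num = w_step_num = ceil((1024 - 512) / 256) + 1 = 3, i.e. nine
# overlapping 512x512 crops stacked along the batch axis. slidingjoins then
# inverts this: wherever crops overlap it keeps the prediction with the
# higher softmax probability, rebuilding the full-resolution prediction map.
# Both functions read crop size and stride from get_args(), so they are only
# usable once Megatron's argument parsing has run.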
tools/checkpoint_loader_megatron.py 0 → 100644
import json
import os
import sys
import types

import torch


def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron loader')

    group.add_argument('--true-vocab-size', type=int, default=None,
                       help='Original size of vocab; if specified, will trim '
                       'padding from the embedding table.')
    group.add_argument('--vocab-file', type=str, default=None,
                       help='Path to the vocab file. If specified, will use '
                       'this to get the vocab size and trim padding from the '
                       'embedding table.')
    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of the Megatron repository')
def _load_checkpoint(queue, args):

    # Search in directory above this
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.arguments import parse_args, validate_args
        from megatron.global_vars import set_args, set_global_variables
        from megatron.checkpointing import load_args_from_checkpoint, \
            load_checkpoint
        from megatron.model import ModelType, module
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to "
              "Megatron using --megatron-path. Exiting.")
        queue.put("exit")
        exit(1)

    # We want all arguments to come from us
    sys.argv = ['script.py',
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--load', args.load_dir
                ]

    margs = parse_args()
    margs = load_args_from_checkpoint(margs)

    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we are plenty of processes
    margs.world_size = margs.tensor_model_parallel_size * \
        margs.pipeline_model_parallel_size

    margs = validate_args(margs)

    def check_for_arg(arg_name):
        if getattr(margs, arg_name, None) is None:
            print(f"Checkpoint does not specify the argument {arg_name}. "
                  f"Exiting.")
            print(f"Arguments: {margs}")
            queue.put("exit")
            exit(1)

    check_for_arg('tensor_model_parallel_size')
    check_for_arg('pipeline_model_parallel_size')
    check_for_arg('num_layers')
    check_for_arg('hidden_size')
    check_for_arg('seq_length')
    check_for_arg('num_attention_heads')
    check_for_arg('max_position_embeddings')
    check_for_arg('tokenizer_type')
    check_for_arg('iteration')
    check_for_arg('bert_binary_head')
    check_for_arg('params_dtype')

    # Determine how to make our models
    if args.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif args.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {args.model_type}')

    # Suppress warning about torch.distributed not being initialized
    module.MegatronModule.embedding_warning_printed = True

    consumed_train_samples = None
    consumed_valid_samples = None

    def get_models(count, dtype, pre_process, post_process):
        nonlocal consumed_train_samples
        nonlocal consumed_valid_samples
        models = []
        for rank in range(count):
            mpu.initialize.set_tensor_model_parallel_rank(rank)
            model_ = [model_provider(pre_process, post_process).to(dtype)]
            margs.consumed_train_samples = 0
            margs.consumed_valid_samples = 0
            load_checkpoint(model_, None, None)
            assert (len(model_) == 1)
            model_ = model_[0]
            if consumed_train_samples is not None:
                assert (margs.consumed_train_samples ==
                        consumed_train_samples)
            else:
                consumed_train_samples = margs.consumed_train_samples
            if consumed_valid_samples is not None:
                assert (margs.consumed_valid_samples ==
                        consumed_valid_samples)
            else:
                consumed_valid_samples = margs.consumed_valid_samples
            models.append(model_)
        return models

    if margs.num_layers_per_virtual_pipeline_stage is not None:
        print("Models with an interleaved pipeline schedule are not yet "
              "supported.")
        queue.put("exit")
        exit(1)

    set_global_variables(margs)
    mpu.initialize.set_tensor_model_parallel_world_size(
        margs.tensor_model_parallel_size)
    mpu.initialize.set_pipeline_model_parallel_world_size(
        margs.pipeline_model_parallel_size)
    fused_kernels.load(margs)

    # Get true (non-padded) vocab size
    if args.true_vocab_size is not None:
        true_vocab_size = args.true_vocab_size
    elif args.vocab_file is not None:
        vocab = json.load(open(args.vocab_file))
        true_vocab_size = len(vocab)
        if args.true_vocab_size is not None and \
                true_vocab_size != args.true_vocab_size:
            print("Both --true-vocab-size and --vocab-file specified and "
                  "the vocab size does not match, aborting.")
            queue.put("exit")
            exit(1)
    else:
        true_vocab_size = None

    # short aliases
    tp_size = margs.tensor_model_parallel_size
    pp_size = margs.pipeline_model_parallel_size

    # metadata
    md = types.SimpleNamespace()
    md.model_type = args.model_type
    md.num_layers = margs.num_layers
    md.hidden_size = margs.hidden_size
    md.seq_length = margs.seq_length
    md.num_attention_heads = margs.num_attention_heads
    md.max_position_embeddings = margs.max_position_embeddings
    md.tokenizer_type = margs.tokenizer_type
    md.iteration = margs.iteration
    md.params_dtype = margs.params_dtype
    md.bert_binary_head = margs.bert_binary_head
    md.previous_tensor_parallel_size = margs.tensor_model_parallel_size
    md.previous_pipeline_parallel_size = margs.pipeline_model_parallel_size
    md.true_vocab_size = true_vocab_size
    md.make_vocab_size_divisible_by = margs.make_vocab_size_divisible_by

    # Get first pipe stage
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    post_process = pp_size == 1
    models = get_models(tp_size, md.params_dtype, True, post_process)

    md.consumed_train_samples = consumed_train_samples
    md.consumed_valid_samples = consumed_valid_samples
    queue.put(md)

    def queue_put(name, msg):
        print(f"sending {name}")
        msg["name"] = name
        queue.put(msg)

    # Send embeddings
    message = {
        "position embeddings": models[0].language_model.embedding
        .position_embeddings.weight.data,
        "word embeddings": torch.cat(
            [models[tp_rank].language_model.embedding
             .word_embeddings.weight.data
             for tp_rank in range(tp_size)],
            dim=0)
    }

    queue_put("embeddings", message)

    total_layer_num = 0
    for pp_rank in range(pp_size):
        if pp_rank > 0:
            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == pp_size - 1
            models = get_models(tp_size, md.params_dtype, False,
                                post_process)
        for layer_num in range(len(models[0].language_model.encoder.layers)):
            message = {}

            # Get non-parallel tensors from tp_rank 0
            layer = models[0].language_model.encoder.layers[layer_num]
            message["input layernorm weight"] = \
                layer.input_layernorm.weight.data
            message["input layernorm bias"] = \
                layer.input_layernorm.bias.data
            message["dense bias"] = layer.self_attention.dense.bias.data
            message["post layernorm weight"] = \
                layer.post_attention_layernorm.weight.data
            message["post layernorm bias"] = \
                layer.post_attention_layernorm.bias.data
            message["mlp l1 bias"] = layer.mlp.dense_4h_to_h.bias.data

            # Grab all parallel tensors for this layer
            qkv_weight = []
            qkv_bias = []
            dense_weight = []
            mlp_l0_weight = []
            mlp_l0_bias = []
            mlp_l1_weight = []
            for tp_rank, model in enumerate(models):
                layer = model.language_model.encoder.layers[layer_num]
                qkv_weight.append(
                    layer.self_attention.query_key_value.weight.data)
                qkv_bias.append(
                    layer.self_attention.query_key_value.bias.data)
                dense_weight.append(layer.self_attention.dense.weight.data)
                mlp_l0_weight.append(layer.mlp.dense_h_to_4h.weight.data)
                mlp_l0_bias.append(layer.mlp.dense_h_to_4h.bias.data)
                mlp_l1_weight.append(layer.mlp.dense_4h_to_h.weight.data)

            # concat them
            message["qkv weight"] = torch.cat(qkv_weight, dim=0)
            message["qkv bias"] = torch.cat(qkv_bias, dim=0)
            message["dense weight"] = torch.cat(dense_weight, dim=1)
            message["mlp l0 weight"] = torch.cat(mlp_l0_weight, dim=0)
            message["mlp l0 bias"] = torch.cat(mlp_l0_bias, dim=0)
            message["mlp l1 weight"] = torch.cat(mlp_l1_weight, dim=1)

            queue_put(f"transformer layer {total_layer_num}", message)

            total_layer_num = total_layer_num + 1

    # Send final layernorm from tp_rank 0
    message = {
        "weight": models[0].language_model.encoder.final_layernorm.weight
.
data
,
"bias"
:
models
[
0
].
language_model
.
encoder
.
final_layernorm
.
bias
.
data
}
queue_put
(
"final layernorm"
,
message
)
# Send BERT lm head and binary head if it exists
if
md
.
model_type
==
'BERT'
:
print
(
"Sending LM Pooler"
)
message
=
{
"weight"
:
models
[
0
].
language_model
.
pooler
.
dense
.
weight
.
data
,
"bias"
:
models
[
0
].
language_model
.
pooler
.
dense
.
bias
.
data
}
queue_put
(
"pooler"
,
message
)
message
=
{
"dense weight"
:
models
[
0
].
lm_head
.
dense
.
weight
.
data
,
"dense bias"
:
models
[
0
].
lm_head
.
dense
.
bias
.
data
,
"layernorm weight"
:
models
[
0
].
lm_head
.
layernorm
.
weight
.
data
,
"layernorm bias"
:
models
[
0
].
lm_head
.
layernorm
.
bias
.
data
}
queue_put
(
"lm head"
,
message
)
if
md
.
bert_binary_head
:
print
(
"Sending BERT Binary head"
)
queue
.
put
(
"binary head"
)
message
=
{
"weight"
:
models
[
0
].
binary_head
.
weight
.
data
,
"bias"
:
models
[
0
].
binary_head
.
bias
.
data
}
queue_put
(
"binary head"
,
message
)
queue
.
put
(
"done"
)
def
load_checkpoint
(
queue
,
args
):
try
:
_load_checkpoint
(
queue
,
args
)
except
:
queue
.
put
(
"exit"
)
raise
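The concatenation dimensions above follow Megatron's parallel-layer layout: column-parallel layers (query_key_value, dense_h_to_4h) shard the rows of their weights and their biases, hence dim=0, while row-parallel layers (the attention output dense and dense_4h_to_h) shard columns, hence dim=1, with their biases replicated rather than split. A minimal sketch of the round trip with made-up shapes (nothing below is part of the tool itself):

    import torch

    hidden, tp = 8, 2

    # Column-parallel (e.g. dense_h_to_4h): each rank holds a slice of the rows.
    col_full = torch.randn(4 * hidden, hidden)
    col_shards = torch.chunk(col_full, tp, dim=0)               # what the saver does
    assert torch.equal(torch.cat(col_shards, dim=0), col_full)  # what the loader does

    # Row-parallel (e.g. dense_4h_to_h): each rank holds a slice of the columns.
    row_full = torch.randn(hidden, 4 * hidden)
    row_shards = torch.chunk(row_full, tp, dim=1)
    assert torch.equal(torch.cat(row_shards, dim=1), row_full)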
tools/checkpoint_saver_megatron.py
0 → 100644
View file @ 2eea6216
import argparse
from collections.abc import Mapping
import concurrent.futures
import os
import sys

import torch


def add_arguments(parser):
    group = parser.add_argument_group(title='Megatron saver')

    group.add_argument('--megatron-path', type=str, default=None,
                       help='Base directory of Megatron repository')

    group.add_argument('--target-tensor-parallel-size', type=int,
                       help='Target tensor model parallel size, defaults to the tensor parallel size '
                       'in the input checkpoint if provided by the loader, otherwise to 1')

    group.add_argument('--target-pipeline-parallel-size', type=int,
                       help='Target pipeline model parallel size, defaults to the pipeline parallel size '
                       'in the input checkpoint if provided by the loader, otherwise to 1')


def save_checkpoint(queue, args):
    # Search in directory above this
    sys.path.append(os.path.abspath(
        os.path.join(os.path.dirname(__file__),
                     os.path.pardir)))
    if args.megatron_path is not None:
        sys.path.insert(0, args.megatron_path)

    try:
        from megatron.arguments import (parse_args, validate_args)
        from megatron.checkpointing import save_checkpoint
        from megatron.global_vars import set_global_variables, get_args
        from megatron.model import ModelType
        from megatron.tokenizer.tokenizer import _vocab_size_with_padding
        from megatron import mpu, fused_kernels
    except ModuleNotFoundError:
        print("Unable to import Megatron, please specify the path to Megatron using --megatron-path. Exiting.")
        exit(1)

    def queue_get(name=None):
        val = queue.get()
        if val == "exit":
            print("Loader exited, exiting saver")
            exit(1)
        if name is not None and args.checking and val["name"] != name:
            val_name = val["name"]
            print(f'Unexpected message. Expecting "{name}" but got "{val_name}". Exiting saver.')
            exit(1)
        if name is not None:
            print(f"received {name}")
        return val

    def check_message(msg):
        if not args.checking:
            return
        msg_name = msg.pop("name")
        if len(msg.keys()) > 0:
            print(f"Unexpected values in {msg_name}:")
            for key in msg.keys():
                print(f"   {key}")
            print(f"Exiting. If you want to ignore this, use the argument --no-checking.")
            exit(1)

    md = queue_get()

    if args.target_tensor_parallel_size is None:
        if hasattr(md, 'previous_tensor_parallel_size'):
            args.target_tensor_parallel_size = md.previous_tensor_parallel_size
        else:
            print("loader did not provide a tensor parallel size and --target-tensor-parallel-size not provided on command line. "
                  "Default to 1.")
            args.target_tensor_parallel_size = 1

    if args.target_pipeline_parallel_size is None:
        if hasattr(md, 'previous_pipeline_parallel_size'):
            args.target_pipeline_parallel_size = md.previous_pipeline_parallel_size
        else:
            print("loader did not provide a pipeline parallel size and --target-pipeline-parallel-size not provided on command line. "
                  "Default to 1.")
            args.target_pipeline_parallel_size = 1

    # Arguments do sanity checks on the world size, but we don't care,
    # so trick it into thinking we have plenty of processes
    if args.target_tensor_parallel_size is not None and args.target_pipeline_parallel_size is not None:
        os.environ["WORLD_SIZE"] = f'{args.target_tensor_parallel_size * args.target_pipeline_parallel_size}'

    # We want all arguments to come from us
    sys.argv = ['script.py',
                '--num-layers', str(md.num_layers),
                '--hidden-size', str(md.hidden_size),
                '--seq-length', str(md.seq_length),
                '--num-attention-heads', str(md.num_attention_heads),
                '--max-position-embeddings', str(md.max_position_embeddings),
                '--tokenizer-type', str(md.tokenizer_type),
                '--tensor-model-parallel-size', str(args.target_tensor_parallel_size),
                '--pipeline-model-parallel-size', str(args.target_pipeline_parallel_size),
                '--no-masked-softmax-fusion',
                '--no-bias-gelu-fusion',
                '--no-bias-dropout-fusion',
                '--use-cpu-initialization',
                '--micro-batch-size', '1',
                '--no-load-optim',
                '--no-load-rng',
                '--no-save-optim',
                '--no-save-rng',
                '--no-initialization',
                '--save-interval', '1',
                '--save', args.save_dir
                ]

    if md.make_vocab_size_divisible_by is not None:
        sys.argv.extend(['--make-vocab-size-divisible-by', str(md.make_vocab_size_divisible_by)])
    if md.params_dtype == torch.float16:
        sys.argv.append('--fp16')
    elif md.params_dtype == torch.bfloat16:
        sys.argv.append('--bf16')

    if md.model_type == 'BERT' and not md.bert_binary_head:
        sys.argv.append('--bert-no-binary-head')

    margs = parse_args()
    validate_args(margs)
    set_global_variables(margs)

    # margs = megatron args
    margs = get_args()

    if hasattr(md, 'consumed_train_samples'):
        margs.consumed_train_samples = md.consumed_train_samples
        margs.consumed_valid_samples = md.consumed_valid_samples
        print(f"Setting consumed_train_samples to {margs.consumed_train_samples}"
              f" and consumed_valid_samples to {margs.consumed_valid_samples}")
    else:
        print("consumed_train_samples not provided.")

    # Determine how to make our models
    if md.model_type == 'GPT':
        from pretrain_gpt import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    elif md.model_type == 'BERT':
        from pretrain_bert import model_provider
        margs.model_type = ModelType.encoder_or_decoder
    else:
        raise Exception(f'unrecognized model type: {args.model_type}')

    def get_models(count, dtype, pre_process, post_process):
        models = [model_provider(pre_process, post_process).to(dtype) for _ in range(count)]
        return models

    # fake initializing distributed
    mpu.initialize.set_tensor_model_parallel_world_size(args.target_tensor_parallel_size)
    mpu.initialize.set_pipeline_model_parallel_world_size(args.target_pipeline_parallel_size)
    mpu.initialize.set_tensor_model_parallel_rank(0)
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    fused_kernels.load(margs)

    # Embeddings
    # -----------
    embeddings_msg = queue_get("embeddings")

    pos_embed = embeddings_msg.pop("position embeddings")
    orig_word_embed = embeddings_msg.pop("word embeddings")
    check_message(embeddings_msg)

    # Deal with padding
    if md.true_vocab_size is not None:
        # figure out what our padded vocab size is
        orig_vocab_size = orig_word_embed.shape[0]
        margs.padded_vocab_size = _vocab_size_with_padding(md.true_vocab_size, margs)

        # Cut out extra padding we don't need
        if orig_vocab_size > margs.padded_vocab_size:
            full_word_embed = orig_word_embed[0:margs.padded_vocab_size, :]

        # Expanding embedding to larger size by replicating final entry
        elif orig_vocab_size < margs.padded_vocab_size:
            padding_size = margs.padded_vocab_size - orig_vocab_size
            full_word_embed = torch.cat((
                orig_word_embed,
                orig_word_embed[-1].unsqueeze(0).expand(padding_size, -1)))

        # Same size!
        else:
            full_word_embed = orig_word_embed
    else:
        print("Original vocab size not specified, leaving embedding table as-is. "
              "If you've changed the tensor parallel size this could cause problems.")
        margs.padded_vocab_size = orig_word_embed.shape[0]
        full_word_embed = orig_word_embed

    # Split into new tensor model parallel sizes
    out_word_embed = torch.chunk(full_word_embed, args.target_tensor_parallel_size, dim=0)

    # Make models for first pipeline stage and fill in embeddings
    mpu.initialize.set_pipeline_model_parallel_rank(0)
    post_process = args.target_pipeline_parallel_size == 1
    models = get_models(args.target_tensor_parallel_size, md.params_dtype, True, post_process)
    for tp_rank, model in enumerate(models):
        print(f"word embeddings shape {model.language_model.embedding.word_embeddings.weight.shape}")
        model.language_model.embedding.word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
        model.language_model.embedding.position_embeddings.weight.data.copy_(pos_embed)

    # Transformer layers
    # -------------------
    total_layer_num = 0
    for pp_rank in range(args.target_pipeline_parallel_size):
        # For later pipeline parallel ranks, make the new models
        if pp_rank > 0:
            mpu.initialize.set_pipeline_model_parallel_rank(pp_rank)
            post_process = pp_rank == args.target_pipeline_parallel_size - 1
            models = get_models(args.target_tensor_parallel_size, md.params_dtype, False, post_process)

        for layer in range(len(models[0].language_model.encoder.layers)):
            msg = queue_get(f"transformer layer {total_layer_num}")

            # duplicated tensors
            input_layernorm_weight = msg.pop("input layernorm weight")
            input_layernorm_bias = msg.pop("input layernorm bias")
            dense_bias = msg.pop("dense bias")
            post_layernorm_weight = msg.pop("post layernorm weight")
            post_layernorm_bias = msg.pop("post layernorm bias")
            mlp_l1_bias = msg.pop("mlp l1 bias")

            # Split up the parallel tensors
            qkv_weight = torch.chunk(msg.pop("qkv weight"), args.target_tensor_parallel_size, dim=0)
            qkv_bias = torch.chunk(msg.pop("qkv bias"), args.target_tensor_parallel_size, dim=0)
            dense_weight = torch.chunk(msg.pop("dense weight"), args.target_tensor_parallel_size, dim=1)
            mlp_l0_weight = torch.chunk(msg.pop("mlp l0 weight"), args.target_tensor_parallel_size, dim=0)
            mlp_l0_bias = torch.chunk(msg.pop("mlp l0 bias"), args.target_tensor_parallel_size, dim=0)
            mlp_l1_weight = torch.chunk(msg.pop("mlp l1 weight"), args.target_tensor_parallel_size, dim=1)

            # Save them to the model
            for tp_rank in range(args.target_tensor_parallel_size):
                l = models[tp_rank].language_model.encoder.layers[layer]
                l.input_layernorm.weight.data.copy_(input_layernorm_weight)
                l.input_layernorm.bias.data.copy_(input_layernorm_bias)
                l.self_attention.query_key_value.weight.data.copy_(qkv_weight[tp_rank])
                l.self_attention.query_key_value.bias.data.copy_(qkv_bias[tp_rank])
                l.self_attention.dense.weight.data.copy_(dense_weight[tp_rank])
                l.self_attention.dense.bias.data.copy_(dense_bias)
                l.post_attention_layernorm.weight.data.copy_(post_layernorm_weight)
                l.post_attention_layernorm.bias.data.copy_(post_layernorm_bias)
                l.mlp.dense_h_to_4h.weight.data.copy_(mlp_l0_weight[tp_rank])
                l.mlp.dense_h_to_4h.bias.data.copy_(mlp_l0_bias[tp_rank])
                l.mlp.dense_4h_to_h.weight.data.copy_(mlp_l1_weight[tp_rank])
                l.mlp.dense_4h_to_h.bias.data.copy_(mlp_l1_bias)

            total_layer_num = total_layer_num + 1
            check_message(msg)

        if post_process:
            msg = queue_get("final layernorm")
            final_layernorm_weight = msg.pop("weight")
            final_layernorm_bias = msg.pop("bias")
            for tp_rank in range(args.target_tensor_parallel_size):
                models[tp_rank].language_model.encoder.final_layernorm.weight.data.copy_(final_layernorm_weight)
                models[tp_rank].language_model.encoder.final_layernorm.bias.data.copy_(final_layernorm_bias)
                if pp_rank != 0:
                    # Copy word embeddings to final pipeline rank
                    models[tp_rank].word_embeddings.weight.data.copy_(out_word_embed[tp_rank])
            del final_layernorm_weight
            del final_layernorm_bias
            check_message(msg)

            msg = queue_get()
            if msg != "done" and msg["name"] == "pooler":
                if not hasattr(models[0].language_model, 'pooler'):
                    print("ERROR: got a pooler, but model does not have one")
                    exit(1)
                print("received pooler")
                pooler_weight = msg.pop("weight")
                pooler_bias = msg.pop("bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].language_model.pooler.dense.weight.data.copy_(pooler_weight)
                    models[tp_rank].language_model.pooler.dense.bias.data.copy_(pooler_bias)
                del pooler_weight
                del pooler_bias
                check_message(msg)
                msg = queue_get()

            if msg != "done" and msg["name"] == "lm head":
                if not hasattr(models[0], 'lm_head'):
                    print("ERROR: got an lm head, but model does not have one")
                    exit(1)
                print("received lm head")
                lm_head_dense_weight = msg.pop("dense weight")
                lm_head_dense_bias = msg.pop("dense bias")
                lm_head_layernorm_weight = msg.pop("layernorm weight")
                lm_head_layernorm_bias = msg.pop("layernorm bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].lm_head.dense.weight.data.copy_(lm_head_dense_weight)
                    models[tp_rank].lm_head.dense.bias.data.copy_(lm_head_dense_bias)
                    models[tp_rank].lm_head.layernorm.weight.data.copy_(lm_head_layernorm_weight)
                    models[tp_rank].lm_head.layernorm.bias.data.copy_(lm_head_layernorm_bias)
                check_message(msg)
                msg = queue_get()

            if msg != "done" and msg["name"] == "binary head":
                if not hasattr(models[0], 'binary_head'):
                    print("ERROR: got a binary head, but model does not have one")
                    exit(1)
                print("received binary head")
                binary_head_weight = msg.pop("weight")
                binary_head_bias = msg.pop("bias")
                for tp_rank in range(args.target_tensor_parallel_size):
                    models[tp_rank].binary_head.weight.data.copy_(binary_head_weight)
                    models[tp_rank].binary_head.bias.data.copy_(binary_head_bias)
                check_message(msg)
                msg = queue_get()

            if msg != "done":
                print("ERROR: got some more data but was expecting to be done")

        for tp_rank in range(args.target_tensor_parallel_size):
            mpu.initialize.set_tensor_model_parallel_rank(tp_rank)
            save_checkpoint(md.iteration, [models[tp_rank]], None, None)

    print("Done!")
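The padding branch above mirrors what Megatron computes at training time: _vocab_size_with_padding rounds the true vocabulary up so every tensor-parallel rank gets an equal, divisible slice of the embedding table. A sketch of the rounding rule (a restatement for illustration, not the helper itself, which lives in megatron/tokenizer/tokenizer.py):

    def padded_vocab_size(true_vocab_size, make_divisible_by, tp_size):
        # Round up to the next multiple of make_divisible_by * tp_size.
        multiple = make_divisible_by * tp_size
        return ((true_vocab_size + multiple - 1) // multiple) * multiple

    # e.g. GPT-2's 50257 tokens with --make-vocab-size-divisible-by 128 and TP=8:
    assert padded_vocab_size(50257, 128, 8) == 51200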
tools/checkpoint_util.py
0 → 100644
View file @ 2eea6216
import argparse
import importlib
import torch.multiprocessing as mp
import os
import sys

# A loader is a python file with at least two functions
# - add_arguments - takes in a parser and adds any arguments needed
# - load_checkpoint - takes in the queue and parsed arguments

# A saver is similar but has save_checkpoint instead of
# load_checkpoint

# The loader and saver process are each given a queue, the loader
# should load the checkpoint and send the weights in messages in the
# following order, the saver should receive them in this order and
# save the checkpoints. A message consists of a python dictionary with
# a "name" for error checking and an entry for each tensor as
# indicated below. Note that the weights sent over the queue are the
# full model weights, nothing split.

# If the loader ever sends "exit" to the queue, that means something
# went wrong and it is exiting.

# - Metadata Namespace with the following attributes:
#     model_type - GPT, BERT, T5, etc. (Part of protocol to allow this to be deduced later instead of given on command line)
#     num_layers - Number of transformer layers
#     hidden_size
#     seq_length
#     num_attention_heads
#     max_position_embeddings
#     tokenizer_type
#     iteration
#     params_dtype
#     bert_binary_head - Used only if model_type is BERT
#     previous_tensor_parallel_size - Optional
#     previous_pipeline_parallel_size - Optional
#     true_vocab_size
#     make_vocab_size_divisible_by
#     consumed_train_samples
#     consumed_valid_samples
# messages
# {
#   "name": "embeddings"
#   "position embeddings"
#   "word embeddings"
# }
# (for each transformer layer):
# {
#   "name": "transformer layer N"
#   "input layernorm weight"
#   "input layernorm bias"
#   "qkv weight"
#   "qkv bias"
#   "dense weight"
#   "dense bias"
#   "post layernorm weight"
#   "post layernorm bias"
#   "mlp l0 weight"
#   "mlp l0 bias"
#   "mlp l1 weight"
#   "mlp l1 bias"
# }
# {
#   "name": "final layernorm"
#   "weight"
#   "bias"
# }
# if present (i.e. for BERT):
# {
#   "name": "pooler"
#   "weight"
#   "bias"
# }
# {
#   "name": "lm head"
#   "dense weight"
#   "dense bias"
#   "layernorm weight"
#   "layernorm bias"
# }
# {
#   "name": "binary head"
#   "weight"
#   "bias"
# }
# - "done"
def load_plugin(plugin_type, name):
    module_name = f"checkpoint_{plugin_type}_{name}"
    try:
        plugin = importlib.import_module(module_name)
    except ModuleNotFoundError:
        module_name = name
        try:
            plugin = importlib.import_module(module_name)
        except ModuleNotFoundError:
            sys.exit(f"Unable to load {plugin_type} plugin {name}. Exiting.")

    if not hasattr(plugin, 'add_arguments'):
        sys.exit(f"{module_name} module is not a plugin. Exiting.")

    print(f"Loaded {module_name} as the {plugin_type}.")
    return plugin

def main():
    parser = argparse.ArgumentParser(description="Megatron Checkpoint Utility Arguments",
                                     allow_abbrev=False, conflict_handler='resolve')

    parser.add_argument('--model-type', type=str, required=True,
                        choices=['GPT', 'BERT'],
                        help='Type of the model')
    parser.add_argument('--loader', type=str, default='megatron',
                        help='Module name to load checkpoint, should be on python path')
    parser.add_argument('--saver', type=str, default='megatron',
                        help='Module name to save checkpoint, should be on python path')
    parser.add_argument('--load-dir', type=str, required=True,
                        help='Directory to load model checkpoint from')
    parser.add_argument('--save-dir', type=str, required=True,
                        help='Directory to save model checkpoint to')
    parser.add_argument('--max-queue-size', type=int, default=50,
                        help='Maximum number of tensors in the queue')
    parser.add_argument('--no-checking', action='store_false',
                        help='Do not perform checking on the name and ordering of weights',
                        dest='checking')

    known_args, _ = parser.parse_known_args()
    loader = load_plugin('loader', known_args.loader)
    saver = load_plugin('saver', known_args.saver)

    loader.add_arguments(parser)
    saver.add_arguments(parser)

    args = parser.parse_args()

    queue = mp.Queue(maxsize=args.max_queue_size)

    print("Starting saver...")
    saver_proc = mp.Process(target=saver.save_checkpoint, args=(queue, args))
    saver_proc.start()

    print("Starting loader...")
    loader.load_checkpoint(queue, args)

    print("Waiting for saver to complete...")
    saver_proc.join()

if __name__ == '__main__':
    main()
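End to end, a resharding run looks like this (the directory paths are placeholders; --target-tensor-parallel-size comes from the megatron saver plugin above, everything else from the parser in main):

    python tools/checkpoint_util.py \
        --model-type GPT \
        --loader megatron \
        --saver megatron \
        --target-tensor-parallel-size 4 \
        --load-dir /path/to/old/checkpoint \
        --save-dir /path/to/new/checkpoint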