Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
b8428a7f
Commit
b8428a7f
authored
Jul 15, 2022
by
peng xu
Browse files
Merge branch 'main' into beam_search
parents
e5034150
3f4e71df
Changes
67
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2476 additions
and
334 deletions
+2476
-334
megatron/mpu/random.py
megatron/mpu/random.py
+17
-13
megatron/optimizer/__init__.py
megatron/optimizer/__init__.py
+22
-10
megatron/optimizer/clip_grads.py
megatron/optimizer/clip_grads.py
+27
-28
megatron/optimizer/distrib_optimizer.py
megatron/optimizer/distrib_optimizer.py
+696
-0
megatron/optimizer/optimizer.py
megatron/optimizer/optimizer.py
+337
-121
megatron/p2p_communication.py
megatron/p2p_communication.py
+6
-3
megatron/schedules.py
megatron/schedules.py
+21
-10
megatron/training.py
megatron/training.py
+41
-53
pretrain_vision_classify.py
pretrain_vision_classify.py
+19
-5
pretrain_vision_dino.py
pretrain_vision_dino.py
+121
-0
pretrain_vision_inpaint.py
pretrain_vision_inpaint.py
+150
-0
tasks/finetune_utils.py
tasks/finetune_utils.py
+1
-1
tasks/vision/classification/classification.py
tasks/vision/classification/classification.py
+94
-0
tasks/vision/classification/eval_utils.py
tasks/vision/classification/eval_utils.py
+3
-5
tasks/vision/finetune_utils.py
tasks/vision/finetune_utils.py
+60
-63
tasks/vision/main.py
tasks/vision/main.py
+24
-22
tasks/vision/segmentation/cityscapes.py
tasks/vision/segmentation/cityscapes.py
+207
-0
tasks/vision/segmentation/data.py
tasks/vision/segmentation/data.py
+154
-0
tasks/vision/segmentation/finetune_segformer.py
tasks/vision/segmentation/finetune_segformer.py
+251
-0
tasks/vision/segmentation/finetune_setr.py
tasks/vision/segmentation/finetune_setr.py
+225
-0
No files found.
megatron/mpu/random.py
View file @
b8428a7f
...
...
@@ -87,17 +87,21 @@ def split_tensor_into_1d_equal_chunks(tensor, new_buffer=False):
def
gather_split_1d_tensor
(
tensor
):
"""Opposite of above function, gather values from model parallel ranks."""
world_size
=
get_tensor_model_parallel_world_size
()
numel
=
torch
.
numel
(
tensor
)
numel_gathered
=
world_size
*
numel
numel_gathered
=
torch
.
numel
(
tensor
)
*
\
get_tensor_model_parallel_world_size
()
gathered
=
torch
.
empty
(
numel_gathered
,
dtype
=
tensor
.
dtype
,
device
=
torch
.
cuda
.
current_device
(),
requires_grad
=
False
)
chunks
=
[
gathered
[
i
*
numel
:(
i
+
1
)
*
numel
]
for
i
in
range
(
world_size
)]
torch
.
distributed
.
all_gather
(
chunks
,
tensor
,
group
=
get_tensor_model_parallel_group
())
# TODO: This API is experimental in pytorch (as of Feb 2022) and
# this might break in future pytorch releases. We chose this API
# as opposed to torch.distributed.all_gather for efficiency reasons.
# This API calls directly NCCL all-gather versus the former does
# internal copies and can potentially cause slow down.
torch
.
distributed
.
_all_gather_base
(
gathered
,
tensor
,
group
=
get_tensor_model_parallel_group
())
return
gathered
def
_kernel_make_viewless_tensor
(
inp
,
requires_grad
):
'''Make a viewless tensor.
...
...
@@ -303,10 +307,10 @@ class CheckpointFunction(torch.autograd.Function):
tracked/set/reset.
"""
@
staticmethod
def
forward
(
ctx
,
run_function
,
distribute_
checkpoint
ed_activations
,
*
args
):
def
forward
(
ctx
,
run_function
,
distribute_
sav
ed_activations
,
*
args
):
ctx
.
run_function
=
run_function
ctx
.
distribute_
checkpoint
ed_activations
\
=
distribute_
checkpoint
ed_activations
ctx
.
distribute_
sav
ed_activations
\
=
distribute_
sav
ed_activations
# Copy the rng states.
ctx
.
fwd_cpu_rng_state
=
torch
.
get_rng_state
()
...
...
@@ -318,7 +322,7 @@ class CheckpointFunction(torch.autograd.Function):
# Divide hidden states across model parallel group and only keep
# the chunk corresponding to the current rank.
if
distribute_
checkpoint
ed_activations
:
if
distribute_
sav
ed_activations
:
ctx
.
input_0_shape
=
args
[
0
].
data
.
shape
safely_set_viewless_tensor_data
(
args
[
0
],
...
...
@@ -335,7 +339,7 @@ class CheckpointFunction(torch.autograd.Function):
raise
RuntimeError
(
"Checkpointing is not compatible with .grad(), "
"please use .backward() if possible"
)
inputs
=
ctx
.
saved_tensors
if
ctx
.
distribute_
checkpoint
ed_activations
:
if
ctx
.
distribute_
sav
ed_activations
:
safely_set_viewless_tensor_data
(
inputs
[
0
],
gather_split_1d_tensor
(
inputs
[
0
].
data
).
view
(
ctx
.
input_0_shape
))
...
...
@@ -368,8 +372,8 @@ class CheckpointFunction(torch.autograd.Function):
return
(
None
,
None
)
+
grads
def
checkpoint
(
function
,
distribute_
checkpoint
ed_activations
,
*
args
):
def
checkpoint
(
function
,
distribute_
sav
ed_activations
,
*
args
):
"""Checkpoint a model or part of the model.
This has been directly copied from torch.utils.checkpoint."""
return
CheckpointFunction
.
apply
(
function
,
distribute_
checkpoint
ed_activations
,
*
args
)
distribute_
sav
ed_activations
,
*
args
)
megatron/optimizer/__init__.py
View file @
b8428a7f
...
...
@@ -17,8 +17,8 @@ from apex.optimizers import FusedAdam as Adam
from
apex.optimizers
import
FusedSGD
as
SGD
from
megatron
import
get_args
from
megatron.model
import
LayerNorm
from
.distrib_optimizer
import
DistributedOptimizer
from
.grad_scaler
import
ConstantGradScaler
,
DynamicGradScaler
from
.optimizer
import
Float16OptimizerWithFloat16Params
,
FP32Optimizer
...
...
@@ -105,7 +105,11 @@ def get_megatron_optimizer(model,
if
args
.
DDP_impl
==
'local'
:
params_have_main_grad
=
True
if
args
.
fp16
or
args
.
bf16
:
# Mixed precision optimizer.
# - Note: both the Float16Optimizer and the DistributedOptimizer inherit
# from the MixedPrecisionOptimizer, which manages any optimizer where
# the model params and main params are distinct.
if
args
.
fp16
or
args
.
bf16
or
args
.
use_distributed_optimizer
:
# Grad scaler:
# if loss-scale is provided, instantiate the constant scaler.
...
...
@@ -114,9 +118,11 @@ def get_megatron_optimizer(model,
# otherwise we are running in bf16 with no loss-scale so
# leave it as None.
grad_scaler
=
None
# Constant loss scale.
if
args
.
loss_scale
:
grad_scaler
=
ConstantGradScaler
(
args
.
loss_scale
)
# Dynamic loss scale.
else
:
if
args
.
fp16
:
...
...
@@ -129,16 +135,22 @@ def get_megatron_optimizer(model,
hysteresis
=
args
.
hysteresis
)
# Megatron optimizer.
return
Float16OptimizerWithFloat16Params
(
optimizer
,
args
.
clip_grad
,
args
.
log_num_zeros_in_grad
,
params_have_main_grad
,
args
.
use_contiguous_buffers_in_local_ddp
,
args
.
bf16
,
grad_scaler
)
opt_ty
=
DistributedOptimizer
\
if
args
.
use_distributed_optimizer
else
\
Float16OptimizerWithFloat16Params
return
opt_ty
(
optimizer
,
args
.
clip_grad
,
args
.
log_num_zeros_in_grad
,
params_have_main_grad
,
args
.
use_contiguous_buffers_in_local_ddp
,
args
.
fp16
,
args
.
bf16
,
grad_scaler
,
model
)
# FP32.
return
FP32Optimizer
(
optimizer
,
args
.
clip_grad
,
args
.
log_num_zeros_in_grad
,
params_have_main_grad
,
args
.
use_contiguous_buffers_in_local_ddp
)
args
.
use_contiguous_buffers_in_local_ddp
,
model
)
megatron/optimizer/clip_grads.py
View file @
b8428a7f
...
...
@@ -21,12 +21,13 @@ from torch._six import inf
from
apex.multi_tensor_apply
import
multi_tensor_applier
import
amp_C
from
megatron
import
mpu
from
megatron.model.module
import
param_is_not_shared
from
megatron.mpu.layers
import
param_is_not_tensor_parallel_duplicate
def
clip_grad_norm_fp32
(
parameters
,
max_norm
,
norm_type
=
2
):
def
clip_grad_norm_fp32
(
parameters
,
grads_for_norm
,
max_norm
,
norm_type
=
2
,
model_parallel_group
=
None
):
"""Clips gradient norm of an iterable of parameters whose gradients
are in fp32.
...
...
@@ -37,9 +38,13 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
Arguments:
parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
single Tensor that will have gradients normalized
grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
Tensor that will be used for calculating the grad norm.
max_norm (float or int): max norm of the gradients
norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
infinity norm.
model_parallel_group (group): given the nature of the distributed
optimizer, this is passed as an argument.
Returns:
Total norm of the parameters (viewed as a single vector).
...
...
@@ -47,25 +52,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
if
isinstance
(
parameters
,
torch
.
Tensor
):
parameters
=
[
parameters
]
if
isinstance
(
grads_for_norm
,
torch
.
Tensor
):
grads_for_norm
=
[
grads_for_norm
]
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
# Grads.
grads
=
[]
grads_for_norm
=
[]
for
param
in
parameters
:
grad_not_none
=
param
.
grad
is
not
None
is_not_shared
=
param_is_not_shared
(
param
)
is_not_tp_duplicate
=
param_is_not_tensor_parallel_duplicate
(
param
)
if
grad_not_none
:
grad
=
param
.
grad
.
detach
()
if
grad_not_none
:
# Make sure the grads are in fp32
if
param
.
grad
is
not
None
:
assert
param
.
grad
.
type
()
==
'torch.cuda.FloatTensor'
grads
.
append
(
grad
)
if
grad_not_none
and
is_not_shared
and
is_not_tp_duplicate
:
grads_for_norm
.
append
(
grad
)
grads
.
append
(
param
.
grad
.
detach
())
# Norm parameters.
max_norm
=
float
(
max_norm
)
...
...
@@ -79,7 +74,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Take max across all model-parallel GPUs.
torch
.
distributed
.
all_reduce
(
total_norm_cuda
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
mpu
.
get_
model_parallel_group
()
)
group
=
model_parallel_group
)
total_norm
=
total_norm_cuda
[
0
].
item
()
else
:
...
...
@@ -88,12 +83,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Use apex's multi-tensor applier for efficiency reasons.
# Multi-tensor applier takes a function and a list of list
# and performs the operation on that list all in one kernel.
grad_norm
,
_
=
multi_tensor_applier
(
amp_C
.
multi_tensor_l2norm
,
dummy_overflow_buf
,
[
grads_for_norm
],
False
# no per-parameter norm
)
if
grads_for_norm
:
grad_norm
,
_
=
multi_tensor_applier
(
amp_C
.
multi_tensor_l2norm
,
dummy_overflow_buf
,
[
grads_for_norm
],
False
# no per-parameter norm
)
else
:
grad_norm
=
torch
.
cuda
.
FloatTensor
([
0
])
# Since we will be summing across data parallel groups,
# we need the pow(norm-type).
total_norm
=
grad_norm
**
norm_type
...
...
@@ -106,7 +104,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
# Sum across all model-parallel GPUs.
torch
.
distributed
.
all_reduce
(
total_norm
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
mpu
.
get_
model_parallel_group
()
)
group
=
model_parallel_group
)
total_norm
=
total_norm
.
item
()
**
(
1.0
/
norm_type
)
# Scale.
...
...
@@ -121,7 +119,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
return
total_norm
def
count_zeros_fp32
(
parameters
):
def
count_zeros_fp32
(
parameters
,
model_parallel_group
):
if
isinstance
(
parameters
,
torch
.
Tensor
):
parameters
=
[
parameters
]
...
...
@@ -130,7 +128,7 @@ def count_zeros_fp32(parameters):
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
total_num_zeros
=
0.0
total_num_zeros
=
torch
.
cuda
.
FloatTensor
([
0.0
])
for
param
in
parameters
:
grad_not_none
=
param
.
grad
is
not
None
is_not_shared
=
param_is_not_shared
(
param
)
...
...
@@ -143,7 +141,8 @@ def count_zeros_fp32(parameters):
# Sum across all model-parallel GPUs.
torch
.
distributed
.
all_reduce
(
total_num_zeros
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
mpu
.
get_model_parallel_group
())
group
=
model_parallel_group
)
total_num_zeros
=
total_num_zeros
.
item
()
return
total_num_zeros
megatron/optimizer/distrib_optimizer.py
0 → 100644
View file @
b8428a7f
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Megatron distributed optimizer."""
import
math
import
torch
from
megatron
import
get_args
from
megatron
import
get_timers
from
megatron
import
mpu
from
megatron
import
print_rank_0
from
megatron.model.module
import
param_is_not_shared
from
megatron.mpu.layers
import
param_is_not_tensor_parallel_duplicate
from
.optimizer
import
MixedPrecisionOptimizer
,
_zero_grad_group_helper
class
Range
:
"""
A range represents a start and end points for indexing a shard
from a full tensor.
"""
def
__init__
(
self
,
start
,
end
):
self
.
start
=
start
self
.
end
=
end
self
.
size
=
end
-
start
def
normalize
(
self
,
start
=
0
):
return
Range
(
start
,
start
+
self
.
size
)
def
__str__
(
self
):
return
"%d,%d [%d]"
%
(
self
.
start
,
self
.
end
,
self
.
size
)
class
DistributedOptimizer
(
MixedPrecisionOptimizer
):
"""Distributed optimizer, for all data types (fp16, bf16, and fp32).
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradeints with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters are store in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a continuous buffer
holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
Note that main grad is not necessarily in float32.
use_contiguous_buffers_in_local_ddp: if true, the local DDP model
is using a contiguous buffer to hold the model grads.
fp16: if true, the model is running in fp16.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constnat gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
models: list of models (i.e., the virtual pipelining models). This
is used by the distributed optimizer for mapping parameters.
"""
@
classmethod
def
build_model_gbuf_param_range_map
(
cls
,
model
,
dtype
,
gbuf_world_range
):
"""
Build mapping from param reference to grad buffer shard ranges.
This method builds a mapping from parameter references to grad
buffer shard ranges, specific to each data-parallel (DP) rank's
set of 'owned' parameters. Each grad buffer (padded to be an even
multiple of DP-world-size) is conceptually divided into DP-world-size
contiguous regions, where each DP rank 'owns' a contiguous regions.
Ownership in this sense means DP rank is responsible for reducing
the relevant subset of grads, and updating the relevant subset of
params.
This conceptual partitioning of the grad buffer does NOT respect
parameter boundaries, and as such it is assumed that each created
range references a shard (or subset) of the full parameter. It is
easiest to think of each DP rank as operating (i.e., reducing,
gathering) purely on views into the grad buffer, for all model-to-
main & main-to-model operations.
This method creates three ranges:
- The param's range within the entire grad buffer (i.e., world index).
- The param's range within the DP rank's local view of the grad buffer.
- The param's range within itself (i.e., its shard).
"""
# Param range map.
param_world_index_map
=
model
.
_grad_buffer_param_index_map
[
dtype
]
param_range_map
=
{}
for
param
,
param_world_indexes
in
param_world_index_map
.
items
():
# Param range.
param_world_start
,
param_world_end
=
param_world_indexes
param_local_start
=
max
(
0
,
param_world_start
-
gbuf_world_range
.
start
)
param_local_end
=
min
(
gbuf_world_range
.
size
,
param_world_end
-
gbuf_world_range
.
start
)
# Add param, if within local gbuf range.
if
param_local_end
>
param_local_start
:
param_local_range
=
Range
(
param_local_start
,
param_local_end
)
param_world_range
=
param_local_range
.
normalize
(
param_local_start
+
gbuf_world_range
.
start
)
sub_param_start
=
max
(
0
,
gbuf_world_range
.
start
-
param_world_start
)
sub_param_range
=
param_local_range
.
normalize
(
sub_param_start
)
param_range_map
[
param
]
=
{
"gbuf_world"
:
param_world_range
,
"gbuf_local"
:
param_local_range
,
"param"
:
sub_param_range
,
}
return
param_range_map
@
classmethod
def
build_model_gbuf_range
(
cls
,
model
,
dtype
):
"""
Build mapping between params and their grad buffers.
This method does the initial setup for the method above. This setup
includes determining the shard ranges into the DDP's grad buffer for
each data-parallel (DP) rank. Each DP rank keeps range info for
all other DP ranks, for the purpose of creating args for
reduce-scatter and all-gather.
"""
data_parallel_rank
=
mpu
.
get_data_parallel_rank
()
data_parallel_world_size
=
mpu
.
get_data_parallel_world_size
()
# Grad buffer range.
grad_buffer
=
model
.
_grad_buffers
[
dtype
]
gbuf_size
=
grad_buffer
.
numel
max_gbuf_range_size
=
int
(
math
.
ceil
(
gbuf_size
/
data_parallel_world_size
))
# All world ranges. (i.e., across all data parallel ranks)
gbuf_world_all_ranges
=
[]
for
r
in
range
(
data_parallel_world_size
):
gbuf_world_start
=
r
*
max_gbuf_range_size
gbuf_world_end
=
min
(
gbuf_size
,
gbuf_world_start
+
max_gbuf_range_size
)
gbuf_world_range
=
Range
(
gbuf_world_start
,
gbuf_world_end
)
gbuf_world_all_ranges
.
append
(
gbuf_world_range
)
# Local DP's ranges.
gbuf_world_range
=
gbuf_world_all_ranges
[
data_parallel_rank
]
gbuf_local_range
=
gbuf_world_range
.
normalize
()
# Get each param's ranges.
param_range_map
=
cls
.
build_model_gbuf_param_range_map
(
model
,
dtype
,
gbuf_world_range
)
# Group into dict.
data
=
{
"local"
:
gbuf_local_range
,
"world"
:
gbuf_world_range
,
"world_all"
:
gbuf_world_all_ranges
,
"param_map"
:
param_range_map
,
"max_range_size"
:
max_gbuf_range_size
,
}
return
data
@
classmethod
def
build_model_gbuf_range_map
(
cls
,
model
):
"""
Create param-to-grad-buffer mappings, for grad buffer data types
within a specific virtual model.
"""
return
{
dtype
:
cls
.
build_model_gbuf_range
(
model
,
dtype
)
for
dtype
in
model
.
_grad_buffers
}
@
classmethod
def
build_model_param_gbuf_map
(
cls
,
model_gbuf_ranges
):
"""
Create a reverse of the model_gbuf_ranges, for referencing in
opposite direction.
"""
param_gbuf_map
=
{}
for
model_index
,
model_gbuf_range_map
in
enumerate
(
model_gbuf_ranges
):
for
dtype
,
gbuf_range_map
in
model_gbuf_range_map
.
items
():
for
param
,
param_range_map
in
gbuf_range_map
[
"param_map"
].
items
():
param_gbuf_map
[
param
]
=
(
model_index
,
dtype
)
return
param_gbuf_map
@
classmethod
def
build_optimizer_group_ranges
(
cls
,
param_groups
,
model_gbuf_ranges
):
"""
Create optimizer groups.
Given the set of parameter shard ranges that are owned by the current
data-parallel (DP) rank, gather the set of parameters that will be
used (in the method below) to create the current DP's optimizer
groups.
"""
num_groups
=
len
(
param_groups
)
# Param group map.
param_group_map
=
{}
for
group_index
,
group
in
enumerate
(
param_groups
):
for
param
in
group
[
"params"
]:
assert
param
.
requires_grad
param_group_map
[
param
]
=
group_index
# Optimizer group ranges.
group_ranges
=
[
{
"params"
:
[]}
for
_
in
param_groups
]
for
model_gbuf_range_map
in
model_gbuf_ranges
:
for
dtype
,
gbuf_range_map
in
model_gbuf_range_map
.
items
():
for
param
in
gbuf_range_map
[
"param_map"
]:
group_index
=
param_group_map
[
param
]
group_range
=
group_ranges
[
group_index
]
group_range
[
"params"
].
append
(
param
)
# Squeeze zero-size group ranges.
for
group_index
,
group_range
in
enumerate
(
group_ranges
):
group_range
[
"orig_group"
]
=
param_groups
[
group_index
]
group_ranges
=
[
g
for
g
in
group_ranges
if
len
(
g
[
"params"
])
>
0
]
return
group_ranges
@
classmethod
def
build_model_and_main_param_groups
(
cls
,
model_gbuf_ranges
,
param_gbuf_map
,
opt_group_ranges
):
"""
Create main parameter groups needed for the optimizer step.
These groups encompass both: 1) groups used by this class, for
reducing/gather, and 2) groups used by the inner optimizer for the
parameter update. Given that the conceptual grad buffer partitioning
(created in earlier method) doesn't respect parameter boundaries,
the optimizer operates on shards of the model parameters, rather than
the full parameters.
"""
# Parameter groups:
# model_float16_groups: original float16 parameters
# model_fp32_groups: original fp32 parameters
# shard_float16_groups: shards of original float16 parameters
# shard_fp32_groups: shards of original fp32 parameters
# shard_fp32_from_float16_groups: fp32 copy of float16 parameters
model_float16_groups
=
[]
model_fp32_groups
=
[]
shard_float16_groups
=
[]
shard_fp32_groups
=
[]
shard_fp32_from_float16_groups
=
[]
# Allocate (or slice) each group's param shard.
for
group_index
,
group_range
in
enumerate
(
opt_group_ranges
):
# Params of this group.
model_float16_params_this_group
=
[]
model_fp32_params_this_group
=
[]
shard_float16_params_this_group
=
[]
shard_fp32_params_this_group
=
[]
shard_fp32_from_float16_params_this_group
=
[]
model_float16_groups
.
append
(
model_float16_params_this_group
)
model_fp32_groups
.
append
(
model_fp32_params_this_group
)
shard_float16_groups
.
append
(
shard_float16_params_this_group
)
shard_fp32_groups
.
append
(
shard_fp32_params_this_group
)
shard_fp32_from_float16_groups
.
append
(
shard_fp32_from_float16_params_this_group
)
for
model_param
in
group_range
[
"params"
]:
assert
model_param
.
requires_grad
model_index
,
dtype
=
param_gbuf_map
[
model_param
]
gbuf_range
=
model_gbuf_ranges
[
model_index
][
dtype
]
param_range
=
gbuf_range
[
"param_map"
][
model_param
][
"param"
]
# fp16, bf16 params.
if
model_param
.
type
()
in
[
'torch.cuda.HalfTensor'
,
'torch.cuda.BFloat16Tensor'
]:
# Clone model -> main.
shard_model_param
=
model_param
.
detach
().
view
(
-
1
)
\
[
param_range
.
start
:
param_range
.
end
]
shard_main_param
=
shard_model_param
.
clone
().
float
()
mpu
.
copy_tensor_model_parallel_attributes
(
shard_model_param
,
model_param
)
mpu
.
copy_tensor_model_parallel_attributes
(
shard_main_param
,
model_param
)
if
hasattr
(
model_param
,
'shared'
):
shard_model_param
.
shared
=
model_param
.
shared
shard_main_param
.
shared
=
model_param
.
shared
# Add to group.
model_float16_params_this_group
.
append
(
model_param
)
shard_float16_params_this_group
.
append
(
shard_model_param
)
shard_fp32_from_float16_params_this_group
.
append
(
shard_main_param
)
# fp32 params.
elif
model_param
.
type
()
==
'torch.cuda.FloatTensor'
:
shard_model_param
=
model_param
.
view
(
-
1
)
\
[
param_range
.
start
:
param_range
.
end
]
model_fp32_params_this_group
.
append
(
model_param
)
shard_fp32_params_this_group
.
append
(
shard_model_param
)
mpu
.
copy_tensor_model_parallel_attributes
(
shard_model_param
,
model_param
)
if
hasattr
(
model_param
,
'shared'
):
shard_model_param
.
shared
=
model_param
.
shared
else
:
raise
TypeError
(
'Wrapped parameters must be one of '
'torch.cuda.FloatTensor, '
'torch.cuda.HalfTensor, or '
'torch.cuda.BFloat16Tensor. '
'Received {}'
.
format
(
param
.
type
()))
# Update optimizer's params.
group_range
[
"orig_group"
][
"params"
]
=
[
*
shard_fp32_params_this_group
,
*
shard_fp32_from_float16_params_this_group
,
]
return
(
model_float16_groups
,
model_fp32_groups
,
shard_float16_groups
,
shard_fp32_groups
,
shard_fp32_from_float16_groups
,
)
def
__init__
(
self
,
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
,
fp16
,
bf16
,
grad_scaler
,
models
):
"""
See top of class definition for argument descriptions.
The steps in this method create the core mapping between DDP grad
buffers, parameters, and parameter shard ranges, that is needed for
converting between model param indexes and main parameter shard
indexes. This method also updates the optimizer parameter groups
with the newly created shards.
"""
super
().
__init__
(
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
,
fp16
,
bf16
,
grad_scaler
,
models
)
# Verify that contiguous buffers are being used.
# - Note: this should already be checked in arguments.py.
assert
use_contiguous_buffers_in_local_ddp
# Model grad buffer ranges.
self
.
model_gbuf_ranges
=
[]
for
model_index
,
model
in
enumerate
(
self
.
models
):
self
.
model_gbuf_ranges
.
append
(
self
.
build_model_gbuf_range_map
(
model
))
self
.
model_param_gbuf_map
=
\
self
.
build_model_param_gbuf_map
(
self
.
model_gbuf_ranges
)
# Optimizer ranges.
self
.
opt_group_ranges
=
self
.
build_optimizer_group_ranges
(
self
.
optimizer
.
param_groups
,
self
.
model_gbuf_ranges
)
# Allocate main param shards.
(
self
.
model_float16_groups
,
self
.
model_fp32_groups
,
self
.
shard_float16_groups
,
self
.
shard_fp32_groups
,
self
.
shard_fp32_from_float16_groups
,
)
=
self
.
build_model_and_main_param_groups
(
self
.
model_gbuf_ranges
,
self
.
model_param_gbuf_map
,
self
.
opt_group_ranges
)
# Update optimizer groups.
# - Also, leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors.
self
.
optimizer
.
param_groups
=
\
[
g
[
"orig_group"
]
for
g
in
self
.
opt_group_ranges
]
self
.
optimizer
.
load_state_dict
(
self
.
optimizer
.
state_dict
())
def
get_model_param_range_map
(
self
,
param
):
"""
Given a model param, get the index sub-range of the param that this
data-parallel rank owns.
"""
model_index
,
dtype
=
self
.
model_param_gbuf_map
[
param
]
gbuf_range_map
=
self
.
model_gbuf_ranges
[
model_index
][
dtype
]
param_range_map
=
gbuf_range_map
[
"param_map"
][
param
]
return
param_range_map
def
get_model_parallel_group
(
self
):
"""
With the distributed optimizer, the model parallel group is the
entire world.
"""
return
None
def
state_dict
(
self
):
"""
The state dict must contain the fp32-from-float16 shards.
"""
state_dict
=
{}
state_dict
[
'optimizer'
]
=
self
.
optimizer
.
state_dict
()
if
self
.
grad_scaler
:
state_dict
[
'grad_scaler'
]
=
self
.
grad_scaler
.
state_dict
()
state_dict
[
'shard_fp32_from_float16_groups'
]
=
\
self
.
shard_fp32_from_float16_groups
return
state_dict
def
load_state_dict
(
self
,
state_dict
):
"""
Load the state dict.
"""
# Optimizer.
optimizer_key
=
'optimizer'
if
optimizer_key
not
in
state_dict
:
optimizer_key
=
'optimizer_state_dict'
print_rank_0
(
'***WARNING*** loading optimizer from '
'an old checkpoint ...'
)
self
.
optimizer
.
load_state_dict
(
state_dict
[
optimizer_key
])
# Grad scaler.
if
'grad_scaler'
not
in
state_dict
:
print_rank_0
(
'***WARNING*** found an old checkpoint, will not '
'load grad scaler ...'
)
else
:
if
self
.
grad_scaler
:
self
.
grad_scaler
.
load_state_dict
(
state_dict
[
'grad_scaler'
])
else
:
print_rank_0
(
'***WARNING*** fould the grad scaler in the '
'checkpoint but it is None in the class. '
'Skipping loading grad scaler ...'
)
# Copy data for the main params.
for
current_group
,
saved_group
in
zip
(
self
.
shard_fp32_from_float16_groups
,
state_dict
[
"shard_fp32_from_float16_groups"
]):
for
current_param
,
saved_param
in
zip
(
current_group
,
saved_group
):
current_param
.
data
.
copy_
(
saved_param
.
data
)
def
zero_grad
(
self
,
set_to_none
=
True
):
"""
Zero grads.
We only need to zero the model related parameters, i.e.,
model_float16_groups & model_fp32_groups. We additionally zero
the remaining groups as a memory optimization to reduce
fragmentation; in the case of set_to_none==True, the space
used by this field can be safely deallocated at this point.
"""
for
groups
in
(
self
.
model_float16_groups
,
self
.
model_fp32_groups
,
self
.
shard_float16_groups
,
# grad empty/unused here?
self
.
shard_fp32_groups
,
# throws grad-access warning
self
.
shard_fp32_from_float16_groups
):
for
group
in
groups
:
_zero_grad_group_helper
(
group
,
set_to_none
)
def
get_model_grad_buffer_dp_views
(
self
):
"""
Get shard views of each of the DDP's grad buffers.
In this nested list, the top level is grouped by the virtual model
index and the grad buffer's data type. The sub-level is a list of
shards of that grad buffer, where each shard in the list represents
a contiguous view of the grad buffer, that is owned by a data-parallel
rank. The shard boundary does not respect parameter boundaries, and
so the elements of some parameters are split across data parallel
ranks.
Additionally, return references to the entire grad buffers, for use
in _reduce_scatter_base and _all_gather_base.
"""
data_parallel_world_size
=
mpu
.
get_data_parallel_world_size
()
# Grad buffer views.
gbuf_view_items
=
[]
for
model_index
,
model
in
enumerate
(
self
.
models
):
for
dtype
,
gbuf
in
model
.
_grad_buffers
.
items
():
assert
gbuf
.
numel_padded
%
data_parallel_world_size
==
0
shard_size
=
int
(
gbuf
.
numel_padded
/
data_parallel_world_size
)
gbuf_views
=
[
gbuf
.
data
[(
r
*
shard_size
):((
r
+
1
)
*
shard_size
)]
for
r
in
range
(
data_parallel_world_size
)]
gbuf_view_items
.
append
((
model_index
,
dtype
,
gbuf
.
data
,
gbuf_views
))
return
gbuf_view_items
def
reduce_model_grads
(
self
,
args
,
timers
):
"""
Reduce-scatter model grads.
The DDP's grad buffer is used for the reduce-scatter, and thus no
tensors are dynamically allocated.
Note: this is a different order of reduction, versus the non-
distributed optimizer, which reduces: 1) layernorm grads, 2) all
grads, 3) embedding grads.
"""
# All-reduce layer-norm grads (for sequence parallelism).
timers
(
'backward-layernorm-all-reduce'
).
start
()
self
.
allreduce_layernorm_grads
(
args
)
timers
(
'backward-layernorm-all-reduce'
).
stop
()
# All-reduce embedding grads.
timers
(
'backward-embedding-all-reduce'
).
start
()
self
.
allreduce_embedding_grads
(
args
)
timers
(
'backward-embedding-all-reduce'
).
stop
()
# Reduce-scatter setup.
timers
(
'backward-params-all-reduce'
).
start
()
data_parallel_rank
=
mpu
.
get_data_parallel_rank
()
data_parallel_world_size
=
mpu
.
get_data_parallel_world_size
()
data_parallel_group
=
mpu
.
get_data_parallel_group
()
# Scale grad buffers by '1 / data_parallel_world_size'.
for
model
in
self
.
models
:
for
dtype
,
gbuf
in
model
.
_grad_buffers
.
items
():
gbuf
.
data
/=
data_parallel_world_size
# Reduce-scatter all grads.
gbuf_view_items
=
self
.
get_model_grad_buffer_dp_views
()
for
index
,
(
model_index
,
dtype
,
gbuf
,
gbuf_views
)
\
in
enumerate
(
gbuf_view_items
):
torch
.
distributed
.
_reduce_scatter_base
(
gbuf_views
[
data_parallel_rank
],
gbuf
,
group
=
data_parallel_group
,
)
timers
(
'backward-params-all-reduce'
).
stop
()
def
gather_model_params
(
self
,
args
,
timers
):
"""
All-gather updated model params.
The DDP's grad buffer is used for the all-gather, and thus no
tensors are dynamically allocated. After the all-gather, the params
can be copied from param.main_grad to param.
"""
timers
(
'backward-params-all-gather'
).
start
()
data_parallel_rank
=
mpu
.
get_data_parallel_rank
()
data_parallel_group
=
mpu
.
get_data_parallel_group
()
# All-gather updated main params.
# - All grad buffer views are guaranteed to have the same num elements
# across all data parallel ranks, with grad buffer padding that is done
# in distributed.py. Thus, all sub-views will have consistent start/end
# indexes across data parallel ranks.
gbuf_view_items
=
self
.
get_model_grad_buffer_dp_views
()
for
index
,
(
model_index
,
dtype
,
gbuf
,
gbuf_views
)
\
in
enumerate
(
gbuf_view_items
):
torch
.
distributed
.
_all_gather_base
(
gbuf
,
gbuf_views
[
data_parallel_rank
],
group
=
data_parallel_group
,
)
# Each model param now contains its updated values in its
# '.main_grad' field.
for
model
in
self
.
models
:
for
dtype
,
param_map
in
model
.
_grad_buffer_param_index_map
.
items
():
for
param
in
param_map
:
param
.
detach
().
copy_
(
param
.
main_grad
)
timers
(
'backward-params-all-gather'
).
stop
()
def
_collect_main_grad_data_for_unscaling
(
self
):
"""
Note: this should be equivalent to the float-16 optimizer's method,
but writtent differently, so the two should be combined.
"""
return
[
param
.
grad
.
data
for
group
in
self
.
optimizer
.
param_groups
for
param
in
group
[
"params"
]
]
def
_get_model_and_main_params_data_float16
(
self
):
"""
Get aligned list of model and main params.
"""
model_data
=
[]
main_data
=
[]
for
model_group
,
main_group
in
zip
(
self
.
shard_float16_groups
,
self
.
shard_fp32_from_float16_groups
):
for
model_param
,
main_param
in
zip
(
model_group
,
main_group
):
model_data
.
append
(
model_param
.
data
)
main_data
.
append
(
main_param
.
data
)
return
model_data
,
main_data
def
_copy_model_grads_to_main_grads
(
self
):
"""
Copy model grads to main grads.
Since this step follows a reduce-scatter through the DDP's grad
buffer, this method is responsible for copying the updated grads
from the grad buffer to the main shard's grad field.
"""
# Utility method for copying group grads.
def
copy_group_grads
(
model_groups
,
shard_main_groups
):
for
model_group
,
shard_main_group
in
zip
(
model_groups
,
shard_main_groups
):
for
model_param
,
shard_main_param
in
zip
(
model_group
,
shard_main_group
):
param_range_map
=
self
.
get_model_param_range_map
(
model_param
)
param_range
=
param_range_map
[
"param"
]
assert
param_range
.
size
==
shard_main_param
.
nelement
()
model_grad
=
model_param
.
main_grad
shard_model_grad
=
model_grad
.
view
(
-
1
)
\
[
param_range
.
start
:
param_range
.
end
]
shard_main_param
.
grad
=
shard_model_grad
.
float
()
# Copy model groups to shard groups.
copy_group_grads
(
self
.
model_float16_groups
,
self
.
shard_fp32_from_float16_groups
)
copy_group_grads
(
self
.
model_fp32_groups
,
self
.
shard_fp32_groups
)
def
_copy_main_params_to_model_params
(
self
):
"""
Copy main params to model params.
Since this step is followed by an all-gather through the DDP's grad
buffer, this method is responsible for copying the updated params
from the main shards into the correct position in the grad buffer.
"""
# Utility method for copying group params.
def
copy_group_params
(
shard_main_groups
,
model_groups
):
for
shard_main_group
,
model_group
in
zip
(
shard_main_groups
,
model_groups
):
for
shard_main_param
,
model_param
in
zip
(
shard_main_group
,
model_group
):
param_range_map
=
self
.
get_model_param_range_map
(
model_param
)
param_range
=
param_range_map
[
"param"
]
assert
param_range
.
size
==
shard_main_param
.
nelement
()
model_grad
=
model_param
.
main_grad
shard_model_grad
=
model_grad
.
view
(
-
1
)
\
[
param_range
.
start
:
param_range
.
end
]
shard_model_grad
.
data
.
copy_
(
shard_main_param
)
# Copy shard groups to model groups.
copy_group_params
(
self
.
shard_fp32_from_float16_groups
,
self
.
model_float16_groups
)
copy_group_params
(
self
.
shard_fp32_groups
,
self
.
model_fp32_groups
)
megatron/optimizer/optimizer.py
View file @
b8428a7f
...
...
@@ -17,15 +17,20 @@
from
abc
import
ABC
from
abc
import
abstractmethod
import
torch
from
apex.multi_tensor_apply
import
multi_tensor_applier
import
amp_C
import
torch
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
from
torch._utils
import
_flatten_dense_tensors
,
_unflatten_dense_tensors
from
megatron
import
get_timers
from
megatron
import
mpu
from
megatron
import
print_rank_0
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
Float16Module
from
megatron.model.module
import
param_is_not_shared
from
megatron.mpu.layers
import
param_is_not_tensor_parallel_duplicate
from
megatron.utils
import
unwrap_model
from
.clip_grads
import
clip_grad_norm_fp32
,
count_zeros_fp32
...
...
@@ -69,7 +74,8 @@ class MegatronOptimizer(ABC):
def
__init__
(
self
,
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
):
use_contiguous_buffers_in_local_ddp
,
models
):
"""Input optimizer is the base optimizer for example Adam."""
self
.
optimizer
=
optimizer
...
...
@@ -80,10 +86,15 @@ class MegatronOptimizer(ABC):
self
.
params_have_main_grad
=
params_have_main_grad
self
.
use_contiguous_buffers_in_local_ddp
=
use_contiguous_buffers_in_local_ddp
# 'models' are retained for access to the contiguous grad buffers.
# (see distributed optimizer)
self
.
models
=
models
if
self
.
use_contiguous_buffers_in_local_ddp
:
assert
self
.
params_have_main_grad
,
\
"use of contiguous buffer requires that params have main grad"
def
get_parameters
(
self
):
params
=
[]
for
param_group
in
self
.
optimizer
.
param_groups
:
...
...
@@ -92,14 +103,42 @@ class MegatronOptimizer(ABC):
return
params
def
get_main_grads_for_grad_norm
(
self
):
# Filter parameters based on:
# - grad should not be none
# - parameter should not be shared
# - should not be a replica due to tensor model parallelism
params
=
self
.
get_parameters
()
grads_for_norm
=
[]
for
param
in
params
:
grad
=
param
.
grad
grad_not_none
=
grad
is
not
None
is_not_shared
=
param_is_not_shared
(
param
)
is_not_tp_duplicate
=
param_is_not_tensor_parallel_duplicate
(
param
)
if
grad_not_none
and
is_not_shared
and
is_not_tp_duplicate
:
grads_for_norm
.
append
(
grad
)
return
grads_for_norm
def
get_model_parallel_group
(
self
):
"""Default returned here, but the distributed optimizer overrides this."""
return
mpu
.
get_model_parallel_group
()
def
clip_grad_norm
(
self
,
clip_grad
):
params
=
self
.
get_parameters
()
return
clip_grad_norm_fp32
(
params
,
clip_grad
)
grads_for_norm
=
self
.
get_main_grads_for_grad_norm
()
return
clip_grad_norm_fp32
(
params
,
grads_for_norm
,
clip_grad
,
model_parallel_group
=
self
.
get_model_parallel_group
())
def
count_zeros
(
self
):
params
=
self
.
get_parameters
()
return
count_zeros_fp32
(
params
)
return
count_zeros_fp32
(
params
,
model_parallel_group
=
self
.
get_model_parallel_group
())
@
abstractmethod
...
...
@@ -118,11 +157,6 @@ class MegatronOptimizer(ABC):
return
self
.
get_loss_scale
()
*
loss
@
abstractmethod
def
step
(
self
):
pass
@
abstractmethod
def
reload_model_params
(
self
):
"""Refreshes any internal state from the current model parameters.
...
...
@@ -166,9 +200,119 @@ class MegatronOptimizer(ABC):
param_groups
=
property
(
_get_param_groups
,
_set_param_groups
)
@
abstractmethod
def
step
(
self
,
args
,
timers
):
pass
class
Float16OptimizerWithFloat16Params
(
MegatronOptimizer
):
"""Float16 optimizer for fp16 and bf16 data types.
def
gather_model_params
(
self
,
args
,
timers
):
"""
For the case of a non-distributed-optimizer, there is nothing to
do here.
"""
pass
def
allreduce_word_embedding_grads
(
self
,
args
):
"""
All-reduce word embedding grads.
Reduce grads across first and last stages to ensure that word_embeddings
parameters stay in sync. This should only run for models that support
pipelined model parallelism (BERT and GPT-2).
"""
if
mpu
.
is_rank_in_embedding_group
(
ignore_virtual
=
True
)
and
\
mpu
.
get_pipeline_model_parallel_world_size
()
>
1
:
if
mpu
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
unwrapped_model
=
self
.
models
[
0
]
elif
mpu
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
unwrapped_model
=
self
.
models
[
-
1
]
else
:
# We do not support the interleaved schedule for T5 yet.
unwrapped_model
=
self
.
models
[
0
]
unwrapped_model
=
unwrap_model
(
unwrapped_model
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
if
unwrapped_model
.
share_word_embeddings
:
word_embeddings_weight
=
unwrapped_model
.
word_embeddings_weight
()
if
args
.
DDP_impl
==
'local'
:
grad
=
word_embeddings_weight
.
main_grad
else
:
grad
=
word_embeddings_weight
.
grad
torch
.
distributed
.
all_reduce
(
grad
,
group
=
mpu
.
get_embedding_group
())
def
allreduce_position_embedding_grads
(
self
,
args
):
"""
All-reduce position_embeddings grad across first (encoder) and
split (decoder) stages to ensure that position embeddings parameters
stay in sync. This should only run for T5 models with pipeline
parallelism.
"""
if
mpu
.
is_rank_in_position_embedding_group
()
and
\
mpu
.
get_pipeline_model_parallel_world_size
()
>
1
and
\
args
.
pipeline_model_parallel_split_rank
is
not
None
:
unwrapped_model
=
self
.
models
[
0
]
unwrapped_model
=
unwrap_model
(
unwrapped_model
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
assert
args
.
DDP_impl
==
'local'
,
\
'T5 model is only supported with local DDP mode'
grad
=
unwrapped_model
.
language_model
.
embedding
.
position_embeddings
.
weight
.
main_grad
torch
.
distributed
.
all_reduce
(
grad
,
group
=
mpu
.
get_position_embedding_group
())
def
allreduce_embedding_grads
(
self
,
args
):
"""All-reduce both word and position embeddings."""
self
.
allreduce_word_embedding_grads
(
args
)
self
.
allreduce_position_embedding_grads
(
args
)
def
allreduce_layernorm_grads
(
self
,
args
):
"""All-reduce layernorm grads (for sequence parallelism)."""
# All-reduce layernorm parameters across model parallel nodes
# when sequence parallelism is used
if
mpu
.
get_tensor_model_parallel_world_size
()
>
1
and
\
args
.
sequence_parallel
:
grads
=
[]
for
model_module
in
self
.
models
:
unwrapped_model
=
unwrap_model
(
model_module
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
for
param
in
unwrapped_model
.
parameters
():
if
getattr
(
param
,
'sequence_parallel'
,
False
):
grad
=
param
.
main_grad
if
args
.
DDP_impl
==
'local'
else
param
.
grad
grads
.
append
(
grad
.
data
)
coalesced
=
_flatten_dense_tensors
(
grads
)
torch
.
distributed
.
all_reduce
(
coalesced
,
group
=
mpu
.
get_tensor_model_parallel_group
())
for
buf
,
synced
in
zip
(
grads
,
_unflatten_dense_tensors
(
coalesced
,
grads
)):
buf
.
copy_
(
synced
)
def
reduce_model_grads
(
self
,
args
,
timers
):
"""All-reduce all grads, and all-reduce embeddings."""
# All-reduce layer-norm grads (for sequence parallelism).
timers
(
'backward-layernorm-all-reduce'
).
start
()
self
.
allreduce_layernorm_grads
(
args
)
timers
(
'backward-layernorm-all-reduce'
).
stop
()
# All-reduce if needed.
if
args
.
DDP_impl
==
'local'
:
timers
(
'backward-params-all-reduce'
).
start
()
for
model
in
self
.
models
:
model
.
allreduce_gradients
()
timers
(
'backward-params-all-reduce'
).
stop
()
# All-reduce embedding grads.
timers
(
'backward-embedding-all-reduce'
).
start
()
self
.
allreduce_embedding_grads
(
args
)
timers
(
'backward-embedding-all-reduce'
).
stop
()
class
MixedPrecisionOptimizer
(
MegatronOptimizer
):
"""Base class for both the float-16 and the distributed optimizer.
Arguments:
optimizer: base optimizer such as Adam or SGD
...
...
@@ -184,27 +328,36 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
Note that main grad is not necessarily in float32.
use_contiguous_buffers_in_local_ddp: if true, the local DDP model
is using a contiguous buffer to hold the model grads.
fp16: if true, the model is running in fp16.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constnat gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
models: list of models (i.e., the virtual pipelining models). This
is used by the distributed optimizer for mapping parameters.
"""
def
__init__
(
self
,
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
,
bf16
,
grad_scaler
):
fp16
,
bf16
,
grad_scaler
,
models
):
super
(
Float16OptimizerWithFloat16Params
,
self
).
__init__
(
super
().
__init__
(
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
)
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
,
models
)
self
.
fp16
=
fp16
self
.
bf16
=
bf16
self
.
grad_scaler
=
grad_scaler
# None grad scaler is only supported for bf16.
if
self
.
grad_scaler
is
None
:
assert
self
.
b
f16
,
'fp16 expects a grad scaler.'
assert
not
self
.
f
p
16
,
'fp16 expects a grad scaler.'
# Tensor used to determine if a nan/if has happend.
# Any non-zero value indicates inf/nan.
...
...
@@ -225,6 +378,131 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
if
self
.
grad_scaler
is
None
:
self
.
_scale_one
=
torch
.
cuda
.
FloatTensor
([
1.0
])
def
get_loss_scale
(
self
):
if
self
.
grad_scaler
is
None
:
return
self
.
_scale_one
return
self
.
grad_scaler
.
scale
def
reload_model_params
(
self
):
self
.
_copy_model_params_to_main_params
()
def
_unscale_main_grads_and_check_for_nan
(
self
):
# Collect main grads.
main_grads
=
self
.
_collect_main_grad_data_for_unscaling
()
# Reset found inf.
self
.
found_inf
.
fill_
(
0.0
)
# Unscale and set found inf/nan
torch
.
_amp_foreach_non_finite_check_and_unscale_
(
main_grads
,
self
.
found_inf
,
self
.
grad_scaler
.
inv_scale
)
# Update across all model parallel instances.
torch
.
distributed
.
all_reduce
(
self
.
found_inf
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
self
.
get_model_parallel_group
())
# Check for nan.
found_inf_flag
=
(
self
.
found_inf
.
item
()
>
0
)
return
found_inf_flag
@
torch
.
no_grad
()
def
step
(
self
,
args
,
timers
):
# Copy gradients from model params to main params.
timers
(
'optimizer-copy-to-main-grad'
).
start
()
self
.
_copy_model_grads_to_main_grads
()
timers
(
'optimizer-copy-to-main-grad'
).
stop
()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if
self
.
grad_scaler
:
# Unscale and check for inf/nan.
timers
(
'optimizer-unscale-and-check-inf'
).
start
()
found_inf_flag
=
self
.
_unscale_main_grads_and_check_for_nan
()
timers
(
'optimizer-unscale-and-check-inf'
).
stop
()
# We are done with scaling gradients
# so we can update the loss scale.
self
.
grad_scaler
.
update
(
found_inf_flag
)
# If we found inf/nan, skip the update.
if
found_inf_flag
:
return
False
,
None
,
None
# Clip the main gradients.
timers
(
'optimizer-clip-main-grad'
).
start
()
grad_norm
=
None
if
self
.
clip_grad
>
0.0
:
grad_norm
=
self
.
clip_grad_norm
(
self
.
clip_grad
)
timers
(
'optimizer-clip-main-grad'
).
stop
()
# Count the zeros in the grads.
timers
(
'optimizer-count-zeros'
).
start
()
num_zeros_in_grad
=
self
.
count_zeros
()
if
\
self
.
log_num_zeros_in_grad
else
None
timers
(
'optimizer-count-zeros'
).
stop
()
# Step the optimizer.
timers
(
'optimizer-inner-step'
).
start
()
self
.
optimizer
.
step
()
timers
(
'optimizer-inner-step'
).
stop
()
# Update params from main params.
timers
(
'optimizer-copy-main-to-model-params'
).
start
()
self
.
_copy_main_params_to_model_params
()
timers
(
'optimizer-copy-main-to-model-params'
).
stop
()
# Successful update.
return
True
,
grad_norm
,
num_zeros_in_grad
class
Float16OptimizerWithFloat16Params
(
MixedPrecisionOptimizer
):
"""Float16 optimizer for fp16 and bf16 data types.
Arguments:
optimizer: base optimizer such as Adam or SGD
clip_grad: clip gradeints with this global L2 norm. Note
that clipping is ignored if clip_grad == 0
log_num_zeros_in_grad: return number of zeros in the gradients.
params_have_main_grad: flag indicating if parameters have
a `main_grad` field. If this is set, we are assuming
that the model parameters are store in the `main_grad`
field instead of the typical `grad` field. This happens
for the DDP cases where there is a continuous buffer
holding the gradients. For example for bfloat16, we want
to do gradient accumulation and all-reduces in float32
and as a result we store those gradients in the main_grad.
Note that main grad is not necessarily in float32.
use_contiguous_buffers_in_local_ddp: if true, the local DDP model
is using a contiguous buffer to hold the model grads.
fp16: if true, the model is running in fp16.
bf16: if true, the model is running in bfloat16.
grad_scaler: used for scaling gradients. Note that this can be
None. This case happens when `bf16 = True` and we don't
use any loss scale. Note that for `bf16 = True`, we can have
a constnat gradient scaler. Also for `bf16 = False`, we
always require a grad scaler.
models: list of models (i.e., the virtual pipelining models). This
is used by the distributed optimizer for mapping parameters.
"""
def
__init__
(
self
,
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
,
fp16
,
bf16
,
grad_scaler
,
models
):
super
().
__init__
(
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
,
fp16
,
bf16
,
grad_scaler
,
models
)
# ======================
# main parameter stuff
# ======================
...
...
@@ -259,12 +537,12 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
main_param
.
shared
=
param
.
shared
# Replace the optimizer params with the new fp32 copy.
param_group
[
'params'
][
i
]
=
main_param
fp32_from_float16_params_this_group
.
append
(
main_param
)
# Reset existing state dict key to the new main param.
if
param
in
self
.
optimizer
.
state
:
self
.
optimizer
.
state
[
main_param
]
\
=
self
.
optimizer
.
state
.
pop
(
param
)
# fp32 params.
elif
param
.
type
()
==
'torch.cuda.FloatTensor'
:
fp32_params_this_group
.
append
(
param
)
...
...
@@ -282,10 +560,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
fp32_from_float16_params_this_group
)
self
.
fp32_from_fp32_groups
.
append
(
fp32_params_this_group
)
# Leverage state_dict() and load_state_dict() to
# recast preexisting per-param state tensors
self
.
optimizer
.
load_state_dict
(
self
.
optimizer
.
state_dict
())
def
zero_grad
(
self
,
set_to_none
=
True
):
"""We only need to zero the model related parameters, i.e.,
...
...
@@ -301,10 +575,34 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
_zero_grad_group_helper
(
group
,
set_to_none
)
def
get_loss_scale
(
self
):
if
self
.
grad_scaler
is
None
:
return
self
.
_scale_one
return
self
.
grad_scaler
.
scale
def
_collect_main_grad_data_for_unscaling
(
self
):
main_grads
=
[]
# fp32 params from float16 ones.
for
main_group
in
self
.
fp32_from_float16_groups
:
for
main_param
in
main_group
:
if
main_param
.
grad
is
not
None
:
main_grads
.
append
(
main_param
.
grad
.
data
)
# Append fp32 parameters.
for
main_group
in
self
.
fp32_from_fp32_groups
:
for
main_param
in
main_group
:
if
main_param
.
grad
is
not
None
:
main_grads
.
append
(
main_param
.
grad
.
data
)
return
main_grads
def
_get_model_and_main_params_data_float16
(
self
):
model_data
=
[]
main_data
=
[]
for
model_group
,
main_group
in
zip
(
self
.
float16_groups
,
self
.
fp32_from_float16_groups
):
for
model_param
,
main_param
in
zip
(
model_group
,
main_group
):
model_data
.
append
(
model_param
.
data
)
main_data
.
append
(
main_param
.
data
)
return
model_data
,
main_data
def
_copy_model_grads_to_main_grads
(
self
):
...
...
@@ -338,43 +636,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
if
not
self
.
use_contiguous_buffers_in_local_ddp
:
model_param
.
main_grad
=
None
def
_unscale_main_grads_and_check_for_nan
(
self
):
main_grads
=
[]
# fp32 params fromm float16 ones.
for
main_group
in
self
.
fp32_from_float16_groups
:
for
main_param
in
main_group
:
if
main_param
.
grad
is
not
None
:
main_grads
.
append
(
main_param
.
grad
.
data
)
# Append fp32 parameters.
for
main_group
in
self
.
fp32_from_fp32_groups
:
for
main_param
in
main_group
:
if
main_param
.
grad
is
not
None
:
main_grads
.
append
(
main_param
.
grad
.
data
)
# Reset found inf.
self
.
found_inf
.
fill_
(
0.0
)
# Unscale and set found inf/nan
torch
.
_amp_foreach_non_finite_check_and_unscale_
(
main_grads
,
self
.
found_inf
,
self
.
grad_scaler
.
inv_scale
)
# Update across all model parallel instances.
torch
.
distributed
.
all_reduce
(
self
.
found_inf
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
mpu
.
get_model_parallel_group
())
# Check for nan.
found_inf_flag
=
(
self
.
found_inf
.
item
()
>
0
)
return
found_inf_flag
def
_get_model_and_main_params_data_float16
(
self
):
model_data
=
[]
main_data
=
[]
for
model_group
,
main_group
in
zip
(
self
.
float16_groups
,
self
.
fp32_from_float16_groups
):
for
model_param
,
main_param
in
zip
(
model_group
,
main_group
):
model_data
.
append
(
model_param
.
data
)
main_data
.
append
(
main_param
.
data
)
return
model_data
,
main_data
def
_copy_main_params_to_model_params
(
self
):
# Only needed for the float16 params.
...
...
@@ -390,60 +651,6 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
overflow_buf
=
self
.
_dummy_overflow_buf
)
def
reload_model_params
(
self
):
self
.
_copy_model_params_to_main_params
()
@
torch
.
no_grad
()
def
step
(
self
):
timers
=
get_timers
()
# Copy gradients from model params to main params.
timers
(
'optimizer-copy-to-main-grad'
).
start
()
self
.
_copy_model_grads_to_main_grads
()
timers
(
'optimizer-copy-to-main-grad'
).
stop
()
# Do unscale, check for inf, and update grad scaler only for
# the case that grad scaler is provided.
if
self
.
grad_scaler
:
# Unscale and check for inf/nan.
timers
(
'optimizer-unscale-and-check-inf'
).
start
()
found_inf_flag
=
self
.
_unscale_main_grads_and_check_for_nan
()
timers
(
'optimizer-unscale-and-check-inf'
).
stop
()
# We are done with scaling gradients
# so we can update the loss scale.
self
.
grad_scaler
.
update
(
found_inf_flag
)
# If we found inf/nan, skip the update.
if
found_inf_flag
:
return
False
,
None
,
None
# Clip the main gradients.
timers
(
'optimizer-clip-main-grad'
).
start
()
grad_norm
=
None
if
self
.
clip_grad
>
0.0
:
grad_norm
=
self
.
clip_grad_norm
(
self
.
clip_grad
)
timers
(
'optimizer-clip-main-grad'
).
stop
()
# count the zeros in the grads
num_zeros_in_grad
=
self
.
count_zeros
()
if
\
self
.
log_num_zeros_in_grad
else
None
# Step the optimizer.
self
.
optimizer
.
step
()
# Update params from main params.
timers
(
'optimizer-copy-main-to-model-params'
).
start
()
self
.
_copy_main_params_to_model_params
()
timers
(
'optimizer-copy-main-to-model-params'
).
stop
()
# Successful update.
return
True
,
grad_norm
,
num_zeros_in_grad
def
state_dict
(
self
):
state_dict
=
{}
state_dict
[
'optimizer'
]
=
self
.
optimizer
.
state_dict
()
...
...
@@ -485,17 +692,18 @@ class Float16OptimizerWithFloat16Params(MegatronOptimizer):
current_param
.
data
.
copy_
(
saved_param
.
data
)
class
FP32Optimizer
(
MegatronOptimizer
):
def
__init__
(
self
,
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
):
use_contiguous_buffers_in_local_ddp
,
models
):
super
(
FP32Optimizer
,
self
).
__init__
(
optimizer
,
clip_grad
,
log_num_zeros_in_grad
,
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
)
params_have_main_grad
,
use_contiguous_buffers_in_local_ddp
,
models
)
self
.
_scale
=
torch
.
cuda
.
FloatTensor
([
1.0
])
...
...
@@ -512,11 +720,12 @@ class FP32Optimizer(MegatronOptimizer):
@
torch
.
no_grad
()
def
step
(
self
):
def
step
(
self
,
args
,
timers
):
"""Clip gradients (if needed) and step the base optimizer.
Always return successful since there is no overflow."""
# Copy main_grads to grads.
timers
(
'optimizer-copy-to-main-grad'
).
start
()
if
self
.
params_have_main_grad
:
for
param_group
in
self
.
optimizer
.
param_groups
:
for
param
in
param_group
[
'params'
]:
...
...
@@ -527,18 +736,25 @@ class FP32Optimizer(MegatronOptimizer):
# persist and therefore should not be deallocated.)
if
not
self
.
use_contiguous_buffers_in_local_ddp
:
param
.
main_grad
=
None
timers
(
'optimizer-copy-to-main-grad'
).
stop
()
# Clip gradients.
timers
(
'optimizer-clip-main-grad'
).
start
()
grad_norm
=
None
if
self
.
clip_grad
>
0.0
:
grad_norm
=
self
.
clip_grad_norm
(
self
.
clip_grad
)
timers
(
'optimizer-clip-main-grad'
).
stop
()
# count the zeros in the grads
timers
(
'optimizer-count-zeros'
).
start
()
num_zeros_in_grad
=
self
.
count_zeros
()
if
\
self
.
log_num_zeros_in_grad
else
None
timers
(
'optimizer-count-zeros'
).
stop
()
# Update parameters.
timers
(
'optimizer-inner-step'
).
start
()
self
.
optimizer
.
step
()
timers
(
'optimizer-inner-step'
).
stop
()
# No overflow for FP32 optimizer.
return
True
,
grad_norm
,
num_zeros_in_grad
...
...
megatron/p2p_communication.py
View file @
b8428a7f
...
...
@@ -61,7 +61,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
tensor_shape
=
(
args
.
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
)
override_scatter_gather_tensors_in_pipeline
=
False
if
args
.
scatter_gather_tensors_in_pipeline
:
if
args
.
scatter_gather_tensors_in_pipeline
and
\
not
args
.
sequence_parallel
:
tensor_chunk_shape
=
reduce
(
operator
.
mul
,
tensor_shape
,
1
)
if
tensor_chunk_shape
%
mpu
.
get_tensor_model_parallel_world_size
()
==
0
:
tensor_chunk_shape
=
tensor_chunk_shape
//
\
...
...
@@ -93,7 +94,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# Split tensor into smaller chunks if using scatter-gather optimization.
if
not
override_scatter_gather_tensors_in_pipeline
and
\
args
.
scatter_gather_tensors_in_pipeline
:
args
.
scatter_gather_tensors_in_pipeline
and
\
not
args
.
sequence_parallel
:
if
tensor_send_next
is
not
None
:
tensor_send_next
=
mpu
.
split_tensor_into_1d_equal_chunks
(
tensor_send_next
)
...
...
@@ -138,7 +140,8 @@ def _communicate(tensor_send_next, tensor_send_prev, recv_prev, recv_next,
# If using scatter-gather optimization, gather smaller chunks.
if
not
override_scatter_gather_tensors_in_pipeline
and
\
args
.
scatter_gather_tensors_in_pipeline
:
args
.
scatter_gather_tensors_in_pipeline
and
\
not
args
.
sequence_parallel
:
if
recv_prev
:
tensor_recv_prev
=
mpu
.
gather_split_1d_tensor
(
tensor_recv_prev
).
view
(
tensor_shape
).
requires_grad_
()
...
...
megatron/schedules.py
View file @
b8428a7f
...
...
@@ -279,8 +279,12 @@ def forward_backward_pipelining_with_interleaving(forward_step_func,
pipeline_parallel_rank
=
mpu
.
get_pipeline_model_parallel_rank
()
args
=
get_args
()
tensor_shape
=
(
args
.
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
)
if
args
.
sequence_parallel
:
seq_length
=
args
.
seq_length
//
mpu
.
get_tensor_model_parallel_world_size
()
else
:
seq_length
=
args
.
seq_length
tensor_shape
=
(
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
)
# Compute number of warmup and remaining microbatches.
num_model_chunks
=
len
(
model
)
num_microbatches
=
get_num_microbatches
()
*
num_model_chunks
...
...
@@ -514,18 +518,25 @@ def get_tensor_shapes(rank, model_type):
# Otherwise, send one tensor (pre-transpose).
args
=
get_args
()
tensor_shapes
=
[]
if
args
.
sequence_parallel
:
seq_length
=
args
.
seq_length
//
mpu
.
get_tensor_model_parallel_world_size
()
else
:
seq_length
=
args
.
seq_length
if
model_type
==
ModelType
.
encoder_and_decoder
:
if
args
.
sequence_parallel
:
decoder_seq_length
=
args
.
decoder_seq_length
//
mpu
.
get_tensor_model_parallel_world_size
()
else
:
decoder_seq_length
=
args
.
decoder_seq_length
if
mpu
.
is_pipeline_stage_before_split
(
rank
):
# If next rank is after split, then need transpose for encoder_hidden_state.
if
mpu
.
is_pipeline_stage_before_split
(
rank
+
1
):
tensor_shapes
.
append
((
args
.
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
else
:
tensor_shapes
.
append
((
args
.
micro_batch_size
,
args
.
seq_length
,
args
.
hidden_size
))
tensor_shapes
.
append
((
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
else
:
tensor_shapes
.
append
((
args
.
decoder_seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
tensor_shapes
.
append
((
args
.
micro_batch_size
,
args
.
seq_length
,
args
.
hidden_size
))
tensor_shapes
.
append
((
decoder_seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
tensor_shapes
.
append
((
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
else
:
tensor_shapes
.
append
((
args
.
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
tensor_shapes
.
append
((
seq_length
,
args
.
micro_batch_size
,
args
.
hidden_size
))
return
tensor_shapes
...
...
megatron/training.py
View file @
b8428a7f
...
...
@@ -21,7 +21,6 @@ import sys
import
time
# The earliest we can measure the start time.
_TRAIN_START_TIME
=
time
.
time
()
import
torch
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
...
...
@@ -43,6 +42,7 @@ from megatron.model import ModelType
from
megatron.optimizer
import
get_megatron_optimizer
from
megatron.initialize
import
initialize_megatron
from
megatron.initialize
import
write_args_to_tensorboard
from
megatron.initialize
import
set_jit_fusion_options
from
megatron.optimizer_param_scheduler
import
OptimizerParamScheduler
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.utils
import
check_adlr_autoresume_termination
...
...
@@ -51,7 +51,7 @@ from megatron.data.data_samplers import build_pretraining_data_loader
from
megatron.utils
import
calc_params_l2_norm
from
megatron.schedules
import
get_forward_backward_func
from
megatron.utils
import
report_memory
from
megatron.model.vision.knn_monitor
import
compute_feature_bank
def
print_datetime
(
string
):
...
...
@@ -100,6 +100,8 @@ def pretrain(train_valid_test_dataset_provider,
# Initalize and get arguments, timers, and Tensorboard writer.
initialize_megatron
(
extra_args_provider
=
extra_args_provider
,
args_defaults
=
args_defaults
)
# Set pytorch JIT layer fusion options and warmup JIT functions.
set_jit_fusion_options
()
# Adjust the startup time so it reflects the largest value.
# This will be closer to what scheduler will see (outside of
...
...
@@ -362,12 +364,11 @@ def setup_model_and_optimizer(model_provider_func,
args
=
get_args
()
model
=
get_model
(
model_provider_func
,
model_type
)
unwrapped_model
=
unwrap_model
(
model
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
optimizer
=
get_megatron_optimizer
(
unwrapped_model
,
no_wd_decay_cond
,
scale_lr_cond
,
lr_mult
)
optimizer
=
get_megatron_optimizer
(
model
,
no_wd_decay_cond
,
scale_lr_cond
,
lr_mult
)
opt_param_scheduler
=
get_optimizer_param_scheduler
(
optimizer
)
if
args
.
load
is
not
None
:
...
...
@@ -410,66 +411,44 @@ def train_step(forward_step_func, data_iterator,
partition
.
zero_grad_buffer
()
optimizer
.
zero_grad
()
# Forward pass.
forward_backward_func
=
get_forward_backward_func
()
losses_reduced
=
forward_backward_func
(
forward_step_func
,
data_iterator
,
model
,
optimizer
,
timers
,
forward_only
=
False
)
# Empty unused memory
# Empty unused memory
.
if
args
.
empty_unused_memory_level
>=
1
:
torch
.
cuda
.
empty_cache
()
# All-reduce if needed.
if
args
.
DDP_impl
==
'local'
:
timers
(
'backward-params-all-reduce'
).
start
()
for
model_module
in
model
:
model_module
.
allreduce_gradients
()
timers
(
'backward-params-all-reduce'
).
stop
()
# All-reduce word_embeddings' grad across first and last stages to ensure
# that word_embeddings parameters stay in sync.
# This should only run for models that support pipelined model parallelism
# (BERT and GPT-2).
timers
(
'backward-embedding-all-reduce'
).
start
()
if
mpu
.
is_rank_in_embedding_group
(
ignore_virtual
=
True
)
and
\
mpu
.
get_pipeline_model_parallel_world_size
()
>
1
:
if
mpu
.
is_pipeline_first_stage
(
ignore_virtual
=
True
):
unwrapped_model
=
model
[
0
]
elif
mpu
.
is_pipeline_last_stage
(
ignore_virtual
=
True
):
unwrapped_model
=
model
[
-
1
]
else
:
# We do not support the interleaved schedule for T5 yet.
unwrapped_model
=
model
[
0
]
unwrapped_model
=
unwrap_model
(
unwrapped_model
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
if
unwrapped_model
.
share_word_embeddings
:
word_embeddings_weight
=
unwrapped_model
.
word_embeddings_weight
()
if
args
.
DDP_impl
==
'local'
:
grad
=
word_embeddings_weight
.
main_grad
else
:
grad
=
word_embeddings_weight
.
grad
torch
.
distributed
.
all_reduce
(
grad
,
group
=
mpu
.
get_embedding_group
())
# All-reduce position_embeddings grad across first (encoder) and split (decoder)
# stages to ensure that position embeddings parameters stay in sync.
# This should only run for T5 models with pipeline parallelism
if
mpu
.
is_rank_in_position_embedding_group
()
and
\
mpu
.
get_pipeline_model_parallel_world_size
()
>
1
and
\
args
.
pipeline_model_parallel_split_rank
is
not
None
:
unwrapped_model
=
model
[
0
]
unwrapped_model
=
unwrap_model
(
unwrapped_model
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
assert
args
.
DDP_impl
==
'local'
,
\
'T5 model is only supported with local DDP mode'
grad
=
unwrapped_model
.
language_model
.
embedding
.
position_embeddings
.
weight
.
main_grad
torch
.
distributed
.
all_reduce
(
grad
,
group
=
mpu
.
get_position_embedding_group
())
timers
(
'backward-embedding-all-reduce'
).
stop
()
# Reduce gradients.
timers
(
'backward-reduce-model-grads'
).
start
()
optimizer
.
reduce_model_grads
(
args
,
timers
)
timers
(
'backward-reduce-model-grads'
).
stop
()
# Vision gradients.
if
args
.
vision_pretraining
and
args
.
vision_pretraining_type
==
"dino"
:
unwrapped_model
=
unwrap_model
(
model
[
0
],
(
torchDDP
,
LocalDDP
,
Float16Module
))
unwrapped_model
.
cancel_gradients_last_layer
(
args
.
curr_iteration
)
# Update parameters.
timers
(
'optimizer'
).
start
()
update_successful
,
grad_norm
,
num_zeros_in_grad
=
optimizer
.
step
()
update_successful
,
grad_norm
,
num_zeros_in_grad
=
optimizer
.
step
(
args
,
timers
)
timers
(
'optimizer'
).
stop
()
# Gather params.
if
update_successful
:
timers
(
'backward-gather-model-params'
).
start
()
optimizer
.
gather_model_params
(
args
,
timers
)
timers
(
'backward-gather-model-params'
).
stop
()
# Vision momentum.
if
args
.
vision_pretraining
and
args
.
vision_pretraining_type
==
"dino"
:
unwrapped_model
=
unwrap_model
(
model
[
0
],
(
torchDDP
,
LocalDDP
,
Float16Module
))
unwrapped_model
.
update_momentum
(
args
.
curr_iteration
)
# Update learning rate.
if
update_successful
:
increment
=
get_num_microbatches
()
*
\
...
...
@@ -480,7 +459,7 @@ def train_step(forward_step_func, data_iterator,
else
:
skipped_iter
=
1
# Empty unused memory
# Empty unused memory
.
if
args
.
empty_unused_memory_level
>=
2
:
torch
.
cuda
.
empty_cache
()
...
...
@@ -547,10 +526,15 @@ def training_log(loss_dict, total_loss_dict, learning_rate, iteration,
add_to_logging
(
'backward-send-forward-recv'
)
add_to_logging
(
'backward-send-backward-recv'
)
add_to_logging
(
'backward-params-all-reduce'
)
add_to_logging
(
'backward-layernorm-all-reduce'
)
add_to_logging
(
'backward-embedding-all-reduce'
)
add_to_logging
(
'backward-reduce-model-grads'
)
add_to_logging
(
'backward-gather-model-params'
)
add_to_logging
(
'optimizer-copy-to-main-grad'
)
add_to_logging
(
'optimizer-unscale-and-check-inf'
)
add_to_logging
(
'optimizer-clip-main-grad'
)
add_to_logging
(
'optimizer-count-zeros'
)
add_to_logging
(
'optimizer-inner-step'
)
add_to_logging
(
'optimizer-copy-main-to-model-params'
)
add_to_logging
(
'optimizer'
)
add_to_logging
(
'batch-generator'
)
...
...
@@ -702,6 +686,7 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
report_memory_flag
=
True
while
iteration
<
args
.
train_iters
:
update_num_microbatches
(
args
.
consumed_train_samples
)
args
.
curr_iteration
=
iteration
loss_dict
,
skipped_iter
,
grad_norm
,
num_zeros_in_grad
=
\
train_step
(
forward_step_func
,
train_data_iterator
,
...
...
@@ -791,6 +776,9 @@ def evaluate(forward_step_func,
"""Evaluation."""
args
=
get_args
()
if
args
.
vision_pretraining
and
args
.
vision_pretraining_type
==
"dino"
:
compute_feature_bank
(
model
)
# Turn on evaluation mode which disables dropout.
for
model_module
in
model
:
model_module
.
eval
()
...
...
pretrain_vi
t
.py
→
pretrain_vi
sion_classify
.py
View file @
b8428a7f
...
...
@@ -22,20 +22,32 @@ from megatron import get_args, get_timers, mpu, print_rank_0
from
megatron.data.vit_dataset
import
build_train_valid_datasets
from
megatron.model
import
ModelType
from
megatron.model.vision.classification
import
VitClassificationModel
from
megatron.model.vision.classification
import
MitClassificationModel
from
megatron.training
import
pretrain
from
megatron.utils
import
average_losses_across_data_parallel_group
def
model_provider
(
pre_process
=
True
,
post_process
=
True
):
"""Build the model."""
print_rank_0
(
"building VIT model ..."
)
args
=
get_args
()
model
=
VitClassificationModel
(
num_classes
=
args
.
num_classes
,
pre_process
=
pre_process
,
post_process
=
post_process
)
if
args
.
vision_backbone_type
==
'vit'
:
print_rank_0
(
"building VIT model ..."
)
model
=
VitClassificationModel
(
num_classes
=
args
.
num_classes
,
pre_process
=
pre_process
,
post_process
=
post_process
)
elif
args
.
vision_backbone_type
==
'mit'
:
print_rank_0
(
"building MIT model ..."
)
model
=
MitClassificationModel
(
num_classes
=
args
.
num_classes
,
pre_process
=
pre_process
,
post_process
=
post_process
)
else
:
raise
Exception
(
'{} vision backbone is not supported.'
.
format
(
args
.
vision_backbone_type
))
return
model
def
get_batch
(
data_iterator
):
"""Build the batch."""
data
=
next
(
data_iterator
)
...
...
@@ -46,6 +58,7 @@ def get_batch(data_iterator):
return
images
,
labels
def
loss_func
(
labels
,
output_tensor
):
logits
=
output_tensor
.
contiguous
().
float
()
loss
=
F
.
cross_entropy
(
logits
,
labels
)
...
...
@@ -58,6 +71,7 @@ def loss_func(labels, output_tensor):
return
loss
,
{
"loss"
:
averaged_loss
[
0
],
"accuracy"
:
averaged_loss
[
1
]}
def
forward_step
(
data_iterator
,
model
):
"""Forward step."""
timers
=
get_timers
()
...
...
@@ -98,5 +112,5 @@ if __name__ == "__main__":
model_provider
,
ModelType
.
encoder_or_decoder
,
forward_step
,
args_defaults
=
{
'dataloader_type'
:
'cyclic'
}
args_defaults
=
{
'dataloader_type'
:
'cyclic'
,
'vision_pretraining'
:
True
}
)
pretrain_vision_dino.py
0 → 100644
View file @
b8428a7f
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
torch
import
torch.nn.functional
as
F
import
torch.nn
as
nn
import
numpy
as
np
import
torch.distributed
as
dist
from
functools
import
partial
from
megatron
import
get_args
,
get_timers
,
mpu
,
print_rank_0
from
megatron.data.vit_dataset
import
build_train_valid_datasets
from
megatron.model.vision.dino
import
DINOPretrainModel
from
megatron.model.vision.knn_monitor
import
knn_predict
,
get_feature_bank
from
megatron.training
import
pretrain
from
megatron.utils
import
average_losses_across_data_parallel_group
,
unwrap_model
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
Float16Module
from
megatron.model
import
ModelType
def model_provider(pre_process=True, post_process=True):
    """Construct the DINO pretraining model.

    Args:
        pre_process: whether this pipeline stage runs the input layers.
        post_process: whether this pipeline stage runs the output layers.
    """
    model = DINOPretrainModel(
        pre_process=pre_process,
        post_process=post_process,
    )
    return model
def get_batch(data_iterator):
    """Fetch one (images, labels) batch and move it to the GPU.

    With multi-crop augmentation the image entry is a list of tensors
    (one per crop); otherwise it is a single tensor.
    """
    # Vision pretraining uses data parallelism only, so each rank reads
    # its own shard and no tensor broadcast is required.
    sample = next(data_iterator)
    raw_images = sample[0]
    if isinstance(raw_images, list):
        images = [crop.cuda() for crop in raw_images]
    else:
        images = raw_images.cuda()
    labels = sample[1].cuda()
    return images, labels
def loss_func(model, labels, output_tensor, collect_data=False):
    """Compute the DINO loss (training) or kNN accuracy metrics (eval).

    NOTE(review): ``collect_data`` is accepted for interface compatibility
    with other vision loss functions but is not used here.
    """
    args = get_args()
    bare_model = unwrap_model(model, (torchDDP, LocalDDP, Float16Module))

    if bare_model.training:
        # Training: self-distillation loss between student and teacher heads.
        student_out, teacher_out = output_tensor
        loss = bare_model.dino_loss(student_out, teacher_out,
                                    args.curr_iteration)
        averaged_loss = average_losses_across_data_parallel_group([loss])
        return loss, {"loss": averaged_loss[0]}

    # Evaluation: score the teacher features with a kNN classifier against
    # the precomputed feature bank, for several values of k.
    _, teacher_feature = output_tensor
    feature_bank, feature_labels, classes = get_feature_bank()
    query = F.normalize(teacher_feature.float(), dim=1)
    accuracies = []
    for k in (10, 20, 100, 200):
        pred_labels = knn_predict(query, feature_bank, feature_labels,
                                  classes, k, 0.07)
        accuracies.append((pred_labels[:, 0] == labels).float().mean())
    averaged_loss = average_losses_across_data_parallel_group(accuracies)
    # No meaningful loss in eval mode; return 0 plus the accuracy metrics.
    return 0, {"knn_acc_10": averaged_loss[0],
               "knn_acc_20": averaged_loss[1],
               "knn_acc_100": averaged_loss[2],
               "knn_acc_200": averaged_loss[3]}
def forward_step(data_iterator, model):
    """Run one forward pass and return (output, loss-closure)."""
    timers = get_timers()

    # Fetch the batch under the batch-generator timer.
    timers("batch-generator").start()
    images, labels = get_batch(data_iterator)
    timers("batch-generator").stop()

    output_tensor = model(images)
    return output_tensor, partial(loss_func, model, labels)
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets "
        "for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w),
    )
    print_rank_0("> finished creating VIT datasets ...")

    # No dedicated test dataset for vision pretraining.
    return train_ds, valid_ds, None
if __name__ == "__main__":
    # Cyclic dataloader + vision-pretraining defaults for DINO.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        args_defaults={'dataloader_type': 'cyclic',
                       'vision_pretraining': True},
    )
pretrain_vision_inpaint.py
0 → 100644
View file @
b8428a7f
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pretrain VIT"""
import
torch
import
torch.nn.functional
as
F
from
functools
import
partial
from
megatron
import
get_args
,
get_timers
,
mpu
,
print_rank_0
,
print_rank_last
from
megatron.data.vit_dataset
import
build_train_valid_datasets
from
megatron.model.vision.inpainting
import
VitInpaintingModel
from
megatron.model.vision.inpainting
import
MitInpaintingModel
from
megatron.training
import
pretrain
from
megatron.utils
import
average_losses_across_data_parallel_group
from
tasks.vision.metrics
import
SSIM
,
PSNR
from
megatron.model
import
ModelType
def model_provider(pre_process=True, post_process=True):
    """Build the inpainting model for the configured vision backbone.

    Supports the 'vit' and 'mit' backbones; any other value raises.
    """
    args = get_args()
    backbone = args.vision_backbone_type
    if backbone == 'vit':
        model = VitInpaintingModel(pre_process=pre_process,
                                   post_process=post_process)
    elif backbone == 'mit':
        model = MitInpaintingModel(pre_process=pre_process,
                                   post_process=post_process)
    else:
        raise Exception(
            '{} vision backbone is not supported.'.format(
                args.vision_backbone_type))
    return model
def get_batch(data_iterator):
    """Fetch one (images, masks) pair and move both tensors to the GPU."""
    # Only data parallelism is used here, so no broadcast is needed.
    sample = next(data_iterator)
    images = sample[0][0].cuda()
    masks = sample[0][1].cuda()
    return images, masks
def loss_func(images, masks, masked_images, outputs, collect_data=False):
    """Inpainting loss / metrics.

    When ``collect_data`` is False, returns the masked-region MSE loss and
    averaged (loss, psnr, ssim) for logging. Otherwise returns a
    visualization tensor of (ground truth, masked input, synthesized
    output) concatenated along dim=2, together with raw SSIM and PSNR.
    """
    outputs = outputs.contiguous().float()
    # Zero everything outside the mask so only the region to be inpainted
    # contributes to the loss/metrics, in both prediction and ground truth.
    keep_outside = (1 - masks).bool()
    flip_masked_outputs = outputs.masked_fill(keep_outside, 0)
    flip_masked_images = images.masked_fill(keep_outside, 0)

    ssim_fun = SSIM()
    psnr_fun = PSNR()

    if not collect_data:
        mask_count = torch.count_nonzero(masks)
        # Sum-reduced MSE normalized by the number of masked elements.
        loss = F.mse_loss(flip_masked_outputs,
                          flip_masked_images.float(),
                          reduction="sum") / mask_count
        ssim = ssim_fun(flip_masked_outputs, flip_masked_images.float())
        psnr = psnr_fun(flip_masked_outputs, flip_masked_images.float())
        averaged_loss = average_losses_across_data_parallel_group(
            [loss, psnr, ssim]
        )
        return loss, {"loss": averaged_loss[0],
                      "psnr": averaged_loss[1],
                      'ssim': averaged_loss[2]}

    # Visualization path: splice the predicted region into the masked input.
    synth_images = masked_images.float() + flip_masked_outputs
    ssim = ssim_fun(synth_images, images.float())
    psnr = psnr_fun(synth_images, images.float())
    return torch.cat((images, masked_images, synth_images), dim=2), ssim, psnr
def forward_step(data_iterator, model):
    """Forward step: mask the images, run the model, return loss closure."""
    timers = get_timers()

    timers("batch-generator").start()
    images, masks = get_batch(data_iterator)
    timers("batch-generator").stop()

    # Zero out the regions the model must inpaint.
    masked_images = images.masked_fill(masks.bool(), 0)
    outputs = model(masked_images)

    return outputs, partial(loss_func, images, masks, masked_images)
def process_non_loss_data(data, iteration, writer):
    """Log validation inpainting visualizations and average PSNR/SSIM.

    ``data`` is a sequence of (image grid, ssim, psnr) triples produced by
    the evaluation loop; image grids are clamped to [0, 1] in place before
    being written to the summary writer.
    """
    psnr_total = 0
    ssim_total = 0
    for output_tb, ssim, psnr in data:
        # Clamp in place so the grid is a valid [0, 1] image for logging.
        output_tb[output_tb < 0] = 0
        output_tb[output_tb > 1] = 1
        writer.add_images("gt-input-output-vald", output_tb,
                          global_step=iteration, walltime=None,
                          dataformats='NCHW')
        psnr_total += psnr.item()
        ssim_total += ssim.item()
    writer.add_scalar('PSNR generate value-validation',
                      psnr_total / len(data), iteration)
    writer.add_scalar('SSIM generate value-validation',
                      ssim_total / len(data), iteration)
def train_valid_test_datasets_provider(train_val_test_num_samples):
    """Build train, valid, and test datasets."""
    args = get_args()

    print_rank_0(
        "> building train, validation, and test datasets "
        "for VIT ..."
    )
    train_ds, valid_ds = build_train_valid_datasets(
        data_path=args.data_path,
        image_size=(args.img_h, args.img_w),
    )
    print_rank_0("> finished creating VIT datasets ...")

    # No dedicated test dataset for vision pretraining.
    return train_ds, valid_ds, None
if __name__ == "__main__":
    # Cyclic dataloader + vision-pretraining defaults; the extra
    # process_non_loss_data hook logs validation visualizations.
    pretrain(
        train_valid_test_datasets_provider,
        model_provider,
        ModelType.encoder_or_decoder,
        forward_step,
        process_non_loss_data,
        args_defaults={'dataloader_type': 'cyclic',
                       'vision_pretraining': True},
    )
tasks/finetune_utils.py
View file @
b8428a7f
...
...
@@ -229,7 +229,7 @@ def _train(model, optimizer, opt_param_scheduler, forward_step,
prefix
=
'iteration {}'
.
format
(
iteration
)
evaluate_and_print_results
(
prefix
,
forward_step
,
valid_dataloader
,
model
,
iteration
,
False
)
iteration
,
None
,
False
)
# Exiting based on iterations
if
args
.
exit_interval
and
iteration
%
args
.
exit_interval
==
0
:
...
...
tasks/vision/classification.py
→
tasks/vision/classification
/classification
.py
View file @
b8428a7f
...
...
@@ -15,12 +15,15 @@
"""Vision-classification finetuning/evaluation."""
from
megatron
import
get_args
import
torch.nn.functional
as
F
from
functools
import
partial
from
megatron
import
get_args
,
get_timers
from
megatron
import
print_rank_0
from
megatron.model.vi
t_model
import
Vit
Model
from
megatron.model.vi
sion.classification
import
VitClassification
Model
from
megatron.data.vit_dataset
import
build_train_valid_datasets
from
tasks.vision.eval_utils
import
accuracy_func_provider
from
tasks.vision.
classification.
eval_utils
import
accuracy_func_provider
from
tasks.vision.finetune_utils
import
finetune
from
megatron.utils
import
average_losses_across_data_parallel_group
def
classification
():
...
...
@@ -30,7 +33,7 @@ def classification():
train_ds
,
valid_ds
=
build_train_valid_datasets
(
data_path
=
args
.
data_path
,
crop
_size
=
args
.
img_
dim
,
image
_size
=
(
args
.
img_
h
,
args
.
img_w
)
,
)
return
train_ds
,
valid_ds
...
...
@@ -40,16 +43,52 @@ def classification():
print_rank_0
(
"building classification model for ImageNet ..."
)
return
VitModel
(
num_classes
=
args
.
num_classes
,
finetune
=
True
,
pre_process
=
pre_process
,
post_process
=
post_process
)
return
VitClassificationModel
(
num_classes
=
args
.
num_classes
,
finetune
=
True
,
pre_process
=
pre_process
,
post_process
=
post_process
)
def
process_batch
(
batch
):
"""Process batch and produce inputs for the model."""
images
=
batch
[
0
].
cuda
().
contiguous
()
labels
=
batch
[
1
].
cuda
().
contiguous
()
return
images
,
labels
def
cross_entropy_loss_func
(
labels
,
output_tensor
):
logits
=
output_tensor
# Cross-entropy loss.
loss
=
F
.
cross_entropy
(
logits
.
contiguous
().
float
(),
labels
)
# Reduce loss for logging.
averaged_loss
=
average_losses_across_data_parallel_group
([
loss
])
return
loss
,
{
'lm loss'
:
averaged_loss
[
0
]}
def
_cross_entropy_forward_step
(
batch
,
model
):
"""Simple forward step with cross-entropy loss."""
timers
=
get_timers
()
# Get the batch.
timers
(
"batch generator"
).
start
()
try
:
batch_
=
next
(
batch
)
except
BaseException
:
batch_
=
batch
images
,
labels
=
process_batch
(
batch_
)
timers
(
"batch generator"
).
stop
()
# Forward model.
output_tensor
=
model
(
images
)
return
output_tensor
,
partial
(
cross_entropy_loss_func
,
labels
)
"""Finetune/evaluate."""
finetune
(
train_valid_datasets_provider
,
model_provider
,
forward_step
=
_cross_entropy_forward_step
,
end_of_epoch_callback_provider
=
accuracy_func_provider
,
)
def
main
():
classification
()
tasks/vision/eval_utils.py
→
tasks/vision/
classification/
eval_utils.py
View file @
b8428a7f
...
...
@@ -33,11 +33,10 @@ def accuracy_func_provider():
"""Provide function that calculates accuracies."""
args
=
get_args
()
data_path
=
args
.
data_path
crop_size
=
args
.
img_
dim
crop_size
=
(
args
.
img_
h
,
args
.
img_w
)
# mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
# Build dataloaders.
val_data_path
=
os
.
path
.
join
(
data_path
[
0
],
"val"
)
val_data_path
=
data_path
[
1
]
normalize
=
transforms
.
Normalize
(
mean
=
[
0.5
,
0.5
,
0.5
],
std
=
[
0.5
,
0.5
,
0.5
])
transform_val
=
transforms
.
Compose
(
[
...
...
@@ -54,6 +53,7 @@ def accuracy_func_provider():
args
.
micro_batch_size
,
num_workers
=
args
.
num_workers
,
drop_last
=
(
mpu
.
get_data_parallel_world_size
()
>
1
),
shuffle
=
False
)
def
metrics_func
(
model
,
epoch
):
...
...
@@ -71,7 +71,6 @@ def accuracy_func_provider():
def
calculate_correct_answers
(
model
,
dataloader
,
epoch
):
"""Calculate correct over total answers"""
args
=
get_args
()
forward_backward_func
=
get_forward_backward_func
()
for
m
in
model
:
m
.
eval
()
...
...
@@ -98,7 +97,6 @@ def calculate_correct_answers(model, dataloader, epoch):
images
,
labels
=
process_batch
(
batch_
)
# Forward model.
args
=
get_args
()
output_tensor
=
model
(
images
)
return
output_tensor
,
partial
(
loss_func
,
labels
)
...
...
tasks/vision/finetune_utils.py
View file @
b8428a7f
...
...
@@ -17,11 +17,10 @@
import
torch
import
torch.nn.functional
as
F
from
functools
import
partial
from
megatron
import
get_args
from
megatron
import
print_rank_0
from
megatron
import
get_timers
from
megatron
import
mpu
from
megatron
import
mpu
,
utils
from
megatron.checkpointing
import
load_checkpoint
from
megatron.checkpointing
import
save_checkpoint
from
megatron.training
import
evaluate_and_print_results
...
...
@@ -29,7 +28,10 @@ from megatron.training import setup_model_and_optimizer
from
megatron.training
import
train_step
from
megatron.training
import
training_log
from
megatron.utils
import
check_adlr_autoresume_termination
from
megatron.utils
import
average_losses_across_data_parallel_group
from
megatron.utils
import
average_losses_across_data_parallel_group
,
print_params_min_max_norm
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
Float16Module
,
ModelType
def
process_batch
(
batch
):
...
...
@@ -39,45 +41,16 @@ def process_batch(batch):
return
images
,
labels
def
cross_entropy_loss_func
(
labels
,
output_tensor
):
logits
=
output_tensor
# Cross-entropy loss.
loss
=
F
.
cross_entropy
(
logits
.
contiguous
().
float
(),
labels
)
# Reduce loss for logging.
averaged_loss
=
average_losses_across_data_parallel_group
([
loss
])
return
loss
,
{
'lm loss'
:
averaged_loss
[
0
]}
def
_cross_entropy_forward_step
(
batch
,
model
):
"""Simple forward step with cross-entropy loss."""
timers
=
get_timers
()
# Get the batch.
timers
(
"batch generator"
).
start
()
try
:
batch_
=
next
(
batch
)
except
BaseException
:
batch_
=
batch
images
,
labels
=
process_batch
(
batch_
)
timers
(
"batch generator"
).
stop
()
# Forward model.
output_tensor
=
model
(
images
)
return
output_tensor
,
partial
(
cross_entropy_loss_func
,
labels
)
def
build_data_loader
(
dataset
,
micro_batch_size
,
num_workers
,
drop_last
):
def
build_data_loader
(
dataset
,
micro_batch_size
,
num_workers
,
drop_last
,
shuffle
):
"""Data loader. Note that batch-size is the local (per GPU) batch-size."""
# Sampler.
world_size
=
mpu
.
get_data_parallel_world_size
()
rank
=
mpu
.
get_data_parallel_rank
()
sampler
=
torch
.
utils
.
data
.
distributed
.
DistributedSampler
(
dataset
,
num_replicas
=
world_size
,
rank
=
rank
dataset
,
num_replicas
=
world_size
,
rank
=
rank
,
drop_last
=
drop_last
,
shuffle
=
shuffle
)
# Data loader. Note that batch size is the per GPU batch size.
...
...
@@ -112,14 +85,14 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
print_rank_0
(
'building train and validation dataloaders ...'
)
# Training dataset.
train_dataloader
=
build_data_loader
(
train_dataset
,
args
.
micro_batch_size
,
args
.
num_workers
,
not
args
.
keep_last
)
args
.
num_workers
,
False
,
True
)
# Set the training iterations.
args
.
train_iters_per_epoch
=
len
(
train_dataloader
)
args
.
train_iters
=
args
.
epochs
*
args
.
train_iters_per_epoch
# Validation dataset. For this dataset, we do not need to set up
# shuffling so we can just use a simple infinite loop.
valid_dataloader_
=
build_data_loader
(
valid_dataset
,
args
.
micro_batch_size
,
args
.
num_workers
,
not
args
.
keep_last
)
args
.
num_workers
,
True
,
False
)
valid_dataloader
=
_build_infinite_size_dataloader
(
valid_dataloader_
)
# Now that we've built the data loaders, set batch_size arguments
...
...
@@ -132,6 +105,7 @@ def _build_train_valid_dataloaders(train_dataset, valid_dataset):
return
train_dataloader
,
valid_dataloader
def
_train
(
model
,
optimizer
,
...
...
@@ -140,6 +114,7 @@ def _train(
train_dataloader
,
valid_dataloader
,
end_of_epoch_callback
,
process_non_loss_data_func
=
None
):
"""Train the model."""
args
=
get_args
()
...
...
@@ -167,6 +142,7 @@ def _train(
# Set the data loader epoch to shuffle the index iterator.
train_dataloader
.
sampler
.
set_epoch
(
args
.
seed
+
epoch
)
train_dataloader
.
dataset
.
set_epoch
(
epoch
)
# For all the batches in the dataset.
for
iteration_
,
batch
in
enumerate
(
train_dataloader
):
...
...
@@ -185,8 +161,6 @@ def _train(
# Logging.
params_norm
=
None
if
args
.
log_params_norm
:
params_norm
=
calc_params_l2_norm
(
model
)
report_memory_flag
=
training_log
(
losses_dict
,
...
...
@@ -202,20 +176,16 @@ def _train(
)
# Autoresume
if
args
.
adlr_autoresume
and
(
iteration
%
args
.
adlr_autoresume_interval
==
0
):
check_adlr_autoresume_termination
(
iteration
,
model
,
optimizer
,
opt_param_scheduler
)
if
args
.
adlr_autoresume
and
\
iteration
%
args
.
adlr_autoresume_interval
==
0
:
check_adlr_autoresume_termination
(
iteration
,
model
,
optimizer
,
opt_param_scheduler
)
# Checkpointing
if
(
args
.
save
and
args
.
save_interval
and
iteration
%
args
.
save_interval
==
0
):
save_checkpoint
(
iteration
,
model
,
optimizer
,
opt_param_scheduler
)
if
args
.
save
and
args
.
save_interval
and
\
iteration
%
args
.
save_interval
==
0
:
save_checkpoint
(
iteration
,
model
,
optimizer
,
opt_param_scheduler
)
# Evaluation
if
args
.
eval_interval
and
iteration
%
args
.
eval_interval
==
0
:
...
...
@@ -226,13 +196,10 @@ def _train(
valid_dataloader
,
model
,
iteration
,
process_non_loss_data_func
,
False
,
)
# Checkpointing at the end of each epoch.
if
args
.
save
:
save_checkpoint
(
iteration
,
model
,
optimizer
,
opt_param_scheduler
)
# Callback at the end of each epoch.
if
end_of_epoch_callback
is
not
None
:
end_of_epoch_callback
(
model
,
epoch
)
...
...
@@ -241,7 +208,9 @@ def _train(
def
finetune
(
train_valid_datasets_provider
,
model_provider
,
forward_step
=
_cross_entropy_forward_step
,
forward_step
,
model_type
=
ModelType
.
encoder_or_decoder
,
process_non_loss_data_func
=
None
,
end_of_epoch_callback_provider
=
None
,
):
"""Main finetune function used across all tasks."""
...
...
@@ -266,7 +235,12 @@ def finetune(
# Build model, optimizer and learning rate scheduler.
timers
(
"model and optimizer"
).
start
()
model
,
optimizer
,
opt_param_scheduler
=
setup_model_and_optimizer
(
model_provider
)
model
,
optimizer
,
opt_param_scheduler
=
\
setup_model_and_optimizer
(
model_provider
,
model_type
,
scale_lr_cond
=
lambda
name
,
param
:
".head."
in
name
,
lr_mult
=
args
.
head_lr_mult
)
timers
(
"model and optimizer"
).
stop
()
# If pretrained checkpoint is provided and we have not trained for
...
...
@@ -274,13 +248,34 @@ def finetune(
# checkpoint.
timers
(
"pretrained checkpoint"
).
start
()
if
args
.
iteration
==
0
and
args
.
pretrained_checkpoint
is
not
None
:
original_load
=
args
.
load
args
.
load
=
args
.
pretrained_checkpoint
_
=
load_checkpoint
(
model
,
None
,
None
,
strict
=
False
)
args
.
load
=
original_load
if
args
.
pretrained_checkpoint_type
==
'default'
:
original_load
=
args
.
load
args
.
load
=
args
.
pretrained_checkpoint
_
=
load_checkpoint
(
model
,
None
,
None
,
strict
=
False
)
args
.
load
=
original_load
elif
args
.
pretrained_checkpoint_type
==
'external'
:
unwrap_model
=
utils
.
unwrap_model
(
model
)
state_dict
=
torch
.
load
(
args
.
pretrained_checkpoint
,
map_location
=
"cpu"
)
unwrap_model
[
0
].
module
.
backbone
.
load_state_dict
(
state_dict
,
strict
=
False
)
elif
args
.
pretrained_checkpoint_type
==
'constrastive'
:
unwrap_model
=
utils
.
unwrap_model
(
model
)
state_dict
=
torch
.
load
(
args
.
pretrained_checkpoint
,
map_location
=
"cpu"
)
state_dict
=
state_dict
[
"model"
]
state_dict
=
{
k
.
replace
(
"teacher.backbone."
,
""
):
v
for
k
,
v
in
state_dict
.
items
()
if
k
.
startswith
(
"teacher.backbone."
)}
unwrap_model
[
0
].
module
.
backbone
.
load_state_dict
(
state_dict
,
strict
=
False
)
else
:
raise
Exception
(
"pretrained checkpoint type {} not supported"
.
format
(
args
.
pretrained_checkpoint_type
))
# This is critical when only model is loaded. We should make sure
# master parameters are also updated.
optimizer
.
reload_model_params
()
timers
(
"pretrained checkpoint"
).
stop
()
# Print setup timing.
...
...
@@ -305,11 +300,13 @@ def finetune(
train_dataloader
,
valid_dataloader
,
end_of_epoch_callback
,
process_non_loss_data_func
,
)
# Or just evaluate.
else
:
if
end_of_epoch_callback
is
not
None
:
print_rank_0
(
"evaluation only mode, setting epoch to -1"
)
end_of_epoch_callback
(
model
,
epoch
=-
1
,
output_predictions
=
True
)
end_of_epoch_callback
(
model
,
epoch
=-
1
)
print_rank_0
(
"done :-)"
)
tasks/vision/main.py
View file @
b8428a7f
...
...
@@ -28,32 +28,24 @@ sys.path.append(
)
from
megatron
import
get_args
from
megatron.initialize
import
initialize_megatron
from
classification
import
main
def
get_tasks_args
(
parser
):
"""Provide extra arguments required for tasks."""
group
=
parser
.
add_argument_group
(
title
=
"tasks"
)
group
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
None
,
help
=
"Number of finetunning epochs. Zero results in "
"evaluation only."
,
)
group
.
add_argument
(
"--pretrained-checkpoint"
,
type
=
str
,
default
=
None
,
help
=
"Pretrained checkpoint used for finetunning."
,
)
group
.
add_argument
(
"--keep-last"
,
action
=
"store_true"
,
help
=
"Keep the last batch (maybe incomplete) in"
"the data loader"
,
)
group
.
add_argument
(
'--task'
,
type
=
str
,
default
=
'segment'
,
choices
=
[
'classify'
,
'segment_setr'
,
'segment_segformer'
],
help
=
'task name.'
)
group
.
add_argument
(
"--epochs"
,
type
=
int
,
default
=
None
,
help
=
"Number of finetunning epochs. Zero results in "
"evaluation only."
)
group
.
add_argument
(
'--pretrained-checkpoint-type'
,
type
=
str
,
default
=
'default'
,
choices
=
[
'default'
,
'external'
,
'constrastive'
],
help
=
'Type of pretrained checkpoint'
)
group
.
add_argument
(
"--pretrained-checkpoint"
,
type
=
str
,
default
=
None
,
help
=
"Pretrained checkpoint used for finetunning."
)
group
.
add_argument
(
'--seg-stride'
,
type
=
int
,
default
=
None
,
help
=
'sliding window stride during evaluation'
)
return
parser
...
...
if __name__ == "__main__":
    # Initialize Megatron (distributed setup, arg parsing) with the extra
    # vision-task arguments defined above.
    initialize_megatron(extra_args_provider=get_tasks_args)
    args = get_args()
    # Dispatch to the driver selected by --task.  The imports are done
    # lazily inside each branch so only the chosen task's dependencies
    # (and their heavy transitive imports) are loaded.
    if args.task == 'classify':
        from tasks.vision.classification.classification import main
        main()
    elif args.task == 'segment_setr':
        from tasks.vision.segmentation.finetune_setr import main
        main()
    elif args.task == 'segment_segformer':
        from tasks.vision.segmentation.finetune_segformer import main
        main()
tasks/vision/segmentation/cityscapes.py
0 → 100644
View file @
b8428a7f
# BSD 3-Clause License
#
# Copyright (c) Soumith Chintala 2016,
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# * Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# * Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# * Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
# code taken from
# https://github.com/pytorch/vision/blob/main/torchvision/datasets/cityscapes.py
# modified it to change max label index from 255 to 19 (num_classes)
import
torch
import
json
import
os
from
collections
import
namedtuple
from
typing
import
Any
,
Callable
,
Dict
,
List
,
Optional
,
Union
,
Tuple
import
numpy
as
np
from
torchvision.datasets.utils
import
extract_archive
,
verify_str_arg
,
iterable_to_str
from
torchvision.datasets
import
VisionDataset
from
PIL
import
Image
from
megatron
import
print_rank_0
class Cityscapes(VisionDataset):
    """`Cityscapes <http://www.cityscapes-dataset.com/>`_ Dataset.
    Args:
        root (string): Root directory of dataset where directory ``leftImg8bit``
            and ``gtFine`` or ``gtCoarse`` are located.
        split (string, optional): The image split to use, ``train``, ``test`` or ``val`` if mode="fine"
            otherwise ``train``, ``train_extra`` or ``val``
        mode (string, optional): The quality mode to use, ``fine`` or ``coarse``
        target_type (string or list, optional): Type of target to use, ``instance``, ``semantic``, ``polygon``
            or ``color``. Can also be a list to output a tuple with all specified target types.
        transform (callable, optional): A function/transform that takes in a PIL image
            and returns a transformed version. E.g, ``transforms.RandomCrop``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.

    Modified from torchvision so that ignored pixels use index 19
    (== num_classes) instead of 255.
    """

    # 19 trainable classes; index 19 is the "ignore" bucket.
    num_classes = 19
    ignore_index = 19

    # RGB color per train id (rows 0..18) plus black for the ignore index
    # (row 19); used to render predictions for visualization.
    color_table = torch.tensor(
        [[128, 64, 128],
         [244, 35, 232],
         [70, 70, 70],
         [102, 102, 156],
         [190, 153, 153],
         [153, 153, 153],
         [250, 170, 30],
         [220, 220, 0],
         [107, 142, 35],
         [152, 251, 152],
         [70, 130, 180],
         [220, 20, 60],
         [255, 0, 0],
         [0, 0, 142],
         [0, 0, 70],
         [0, 60, 100],
         [0, 80, 100],
         [0, 0, 230],
         [119, 11, 32],
         [0, 0, 0]],
        dtype=torch.float,
        device='cuda')

    # Based on https://github.com/mcordts/cityscapesScripts
    CityscapesClass = namedtuple(
        'CityscapesClass',
        ['name', 'id', 'train_id', 'category', 'category_id',
         'has_instances', 'ignore_in_eval', 'color'])

    classes = [
        CityscapesClass('unlabeled', 0, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('ego vehicle', 1, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('rectification border', 2, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('out of roi', 3, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('static', 4, 19, 'void', 0, False, True, (0, 0, 0)),
        CityscapesClass('dynamic', 5, 19, 'void', 0, False, True, (111, 74, 0)),
        CityscapesClass('ground', 6, 19, 'void', 0, False, True, (81, 0, 81)),
        CityscapesClass('road', 7, 0, 'flat', 1, False, False, (128, 64, 128)),
        CityscapesClass('sidewalk', 8, 1, 'flat', 1, False, False, (244, 35, 232)),
        CityscapesClass('parking', 9, 19, 'flat', 1, False, True, (250, 170, 160)),
        CityscapesClass('rail track', 10, 19, 'flat', 1, False, True, (230, 150, 140)),
        CityscapesClass('building', 11, 2, 'construction', 2, False, False, (70, 70, 70)),
        CityscapesClass('wall', 12, 3, 'construction', 2, False, False, (102, 102, 156)),
        CityscapesClass('fence', 13, 4, 'construction', 2, False, False, (190, 153, 153)),
        CityscapesClass('guard rail', 14, 19, 'construction', 2, False, True, (180, 165, 180)),
        CityscapesClass('bridge', 15, 19, 'construction', 2, False, True, (150, 100, 100)),
        CityscapesClass('tunnel', 16, 19, 'construction', 2, False, True, (150, 120, 90)),
        CityscapesClass('pole', 17, 5, 'object', 3, False, False, (153, 153, 153)),
        CityscapesClass('polegroup', 18, 19, 'object', 3, False, True, (153, 153, 153)),
        CityscapesClass('traffic light', 19, 6, 'object', 3, False, False, (250, 170, 30)),
        CityscapesClass('traffic sign', 20, 7, 'object', 3, False, False, (220, 220, 0)),
        CityscapesClass('vegetation', 21, 8, 'nature', 4, False, False, (107, 142, 35)),
        CityscapesClass('terrain', 22, 9, 'nature', 4, False, False, (152, 251, 152)),
        CityscapesClass('sky', 23, 10, 'sky', 5, False, False, (70, 130, 180)),
        CityscapesClass('person', 24, 11, 'human', 6, True, False, (220, 20, 60)),
        CityscapesClass('rider', 25, 12, 'human', 6, True, False, (255, 0, 0)),
        CityscapesClass('car', 26, 13, 'vehicle', 7, True, False, (0, 0, 142)),
        CityscapesClass('truck', 27, 14, 'vehicle', 7, True, False, (0, 0, 70)),
        CityscapesClass('bus', 28, 15, 'vehicle', 7, True, False, (0, 60, 100)),
        CityscapesClass('caravan', 29, 19, 'vehicle', 7, True, True, (0, 0, 90)),
        CityscapesClass('trailer', 30, 19, 'vehicle', 7, True, True, (0, 0, 110)),
        CityscapesClass('train', 31, 16, 'vehicle', 7, True, False, (0, 80, 100)),
        CityscapesClass('motorcycle', 32, 17, 'vehicle', 7, True, False, (0, 0, 230)),
        CityscapesClass('bicycle', 33, 18, 'vehicle', 7, True, False, (119, 11, 32)),
        CityscapesClass('license plate', -1, -1, 'vehicle', 7, False, True, (0, 0, 142)),
    ]

    # Raw annotation label id -> train id.
    label2trainid = {label.id: label.train_id for label in classes}

    # 256-entry lookup table equivalent to applying label2trainid per pixel:
    # identity for ids without a class entry; negative ids (license plate)
    # cannot occur in a uint8 label image and are skipped.  Built once so
    # __getitem__ can remap a whole mask in a single vectorized pass instead
    # of one boolean-mask scan per class (~35 full-image passes before).
    _id_lut = np.arange(256, dtype=np.uint8)
    _id_lut[[k for k in label2trainid if 0 <= k < 256]] = [
        v for k, v in label2trainid.items() if 0 <= k < 256]

    def __init__(
            self,
            root: str,
            split: str = "train",
            mode: str = "fine",
            resolution: int = 1024,
            transform: Optional[Callable] = None,
            target_transform: Optional[Callable] = None,
            transforms: Optional[Callable] = None,
    ) -> None:
        """Index all (image, labelIds) file pairs under *root* for *split*.

        NOTE(review): the directory names 'leftImg8bit_trainvaltest' and
        'gtFine_trainvaltest' are hard-coded, so ``mode='coarse'`` only
        changes the target *file name* suffix — verify the on-disk layout
        matches before using coarse annotations.  ``resolution`` is stored
        but not used inside this class.
        """
        super(Cityscapes, self).__init__(root, transforms, transform,
                                         target_transform)
        self.mode = 'gtFine' if mode == 'fine' else 'gtCoarse'
        self.images_dir = os.path.join(
            self.root, 'leftImg8bit_trainvaltest/leftImg8bit', split)
        self.targets_dir = os.path.join(
            self.root, 'gtFine_trainvaltest/gtFine', split)
        self.split = split
        self.resolution = resolution
        self.images = []
        self.targets = []

        # One sub-directory per city; pair each image with its labelIds map.
        for city in sorted(os.listdir(self.images_dir)):
            img_dir = os.path.join(self.images_dir, city)
            target_dir = os.path.join(self.targets_dir, city)
            for file_name in os.listdir(img_dir):
                target_name = '{}_{}_labelIds.png'.format(
                    file_name.split('_leftImg8bit')[0], self.mode)
                self.images.append(os.path.join(img_dir, file_name))
                self.targets.append(os.path.join(target_dir, target_name))

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        """
        Args:
            index (int): Index
        Returns:
            tuple: (image, target) where target is the semantic segmentation
            mask with raw label ids remapped to train ids (ignore == 19).
        """
        image = Image.open(self.images[index]).convert('RGB')
        target = Image.open(self.targets[index])
        # Remap raw label ids to train ids in one vectorized LUT pass.
        # Assumes the labelIds PNG is 8-bit (values 0..255), which is the
        # standard Cityscapes encoding.
        target = Image.fromarray(Cityscapes._id_lut[np.array(target)])
        if self.transforms is not None:
            image, target = self.transforms(image, target)
        return image, target

    def __len__(self) -> int:
        """Number of indexed (image, target) pairs."""
        return len(self.images)
tasks/vision/segmentation/data.py
0 → 100644
View file @
b8428a7f
import
random
import
os
import
math
import
mmcv
import
torch
import
numpy
as
np
import
torchvision.transforms
as
T
from
torchvision
import
datasets
from
torch.utils.data
import
Dataset
from
megatron.data.autoaugment
import
ImageNetPolicy
from
tasks.vision.segmentation.cityscapes
import
Cityscapes
import
tasks.vision.segmentation.transforms
as
ET
from
megatron.data.autoaugment
import
ImageNetPolicy
from
megatron
import
get_args
from
PIL
import
Image
,
ImageOps
class VitSegmentationJointTransform():
    """Paired spatial augmentation applied to (image, mask) together.

    In training mode this random-resizes/crops and randomly flips both the
    image and its mask with the same parameters; in eval mode it is the
    identity.
    """

    def __init__(self, train=True, resolution=None):
        self.train = train
        if train:
            # Augmentations are only materialized for the training split.
            self.transform0 = ET.RandomSizeAndCrop(resolution)
            self.transform1 = ET.RandomHorizontallyFlip()

    def __call__(self, img, mask):
        if not self.train:
            # Eval: pass both through untouched.
            return img, mask
        img, mask = self.transform0(img, mask)
        img, mask = self.transform1(img, mask)
        return img, mask
class VitSegmentationImageTransform():
    """Image-only pipeline: [photometric distortion (train only)] ->
    ToTensor -> Normalize(mean, std) -> cast to the half/bf16 compute dtype.
    """

    def __init__(self, train=True, resolution=None):
        args = get_args()
        self.train = train
        # Mixed precision is mandatory for this pipeline.
        assert args.fp16 or args.bf16
        self.data_type = torch.half if args.fp16 else torch.bfloat16
        self.mean_std = args.mean_std
        # Shared tail of the pipeline; training additionally prepends a
        # photometric jitter stage.
        stages = [
            T.ToTensor(),
            T.Normalize(*self.mean_std),
            T.ConvertImageDtype(self.data_type),
        ]
        if train:
            assert resolution is not None
            stages.insert(0, ET.PhotoMetricDistortion())
        self.transform = T.Compose(stages)

    def __call__(self, input):
        return self.transform(input)
class VitSegmentationTargetTransform():
    """Converts a PIL segmentation mask into a LongTensor of class ids."""

    def __init__(self, train=True, resolution=None):
        # train/resolution are accepted for signature parity with the
        # image transform; only `train` is stored and neither affects
        # the conversion.
        self.train = train

    def __call__(self, input):
        # PIL -> int32 ndarray -> int64 tensor (F.cross_entropy wants long).
        as_ints = np.array(input, dtype=np.int32)
        return torch.from_numpy(as_ints).long()
class RandomSeedSegmentationDataset(Dataset):
    """Wraps a segmentation dataset and reseeds every RNG per sample.

    Seeding torch/random/numpy with ``base_seed + 100 * epoch + idx`` makes
    the random joint augmentation deterministic per (epoch, sample) across
    data-parallel workers and restarts.
    """

    def __init__(self, dataset, joint_transform, image_transform,
                 target_transform):
        args = get_args()
        self.base_seed = args.seed
        self.curr_seed = self.base_seed
        self.dataset = dataset
        self.joint_transform = joint_transform
        self.image_transform = image_transform
        self.target_transform = target_transform

    def __len__(self):
        return len(self.dataset)

    def set_epoch(self, epoch):
        # Shift the seed space each epoch so augmentations differ per epoch
        # while staying reproducible.
        self.curr_seed = self.base_seed + 100 * epoch

    def __getitem__(self, idx):
        sample_seed = idx + self.curr_seed
        img, mask = self.dataset[idx]
        # Pin all three RNGs before any random transform runs.
        torch.manual_seed(sample_seed)
        random.seed(sample_seed)
        np.random.seed(sample_seed)
        img, mask = self.joint_transform(img, mask)
        return self.image_transform(img), self.target_transform(mask)
def build_cityscapes_train_valid_datasets(data_path, image_size):
    """Build the Cityscapes train/val datasets with their transform stacks.

    Also publishes dataset-wide constants (num_classes, ignore_index,
    color_table, mean_std) on the global args so downstream code can read
    them.
    """
    args = get_args()
    args.num_classes = Cityscapes.num_classes
    args.ignore_index = Cityscapes.ignore_index
    args.color_table = Cityscapes.color_table
    # ImageNet normalization statistics.
    args.mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    def _make(split, is_train):
        # One wrapped dataset per split; train gets the augmenting
        # variants of each transform.
        base = Cityscapes(
            root=data_path[0],
            split=split,
            mode='fine',
            resolution=image_size,
        )
        return RandomSeedSegmentationDataset(
            base,
            joint_transform=VitSegmentationJointTransform(
                train=is_train, resolution=image_size),
            image_transform=VitSegmentationImageTransform(
                train=is_train, resolution=image_size),
            target_transform=VitSegmentationTargetTransform(
                train=is_train, resolution=image_size),
        )

    return _make('train', True), _make('val', False)
def build_train_valid_datasets(data_path, image_size):
    """Public entry point for segmentation datasets.

    Currently only Cityscapes is supported; this thin wrapper exists so
    callers do not depend on the concrete dataset builder.
    """
    return build_cityscapes_train_valid_datasets(data_path, image_size)
tasks/vision/segmentation/finetune_segformer.py
0 → 100644
View file @
b8428a7f
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
functools
import
partial
from
megatron
import
get_args
,
get_timers
from
megatron
import
mpu
,
print_rank_0
,
print_rank_last
from
tasks.vision.finetune_utils
import
finetune
from
tasks.vision.finetune_utils
import
build_data_loader
from
megatron.utils
import
average_losses_across_data_parallel_group
from
megatron.schedules
import
get_forward_backward_func
from
tasks.vision.segmentation.data
import
build_train_valid_datasets
from
tasks.vision.segmentation.seg_models
import
SegformerSegmentationModel
from
megatron.model.vision.utils
import
resize
def calculate_iou(hist_data):
    """Derive segmentation metrics from a confusion matrix.

    Args:
        hist_data: (C, C) confusion matrix; rows index ground truth,
            columns index predictions.

    Returns:
        (iu, acc, acc_cls): per-class IoU vector, overall pixel accuracy,
        and mean per-class accuracy (nan classes ignored in the mean).
    """
    true_pos = np.diag(hist_data)
    total = hist_data.sum()
    acc = true_pos.sum() / total
    # Per-class recall, averaged over classes that actually appear.
    acc_cls = np.nanmean(true_pos / hist_data.sum(axis=1))
    # IoU denominator: TP + FP + FN = row sum + column sum - TP.
    union = hist_data.sum(axis=1) + hist_data.sum(axis=0) - true_pos
    iu = true_pos / union
    return iu, acc, acc_cls
def fast_hist(pred, gtruth, num_classes):
    """Accumulate a (num_classes, num_classes) confusion matrix.

    Pixels whose ground-truth label falls outside [0, num_classes) are
    ignored.  Each valid (gt, pred) pair is flattened to the single index
    ``num_classes * gt + pred`` so one bincount builds the whole matrix:
    row = ground truth class, column = predicted class, diagonal = TP.
    """
    valid = (gtruth >= 0) & (gtruth < num_classes)
    flat = num_classes * gtruth[valid].astype(int) + pred[valid]
    counts = np.bincount(flat, minlength=num_classes ** 2)
    return counts.reshape(num_classes, num_classes)
def segmentation():
    """Segformer semantic-segmentation finetuning/evaluation driver.

    Defines all the provider/step callbacks expected by
    ``tasks.vision.finetune_utils.finetune`` and then invokes it.  Kept as
    nested functions so they can share the closure without module state.
    """

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        model = SegformerSegmentationModel(num_classes=args.num_classes,
                                           pre_process=pre_process,
                                           post_process=post_process)
        print_rank_0("model = {}".format(model))
        return model

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        # Per-class weight: 1 + (1 - frequency) for classes present in the
        # batch, 1 for absent classes.  Currently unused (see the commented
        # call below) but kept for experimentation.
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        """Loss (training) or visualization tensor (non_loss_data=True)."""
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        # Upsample logits to the label resolution before the loss.
        logits = output_tensor.contiguous().float()
        logits = resize(logits, size=masks.shape[1:],
                        mode='bilinear', align_corners=False)

        # Cross-entropy loss.
        # weight = calculate_weight(masks, num_classes)
        loss = F.cross_entropy(logits, masks, ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            # Colorize prediction and ground truth and stack them under the
            # input image for TensorBoard dumping.
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        timers = get_timers()

        # Get the batch.  The pipeline schedule may hand us a generator
        # instead of a materialized batch.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.
        output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""
        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, output_tensor):
            args = get_args()
            logits = output_tensor
            logits = resize(logits, size=labels.shape[1:],
                            mode='bilinear', align_corners=False)
            loss_dict = {}
            # Compute the correct answers.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.cpu().numpy()
            # NOTE(review): args.ignore_index is passed where fast_hist
            # expects num_classes; this only works because for Cityscapes
            # ignore_index == num_classes == 19 — confirm if reused.
            performs = fast_hist(preds.flatten(),
                                 labels.cpu().numpy().flatten(),
                                 args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            # Forward model.
            output_tensor = model(images)

            return output_tensor, partial(loss_func, labels)

        with torch.no_grad():
            # For all the batches in the dataset: accumulate one confusion
            # matrix over the whole validation set.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        # Restore training mode before returning.
        for m in model:
            m.train()

        # Reduce.  Only the last pipeline stage holds the logits, so only
        # it computes and returns the metrics (other stages return None).
        if mpu.is_pipeline_last_stage():
            performs_tensor = torch.cuda.FloatTensor(performs)
            torch.distributed.all_reduce(performs_tensor,
                                         group=mpu.get_data_parallel_group())
            hist = performs_tensor.cpu().numpy()
            iu, acc, acc_cls = calculate_iou(hist)
            miou = np.nanmean(iu)
            return iu, miou

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(" >> |epoch: {}| overall: iou = {},"
                            "miou = {:.4f} %".format(epoch, iou, miou * 100.0))

        return metrics_func

    def dump_output_data(data, iteration, writer):
        # Dump (image | predicted seg | gt seg) strips to TensorBoard.
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )
def main():
    """Entry point used by tasks/vision/main.py for --task segment_segformer."""
    segmentation()
tasks/vision/segmentation/finetune_setr.py
0 → 100644
View file @
b8428a7f
# coding=utf-8
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision-classification finetuning/evaluation."""
import
torch
import
torch.nn.functional
as
F
from
functools
import
partial
from
megatron
import
get_args
,
get_timers
from
megatron
import
mpu
,
print_rank_0
,
print_rank_last
from
tasks.vision.finetune_utils
import
finetune
from
tasks.vision.finetune_utils
import
build_data_loader
from
megatron.utils
import
average_losses_across_data_parallel_group
from
megatron.schedules
import
get_forward_backward_func
from
tasks.vision.segmentation.metrics
import
CFMatrix
from
tasks.vision.segmentation.data
import
build_train_valid_datasets
from
tasks.vision.segmentation.seg_models
import
SetrSegmentationModel
from
tasks.vision.segmentation.utils
import
slidingcrops
,
slidingjoins
def segmentation():
    """SETR semantic-segmentation finetuning/evaluation driver.

    Mirrors the Segformer driver but evaluates with sliding-window crops
    (slidingcrops/slidingjoins) and weights the training cross-entropy by
    inverse class frequency.
    """

    def train_valid_datasets_provider():
        """Build train and validation dataset."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        return train_ds, valid_ds

    def model_provider(pre_process=True, post_process=True):
        """Build the model."""
        args = get_args()
        return SetrSegmentationModel(num_classes=args.num_classes,
                                     pre_process=pre_process,
                                     post_process=post_process)

    def process_batch(batch):
        """Process batch and produce inputs for the model."""
        images = batch[0].cuda().contiguous()
        masks = batch[1].cuda().contiguous()
        return images, masks

    def calculate_weight(masks, num_classes):
        # Per-class weight: 1 + (1 - frequency) for classes present in the
        # batch, 1 for absent classes; softens class imbalance in the loss.
        bins = torch.histc(masks, bins=num_classes, min=0.0, max=num_classes)
        hist_norm = bins.float() / bins.sum()
        hist = ((bins != 0).float() * (1. - hist_norm)) + 1.0
        return hist

    def cross_entropy_loss_func(images, masks, output_tensor,
                                non_loss_data=False):
        """Weighted CE loss (training) or visualization tensor."""
        args = get_args()
        ignore_index = args.ignore_index
        color_table = args.color_table
        weight = calculate_weight(masks, args.num_classes)
        logits = output_tensor.contiguous().float()
        loss = F.cross_entropy(logits, masks, weight=weight,
                               ignore_index=ignore_index)

        if not non_loss_data:
            # Reduce loss for logging.
            averaged_loss = average_losses_across_data_parallel_group([loss])
            return loss, {'lm loss': averaged_loss[0]}
        else:
            # Colorize prediction and ground truth for TensorBoard dumping.
            seg_mask = logits.argmax(dim=1)
            output_mask = F.embedding(seg_mask, color_table).permute(0, 3, 1, 2)
            gt_mask = F.embedding(masks, color_table).permute(0, 3, 1, 2)
            return torch.cat((images, output_mask, gt_mask), dim=2), loss

    def _cross_entropy_forward_step(batch, model):
        """Simple forward step with cross-entropy loss."""
        args = get_args()
        timers = get_timers()

        # Get the batch.  The pipeline schedule may hand us a generator.
        timers("batch generator").start()
        import types
        if isinstance(batch, types.GeneratorType):
            batch_ = next(batch)
        else:
            batch_ = batch
        images, masks = process_batch(batch_)
        timers("batch generator").stop()

        # Forward model.  In eval mode, split the image into sliding-window
        # crops and run them through the model in micro-batch-sized chunks.
        if not model.training:
            images, masks, _, _ = slidingcrops(images, masks)
        #print_rank_0("images size = {}".format(images.size()))

        if not model.training:
            output_tensor = torch.cat(
                [model(image) for image in
                 torch.split(images, args.micro_batch_size)])
        else:
            output_tensor = model(images)

        return output_tensor, partial(cross_entropy_loss_func, images, masks)

    def calculate_correct_answers(model, dataloader, epoch):
        """Calculate correct over total answers"""
        forward_backward_func = get_forward_backward_func()
        for m in model:
            m.eval()

        def loss_func(labels, slices_info, img_size, output_tensor):
            args = get_args()
            logits = output_tensor
            loss_dict = {}
            # Compute the correct answers: stitch the sliding-window
            # predictions back into full images, then accumulate the
            # per-class TP/FP/TN/FN confusion statistics.
            probs = logits.contiguous().float().softmax(dim=1)
            max_probs, preds = torch.max(probs, 1)
            preds = preds.int()
            preds, labels = slidingjoins(preds, max_probs, labels,
                                         slices_info, img_size)
            _, performs = CFMatrix()(preds, labels, args.ignore_index)
            loss_dict['performs'] = performs
            return 0, loss_dict

        # defined inside to capture output_predictions
        def correct_answers_forward_step(batch, model):
            args = get_args()
            try:
                batch_ = next(batch)
            except BaseException:
                batch_ = batch
            images, labels = process_batch(batch_)

            assert not model.training
            images, labels, slices_info, img_size = slidingcrops(images, labels)
            # Forward model, one micro-batch-sized chunk of crops at a time.
            output_tensor = torch.cat(
                [model(image) for image in
                 torch.split(images, args.micro_batch_size)])

            return output_tensor, partial(loss_func, labels,
                                          slices_info, img_size)

    with torch.no_grad():
            # For all the batches in the dataset: accumulate TP/FP/TN/FN.
            performs = None
            for _, batch in enumerate(dataloader):
                loss_dicts = forward_backward_func(correct_answers_forward_step,
                                                   batch, model,
                                                   optimizer=None,
                                                   timers=None,
                                                   forward_only=True)
                for loss_dict in loss_dicts:
                    if performs is None:
                        performs = loss_dict['performs']
                    else:
                        performs += loss_dict['performs']

        # Restore training mode before returning.
        for m in model:
            m.train()

        # Reduce.  Only the last pipeline stage computes/returns metrics.
        if mpu.is_pipeline_last_stage():
            torch.distributed.all_reduce(performs,
                                         group=mpu.get_data_parallel_group())
            # Print on screen.
            # performs[int(ch), :] = [nb_tp, nb_fp, nb_tn, nb_fn]
            true_positive = performs[:, 0]
            false_positive = performs[:, 1]
            false_negative = performs[:, 3]
            iou = true_positive / (true_positive +
                                   false_positive + false_negative)
            # mIoU over classes that actually occur (nan-safe mean).
            miou = iou[~torch.isnan(iou)].mean()
            return iou.tolist(), miou.item()

    def accuracy_func_provider():
        """Provide function that calculates accuracies."""
        args = get_args()
        train_ds, valid_ds = build_train_valid_datasets(
            data_path=args.data_path,
            image_size=(args.img_h, args.img_w)
        )
        dataloader = build_data_loader(
            valid_ds,
            args.micro_batch_size,
            num_workers=args.num_workers,
            drop_last=(mpu.get_data_parallel_world_size() > 1),
            shuffle=False
        )

        def metrics_func(model, epoch):
            print_rank_0("calculating metrics ...")
            iou, miou = calculate_correct_answers(model, dataloader, epoch)
            print_rank_last(" >> |epoch: {}| overall: iou = {},"
                            "miou = {:.4f} %".format(epoch, iou, miou * 100.0))

        return metrics_func

    def dump_output_data(data, iteration, writer):
        # Dump (image | predicted seg | gt seg) strips to TensorBoard.
        for (output_tb, loss) in data:
            # output_tb[output_tb < 0] = 0
            # output_tb[output_tb > 1] = 1
            writer.add_images("image-outputseg-realseg", output_tb,
                              global_step=None, walltime=None,
                              dataformats='NCHW')

    """Finetune/evaluate."""
    finetune(
        train_valid_datasets_provider,
        model_provider,
        forward_step=_cross_entropy_forward_step,
        process_non_loss_data_func=dump_output_data,
        end_of_epoch_callback_provider=accuracy_func_provider,
    )
def main():
    """Entry point used by tasks/vision/main.py for --task segment_setr."""
    segmentation()
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment