OpenDAS / Megatron-LM

Commit 06fc51ce, authored Feb 18, 2022 by Jared Casper
Merge branch 'main' into checkpoint_util
Parents: ec561daa, 0ed2f6ac

Changes: 66 files in total; the 20 changed files shown below account for 3,140 additions and 182 deletions (+3140 −182).
megatron/fused_kernels/scaled_softmax_cuda.cu    +104    −0
megatron/global_vars.py                           +15    −0
megatron/initialize.py                             +7    −4
megatron/model/distributed.py                      +7    −0
megatron/model/fused_layer_norm.py                +24    −3
megatron/model/fused_softmax.py                   +37    −4
megatron/model/language_model.py                   +0    −4
megatron/model/module.py                          +25   −30
megatron/model/transformer.py                    +156   −57
megatron/model/vision/classification.py           +97    −0
megatron/model/vision/dino.py                    +288    −0
megatron/model/vision/esvit_swin_backbone.py     +849    −0
megatron/model/vision/inpainting.py              +151    −0
megatron/model/vision/knn_monitor.py             +128    −0
megatron/model/vision/mit_backbone.py            +420    −0
megatron/model/vision/swin_backbone.py           +625    −0
megatron/model/vision/utils.py                    +27    −0
megatron/model/vision/vit_backbone.py             +91   −67
megatron/mpu/__init__.py                           +6    −0
megatron/mpu/initialize.py                        +83   −13
megatron/fused_kernels/scaled_softmax_cuda.cu (new file, mode 100644)
/* coding=utf-8
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <ATen/ATen.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_profiler_api.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include "scaled_masked_softmax.h"
#include "type_shim.h"
namespace multihead_attn {
namespace fused_softmax {
namespace scaled_softmax {

torch::Tensor fwd_cuda(
    torch::Tensor const& input,
    float scale_factor)
{
  // input is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = input.size(0);
  const int attn_heads = input.size(1);
  const int query_seq_len = input.size(2);
  const int key_seq_len = input.size(3);
  TORCH_INTERNAL_ASSERT(key_seq_len <= 4096);
  TORCH_INTERNAL_ASSERT(query_seq_len > 1);

  // Output
  auto act_options = input.options().requires_grad(false);
  torch::Tensor softmax_results =
      torch::empty({batches, attn_heads, query_seq_len, key_seq_len}, act_options);

  // Softmax Intermediate Result Ptr
  void* input_ptr = static_cast<void*>(input.data_ptr());
  void* softmax_results_ptr = static_cast<void*>(softmax_results.data_ptr());

  DISPATCH_HALF_AND_BFLOAT(
      input.scalar_type(),
      "dispatch_scaled_softmax_forward",
      dispatch_scaled_softmax_forward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(softmax_results_ptr),
          reinterpret_cast<const scalar_t*>(input_ptr),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
      );
  return softmax_results;
}

torch::Tensor bwd_cuda(
    torch::Tensor const& output_grads_,
    torch::Tensor const& softmax_results_,
    float scale_factor)
{
  auto output_grads = output_grads_.contiguous();
  auto softmax_results = softmax_results_.contiguous();

  // output grads is a 4d tensor with dimensions [batches, attn_heads, seq_len, seq_len]
  const int batches = output_grads.size(0);
  const int attn_heads = output_grads.size(1);
  const int query_seq_len = output_grads.size(2);
  const int key_seq_len = output_grads.size(3);

  void* output_grads_ptr = static_cast<void*>(output_grads.data_ptr());

  // Softmax Grad
  DISPATCH_HALF_AND_BFLOAT(
      output_grads_.scalar_type(),
      "dispatch_scaled_masked_softmax_backward",
      dispatch_scaled_masked_softmax_backward<scalar_t, scalar_t, float>(
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t*>(output_grads_ptr),
          reinterpret_cast<scalar_t const*>(softmax_results.data_ptr()),
          scale_factor,
          query_seq_len,
          key_seq_len,
          batches,
          attn_heads);
      );

  // backward pass is completely in-place
  return output_grads;
}
}
}
}
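For reference, the scale-then-softmax computed by fwd_cuda above can be written in a few lines of plain PyTorch; a sketch like the following is useful for sanity-checking the fused kernel. The shapes follow the kernel's [batches, attn_heads, query_seq_len, key_seq_len] convention, and the tensors below are made up for illustration.

import torch

def scaled_softmax_reference(scores: torch.Tensor, scale_factor: float) -> torch.Tensor:
    # scale the attention scores, then softmax over the key dimension
    return torch.softmax(scores * scale_factor, dim=-1)

scores = torch.randn(2, 8, 128, 128)
probs = scaled_softmax_reference(scores, scale_factor=0.125)
assert torch.allclose(probs.sum(dim=-1), torch.ones(2, 8, 128), atol=1e-5)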
megatron/global_vars.py
@@ -21,6 +21,7 @@ import time
 import torch

+from megatron import dist_signal_handler
 from megatron.tokenizer import build_tokenizer
 from .arguments import parse_args
 from .microbatches import build_num_microbatches_calculator
@@ -31,6 +32,7 @@ _GLOBAL_TOKENIZER = None
 _GLOBAL_TENSORBOARD_WRITER = None
 _GLOBAL_ADLR_AUTORESUME = None
 _GLOBAL_TIMERS = None
+_GLOBAL_SIGNAL_HANDLER = None

 def get_args():
@@ -75,6 +77,14 @@ def get_timers():
     _ensure_var_is_initialized(_GLOBAL_TIMERS, 'timers')
     return _GLOBAL_TIMERS

+
+def get_signal_handler():
+    _ensure_var_is_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    return _GLOBAL_SIGNAL_HANDLER
+
+
+def _set_signal_handler():
+    global _GLOBAL_SIGNAL_HANDLER
+    _ensure_var_is_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
+    _GLOBAL_SIGNAL_HANDLER = dist_signal_handler.DistributedSignalHandler().__enter__()
+

 def set_global_variables(extra_args_provider=None, args_defaults={},
                          ignore_unknown_args=False, parse_args=True):
@@ -93,10 +103,15 @@ def set_global_variables(extra_args_provider=None, args_defaults={},
     _set_adlr_autoresume(args)
     _set_timers()

+    if args.exit_signal_handler:
+        _set_signal_handler()
+
+
+def set_args(args):
+    global _GLOBAL_ARGS
+    _GLOBAL_ARGS = args
+

 def _parse_args(extra_args_provider=None, defaults={},
                 ignore_unknown_args=False):
     """Parse entire arguments."""
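The new signal-handler global follows the same guarded-singleton pattern as the other module-level globals in this file. A minimal standalone sketch of that pattern, using illustrative helper names rather than Megatron's private ones:

_GLOBAL_SIGNAL_HANDLER = None

def _ensure_initialized(var, name):
    assert var is not None, '{} is not initialized.'.format(name)

def _ensure_not_initialized(var, name):
    assert var is None, '{} is already initialized.'.format(name)

def set_signal_handler(handler):
    # install exactly once; a second call is a programming error
    global _GLOBAL_SIGNAL_HANDLER
    _ensure_not_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
    _GLOBAL_SIGNAL_HANDLER = handler

def get_signal_handler():
    _ensure_initialized(_GLOBAL_SIGNAL_HANDLER, 'signal handler')
    return _GLOBAL_SIGNAL_HANDLER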
megatron/initialize.py
@@ -62,7 +62,7 @@ def initialize_megatron(extra_args_provider=None, args_defaults={},
     # Random seeds for reproducibility.
     if args.rank == 0:
         print('> setting random seeds to {} ...'.format(args.seed))
-    _set_random_seed(args.seed)
+    _set_random_seed(args.seed, args.data_parallel_random_init)

     # Set pytorch JIT layer fusion options.
     _set_jit_fusion_options()
@@ -118,7 +118,7 @@ def _compile_dependencies():
             args.micro_batch_size
         # Constraints on sequence length and attn_batch_size to enable warp based
         # optimization and upper triangular optimization (for causal mask)
-        custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        custom_kernel_constraint = seq_len > 16 and seq_len <= 4096 and \
             seq_len % 4 == 0 and attn_batch_size % 4 == 0
         # Print a warning.
         if not ((args.fp16 or args.bf16) and
@@ -180,7 +180,7 @@ def _initialize_distributed():
         torch.distributed.init_process_group(
             backend=args.distributed_backend,
             world_size=args.world_size, rank=args.rank,
-            timeout=timedelta(days=7))
+            timeout=timedelta(minutes=10))

     # Set the tensor model-parallel, pipeline model-parallel, and
     # data-parallel communicators.
@@ -203,11 +203,14 @@ def _init_autoresume():
         torch.distributed.barrier()


-def _set_random_seed(seed_):
+def _set_random_seed(seed_, data_parallel_random_init=False):
     """Set random seed for reproducability."""
     if seed_ is not None and seed_ > 0:
         # Ensure that different pipeline MP stages get different seeds.
         seed = seed_ + (100 * mpu.get_pipeline_model_parallel_rank())
+        # Ensure different data parallel ranks get different seeds
+        if data_parallel_random_init:
+            seed = seed + (10 * mpu.get_data_parallel_rank())
         random.seed(seed)
         np.random.seed(seed)
         torch.manual_seed(seed)
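The seed offsets above combine additively; a small standalone sketch of the effective per-rank seed, with the ranks passed in explicitly instead of read from mpu (an assumption made for illustration):

def effective_seed(base_seed, pipeline_rank, data_parallel_rank,
                   data_parallel_random_init=False):
    seed = base_seed + 100 * pipeline_rank          # pipeline stages always differ
    if data_parallel_random_init:
        seed += 10 * data_parallel_rank             # optionally, DP ranks differ too
    return seed

# e.g. base seed 1234, pipeline rank 2, data-parallel rank 3:
assert effective_seed(1234, 2, 3, True) == 1464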
megatron/model/distributed.py
@@ -185,6 +185,13 @@ class DistributedDataParallel(DistributedDataParallelBase):
                 buffer_.zero()


+    def broadcast_params(self):
+        for param in self.module.parameters():
+            torch.distributed.broadcast(param.data,
+                                        src=mpu.get_data_parallel_src_rank(),
+                                        group=mpu.get_data_parallel_group())
+
+
     def allreduce_gradients(self):
         """Reduce gradients across data parallel ranks."""
         # If we have buffers, simply reduce the data in the buffer.
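A hedged standalone sketch of what broadcast_params does, written against plain torch.distributed rather than Megatron's wrapper: it assumes an already-initialized process group and broadcasts from rank 0 instead of the data-parallel source rank.

import torch
import torch.distributed as dist

def broadcast_module_params(module: torch.nn.Module, src: int = 0) -> None:
    # Push rank `src`'s copy of every parameter to all other ranks so that
    # differently-seeded data-parallel replicas start from identical weights.
    for param in module.parameters():
        dist.broadcast(param.data, src=src)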
megatron/model/fused_layer_norm.py
@@ -23,6 +23,12 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib

+try:
+    from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
+    HAVE_PERSIST_LAYER_NORM = True
+except:
+    HAVE_PERSIST_LAYER_NORM = False
+
 global fused_mix_prec_layer_norm_cuda
 fused_mix_prec_layer_norm_cuda = None
@@ -61,13 +67,23 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):

-  def __init__(self, normalized_shape, eps=1e-5):
+  def __init__(self, normalized_shape, eps=1e-5,
+               no_persist_layer_norm=True):
         super(MixedFusedLayerNorm, self).__init__()

         global fused_mix_prec_layer_norm_cuda
         fused_mix_prec_layer_norm_cuda = importlib.import_module(
             "fused_mix_prec_layer_norm_cuda")

+        # List of hiddens sizes supported in the persistent layer norm kernel
+        # If the hidden size is not supported, fall back to the non-persistent
+        # kernel.
+        persist_ln_hidden_sizes = [1024, 1536, 2048, 2304, 3072, 3840, 4096,
+            5120, 6144, 8192, 10240, 12288, 12800, 15360, 16384, 18432, 20480,
+            24576, 25600, 30720, 32768, 40960, 49152, 65536]
+        if normalized_shape not in persist_ln_hidden_sizes or \
+                not HAVE_PERSIST_LAYER_NORM:
+            no_persist_layer_norm = True
+
         if isinstance(normalized_shape, numbers.Integral):
             normalized_shape = (normalized_shape,)
         self.normalized_shape = torch.Size(normalized_shape)
@@ -75,6 +91,7 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.weight = Parameter(torch.Tensor(*normalized_shape))
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
+        self.no_persist_layer_norm = no_persist_layer_norm

   def reset_parameters(self):
@@ -85,6 +102,10 @@ class MixedFusedLayerNorm(torch.nn.Module):
   def forward(self, input):

-    return FusedLayerNormAffineFunction.apply(
-        input, self.weight, self.bias, self.normalized_shape, self.eps)
+    if self.no_persist_layer_norm:
+        return FusedLayerNormAffineFunction.apply(
+            input, self.weight, self.bias, self.normalized_shape, self.eps)
+    else:
+        return FastLayerNormFN.apply(
+            input, self.weight, self.bias, self.eps)
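The fallback rule above can be restated as a small predicate outside the class; this sketch is illustrative only, with the hidden-size list copied from the diff.

PERSIST_LN_HIDDEN_SIZES = {1024, 1536, 2048, 2304, 3072, 3840, 4096, 5120,
                           6144, 8192, 10240, 12288, 12800, 15360, 16384,
                           18432, 20480, 24576, 25600, 30720, 32768, 40960,
                           49152, 65536}

def use_persistent_layer_norm(hidden_size: int, have_persist_kernel: bool,
                              no_persist_layer_norm: bool) -> bool:
    # Persistent (FastLayerNormFN) kernel only when Apex provides it and the
    # hidden size is supported; otherwise the non-persistent fused kernel runs.
    if hidden_size not in PERSIST_LN_HIDDEN_SIZES or not have_persist_kernel:
        return False
    return not no_persist_layer_norm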
megatron/model/fused_softmax.py
@@ -81,6 +81,37 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
         return input_grads, None, None


+class ScaledSoftmax(torch.autograd.Function):
+    """
+    Fused operation which performs following two operations in sequence
+    1. Scale the tensor.
+    2. Perform softmax.
+    """
+
+    @staticmethod
+    def forward(ctx, inputs, scale):
+        import scaled_softmax_cuda
+
+        scale_t = torch.tensor([scale])
+
+        softmax_results = scaled_softmax_cuda.forward(
+            inputs, scale_t[0]
+        )
+        ctx.save_for_backward(softmax_results, scale_t)
+        return softmax_results
+
+    @staticmethod
+    def backward(ctx, output_grads):
+        import scaled_softmax_cuda
+
+        softmax_results, scale_t = ctx.saved_tensors
+
+        input_grads = scaled_softmax_cuda.backward(
+            output_grads, softmax_results, scale_t[0]
+        )
+        return input_grads, None, None
+
+
 class FusedScaleMaskSoftmax(nn.Module):
     """
     fused operation: scaling + mask + softmax
@@ -137,12 +168,11 @@ class FusedScaleMaskSoftmax(nn.Module):
         if (
             self.scaled_masked_softmax_fusion  # user want to fuse
             and self.input_in_float16  # input must be fp16
-            and mask is not None  # mask tensor must not be None
-            and 16 < sk <= 2048  # sk must be 16 ~ 2048
+            and 16 < sk <= 4096  # sk must be 16 ~ 2048
             and sq % 4 == 0  # sq must be divisor of 4
             and attn_batches % 4 == 0  # np * b must be divisor of 4
         ):
-            if 0 <= sk <= 2048:
+            if 0 <= sk <= 4096:
                 batch_per_block = self.get_batch_per_block(sq, sk, b, np)

                 if self.attn_mask_type == AttnMaskType.causal:
@@ -166,7 +196,10 @@ class FusedScaleMaskSoftmax(nn.Module):
                 return probs.view(b, np, sq, sk)
             else:
                 # input is 4D tensor (b, np, sq, sk)
-                return ScaledMaskedSoftmax.apply(input, mask, scale)
+                if mask is not None:
+                    return ScaledMaskedSoftmax.apply(input, mask, scale)
+                else:
+                    return ScaledSoftmax.apply(input, scale)

     def forward_torch_softmax(self, input, mask):
         if self.input_in_float16 and self.softmax_in_fp32:
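A simplified standalone sketch of the updated dispatch rule in FusedScaleMaskSoftmax.forward: fused kernels now also cover the unmasked case, and the supported key length grows from 2048 to 4096. The causal-mask branch and the get_batch_per_block check are omitted; the function is illustrative, not Megatron API.

def choose_softmax_path(input_in_float16, fusion_enabled, mask, sq, sk, attn_batches):
    if (fusion_enabled and input_in_float16
            and 16 < sk <= 4096 and sq % 4 == 0 and attn_batches % 4 == 0):
        # masked inputs keep the old fused kernel; unmasked inputs use the new one
        return "scaled_masked_softmax_cuda" if mask is not None else "scaled_softmax_cuda"
    return "torch_softmax"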
megatron/model/language_model.py
@@ -336,10 +336,6 @@ class TransformerLanguageModel(MegatronModule):
         # Decoder (usually set to False, True if part of an encoder-decoder
         # architecture and in decoder-only stage).
         if self.add_decoder:
-            # Temporary assertion until we verify correctness of pipeline parallelism
-            # implementation of T5.
-            assert args.pipeline_model_parallel_size == 1, \
-                'pipeline parallelism is not supported in the presence of decoder'
             self.decoder = ParallelTransformer(
                 self.init_method,
                 output_layer_init_method,
megatron/model/module.py
@@ -51,8 +51,7 @@ class MegatronModule(torch.nn.Module):
     def word_embeddings_weight(self):
-        if not mpu.is_pipeline_last_stage(ignore_virtual=True) or \
-                mpu.get_pipeline_model_parallel_world_size() == 1:
+        if self.pre_process:
             return self.language_model.embedding.word_embeddings.weight
         else:
             if not self.share_word_embeddings:
@@ -73,6 +72,16 @@ class MegatronModule(torch.nn.Module):
         if args.pipeline_model_parallel_size == 1:
             return

+        if not torch.distributed.is_initialized():
+            if not getattr(MegatronModule, "embedding_warning_printed", False):
+                print("WARNING! Distributed processes aren't initialized, so "
+                      "word embeddings in the last layer are not initialized. "
+                      "If you are just manipulating a model this is fine, but "
+                      "this needs to be handled manually. If you are training "
+                      "something is definitely wrong.")
+                MegatronModule.embedding_warning_printed = True
+            return
+
         # Parameters are shared between the word embeddings layers, and the
         # heads at the end of the model. In a pipelined setup with more than
         # one stage, the initial embedding layer and the head are on different
@@ -85,7 +94,8 @@ class MegatronModule(torch.nn.Module):
         # 3. In the training loop, before an all-reduce between the grads of
         #    the two word_embeddings layers to ensure that every applied weight
         #    update is the same on both stages.
-        if mpu.is_pipeline_last_stage():
+        if mpu.is_pipeline_last_stage() and \
+                not self.pre_process:
             assert not mpu.is_pipeline_first_stage()
             self._word_embeddings_for_head_key = 'word_embeddings_for_head'
             # set word_embeddings weights to 0 here, then copy first
@@ -96,21 +106,10 @@ class MegatronModule(torch.nn.Module):
             self.word_embeddings.weight.data.fill_(0)
             self.word_embeddings.weight.shared = True

-        if not torch.distributed.is_initialized():
-            if not getattr(MegatronModule, "embedding_warning_printed", False):
-                print("WARNING! Distributed processes aren't initialized, so "
-                      "word embeddings in the last layer are not initialized. "
-                      "If you are just manipulating a model this is fine, but "
-                      "this needs to be handled manually. If you are training "
-                      "something is definitely wrong.")
-                MegatronModule.embedding_warning_printed = True
-            return
-
         # Zero out initial weights for decoder embedding.
         # NOTE: We don't currently support T5 with the interleaved schedule.
         if not mpu.is_pipeline_first_stage(ignore_virtual=True) and \
-                not mpu.is_pipeline_last_stage(ignore_virtual=True) and \
-                mpu.is_rank_in_embedding_group():
+                self.pre_process:
             self.language_model.embedding.zero_parameters()

         # Ensure that first and last stages have the same initial parameter
@@ -118,21 +117,17 @@ class MegatronModule(torch.nn.Module):
         if mpu.is_rank_in_embedding_group():
             torch.distributed.all_reduce(self.word_embeddings_weight().data,
                                          group=mpu.get_embedding_group())

         # All-reduce other embeddings as well as necessary. The last stage
         # does not have these other embeddings, so just create placeholder
         # tensors of the right shape with all zeros.
-        # NOTE: We don't currently support T5 with the interleaved schedule.
-        if args.pipeline_model_parallel_split_rank is not None:
-            # TODO: Support tokentype embedding.
-            dimensions = (args.max_position_embeddings, args.hidden_size)
-            if mpu.is_pipeline_last_stage(ignore_virtual=True):
-                position_embeddings = torch.nn.Embedding(*dimensions).cuda()
-                position_embeddings.weight.data.fill_(0)
-            else:
-                self.language_model.embedding.cuda()
-                position_embeddings = self.language_model.embedding.position_embeddings
-            torch.distributed.all_reduce(position_embeddings.weight.data,
-                                         group=mpu.get_embedding_group())
+        # Ensure that encoder(first stage) and decoder(split stage) position
+        # embeddings have the same initial parameter values
+        # NOTE: We don't currently support T5 with the interleaved schedule.
+        if mpu.is_rank_in_position_embedding_group() and \
+                args.pipeline_model_parallel_split_rank is not None:
+            # TODO: Support tokentype embedding.
+            self.language_model.embedding.cuda()
+            position_embeddings = self.language_model.embedding.position_embeddings
+            torch.distributed.all_reduce(position_embeddings.weight.data,
+                                         group=mpu.get_position_embedding_group())


 def conversion_helper(val, conversion):
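A standalone sketch of the synchronization idea behind initialize_word_embeddings: the head's copy of the shared embedding is zero-filled and then all-reduced with the first stage's copy, so both start from identical values. This sketch assumes an initialized process group and uses the default group rather than Megatron's embedding group.

import torch
import torch.distributed as dist

def sync_shared_embedding(weight: torch.Tensor, is_head_copy: bool) -> None:
    if is_head_copy:
        weight.data.fill_(0)          # the head copy starts at zero...
    dist.all_reduce(weight.data)      # ...so the sum equals the first stage's weights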
megatron/model/transformer.py
@@ -27,7 +27,6 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 """ We use the following notation throughout this file:
      h: hidden size
      n: number of attention heads
@@ -43,6 +42,29 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
 """

+class DropPath(MegatronModule):
+    """Drop paths (Stochastic Depth) per sample
+    (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob=0.):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, hidden_state):
+        if self.drop_prob == 0. or not self.training:
+            return hidden_state
+        keep_prob = 1 - self.drop_prob
+        # work with diff dim tensors, not just 2D ConvNets
+        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
+        random_tensor = keep_prob + \
+            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
+        random_tensor.floor_()  # binarize
+        output = hidden_state.div(keep_prob) * random_tensor
+        return output
+
+
 class ParallelMLP(MegatronModule):
     """MLP.
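A quick property check for the DropPath layer added above (it assumes the megatron package is importable so DropPath can be instantiated): in eval mode it is the identity, and in training each sample is either zeroed or rescaled by 1/keep_prob, so the expected value is preserved.

import torch
from megatron.model.transformer import DropPath

drop = DropPath(drop_prob=0.2)
x = torch.ones(10000, 4)

drop.eval()
assert torch.equal(drop(x), x)          # identity at inference time

drop.train()
out = drop(x)
# each row is either 0 or 1/keep_prob = 1.25, and the mean stays close to 1
assert set(out.unique().tolist()) <= {0.0, 1.25}
assert abs(out.mean().item() - 1.0) < 0.05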
@@ -407,7 +429,8 @@ class ParallelTransformerLayer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_number, layer_type=LayerType.encoder,
-                 self_attn_mask_type=AttnMaskType.padding):
+                 self_attn_mask_type=AttnMaskType.padding,
+                 drop_path_rate=0.):
         args = get_args()

         super(ParallelTransformerLayer, self).__init__()
@@ -423,7 +446,8 @@ class ParallelTransformerLayer(MegatronModule):
         # Layernorm on the input data.
         self.input_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)

         # Self attention.
         self.self_attention = ParallelAttention(
@@ -434,11 +458,13 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
-            eps=args.layernorm_epsilon)
+            eps=args.layernorm_epsilon,
+            no_persist_layer_norm=args.no_persist_layer_norm)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(
@@ -449,7 +475,8 @@ class ParallelTransformerLayer(MegatronModule):
             # Layernorm on the attention output.
             self.post_inter_attention_layernorm = LayerNorm(
                 args.hidden_size,
-                eps=args.layernorm_epsilon)
+                eps=args.layernorm_epsilon,
+                no_persist_layer_norm=args.no_persist_layer_norm)

         # MLP
         self.mlp = ParallelMLP(init_method,
@@ -475,25 +502,31 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states

-        # jit scripting for a nn.module (with dropout) is not
-        # trigerring the fusion kernel. For now, we use two
-        # different nn.functional routines to account for varying
-        # dropout semantics during training and inference phases.
-        if self.bias_dropout_fusion:
-            if self.training:
-                bias_dropout_add_func = bias_dropout_add_fused_train
-            else:
-                bias_dropout_add_func = bias_dropout_add_fused_inference
-        else:
-            bias_dropout_add_func = get_bias_dropout_add(self.training)
-
-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            layernorm_input = bias_dropout_add_func(
-                attention_output,
-                attention_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+        if self.drop_path is None:
+            # jit scripting for a nn.module (with dropout) is not
+            # trigerring the fusion kernel. For now, we use two
+            # different nn.functional routines to account for varying
+            # dropout semantics during training and inference phases.
+            if self.bias_dropout_fusion:
+                if self.training:
+                    bias_dropout_add_func = bias_dropout_add_fused_train
+                else:
+                    bias_dropout_add_func = bias_dropout_add_fused_inference
+            else:
+                bias_dropout_add_func = get_bias_dropout_add(self.training)
+
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                layernorm_input = bias_dropout_add_func(
+                    attention_output,
+                    attention_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(attention_output + attention_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            layernorm_input = residual + self.drop_path(out)

         # Layer norm post the self attention.
         layernorm_output = self.post_attention_layernorm(layernorm_input)
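Both branches above implement the same residual update; a pure-PyTorch reference of that update, without the fused kernel or the drop-path wrapper, is:

import torch

def bias_dropout_add_reference(x, bias, residual, prob, training):
    # fused path: residual + dropout(x + bias); the new drop-path branch applies
    # DropPath to the dropped-out term before adding the residual
    return residual + torch.nn.functional.dropout(x + bias, p=prob, training=training)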
@@ -529,24 +562,57 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input

-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
-            output = bias_dropout_add_func(
-                mlp_output,
-                mlp_bias.expand_as(residual),
-                residual,
-                self.hidden_dropout)
+        if self.drop_path is None:
+            # re-enable torch grad to enable fused optimization.
+            with torch.enable_grad():
+                output = bias_dropout_add_func(
+                    mlp_output,
+                    mlp_bias.expand_as(residual),
+                    residual,
+                    self.hidden_dropout)
+        else:
+            out = torch.nn.functional.dropout(mlp_output + mlp_bias,
+                                              p=self.hidden_dropout,
+                                              training=self.training)
+            output = residual + self.drop_path(out)

         return output


+class NoopTransformerLayer(MegatronModule):
+    """A single 'no-op' transformer layer.
+
+    The sole purpose of this layer is for when a standalone embedding layer
+    is used (i.e., args.standalone_embedding_stage == True). In this case,
+    zero transformer layers are assigned when pipeline rank == 0. Additionally,
+    when virtual pipeline rank >= 1, zero total model parameters are created
+    (virtual rank 0 contains the input embedding). This results in the model's
+    input and output tensors being the same, which causes an error when
+    performing certain memory optimiations on the output tensor (e.g.,
+    deallocating it). Thus, this layer disconnects the input from the output
+    via a clone. Since ranks containing a no-op layer are generally under-
+    utilized (both compute and memory), there's no worry of any performance
+    degredation.
+    """
+
+    def __init__(self, layer_number):
+        super().__init__()
+        self.layer_number = layer_number
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                inference_params=None):
+        return hidden_states.clone()
+
+
 class ParallelTransformer(MegatronModule):
     """Transformer class."""

     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
-                 pre_process=True, post_process=True):
+                 pre_process=True, post_process=True,
+                 drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
         args = get_args()
@@ -555,6 +621,7 @@ class ParallelTransformer(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
+        self.drop_path_rate = drop_path_rate

         # Store activation checkpoiting flag.
         self.activations_checkpoint_method = args.activations_checkpoint_method
@@ -565,6 +632,8 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)

+        self.drop_path_rates = [rate.item() for rate in
+                                torch.linspace(0, self.drop_path_rate, args.num_layers)]
+
         # Transformer layers.
         def build_layer(layer_number):
             return ParallelTransformerLayer(
@@ -572,11 +641,13 @@ class ParallelTransformer(MegatronModule):
                 output_layer_init_method,
                 layer_number,
                 layer_type=layer_type,
-                self_attn_mask_type=self_attn_mask_type)
+                self_attn_mask_type=self_attn_mask_type,
+                drop_path_rate=self.drop_path_rates[layer_number - 1])

         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
                 'virtual_pipeline_model_parallel_size'
+            assert args.model_type != ModelType.encoder_and_decoder
             # Number of layers in each model chunk is the number of layers in the stage,
             # divided by the number of model chunks in a stage.
             self.num_layers = self.num_layers // args.virtual_pipeline_model_parallel_size
@@ -593,16 +664,38 @@ class ParallelTransformer(MegatronModule):
                     (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
         else:
             # Each stage gets a contiguous set of layers.
-            offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
-
-        self.layers = torch.nn.ModuleList(
-            [build_layer(i + 1 + offset) for i in range(self.num_layers)])
+            if args.model_type == ModelType.encoder_and_decoder and \
+                    mpu.get_pipeline_model_parallel_world_size() > 1:
+                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
+                if layer_type == LayerType.encoder:
+                    offset = pipeline_rank * self.num_layers
+                else:
+                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
+                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
+            else:
+                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers
+
+        if self.num_layers == 0:
+            # When a standalone embedding stage is used (e.g.,
+            # args.standalone_embedding_stage == True), virtual pipeline ranks
+            # on pipeline rank 0 will have zero transformer layers assigned to
+            # them. This results in the model's input and output tensors to be
+            # the same, which will cause failure for certain output tensor
+            # optimizations (e.g., pipeline output deallocation). To remedy
+            # this, we assign a 'no-op' layer on these ranks, which will
+            # disconnect the input tensor from the output tensor.
+            self.num_layers = 1
+            self.layers = torch.nn.ModuleList([ NoopTransformerLayer(1) ])
+        else:
+            self.layers = torch.nn.ModuleList(
+                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

         if self.post_process:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
-                eps=args.layernorm_epsilon)
+                eps=args.layernorm_epsilon,
+                no_persist_layer_norm=args.no_persist_layer_norm)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]
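A small check of the per-layer stochastic-depth schedule built with torch.linspace above (the rate and layer count below are hypothetical):

import torch

drop_path_rate, num_layers = 0.1, 12
rates = [r.item() for r in torch.linspace(0, drop_path_rate, num_layers)]
# rates grow linearly from 0 at the first layer to drop_path_rate at the last
assert len(rates) == num_layers
assert rates[0] == 0.0 and abs(rates[-1] - drop_path_rate) < 1e-6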
@@ -622,23 +715,6 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        def distribute_checkpointed_activations_helper(layer_number):
-            """Distribute checkpointed activations across the tensor model
-               Parallel ranks if the `distribute-checkpointed-activations
-               is on and either of the following conditions is met:
-                 - it is not the first layer in the in the pipeline stage.
-                   The first layer is used in the pipeline parallelism
-                   and changing its shape throws error in the backward pass.
-                 - we are at the first pipline stage so the input tensor is
-                   not used in pipeline parallelism. Note that no pipeline
-                   parallelism is a special case of this.
-            """
-            not_first_layer_in_pipeline_stage = (layer_number > 0)
-            is_first_pipeline_stage = \
-                (mpu.get_pipeline_model_parallel_rank() == 0)
-            return self.distribute_checkpointed_activations and \
-                (not_first_layer_in_pipeline_stage or is_first_pipeline_stage)
-
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
@@ -647,7 +723,7 @@ class ParallelTransformer(MegatronModule):
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
                     custom(l, l + self.activations_checkpoint_num_layers),
-                    distribute_checkpointed_activations_helper(l),
+                    self.distribute_checkpointed_activations,
                     hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 l += self.activations_checkpoint_num_layers
         elif self.activations_checkpoint_method == 'block':
@@ -658,7 +734,7 @@ class ParallelTransformer(MegatronModule):
                 if l < self.activations_checkpoint_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        distribute_checkpointed_activations_helper(l),
+                        self.distribute_checkpointed_activations,
                         hidden_states, attention_mask, encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
@@ -699,9 +775,32 @@ class ParallelTransformer(MegatronModule):
             # See set_input_tensor()
             hidden_states = self.input_tensor

+        # Viewless tensor.
+        # - We only need to create a viewless tensor in the case of micro batch
+        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
+        #   above creates a view tensor, and '.contiguous()' is a pass-through.
+        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
+        #   the need to make it viewless.
+        #
+        #   However, we don't explicitly check mbs == 1 here because
+        #   make_viewless_tensor() has negligible overhead when its input
+        #   is already viewless.
+        #
+        # - For the 'else' case above, calling make_viewless_tensor() here is
+        #   likely redundant, since p2p_communication.py (likely originator)
+        #   already creates viewless tensors. That said, make_viewless_tensor()
+        #   is called here to be future-proof and corner-case-proof.
+        hidden_states = mpu.make_viewless_tensor(
+            hidden_states,
+            requires_grad=True,
+            keep_graph=True,
+        )
+
         # Transpose encoder output.
         if encoder_output is not None:
             encoder_output = encoder_output.transpose(0, 1).contiguous()

         # Forward pass.
         if self.activations_checkpoint_method is not None:
             hidden_states = self._checkpointed_forward(hidden_states,
                                                        attention_mask,
@@ -725,5 +824,5 @@ class ParallelTransformer(MegatronModule):
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
         return output
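For orientation, the 'uniform' checkpointing loop in _checkpointed_forward can be sketched with torch.utils.checkpoint in place of mpu.checkpoint (an assumption for illustration; Megatron's version additionally handles RNG state and optional redistribution of checkpointed activations across tensor-parallel ranks).

import torch
from torch.utils.checkpoint import checkpoint

def checkpointed_forward(layers, hidden_states, chunk_size=1):
    # Run the layers in chunks of `chunk_size`, recomputing each chunk's
    # activations during the backward pass instead of storing them.
    def run_chunk(start, end):
        def custom_forward(x):
            for layer in layers[start:end]:
                x = layer(x)
            return x
        return custom_forward

    l = 0
    while l < len(layers):
        hidden_states = checkpoint(run_chunk(l, l + chunk_size), hidden_states)
        l += chunk_size
    return hidden_states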
megatron/model/vision/classification.py (new file, mode 100644)
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Vision Transformer(VIT) model."""
import torch
from torch.nn.init import trunc_normal_
from megatron import get_args
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.model.vision.mit_backbone import mit_b3_avg
from megatron.model.module import MegatronModule


class VitClassificationModel(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self, num_classes, finetune=False,
                 pre_process=True, post_process=True):
        super(VitClassificationModel, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.finetune = finetune
        self.pre_process = pre_process
        self.post_process = post_process
        self.backbone = VitBackbone(
            pre_process=self.pre_process,
            post_process=self.post_process,
            single_token_output=True
        )

        if self.post_process:
            if not self.finetune:
                self.head = VitMlpHead(self.hidden_size, self.num_classes)
            else:
                self.head = get_linear_layer(
                    self.hidden_size,
                    self.num_classes,
                    torch.nn.init.zeros_
                )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.backbone.set_input_tensor(input_tensor)

    def forward(self, input):
        hidden_states = self.backbone(input)

        if self.post_process:
            hidden_states = self.head(hidden_states)

        return hidden_states


class MitClassificationModel(MegatronModule):
    """Mix vision Transformer Model."""

    def __init__(self, num_classes,
                 pre_process=True, post_process=True):
        super(MitClassificationModel, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes

        self.backbone = mit_b3_avg()
        self.head = torch.nn.Linear(512, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
        if isinstance(m, torch.nn.Linear) and m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        hidden_states = self.backbone(input)
        hidden_states = self.head(hidden_states)

        return hidden_states
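A standalone illustration of the Module.apply(...) initialization pattern used by MitClassificationModel above; the Sequential head here is made up for the example.

import torch
from torch.nn.init import trunc_normal_

def init_weights(m):
    # apply() visits every submodule, so all linear layers end up with
    # truncated-normal weights and zero bias.
    if isinstance(m, torch.nn.Linear):
        trunc_normal_(m.weight, std=.02)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

head = torch.nn.Sequential(torch.nn.Linear(512, 512), torch.nn.GELU(),
                           torch.nn.Linear(512, 1000))
head.apply(init_weights)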
megatron/model/vision/dino.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b5_avg
from megatron.model.vision.esvit_swin_backbone import get_swin


class DINOLoss(torch.nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))
        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training instable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))
        self.teacher_temp = teacher_temp

    def forward(self, student_output, teacher_output, iteration):
        """
        Cross-entropy between softmax outputs of the teacher
        and student network.
        """
        args = get_args()
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        epoch = iteration // args.iter_per_epoch

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        torch.distributed.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())
        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
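A small numeric check of the teacher-temperature schedule built in DINOLoss.__init__ (the temperatures and epoch counts below are hypothetical): a linear ramp over the warm-up epochs, then constant at teacher_temp.

import numpy as np

warmup_teacher_temp, teacher_temp = 0.04, 0.07
warmup_epochs, nepochs = 30, 300
schedule = np.concatenate((
    np.linspace(warmup_teacher_temp, teacher_temp, warmup_epochs),
    np.ones(nepochs - warmup_epochs) * teacher_temp))
assert len(schedule) == nepochs
assert schedule[0] == warmup_teacher_temp and schedule[-1] == teacher_temp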
class DINOHead(torch.nn.Module):
    def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
        super().__init__()
        args = get_args()
        hidden_dim = args.dino_head_hidden_size
        bottleneck_dim = args.dino_bottleneck_size
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [torch.nn.Linear(in_dim, hidden_dim)]
            layers.append(torch.nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
                layers.append(torch.nn.GELU())
            layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = torch.nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
        if isinstance(m, torch.nn.Linear) and m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = torch.nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
class MultiCropWrapper(MegatronModule):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.
    """
    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)

        start_idx = 0
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx:end_idx]))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        if self.training:
            return self.head(output)
        else:
            return output


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=0):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = \
            np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) \
        * (1 + np.cos(np.pi * iters / len(iters)))
    schedule = np.concatenate((warmup_schedule, schedule))

    assert len(schedule) == epochs * niter_per_ep
    return schedule
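A quick check of cosine_scheduler as it is used later for the teacher momentum: the schedule starts at base_value and decays toward final_value over epochs * niter_per_ep steps. The iteration counts below are made up, and the function above is assumed to be in scope.

import numpy as np

sched = cosine_scheduler(base_value=0.996, final_value=1.0,
                         epochs=100, niter_per_ep=1250)
assert len(sched) == 100 * 1250
assert abs(sched[0] - 0.996) < 1e-6 and abs(sched[-1] - 1.0) < 1e-6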
def get_student_backbone_and_num_features(pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        student = VitBackbone(pre_process=pre_process,
                              post_process=post_process,
                              drop_path_rate=0.1,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        student = mit_b5_avg(drop_path_rate=0.1)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        student = get_swin()
        num_features = student.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                              args.vision_backbone_type))

    return student, num_features


def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        teacher = VitBackbone(pre_process=pre_process,
                              post_process=post_process,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        teacher = mit_b5_avg(drop_path_rate=0.0)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        teacher = get_swin(is_teacher=True)
        num_features = teacher.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                            args.vision_backbone_type))
    return teacher, num_features


class DINOPretrainModel(MegatronModule):
    def __init__(self, pre_process=True, post_process=True):
        super(DINOPretrainModel, self).__init__()
        args = get_args()
        self.out_dim = 65536

        self.dino_loss = DINOLoss(
            self.out_dim,
            args.dino_local_crops_number + 2,
            args.dino_warmup_teacher_temp,
            args.dino_teacher_temp,
            args.dino_warmup_teacher_temp_epochs,
            300,
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.momentum_teacher = 0.996

        student_backbone, num_features = \
            get_student_backbone_and_num_features(pre_process, post_process)

        self.student = MultiCropWrapper(
            student_backbone,
            DINOHead(num_features, self.out_dim,
                     norm_last_layer=args.dino_norm_last_layer)
        )

        self.momentum_schedule = cosine_scheduler(
            self.momentum_teacher, 1,
            args.train_iters // args.iter_per_epoch,
            args.iter_per_epoch
        )

        teacher_backbone, num_features = \
            get_teacher_backbone_and_num_features(pre_process, post_process)
        self.teacher = MultiCropWrapper(
            teacher_backbone,
            DINOHead(num_features, self.out_dim)
        )
        self.teacher.load_state_dict(self.student.state_dict())

        for p in self.teacher.parameters():
            if hasattr(p, "requires_grad") and p.requires_grad is not None:
                p.requires_grad = False

    def set_input_tensor(self, tensor):
        pass

    def forward(self, input):
        student_output = None
        if self.training:
            student_output = self.student(input)
            teacher_output = self.teacher(input[:2])
        else:
            teacher_output = self.teacher(input)

        return student_output, teacher_output

    def cancel_gradients_last_layer(self, iteration):
        args = get_args()
        epoch = iteration // args.iter_per_epoch
        if epoch < args.dino_freeze_last_layer:
            for n, p in self.student.named_parameters():
                if "last_layer" in n:
                    p.grad = None

    def update_momentum(self, iteration):
        with torch.no_grad():
            m = self.momentum_schedule[iteration]
            for param_q, param_k in zip(self.student.parameters(),
                                        self.teacher.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
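A standalone sketch of the momentum (EMA) teacher update performed by update_momentum above, on throwaway linear modules instead of the DINO wrappers:

import copy
import torch

student = torch.nn.Linear(8, 8)
teacher = copy.deepcopy(student)
for p in teacher.parameters():
    p.requires_grad = False          # teacher is never optimized directly

m = 0.996  # momentum value taken from the schedule
with torch.no_grad():
    for p_q, p_k in zip(student.parameters(), teacher.parameters()):
        # teacher <- m * teacher + (1 - m) * student
        p_k.data.mul_(m).add_((1 - m) * p_q.detach().data)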
megatron/model/vision/esvit_swin_backbone.py (new file, mode 100644)
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Modified by Chunyuan Li (chunyl@microsoft.com)
# Swin Transformer
# --------------------------------------------------------
import os
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from megatron.model.transformer import DropPath
from megatron import get_args
from megatron.model import LayerNorm
import numpy as np
from math import sqrt


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
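A round-trip check for the two helpers above (assuming they are in scope): partitioning into windows and reversing must reproduce the original (B, H, W, C) tensor when H and W are multiples of the window size.

import torch

x = torch.randn(2, 14, 14, 32)
windows = window_partition(x, window_size=7)   # -> (2 * 2 * 2, 7, 7, 32)
assert windows.shape == (8, 7, 7, 32)
assert torch.equal(window_reverse(windows, 7, 14, 14), x)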
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super(WindowAttention, self).__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2 Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0).type(attn.type())
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn_out = attn
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x, attn_out

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops

    @staticmethod
    def compute_macs(module, input, output):
        B, N, C = input[0].shape
        module.__flops__ += module.flops(N) * B
class
SwinTransformerBlock
(
nn
.
Module
):
r
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
dim
,
input_resolution
,
num_heads
,
window_size
=
7
,
shift_size
=
0
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
num_heads
=
num_heads
self
.
window_size
=
window_size
self
.
shift_size
=
shift_size
self
.
mlp_ratio
=
mlp_ratio
if
min
(
self
.
input_resolution
)
<=
self
.
window_size
:
# if window size is larger than input resolution, we don't partition windows
self
.
shift_size
=
0
self
.
window_size
=
min
(
self
.
input_resolution
)
assert
0
<=
self
.
shift_size
<
self
.
window_size
,
"shift_size must in 0-window_size"
self
.
norm1
=
norm_layer
(
dim
)
self
.
attn
=
WindowAttention
(
dim
,
window_size
=
(
self
.
window_size
,
self
.
window_size
),
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
)
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
nn
.
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
self
.
H
=
input_resolution
[
0
]
self
.
W
=
input_resolution
[
1
]
self
.
attn_mask_dict
=
{}
def
create_attn_mask
(
self
,
H
,
W
):
# calculate attention mask for SW-MSA
Hp
=
int
(
np
.
ceil
(
H
/
self
.
window_size
))
*
self
.
window_size
Wp
=
int
(
np
.
ceil
(
W
/
self
.
window_size
))
*
self
.
window_size
img_mask
=
torch
.
zeros
((
1
,
Hp
,
Wp
,
1
))
# 1 Hp Wp 1
h_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
w_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
cnt
=
0
for
h
in
h_slices
:
for
w
in
w_slices
:
img_mask
[:,
h
,
w
,
:]
=
cnt
cnt
+=
1
mask_windows
=
window_partition
(
img_mask
,
self
.
window_size
)
# nW, window_size, window_size, 1
mask_windows
=
mask_windows
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
)
attn_mask
=
mask_windows
.
unsqueeze
(
1
)
-
mask_windows
.
unsqueeze
(
2
)
attn_mask
=
attn_mask
.
masked_fill
(
attn_mask
!=
0
,
float
(
-
100.0
)).
masked_fill
(
attn_mask
==
0
,
float
(
0.0
))
return
attn_mask
def
forward
(
self
,
x
):
B
,
L
,
C
=
x
.
shape
H
=
int
(
sqrt
(
L
))
W
=
H
shortcut
=
x
x
=
self
.
norm1
(
x
)
x
=
x
.
view
(
B
,
H
,
W
,
C
)
# pad feature maps to multiples of window size
pad_l
=
pad_t
=
0
pad_r
=
(
self
.
window_size
-
W
%
self
.
window_size
)
%
self
.
window_size
pad_b
=
(
self
.
window_size
-
H
%
self
.
window_size
)
%
self
.
window_size
x
=
F
.
pad
(
x
,
(
0
,
0
,
pad_l
,
pad_r
,
pad_t
,
pad_b
))
_
,
Hp
,
Wp
,
_
=
x
.
shape
# cyclic shift
if
self
.
shift_size
>
0
:
shifted_x
=
torch
.
roll
(
x
,
shifts
=
(
-
self
.
shift_size
,
-
self
.
shift_size
),
dims
=
(
1
,
2
))
if
H
in
self
.
attn_mask_dict
.
keys
():
attn_mask
=
self
.
attn_mask_dict
[
H
]
else
:
self
.
attn_mask_dict
[
H
]
=
self
.
create_attn_mask
(
self
.
H
,
self
.
W
).
to
(
x
.
device
)
attn_mask
=
self
.
attn_mask_dict
[
H
]
else
:
shifted_x
=
x
attn_mask
=
None
# partition windows
x_windows
=
window_partition
(
shifted_x
,
self
.
window_size
)
# nW*B, window_size, window_size, C
x_windows
=
x_windows
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
,
C
)
# nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows
,
attn
=
self
.
attn
(
x_windows
,
attn_mask
)
# nW*B, window_size*window_size, C
# merge windows
attn_windows
=
attn_windows
.
view
(
-
1
,
self
.
window_size
,
self
.
window_size
,
C
)
shifted_x
=
window_reverse
(
attn_windows
,
self
.
window_size
,
Hp
,
Wp
)
# B H' W' C
# reverse cyclic shift
if
self
.
shift_size
>
0
:
x
=
torch
.
roll
(
shifted_x
,
shifts
=
(
self
.
shift_size
,
self
.
shift_size
),
dims
=
(
1
,
2
))
else
:
x
=
shifted_x
if
pad_r
>
0
or
pad_b
>
0
:
x
=
x
[:,
:
H
,
:
W
,
:].
contiguous
()
x
=
x
.
view
(
B
,
H
*
W
,
C
)
# FFN
x
=
shortcut
+
self
.
drop_path
(
x
)
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
,
attn
    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
def
flops
(
self
):
flops
=
0
H
,
W
=
self
.
input_resolution
# norm1
flops
+=
self
.
dim
*
H
*
W
# W-MSA/SW-MSA
nW
=
H
*
W
/
self
.
window_size
/
self
.
window_size
flops
+=
nW
*
self
.
attn
.
flops
(
self
.
window_size
*
self
.
window_size
)
# mlp
flops
+=
2
*
H
*
W
*
self
.
dim
*
self
.
dim
*
self
.
mlp_ratio
# norm2
flops
+=
self
.
dim
*
H
*
W
return
flops
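The flops() accounting above can be sanity-checked by hand. The following is a minimal arithmetic sketch, not part of the model; the dim, resolution, window and head values are illustrative assumptions for a typical first stage, and the attention term simply expands WindowAttention.flops(N).

# Rough FLOP count for one Swin block, assuming dim=96, 56x56 tokens, window 7, 3 heads, mlp_ratio=4.
dim, H, W, window_size, mlp_ratio, num_heads = 96, 56, 56, 7, 4.0, 3
N = window_size * window_size                      # tokens per window
nW = H * W / window_size / window_size             # number of windows
attn_flops = N * dim * 3 * dim \
    + num_heads * N * (dim // num_heads) * N * 2 \
    + N * dim * dim                                # qkv + two attention matmuls + proj
flops = dim * H * W                                # norm1
flops += nW * attn_flops                           # W-MSA / SW-MSA
flops += 2 * H * W * dim * dim * mlp_ratio         # mlp
flops += dim * H * W                               # norm2
print(f"{flops / 1e9:.2f} GFLOPs")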
class
PatchMerging
(
nn
.
Module
):
r
"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
input_resolution
,
dim
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
input_resolution
=
input_resolution
self
.
dim
=
dim
self
.
reduction
=
nn
.
Linear
(
4
*
dim
,
2
*
dim
,
bias
=
False
)
self
.
norm
=
norm_layer
(
4
*
dim
)
def
forward
(
self
,
x
):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B
,
L
,
C
=
x
.
shape
H
=
int
(
sqrt
(
L
))
W
=
H
x
=
x
.
view
(
B
,
H
,
W
,
C
)
# padding
pad_input
=
(
H
%
2
==
1
)
or
(
W
%
2
==
1
)
if
pad_input
:
x
=
F
.
pad
(
x
,
(
0
,
0
,
0
,
W
%
2
,
0
,
H
%
2
))
x0
=
x
[:,
0
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x1
=
x
[:,
1
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x2
=
x
[:,
0
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x3
=
x
[:,
1
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x
=
torch
.
cat
([
x0
,
x1
,
x2
,
x3
],
-
1
)
# B H/2 W/2 4*C
x
=
x
.
view
(
B
,
-
1
,
4
*
C
)
# B H/2*W/2 4*C
x
=
self
.
norm
(
x
)
x
=
self
.
reduction
(
x
)
return
x
def
extra_repr
(
self
)
->
str
:
return
f
"input_resolution=
{
self
.
input_resolution
}
, dim=
{
self
.
dim
}
"
def
flops
(
self
):
H
,
W
=
self
.
input_resolution
flops
=
H
*
W
*
self
.
dim
flops
+=
(
H
//
2
)
*
(
W
//
2
)
*
4
*
self
.
dim
*
2
*
self
.
dim
return
flops
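For intuition, PatchMerging trades spatial resolution for channels: a (B, H*W, C) input comes out as (B, H/2*W/2, 2*C). A small sketch using the class defined above, assuming the module's own sqrt and F (torch.nn.functional) imports are in place and an even, square token grid:

import torch
import torch.nn as nn

B, H, W, C = 2, 8, 8, 96
merge = PatchMerging(input_resolution=(H, W), dim=C, norm_layer=nn.LayerNorm)
x = torch.randn(B, H * W, C)
y = merge(x)
print(y.shape)  # torch.Size([2, 16, 192]) -> (B, (H/2)*(W/2), 2*C)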
class
BasicLayer
(
nn
.
Module
):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
"""
def
__init__
(
self
,
dim
,
input_resolution
,
depth
,
num_heads
,
window_size
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
norm_layer
=
nn
.
LayerNorm
,
downsample
=
None
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
depth
=
depth
self
.
blocks
=
nn
.
ModuleList
([
SwinTransformerBlock
(
dim
=
dim
,
input_resolution
=
input_resolution
,
num_heads
=
num_heads
,
window_size
=
window_size
,
shift_size
=
0
if
(
i
%
2
==
0
)
else
window_size
//
2
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop
,
attn_drop
=
attn_drop
,
drop_path
=
drop_path
[
i
]
if
isinstance
(
drop_path
,
list
)
else
drop_path
,
norm_layer
=
norm_layer
)
for
i
in
range
(
depth
)])
if
downsample
is
not
None
:
self
.
downsample
=
downsample
(
input_resolution
,
dim
=
dim
,
norm_layer
=
norm_layer
)
else
:
self
.
downsample
=
None
def
forward
(
self
,
x
):
for
blk
in
self
.
blocks
:
x
,
_
=
blk
(
x
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
def
forward_with_features
(
self
,
x
):
fea
=
[]
for
blk
in
self
.
blocks
:
x
,
_
=
blk
(
x
)
fea
.
append
(
x
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
,
fea
def
forward_with_attention
(
self
,
x
):
attns
=
[]
for
blk
in
self
.
blocks
:
x
,
attn
=
blk
(
x
)
attns
.
append
(
attn
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
,
attns
def
extra_repr
(
self
)
->
str
:
return
f
"dim=
{
self
.
dim
}
, input_resolution=
{
self
.
input_resolution
}
, depth=
{
self
.
depth
}
"
def
flops
(
self
):
flops
=
0
for
blk
in
self
.
blocks
:
flops
+=
blk
.
flops
()
if
self
.
downsample
is
not
None
:
flops
+=
self
.
downsample
.
flops
()
return
flops
class
PatchEmbed
(
nn
.
Module
):
""" Image to Patch Embedding
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
embed_dim
=
768
,
norm_layer
=
None
):
super
().
__init__
()
img_size
=
(
img_size
,
img_size
)
patch_size
=
(
patch_size
,
patch_size
)
patches_resolution
=
[
img_size
[
0
]
//
patch_size
[
0
],
img_size
[
1
]
//
patch_size
[
1
]]
self
.
img_size
=
img_size
self
.
patch_size
=
patch_size
self
.
patches_resolution
=
patches_resolution
self
.
num_patches
=
patches_resolution
[
0
]
*
patches_resolution
[
1
]
self
.
in_chans
=
in_chans
self
.
embed_dim
=
embed_dim
self
.
proj
=
nn
.
Conv2d
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
if
norm_layer
is
not
None
:
self
.
norm
=
norm_layer
(
embed_dim
)
else
:
self
.
norm
=
None
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
x
=
self
.
proj
(
x
).
flatten
(
2
).
transpose
(
1
,
2
)
# B Ph*Pw C
if
self
.
norm
is
not
None
:
x
=
self
.
norm
(
x
)
return
x
def
flops
(
self
):
Ho
,
Wo
=
self
.
patches_resolution
flops
=
Ho
*
Wo
*
self
.
embed_dim
*
self
.
in_chans
*
(
self
.
patch_size
[
0
]
*
self
.
patch_size
[
1
])
if
self
.
norm
is
not
None
:
flops
+=
Ho
*
Wo
*
self
.
embed_dim
return
flops
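The patch embedding above is a strided convolution followed by flattening, so a 224x224 image with patch_size=4 becomes 56*56 = 3136 tokens. A minimal standalone sketch of the same idea in plain PyTorch, with illustrative sizes:

import torch
import torch.nn as nn

B, C, H, W, patch, embed_dim = 2, 3, 224, 224, 4, 96
proj = nn.Conv2d(C, embed_dim, kernel_size=patch, stride=patch)
x = torch.randn(B, C, H, W)
tokens = proj(x).flatten(2).transpose(1, 2)   # B, (H/patch)*(W/patch), embed_dim
print(tokens.shape)                           # torch.Size([2, 3136, 96])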
class
SwinTransformer
(
nn
.
Module
):
r
""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size.
patch_size (int | tuple(int)): Patch size.
in_chans (int): Number of input channels.
num_classes (int): Number of classes for classification head.
embed_dim (int): Embedding dimension.
depths (tuple(int)): Depth of Swin Transformer layers.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): normalization layer.
ape (bool): If True, add absolute position embedding to the patch embedding.
patch_norm (bool): If True, add normalization after patch embedding.
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
4
,
in_chans
=
3
,
num_classes
=
1000
,
embed_dim
=
96
,
depths
=
[
2
,
2
,
6
,
2
],
num_heads
=
[
3
,
6
,
12
,
24
],
window_size
=
7
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.1
,
norm_layer
=
nn
.
LayerNorm
,
ape
=
False
,
patch_norm
=
True
,
**
kwargs
):
super
().
__init__
()
self
.
num_classes
=
num_classes
self
.
num_layers
=
len
(
depths
)
self
.
embed_dim
=
embed_dim
self
.
ape
=
ape
self
.
patch_norm
=
patch_norm
self
.
num_features
=
int
(
embed_dim
*
2
**
(
self
.
num_layers
-
1
))
self
.
mlp_ratio
=
mlp_ratio
self
.
patch_embed
=
PatchEmbed
(
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
in_chans
,
embed_dim
=
embed_dim
,
norm_layer
=
norm_layer
if
self
.
patch_norm
else
None
)
num_patches
=
self
.
patch_embed
.
num_patches
patches_resolution
=
self
.
patch_embed
.
patches_resolution
self
.
patches_resolution
=
patches_resolution
if
self
.
ape
:
self
.
absolute_pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_patches
,
embed_dim
))
trunc_normal_
(
self
.
absolute_pos_embed
,
std
=
.
02
)
self
.
pos_drop
=
nn
.
Dropout
(
p
=
drop_rate
)
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
sum
(
depths
))]
# stochastic depth decay rule
self
.
layers
=
nn
.
ModuleList
()
for
i_layer
in
range
(
self
.
num_layers
):
layer
=
BasicLayer
(
dim
=
int
(
embed_dim
*
2
**
i_layer
),
input_resolution
=
(
patches_resolution
[
0
]
//
(
2
**
i_layer
),
patches_resolution
[
1
]
//
(
2
**
i_layer
)),
depth
=
depths
[
i_layer
],
num_heads
=
num_heads
[
i_layer
],
window_size
=
window_size
,
mlp_ratio
=
self
.
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
sum
(
depths
[:
i_layer
]):
sum
(
depths
[:
i_layer
+
1
])],
norm_layer
=
norm_layer
,
downsample
=
PatchMerging
if
(
i_layer
<
self
.
num_layers
-
1
)
else
None
)
self
.
layers
.
append
(
layer
)
self
.
norm
=
norm_layer
(
self
.
num_features
)
self
.
avgpool
=
nn
.
AdaptiveAvgPool1d
(
1
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
return
{
'absolute_pos_embed'
}
@
torch
.
jit
.
ignore
def
no_weight_decay_keywords
(
self
):
# todo: to be implemented
return
{
'relative_position_bias_table'
}
def
forward
(
self
,
x
):
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x_region
=
self
.
norm
(
x
)
# B L C
x
=
self
.
avgpool
(
x_region
.
transpose
(
1
,
2
))
# B C 1
x
=
torch
.
flatten
(
x
,
1
)
return
x
def
forward_feature_maps
(
self
,
x
):
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x_grid
=
self
.
norm
(
x
)
# B L C
x
=
self
.
avgpool
(
x_grid
.
transpose
(
1
,
2
))
# B C 1
x
=
torch
.
flatten
(
x
,
1
)
return
x
,
x_grid
def
forward_selfattention
(
self
,
x
,
n
=
1
):
# n=1 return the last layer attn map; otherwise return attn maps in all layers
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
if
n
==
1
:
return
self
.
forward_last_selfattention
(
x
)
else
:
return
self
.
forward_all_selfattention
(
x
)
def
forward_last_selfattention
(
self
,
x
):
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
i
<
len
(
self
.
layers
)
-
1
:
x
=
layer
(
x
)
else
:
x
,
attns
=
layer
.
forward_with_attention
(
x
)
return
attns
[
-
1
]
def
forward_all_selfattention
(
self
,
x
):
attn_out
=
[]
for
layer
in
self
.
layers
:
x
,
attns
=
layer
.
forward_with_attention
(
x
)
attn_out
+=
attns
return
attn_out
def
forward_return_n_last_blocks
(
self
,
x
,
n
=
1
,
return_patch_avgpool
=
False
,
depth
=
[]):
num_blks
=
sum
(
depth
)
start_idx
=
num_blks
-
n
sum_cur
=
0
for
i
,
d
in
enumerate
(
depth
):
sum_cur_new
=
sum_cur
+
d
if
start_idx
>=
sum_cur
and
start_idx
<
sum_cur_new
:
start_stage
=
i
start_blk
=
start_idx
-
sum_cur
sum_cur
=
sum_cur_new
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
# we will return the averaged token features from the `n` last blocks
# note: there is no [CLS] token in Swin Transformer
output
=
[]
s
=
0
for
i
,
layer
in
enumerate
(
self
.
layers
):
x
,
fea
=
layer
.
forward_with_features
(
x
)
if
i
>=
start_stage
:
for
x_
in
fea
[
start_blk
:]:
if
i
==
len
(
self
.
layers
)
-
1
:
# use the norm in the last stage
x_
=
self
.
norm
(
x_
)
x_avg
=
torch
.
flatten
(
self
.
avgpool
(
x_
.
transpose
(
1
,
2
)),
1
)
# B C
# print(f'Stage {i}, x_avg {x_avg.shape}')
output
.
append
(
x_avg
)
start_blk
=
0
return
torch
.
cat
(
output
,
dim
=-
1
)
def
flops
(
self
):
flops
=
0
flops
+=
self
.
patch_embed
.
flops
()
for
i
,
layer
in
enumerate
(
self
.
layers
):
flops
+=
layer
.
flops
()
if
dist
.
get_rank
()
==
0
:
print
(
f
"GFLOPs layer_
{
i
}
:
{
layer
.
flops
()
/
1e9
}
"
)
flops
+=
self
.
num_features
*
self
.
patches_resolution
[
0
]
*
self
.
patches_resolution
[
1
]
//
(
2
**
self
.
num_layers
)
flops
+=
self
.
num_features
*
self
.
num_classes
return
flops
def
init_weights
(
self
,
pretrained
=
''
,
pretrained_layers
=
[],
verbose
=
True
):
if
os
.
path
.
isfile
(
pretrained
):
pretrained_dict
=
torch
.
load
(
pretrained
,
map_location
=
'cpu'
)
logging
.
info
(
f
'=> loading pretrained model
{
pretrained
}
'
)
model_dict
=
self
.
state_dict
()
pretrained_dict
=
{
k
:
v
for
k
,
v
in
pretrained_dict
.
items
()
if
k
in
model_dict
.
keys
()
}
need_init_state_dict
=
{}
for
k
,
v
in
pretrained_dict
.
items
():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or pretrained_layers[0] == '*'
                    or 'relative_position_index' not in k
                    or 'attn_mask' not in k
                )
if
need_init
:
if
verbose
:
logging
.
info
(
f
'=> init
{
k
}
from
{
pretrained
}
'
)
if
'relative_position_bias_table'
in
k
and
v
.
size
()
!=
model_dict
[
k
].
size
():
relative_position_bias_table_pretrained
=
v
relative_position_bias_table_current
=
model_dict
[
k
]
L1
,
nH1
=
relative_position_bias_table_pretrained
.
size
()
L2
,
nH2
=
relative_position_bias_table_current
.
size
()
if
nH1
!=
nH2
:
logging
.
info
(
f
"Error in loading
{
k
}
, passing"
)
else
:
if
L1
!=
L2
:
logging
.
info
(
'=> load_pretrained: resized variant: {} to {}'
.
format
((
L1
,
nH1
),
(
L2
,
nH2
))
)
S1
=
int
(
L1
**
0.5
)
S2
=
int
(
L2
**
0.5
)
relative_position_bias_table_pretrained_resized
=
torch
.
nn
.
functional
.
interpolate
(
relative_position_bias_table_pretrained
.
permute
(
1
,
0
).
view
(
1
,
nH1
,
S1
,
S1
),
size
=
(
S2
,
S2
),
mode
=
'bicubic'
)
v
=
relative_position_bias_table_pretrained_resized
.
view
(
nH2
,
L2
).
permute
(
1
,
0
)
if
'absolute_pos_embed'
in
k
and
v
.
size
()
!=
model_dict
[
k
].
size
():
absolute_pos_embed_pretrained
=
v
absolute_pos_embed_current
=
model_dict
[
k
]
_
,
L1
,
C1
=
absolute_pos_embed_pretrained
.
size
()
_
,
L2
,
C2
=
absolute_pos_embed_current
.
size
()
                    if C1 != C2:
logging
.
info
(
f
"Error in loading
{
k
}
, passing"
)
else
:
if
L1
!=
L2
:
logging
.
info
(
'=> load_pretrained: resized variant: {} to {}'
.
format
((
1
,
L1
,
C1
),
(
1
,
L2
,
C2
))
)
S1
=
int
(
L1
**
0.5
)
S2
=
int
(
L2
**
0.5
)
absolute_pos_embed_pretrained
=
absolute_pos_embed_pretrained
.
reshape
(
-
1
,
S1
,
S1
,
C1
)
absolute_pos_embed_pretrained
=
absolute_pos_embed_pretrained
.
permute
(
0
,
3
,
1
,
2
)
absolute_pos_embed_pretrained_resized
=
torch
.
nn
.
functional
.
interpolate
(
absolute_pos_embed_pretrained
,
size
=
(
S2
,
S2
),
mode
=
'bicubic'
)
v
=
absolute_pos_embed_pretrained_resized
.
permute
(
0
,
2
,
3
,
1
).
flatten
(
1
,
2
)
need_init_state_dict
[
k
]
=
v
self
.
load_state_dict
(
need_init_state_dict
,
strict
=
False
)
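The bicubic resizing of relative_position_bias_table above is worth seeing in isolation: a pretrained table of shape (L1, nH), with L1 = (2*Wh-1)**2, is treated as an nH-channel S1 x S1 image, interpolated to S2 x S2, and flattened back. A minimal standalone sketch with made-up sizes (window 7 resized to window 12):

import torch

nH = 4
S1, S2 = 2 * 7 - 1, 2 * 12 - 1                # 13 -> 23
table = torch.randn(S1 * S1, nH)              # pretrained table, (L1, nH)
resized = torch.nn.functional.interpolate(
    table.permute(1, 0).view(1, nH, S1, S1),
    size=(S2, S2), mode='bicubic')
table_new = resized.view(nH, S2 * S2).permute(1, 0)
print(table_new.shape)                        # torch.Size([529, 4])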
def
freeze_pretrained_layers
(
self
,
frozen_layers
=
[]):
for
name
,
module
in
self
.
named_modules
():
if
(
name
.
split
(
'.'
)[
0
]
in
frozen_layers
or
'.'
.
join
(
name
.
split
(
'.'
)[
0
:
2
])
in
frozen_layers
or
                    (len(frozen_layers) > 0 and frozen_layers[0] == '*')
):
for
_name
,
param
in
module
.
named_parameters
():
param
.
requires_grad
=
False
logging
.
info
(
'=> set param {} requires grad to False'
.
format
(
name
)
)
for
name
,
param
in
self
.
named_parameters
():
if
(
name
.
split
(
'.'
)[
0
]
in
frozen_layers
or
                (len(frozen_layers) > 0 and frozen_layers[0] == '*')
and
param
.
requires_grad
is
True
):
param
.
requires_grad
=
False
logging
.
info
(
'=> set param {} requires grad to False'
.
format
(
name
)
)
return
self
def get_swin(is_teacher=False):
    args = get_args()

    if args.swin_backbone_type == "tiny":
        embed_dim = 96
        depths = [2, 2, 6, 2]
        num_heads = [3, 6, 12, 24]
        drop_path_rate = 0.1
    elif args.swin_backbone_type == 'h3':
        embed_dim = 384
        depths = [2, 2, 18, 2]
        num_heads = [6, 12, 24, 48]
        drop_path_rate = 0.2
    else:
        embed_dim = 128
        depths = [2, 2, 18, 2]
        num_heads = [4, 8, 16, 32]
        drop_path_rate = 0.2

    swin = SwinTransformer(
        img_size=224,
        in_chans=3,
        num_classes=1000,
        patch_size=4,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0,
        attn_drop_rate=0,
        drop_path_rate=(0.0 if is_teacher else drop_path_rate),
        norm_layer=partial(LayerNorm, eps=1e-6),
        ape=False,
        patch_norm=True,
    )

    return swin
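Outside of Megatron's argument plumbing, the same backbone can be exercised directly. A minimal sketch of the "tiny" configuration on a dummy batch; it assumes torch.nn.LayerNorm as the norm layer (the module's own LayerNorm import is not shown here) and uses the default sizes above as illustrative values:

import torch
from functools import partial
from torch.nn import LayerNorm

model = SwinTransformer(
    img_size=224, patch_size=4, in_chans=3, embed_dim=96,
    depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24], window_size=7,
    norm_layer=partial(LayerNorm, eps=1e-6), ape=False, patch_norm=True)
feats = model(torch.randn(2, 3, 224, 224))
print(feats.shape)   # torch.Size([2, 768]); pooled features, embed_dim * 2**(num_layers - 1)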
megatron/model/vision/inpainting.py
0 → 100644
View file @
06fc51ce
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
import apex
import einops
import torch
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b3
from megatron.model.vision.utils import resize


class VitInpaintingModel(MegatronModule):

    def __init__(self, pre_process=True, post_process=True):
        super(VitInpaintingModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.backbone = VitBackbone(
            pre_process=self.pre_process,
            post_process=self.post_process,
            class_token=False,
        )
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.seq_length = args.seq_length
        # full mask

        if self.post_process:
            self.linear_decoder = get_linear_layer(
                self.hidden_size,
                self.backbone.flatten_dim,
                torch.nn.init.zeros_
            )

    def set_input_tensor(self, input_tensor):
        self.backbone.set_input_tensor(input_tensor)

    def forward(self, input):
        hidden_states = self.backbone(input)

        if not self.post_process:
            return hidden_states
        decoded_output = self.linear_decoder(hidden_states)
        output = einops.rearrange(
            decoded_output,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h // self.patch_dim,
            w=self.img_w // self.patch_dim,
        )

        return output


class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class MitInpaintingModel(MegatronModule):
    """Mix vision Transformer Model."""

    def __init__(self, pre_process=True, post_process=True):
        super(MitInpaintingModel, self).__init__()
        self.pre_process = pre_process
        self.post_process = post_process

        args = get_args()
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.flatten_dim = self.patch_dim * self.patch_dim * 3
        self.backbone = mit_b3()

        self.in_channels = [64, 128, 320, 512]
        self.embedding_dim = 768

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4, self.embedding_dim, 1, 1, bias=False)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(0.1)

        self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        c1, c2, c3, c4 = self.backbone(input)

        n, _, h, w = c4.shape
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        _c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
        _c = self.conv_fuse(_c)

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)

        x = self.linear_pred(x)

        output = einops.rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h // self.patch_dim,
            w=self.img_w // self.patch_dim,
        )

        return output
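Both decoders above rely on a single einops.rearrange to fold per-patch pixel predictions back into an image. A standalone sketch of that reshape with toy sizes (patch 4, 32x32 RGB image; values are illustrative only):

import torch
import einops

B, patch, img_h, img_w, C = 2, 4, 32, 32, 3
h, w = img_h // patch, img_w // patch
tokens = torch.randn(B, h * w, patch * patch * C)        # one flattened pixel patch per token
img = einops.rearrange(
    tokens, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
    p1=patch, p2=patch, h=h, w=w)
print(img.shape)                                         # torch.Size([2, 3, 32, 32])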
megatron/model/vision/knn_monitor.py
0 → 100644
View file @
06fc51ce
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder

_FEATURE_BANK = None


def build_data_loader(dataset, drop_last=True, shuffle=False):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""
    # Sampler.
    args = get_args()
    micro_batch_size = 16
    num_workers = args.num_workers
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=not drop_last,
        pin_memory=True,
    )
    return data_loader


def compute_feature_bank(model):
    args = get_args()
    global _FEATURE_BANK
    feature_bank = []
    feature_label = []

    train_ds = ImageFolder(
        root=args.data_path[0],
        transform=ClassificationTransform((args.img_h, args.img_w), train=False),
        data_per_class_fraction=1.0
    )
    classes = len(train_ds.classes)
    dataloader = build_data_loader(train_ds)

    for m in model:
        m.eval()

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            images = batch[0].cuda().contiguous()
            labels = batch[1].cuda().contiguous()
            student_feature, teacher_feature = model[0](images)
            feature = F.normalize(teacher_feature.float(), dim=1)
            feature_bank.append(feature)
            feature_label.append(labels)

    for m in model:
        m.train()

    # [N', D]
    feature_bank = torch.cat(feature_bank, dim=0).contiguous()
    feature_label = torch.cat(feature_label, dim=0).contiguous()

    feature_banks = [torch.zeros_like(feature_bank)
                     for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_banks,
                                 feature_bank,
                                 group=mpu.get_data_parallel_group())

    assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
                              feature_bank))

    feature_labels = [torch.zeros_like(feature_label)
                      for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_labels,
                                 feature_label,
                                 group=mpu.get_data_parallel_group())

    # [D, N]
    feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
    # [N]
    feature_labels = torch.cat(feature_labels, dim=0).contiguous()
    print_rank_0("feature_banks size is {}".format(feature_banks.size()))
    print_rank_0("feature labels size is {}".format(feature_labels.size()))

    _FEATURE_BANK = (feature_banks, feature_labels, classes)


def get_feature_bank():
    global _FEATURE_BANK
    assert _FEATURE_BANK is not None
    return _FEATURE_BANK


# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
                              dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes,
                                device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1,
                                          index=sim_labels.view(-1, 1),
                                          value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(
        one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1),
        dim=1)

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels
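knn_predict expects L2-normalized query features and a feature bank laid out as [D, N], as produced by compute_feature_bank above. A minimal self-contained sketch with random data; the sizes and the knn_k/knn_t values are illustrative, not prescribed:

import torch
import torch.nn.functional as F

B, D, N, classes = 8, 128, 1000, 10
feature = F.normalize(torch.randn(B, D), dim=1)           # query batch, [B, D]
feature_bank = F.normalize(torch.randn(N, D), dim=1).t()  # bank, [D, N]
feature_labels = torch.randint(0, classes, (N,))          # bank labels, [N]
pred = knn_predict(feature, feature_bank, feature_labels, classes, knn_k=20, knn_t=0.07)
print(pred[:, 0])                                         # top-1 predicted class per query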
megatron/model/vision/mit_backbone.py
0 → 100644
View file @
06fc51ce
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# found in the LICENSE file in the root directory of this
# source tree.
# ---------------------------------------------------------------
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
functools
import
partial
from
torch.nn.init
import
trunc_normal_
from
megatron.model.transformer
import
DropPath
from
megatron.model
import
LayerNorm
class
Mlp
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
drop
=
0.
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
dwconv
=
DWConv
(
hidden_features
)
self
.
act
=
act_layer
()
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
drop
=
nn
.
Dropout
(
drop
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
,
H
,
W
):
x
=
self
.
fc1
(
x
)
x
=
self
.
dwconv
(
x
,
H
,
W
)
x
=
self
.
act
(
x
)
x
=
self
.
drop
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
drop
(
x
)
return
x
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
qkv_bias
=
False
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
,
sr_ratio
=
1
):
super
().
__init__
()
assert
dim
%
num_heads
==
0
,
f
"dim
{
dim
}
should be divided by num_heads
{
num_heads
}
."
self
.
dim
=
dim
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**
-
0.5
self
.
q
=
nn
.
Linear
(
dim
,
dim
,
bias
=
qkv_bias
)
self
.
kv
=
nn
.
Linear
(
dim
,
dim
*
2
,
bias
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
self
.
sr_ratio
=
sr_ratio
if
sr_ratio
>
1
:
self
.
sr
=
nn
.
Conv2d
(
dim
,
dim
,
kernel_size
=
sr_ratio
,
stride
=
sr_ratio
)
self
.
norm
=
LayerNorm
(
dim
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
,
H
,
W
):
B
,
N
,
C
=
x
.
shape
q
=
self
.
q
(
x
).
reshape
(
B
,
N
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
0
,
2
,
1
,
3
)
if
self
.
sr_ratio
>
1
:
x_
=
x
.
permute
(
0
,
2
,
1
).
reshape
(
B
,
C
,
H
,
W
)
x_
=
self
.
sr
(
x_
).
reshape
(
B
,
C
,
-
1
).
permute
(
0
,
2
,
1
)
x_
=
self
.
norm
(
x_
)
kv
=
self
.
kv
(
x_
).
reshape
(
B
,
-
1
,
2
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
else
:
kv
=
self
.
kv
(
x
).
reshape
(
B
,
-
1
,
2
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
k
,
v
=
kv
[
0
],
kv
[
1
]
attn
=
(
q
@
k
.
transpose
(
-
2
,
-
1
))
*
self
.
scale
attn
=
attn
.
softmax
(
dim
=-
1
)
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B
,
N
,
C
)
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
class
Block
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
LayerNorm
,
sr_ratio
=
1
):
super
().
__init__
()
self
.
norm1
=
norm_layer
(
dim
)
self
.
attn
=
Attention
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
sr_ratio
=
sr_ratio
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
nn
.
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
,
H
,
W
):
x
=
x
+
self
.
drop_path
(
self
.
attn
(
self
.
norm1
(
x
),
H
,
W
))
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
),
H
,
W
))
return
x
class
OverlapPatchEmbed
(
nn
.
Module
):
""" Image to Patch Embedding
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
7
,
stride
=
4
,
in_chans
=
3
,
embed_dim
=
768
):
super
().
__init__
()
img_size
=
(
img_size
,
img_size
)
patch_size
=
(
patch_size
,
patch_size
)
self
.
proj
=
nn
.
Conv2d
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
stride
,
padding
=
(
patch_size
[
0
]
//
2
,
patch_size
[
1
]
//
2
))
self
.
norm
=
LayerNorm
(
embed_dim
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
):
x
=
self
.
proj
(
x
)
_
,
_
,
H
,
W
=
x
.
shape
x
=
x
.
flatten
(
2
).
transpose
(
1
,
2
)
x
=
self
.
norm
(
x
)
return
x
,
H
,
W
class
MixVisionTransformer
(
nn
.
Module
):
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
num_classes
=
1000
,
embed_dims
=
[
64
,
128
,
256
,
512
],
num_heads
=
[
1
,
2
,
4
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
False
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.
,
norm_layer
=
LayerNorm
,
depths
=
[
3
,
4
,
6
,
3
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
output_avg
=
False
):
super
().
__init__
()
self
.
num_classes
=
num_classes
self
.
depths
=
depths
self
.
output_avg
=
output_avg
# patch_embed
self
.
patch_embed1
=
OverlapPatchEmbed
(
img_size
=
img_size
,
patch_size
=
7
,
stride
=
4
,
in_chans
=
in_chans
,
embed_dim
=
embed_dims
[
0
])
self
.
patch_embed2
=
OverlapPatchEmbed
(
img_size
=
img_size
//
4
,
patch_size
=
3
,
stride
=
2
,
in_chans
=
embed_dims
[
0
],
embed_dim
=
embed_dims
[
1
])
self
.
patch_embed3
=
OverlapPatchEmbed
(
img_size
=
img_size
//
8
,
patch_size
=
3
,
stride
=
2
,
in_chans
=
embed_dims
[
1
],
embed_dim
=
embed_dims
[
2
])
self
.
patch_embed4
=
OverlapPatchEmbed
(
img_size
=
img_size
//
16
,
patch_size
=
3
,
stride
=
2
,
in_chans
=
embed_dims
[
2
],
embed_dim
=
embed_dims
[
3
])
# transformer encoder
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
sum
(
depths
))]
# stochastic depth decay rule
cur
=
0
self
.
block1
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dims
[
0
],
num_heads
=
num_heads
[
0
],
mlp_ratio
=
mlp_ratios
[
0
],
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
cur
+
i
],
norm_layer
=
norm_layer
,
sr_ratio
=
sr_ratios
[
0
])
for
i
in
range
(
depths
[
0
])])
self
.
norm1
=
norm_layer
(
embed_dims
[
0
])
cur
+=
depths
[
0
]
self
.
block2
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dims
[
1
],
num_heads
=
num_heads
[
1
],
mlp_ratio
=
mlp_ratios
[
1
],
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
cur
+
i
],
norm_layer
=
norm_layer
,
sr_ratio
=
sr_ratios
[
1
])
for
i
in
range
(
depths
[
1
])])
self
.
norm2
=
norm_layer
(
embed_dims
[
1
])
cur
+=
depths
[
1
]
self
.
block3
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dims
[
2
],
num_heads
=
num_heads
[
2
],
mlp_ratio
=
mlp_ratios
[
2
],
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
cur
+
i
],
norm_layer
=
norm_layer
,
sr_ratio
=
sr_ratios
[
2
])
for
i
in
range
(
depths
[
2
])])
self
.
norm3
=
norm_layer
(
embed_dims
[
2
])
cur
+=
depths
[
2
]
self
.
block4
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dims
[
3
],
num_heads
=
num_heads
[
3
],
mlp_ratio
=
mlp_ratios
[
3
],
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
cur
+
i
],
norm_layer
=
norm_layer
,
sr_ratio
=
sr_ratios
[
3
])
for
i
in
range
(
depths
[
3
])])
self
.
norm4
=
norm_layer
(
embed_dims
[
3
])
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
reset_drop_path
(
self
,
drop_path_rate
):
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
sum
(
self
.
depths
))]
cur
=
0
for
i
in
range
(
self
.
depths
[
0
]):
self
.
block1
[
i
].
drop_path
.
drop_prob
=
dpr
[
cur
+
i
]
cur
+=
self
.
depths
[
0
]
for
i
in
range
(
self
.
depths
[
1
]):
self
.
block2
[
i
].
drop_path
.
drop_prob
=
dpr
[
cur
+
i
]
cur
+=
self
.
depths
[
1
]
for
i
in
range
(
self
.
depths
[
2
]):
self
.
block3
[
i
].
drop_path
.
drop_prob
=
dpr
[
cur
+
i
]
cur
+=
self
.
depths
[
2
]
for
i
in
range
(
self
.
depths
[
3
]):
self
.
block4
[
i
].
drop_path
.
drop_prob
=
dpr
[
cur
+
i
]
def
freeze_patch_emb
(
self
):
self
.
patch_embed1
.
requires_grad
=
False
def
forward_features
(
self
,
x
):
B
=
x
.
shape
[
0
]
outs
=
[]
# stage 1
x
,
H
,
W
=
self
.
patch_embed1
(
x
)
for
i
,
blk
in
enumerate
(
self
.
block1
):
x
=
blk
(
x
,
H
,
W
)
x
=
self
.
norm1
(
x
)
x
=
x
.
reshape
(
B
,
H
,
W
,
-
1
).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
outs
.
append
(
x
)
# stage 2
x
,
H
,
W
=
self
.
patch_embed2
(
x
)
for
i
,
blk
in
enumerate
(
self
.
block2
):
x
=
blk
(
x
,
H
,
W
)
x
=
self
.
norm2
(
x
)
x
=
x
.
reshape
(
B
,
H
,
W
,
-
1
).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
outs
.
append
(
x
)
# stage 3
x
,
H
,
W
=
self
.
patch_embed3
(
x
)
for
i
,
blk
in
enumerate
(
self
.
block3
):
x
=
blk
(
x
,
H
,
W
)
x
=
self
.
norm3
(
x
)
x
=
x
.
reshape
(
B
,
H
,
W
,
-
1
).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
outs
.
append
(
x
)
# stage 4
x
,
H
,
W
=
self
.
patch_embed4
(
x
)
for
i
,
blk
in
enumerate
(
self
.
block4
):
x
=
blk
(
x
,
H
,
W
)
x
=
self
.
norm4
(
x
)
if
not
self
.
output_avg
:
x
=
x
.
reshape
(
B
,
H
,
W
,
-
1
).
permute
(
0
,
3
,
1
,
2
).
contiguous
()
outs
.
append
(
x
)
return
outs
def
forward
(
self
,
x
):
x
=
self
.
forward_features
(
x
)
if
self
.
output_avg
:
x
=
x
[
3
].
mean
(
dim
=
1
)
return
x
class
DWConv
(
nn
.
Module
):
def
__init__
(
self
,
dim
=
768
):
super
(
DWConv
,
self
).
__init__
()
self
.
dwconv
=
nn
.
Conv2d
(
dim
,
dim
,
3
,
1
,
1
,
bias
=
True
,
groups
=
dim
)
def
forward
(
self
,
x
,
H
,
W
):
B
,
N
,
C
=
x
.
shape
x
=
x
.
transpose
(
1
,
2
).
view
(
B
,
C
,
H
,
W
)
x
=
self
.
dwconv
(
x
)
x
=
x
.
flatten
(
2
).
transpose
(
1
,
2
)
return
x
class
mit_b0
(
MixVisionTransformer
):
def
__init__
(
self
,
**
kwargs
):
super
(
mit_b0
,
self
).
__init__
(
patch_size
=
4
,
embed_dims
=
[
32
,
64
,
160
,
256
],
num_heads
=
[
1
,
2
,
5
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
True
,
norm_layer
=
partial
(
LayerNorm
,
eps
=
1e-6
),
depths
=
[
2
,
2
,
2
,
2
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
drop_rate
=
0.0
,
drop_path_rate
=
0.1
)
class
mit_b1
(
MixVisionTransformer
):
def
__init__
(
self
,
**
kwargs
):
super
(
mit_b1
,
self
).
__init__
(
patch_size
=
4
,
embed_dims
=
[
64
,
128
,
320
,
512
],
num_heads
=
[
1
,
2
,
5
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
True
,
norm_layer
=
partial
(
LayerNorm
,
eps
=
1e-6
),
depths
=
[
2
,
2
,
2
,
2
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
drop_rate
=
0.0
,
drop_path_rate
=
0.1
)
class
mit_b2
(
MixVisionTransformer
):
def
__init__
(
self
,
**
kwargs
):
super
(
mit_b2
,
self
).
__init__
(
patch_size
=
4
,
embed_dims
=
[
64
,
128
,
320
,
512
],
num_heads
=
[
1
,
2
,
5
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
True
,
norm_layer
=
partial
(
LayerNorm
,
eps
=
1e-6
),
depths
=
[
3
,
4
,
6
,
3
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
drop_rate
=
0.0
,
drop_path_rate
=
0.1
)
class
mit_b3
(
MixVisionTransformer
):
def
__init__
(
self
,
**
kwargs
):
super
(
mit_b3
,
self
).
__init__
(
patch_size
=
4
,
embed_dims
=
[
64
,
128
,
320
,
512
],
num_heads
=
[
1
,
2
,
5
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
True
,
norm_layer
=
partial
(
LayerNorm
,
eps
=
1e-6
),
depths
=
[
3
,
4
,
18
,
3
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
drop_rate
=
0.0
,
drop_path_rate
=
0.1
)
class
mit_b3_avg
(
MixVisionTransformer
):
def
__init__
(
self
,
drop_path_rate
=
0.1
,
**
kwargs
):
super
(
mit_b3_avg
,
self
).
__init__
(
patch_size
=
4
,
embed_dims
=
[
64
,
128
,
320
,
512
],
num_heads
=
[
1
,
2
,
5
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
True
,
norm_layer
=
partial
(
LayerNorm
,
eps
=
1e-6
),
depths
=
[
3
,
4
,
18
,
3
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
drop_rate
=
0.0
,
drop_path_rate
=
drop_path_rate
,
output_avg
=
True
)
class
mit_b4
(
MixVisionTransformer
):
def
__init__
(
self
,
**
kwargs
):
super
(
mit_b4
,
self
).
__init__
(
patch_size
=
4
,
embed_dims
=
[
64
,
128
,
320
,
512
],
num_heads
=
[
1
,
2
,
5
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
True
,
norm_layer
=
partial
(
LayerNorm
,
eps
=
1e-6
),
depths
=
[
3
,
8
,
27
,
3
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
drop_rate
=
0.0
,
drop_path_rate
=
0.1
)
class
mit_b5
(
MixVisionTransformer
):
def
__init__
(
self
,
**
kwargs
):
super
(
mit_b5
,
self
).
__init__
(
patch_size
=
4
,
embed_dims
=
[
64
,
128
,
320
,
512
],
num_heads
=
[
1
,
2
,
5
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
True
,
norm_layer
=
partial
(
LayerNorm
,
eps
=
1e-6
),
depths
=
[
3
,
6
,
40
,
3
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
drop_rate
=
0.0
,
drop_path_rate
=
0.1
)
class
mit_b5_avg
(
MixVisionTransformer
):
def
__init__
(
self
,
drop_path_rate
=
0.1
,
**
kwargs
):
super
(
mit_b5_avg
,
self
).
__init__
(
patch_size
=
4
,
embed_dims
=
[
64
,
128
,
320
,
512
],
num_heads
=
[
1
,
2
,
5
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
True
,
norm_layer
=
partial
(
LayerNorm
,
eps
=
1e-6
),
depths
=
[
3
,
6
,
40
,
3
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
drop_rate
=
0.0
,
drop_path_rate
=
drop_path_rate
,
output_avg
=
True
)
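The four OverlapPatchEmbed stages above downsample by 4, 2, 2 and 2, so every mit_b* backbone emits pyramidal feature maps at 1/4, 1/8, 1/16 and 1/32 of the input resolution, with the configured embed_dims as channel counts. A small bookkeeping sketch for a 224x224 input (no model instantiation; the channel widths shown are the mit_b1 and larger configuration):

img = 224
embed_dims = [64, 128, 320, 512]
strides = [4, 2, 2, 2]
res = img
for stage, (s, c) in enumerate(zip(strides, embed_dims), start=1):
    res //= s
    print(f"stage {stage}: {res}x{res} feature map, {c} channels")
# stage 1: 56x56, 64 channels ... stage 4: 7x7, 512 channels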
megatron/model/vision/swin_backbone.py
0 → 100644
View file @
06fc51ce
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Swin Transformer
# --------------------------------------------------------
import
torch
import
torch.nn
as
nn
import
torch.utils.checkpoint
as
checkpoint
from
timm.models.layers
import
DropPath
,
to_2tuple
,
trunc_normal_
import numpy as np
from math import sqrt
from
megatron
import
get_args
from
functools
import
partial
class
Mlp
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
drop
=
0.
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
act
=
act_layer
()
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
drop
=
nn
.
Dropout
(
drop
)
def
forward
(
self
,
x
):
x
=
self
.
fc1
(
x
)
x
=
self
.
act
(
x
)
x
=
self
.
drop
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
drop
(
x
)
return
x
def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
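window_partition and window_reverse are exact inverses whenever H and W are multiples of the window size; a quick round-trip check with toy sizes:

import torch

B, H, W, C, ws = 2, 14, 14, 32, 7
x = torch.randn(B, H, W, C)
windows = window_partition(x, ws)             # (B * (H//ws) * (W//ws), ws, ws, C) = (8, 7, 7, 32)
restored = window_reverse(windows, ws, H, W)  # back to (B, H, W, C)
print(torch.equal(x, restored))               # True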
class
WindowAttention
(
nn
.
Module
):
r
""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def
__init__
(
self
,
dim
,
window_size
,
num_heads
,
qkv_bias
=
True
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
):
super
().
__init__
()
self
.
dim
=
dim
self
.
window_size
=
window_size
# Wh, Ww
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**
-
0.5
# define a parameter table of relative position bias
self
.
relative_position_bias_table
=
nn
.
Parameter
(
torch
.
zeros
((
2
*
window_size
[
0
]
-
1
)
*
(
2
*
window_size
[
1
]
-
1
),
num_heads
))
# 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h
=
torch
.
arange
(
self
.
window_size
[
0
])
coords_w
=
torch
.
arange
(
self
.
window_size
[
1
])
coords
=
torch
.
stack
(
torch
.
meshgrid
([
coords_h
,
coords_w
]))
# 2, Wh, Ww
coords_flatten
=
torch
.
flatten
(
coords
,
1
)
# 2, Wh*Ww
relative_coords
=
coords_flatten
[:,
:,
None
]
-
coords_flatten
[:,
None
,
:]
# 2, Wh*Ww, Wh*Ww
relative_coords
=
relative_coords
.
permute
(
1
,
2
,
0
).
contiguous
()
# Wh*Ww, Wh*Ww, 2
relative_coords
[:,
:,
0
]
+=
self
.
window_size
[
0
]
-
1
# shift to start from 0
relative_coords
[:,
:,
1
]
+=
self
.
window_size
[
1
]
-
1
relative_coords
[:,
:,
0
]
*=
2
*
self
.
window_size
[
1
]
-
1
relative_position_index
=
relative_coords
.
sum
(
-
1
)
# Wh*Ww, Wh*Ww
self
.
register_buffer
(
"relative_position_index"
,
relative_position_index
)
self
.
qkv
=
nn
.
Linear
(
dim
,
dim
*
3
,
bias
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
trunc_normal_
(
self
.
relative_position_bias_table
,
std
=
.
02
)
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
def
forward
(
self
,
x
,
mask
=
None
):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
B_
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
# make torchscript happy (cannot use tensor as tuple)
q
=
q
*
self
.
scale
attn
=
(
q
@
k
.
transpose
(
-
2
,
-
1
))
relative_position_bias
=
self
.
relative_position_bias_table
[
self
.
relative_position_index
.
view
(
-
1
)].
view
(
self
.
window_size
[
0
]
*
self
.
window_size
[
1
],
self
.
window_size
[
0
]
*
self
.
window_size
[
1
],
-
1
)
# Wh*Ww,Wh*Ww,nH
relative_position_bias
=
relative_position_bias
.
permute
(
2
,
0
,
1
).
contiguous
()
# nH, Wh*Ww, Wh*Ww
attn
=
attn
+
relative_position_bias
.
unsqueeze
(
0
)
if
mask
is
not
None
:
nW
=
mask
.
shape
[
0
]
attn
=
attn
.
view
(
B_
//
nW
,
nW
,
self
.
num_heads
,
N
,
N
)
+
mask
.
unsqueeze
(
1
).
unsqueeze
(
0
)
attn
=
attn
.
view
(
-
1
,
self
.
num_heads
,
N
,
N
)
attn
=
self
.
softmax
(
attn
)
else
:
attn
=
self
.
softmax
(
attn
)
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B_
,
N
,
C
)
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
def
extra_repr
(
self
)
->
str
:
return
f
'dim=
{
self
.
dim
}
, window_size=
{
self
.
window_size
}
, num_heads=
{
self
.
num_heads
}
'
def
flops
(
self
,
N
):
# calculate flops for 1 window with token length of N
flops
=
0
# qkv = self.qkv(x)
flops
+=
N
*
self
.
dim
*
3
*
self
.
dim
# attn = (q @ k.transpose(-2, -1))
flops
+=
self
.
num_heads
*
N
*
(
self
.
dim
//
self
.
num_heads
)
*
N
# x = (attn @ v)
flops
+=
self
.
num_heads
*
N
*
N
*
(
self
.
dim
//
self
.
num_heads
)
# x = self.proj(x)
flops
+=
N
*
self
.
dim
*
self
.
dim
return
flops
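The relative_position_index buffer built in __init__ maps every (query, key) pair inside a window to one of (2*Wh-1)*(2*Ww-1) bias-table entries. A standalone sketch of the same construction for a 2x2 window, mirroring the code above step by step:

import torch

Wh = Ww = 2
coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))   # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1)                                    # 2, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]    # 2, Wh*Ww, Wh*Ww
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += Wh - 1        # shift to start from 0
relative_coords[:, :, 1] += Ww - 1
relative_coords[:, :, 0] *= 2 * Ww - 1
relative_position_index = relative_coords.sum(-1)                            # Wh*Ww, Wh*Ww
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])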
class
SwinTransformerBlock
(
nn
.
Module
):
r
""" Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resolution.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
dim
,
input_resolution
,
num_heads
,
window_size
=
7
,
shift_size
=
0
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
num_heads
=
num_heads
self
.
window_size
=
window_size
self
.
shift_size
=
shift_size
self
.
mlp_ratio
=
mlp_ratio
if
min
(
self
.
input_resolution
)
<=
self
.
window_size
:
# if window size is larger than input resolution, we don't partition windows
self
.
shift_size
=
0
self
.
window_size
=
min
(
self
.
input_resolution
)
assert
0
<=
self
.
shift_size
<
self
.
window_size
,
"shift_size must in 0-window_size"
self
.
norm1
=
norm_layer
(
dim
)
self
.
attn
=
WindowAttention
(
dim
,
window_size
=
to_2tuple
(
self
.
window_size
),
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
)
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
nn
.
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
self
.
H
=
input_resolution
[
0
]
self
.
W
=
input_resolution
[
1
]
self
.
attn_mask_dict
=
{}
def
create_attn_mask
(
self
,
H
,
W
):
# calculate attention mask for SW-MSA
Hp
=
int
(
np
.
ceil
(
H
/
self
.
window_size
))
*
self
.
window_size
Wp
=
int
(
np
.
ceil
(
W
/
self
.
window_size
))
*
self
.
window_size
img_mask
=
torch
.
zeros
((
1
,
Hp
,
Wp
,
1
))
# 1 Hp Wp 1
h_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
w_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
cnt
=
0
for
h
in
h_slices
:
for
w
in
w_slices
:
img_mask
[:,
h
,
w
,
:]
=
cnt
cnt
+=
1
mask_windows
=
window_partition
(
img_mask
,
self
.
window_size
)
# nW, window_size, window_size, 1
mask_windows
=
mask_windows
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
)
attn_mask
=
mask_windows
.
unsqueeze
(
1
)
-
mask_windows
.
unsqueeze
(
2
)
attn_mask
=
attn_mask
.
masked_fill
(
attn_mask
!=
0
,
float
(
-
100.0
)).
masked_fill
(
attn_mask
==
0
,
float
(
0.0
))
return
attn_mask
def
forward
(
self
,
x
):
B
,
L
,
C
=
x
.
shape
H
=
int
(
sqrt
(
L
))
W
=
H
shortcut
=
x
x
=
self
.
norm1
(
x
)
x
=
x
.
view
(
B
,
H
,
W
,
C
)
# cyclic shift
if
self
.
shift_size
>
0
:
shifted_x
=
torch
.
roll
(
x
,
shifts
=
(
-
self
.
shift_size
,
-
self
.
shift_size
),
dims
=
(
1
,
2
))
else
:
shifted_x
=
x
# partition windows
x_windows
=
window_partition
(
shifted_x
,
self
.
window_size
)
# nW*B, window_size, window_size, C
x_windows
=
x_windows
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
,
C
)
# nW*B, window_size*window_size, C
# W-MSA/SW-MSA
        # the attention mask for shifted windows is cached per input resolution
        if self.shift_size > 0:
            if H not in self.attn_mask_dict:
                self.attn_mask_dict[H] = self.create_attn_mask(H, W).to(x.device)
            attn_mask = self.attn_mask_dict[H]
        else:
            attn_mask = None
        attn_windows = self.attn(x_windows, mask=attn_mask)
# nW*B, window_size*window_size, C
# merge windows
attn_windows
=
attn_windows
.
view
(
-
1
,
self
.
window_size
,
self
.
window_size
,
C
)
shifted_x
=
window_reverse
(
attn_windows
,
self
.
window_size
,
H
,
W
)
# B H' W' C
# reverse cyclic shift
if
self
.
shift_size
>
0
:
x
=
torch
.
roll
(
shifted_x
,
shifts
=
(
self
.
shift_size
,
self
.
shift_size
),
dims
=
(
1
,
2
))
else
:
x
=
shifted_x
x
=
x
.
view
(
B
,
H
*
W
,
C
)
# FFN
x
=
shortcut
+
self
.
drop_path
(
x
)
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
def
extra_repr
(
self
)
->
str
:
return
f
"dim=
{
self
.
dim
}
, input_resolution=
{
self
.
input_resolution
}
, num_heads=
{
self
.
num_heads
}
, "
\
f
"window_size=
{
self
.
window_size
}
, shift_size=
{
self
.
shift_size
}
, mlp_ratio=
{
self
.
mlp_ratio
}
"
def
flops
(
self
):
flops
=
0
H
,
W
=
self
.
input_resolution
# norm1
flops
+=
self
.
dim
*
H
*
W
# W-MSA/SW-MSA
nW
=
H
*
W
/
self
.
window_size
/
self
.
window_size
flops
+=
nW
*
self
.
attn
.
flops
(
self
.
window_size
*
self
.
window_size
)
# mlp
flops
+=
2
*
H
*
W
*
self
.
dim
*
self
.
dim
*
self
.
mlp_ratio
# norm2
flops
+=
self
.
dim
*
H
*
W
return
flops
class
PatchMerging
(
nn
.
Module
):
r
""" Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
input_resolution
,
dim
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
input_resolution
=
input_resolution
self
.
dim
=
dim
self
.
reduction
=
nn
.
Linear
(
4
*
dim
,
2
*
dim
,
bias
=
False
)
self
.
norm
=
norm_layer
(
4
*
dim
)
def
forward
(
self
,
x
):
"""
x: B, H*W, C
"""
H
,
W
=
self
.
input_resolution
B
,
L
,
C
=
x
.
shape
assert
L
==
H
*
W
,
"input feature has wrong size"
assert
H
%
2
==
0
and
W
%
2
==
0
,
f
"x size (
{
H
}
*
{
W
}
) are not even."
x
=
x
.
view
(
B
,
H
,
W
,
C
)
x0
=
x
[:,
0
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x1
=
x
[:,
1
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x2
=
x
[:,
0
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x3
=
x
[:,
1
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x
=
torch
.
cat
([
x0
,
x1
,
x2
,
x3
],
-
1
)
# B H/2 W/2 4*C
x
=
x
.
view
(
B
,
-
1
,
4
*
C
)
# B H/2*W/2 4*C
x
=
self
.
norm
(
x
)
x
=
self
.
reduction
(
x
)
return
x
def
extra_repr
(
self
)
->
str
:
return
f
"input_resolution=
{
self
.
input_resolution
}
, dim=
{
self
.
dim
}
"
def
flops
(
self
):
H
,
W
=
self
.
input_resolution
flops
=
H
*
W
*
self
.
dim
flops
+=
(
H
//
2
)
*
(
W
//
2
)
*
4
*
self
.
dim
*
2
*
self
.
dim
return
flops
class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None,
                 use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x_b4_ds = x
        if self.downsample is not None:
            x = self.downsample(x)

        return x_b4_ds, x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
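Taken together with PatchMerging, each stage hands back both the tokens it just computed and their downsampled successor. A minimal, illustrative sketch of that wiring (the sizes are assumptions, and it relies on the SwinTransformerBlock defined earlier in this file):

import torch

# Hypothetical wiring of one stage into the next (shapes are illustrative).
stage = BasicLayer(dim=96, input_resolution=(56, 56), depth=2, num_heads=3,
                   window_size=7, downsample=PatchMerging)
tokens = torch.randn(2, 56 * 56, 96)
pre_downsample, next_tokens = stage(tokens)
# pre_downsample: [2, 3136, 96]  -> reshaped into this stage's feature map
# next_tokens:    [2, 784, 192]  -> input to the next BasicLayer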
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
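As a concrete shape illustration for the default configuration (the batch size here is an assumption): a 224x224 RGB image with patch_size=4 becomes 56x56 = 3136 patch tokens of width embed_dim.

import torch

# Hypothetical shape check for the default patch embedding.
embed = PatchEmbed(img_size=224, patch_size=4, in_chans=3, embed_dim=96)
images = torch.randn(2, 3, 224, 224)
tokens = embed(images)
assert tokens.shape == (2, 56 * 56, 96)   # [B, Ph*Pw, embed_dim]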
class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False,
                 patch_norm=True, use_checkpoint=False, output_avg=False,
                 **kwargs):
        super().__init__()

        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.output_avg = output_avg

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        h = self.img_size[0] // self.patch_size[0]
        w = self.img_size[1] // self.patch_size[1]

        outs = []
        for i, layer in enumerate(self.layers):
            px, x = layer(x)
            b, n, c = px.shape

            if i != len(self.layers) - 1 or not self.output_avg:
                px = px.permute(0, 2, 1).contiguous()
                px = px.reshape(b, c, h, w)
                # is this a fair assumption ?? i think it's baked into the architecture
                h, w = h // 2, w // 2
            outs.append(px)

        if self.output_avg:
            return outs[-1].mean(dim=1)

        return outs

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
def get_swin(drop_path_rate=0.3, output_avg=False):
    args = get_args()

    window_size = 7
    embed_dim = 128
    depths = [2, 2, 18, 2]
    num_heads = [4, 8, 16, 32]
    swin = SwinTransformer(
        img_size=(args.img_h, args.img_w,),
        in_chans=3,
        patch_size=args.patch_dim,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        drop_path_rate=drop_path_rate,
        output_avg=output_avg,
    )

    return swin
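get_swin pulls its image size and patch size from Megatron's global args; the sketch below instead instantiates SwinTransformer directly with assumed sizes that mirror the hard-coded settings above, to show the multi-scale outputs.

import torch

# Hypothetical standalone instantiation, bypassing Megatron's global args.
backbone = SwinTransformer(img_size=(224, 224), patch_size=4, in_chans=3,
                           embed_dim=128, depths=[2, 2, 18, 2],
                           num_heads=[4, 8, 16, 32], window_size=7)
features = backbone(torch.randn(2, 3, 224, 224))
# With output_avg=False, one feature map per stage is returned, roughly
# [2, 128, 56, 56], [2, 256, 28, 28], [2, 512, 14, 14], [2, 1024, 7, 7].
for f in features:
    print(f.shape)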
megatron/model/vision/utils.py  0 → 100644
import warnings

import torch
import torch.nn.functional as F


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            # warn when upsampling with align_corners and the sizes do not
            # line up on an `x+1` / `nx+1` grid
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)
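For reference, a small usage sketch of this wrapper (the tensor sizes are assumptions for illustration):

import torch

# Upsample a feature map to a fixed target size.
feat = torch.randn(1, 256, 32, 32)
up = resize(feat, size=(128, 128), mode='bilinear', align_corners=False)
assert up.shape == (1, 256, 128, 128)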
megatron/model/vit_model.py → megatron/model/vision/vit_backbone.py
...
@@ -18,16 +18,19 @@
import math
import einops
import torch
import apex
import torch.nn.functional as F

from megatron import get_args
from megatron.model import LayerNorm
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from .module import MegatronModule
from megatron.model.module import MegatronModule

CLASS_TOKEN_LENGTH = 8

class VitMlpHead(MegatronModule):
    """Pooler layer.
...
@@ -44,19 +47,26 @@ class VitMlpHead(MegatronModule):
    def __init__(self, hidden_size, num_classes):
        super(VitMlpHead, self).__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states, sequence_index=0):
        # hidden_states: [b, s, h]
    def forward(self, hidden_states):
        # hidden_states: [b, 1, h]
        # sequence_index: index of the token to pool.
        hidden_state = hidden_states[:, sequence_index, :]
        dense_in_result = self.dense_in(hidden_state)
        dense_in_result = self.dense_in(hidden_states)
        tanh_result = torch.tanh(dense_in_result)
        dense_out_result = self.dense_out(tanh_result)
        return dense_out_result


def isPerfectSquare(x):
    if (x >= 0):
        sr = math.sqrt(x)
        return (int(sr) * int(sr) == x)
    return False


def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
...
@@ -68,66 +78,78 @@ def twod_interpolate_position_embeddings_hook(
):
    args = get_args()
    num_patches_per_dim = args.img_dim // args.patch_dim
    num_patches = num_patches_per_dim ** 2
    seq_length = num_patches + 1
    num_patches_per_dim_h = args.img_h // args.patch_dim
    num_patches_per_dim_w = args.img_w // args.patch_dim
    num_patches = num_patches_per_dim_h * num_patches_per_dim_w
    hidden_size = args.hidden_size

    key = prefix + "weight"

    # import pdb
    # pdb.set_trace()
    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        input_seq_len = input_param.shape[0]
        assert (isPerfectSquare(input_seq_len) or
                isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH))
        input_has_class_token = not isPerfectSquare(input_seq_len)
        num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
        num_tok_output = num_patches
        output_has_class_token = args.class_token_present

        # update input_param and load it to state_dict[key]
        if input_has_class_token:
            input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
            input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
        else:
            input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size)
            input_param_grid = input_param

        assert input_param.shape[1] == hidden_size

    if input_param.shape[0] != seq_length:
        # update input_param and load it to state_dict[key]
        num_tok_input = input_param.shape[0] - 1
        num_tok_new = seq_length - 1
        input_param_tok, input_param_grid = (
            input_param[:1, :],
            input_param[1:, :],
        )
        if num_tok_input != num_tok_output:

            gs_input = int(math.sqrt(num_tok_input))
            gs_new = int(math.sqrt(num_tok_new))
            gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = gs_new / gs_input
            scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, gs_new * gs_new))
            input_param_grid = input_param_grid.reshape((-1, num_tok_output))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size

            input_param = torch.cat((input_param_tok, input_param_grid), dim=0)
            assert (
                input_param.shape[0] == seq_length
                and input_param.shape[1] == hidden_size
            )

            state_dict[key] = input_param
            input_param = input_param_grid
            assert (
                input_param.shape[0] == num_tok_output
                and input_param.shape[1] == hidden_size
            )

        if output_has_class_token:
            input_param = torch.cat((input_param_tok, input_param), dim=0)

        state_dict[key] = input_param


class VitModel(MegatronModule):
class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
                 num_classes,
                 finetune=False,
    def __init__(self,
                 pre_process=True,
                 post_process=True):
        super(VitModel, self).__init__(share_word_embeddings=False)
                 post_process=True,
                 class_token=True,
                 single_token_output=False,
                 drop_path_rate=0.0):
        super(VitBackbone, self).__init__(share_word_embeddings=False)
        args = get_args()
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
...
@@ -142,25 +164,34 @@ class VitModel(MegatronModule):
        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.patch_dim = args.patch_dim
        self.img_dim = args.img_dim
        self.finetune = finetune
        assert self.img_dim % self.patch_dim == 0
        self.num_patches_per_dim = self.img_dim // self.patch_dim
        self.num_patches = self.num_patches_per_dim ** 2
        self.seq_length = self.num_patches + 1
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output
        self.drop_path_rate = drop_path_rate

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            # cls_token
            self.cls_token = torch.nn.Parameter(
                torch.randn(1, 1, self.hidden_size)
            )
            torch.nn.init.zeros_(self.cls_token)
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
                )
                torch.nn.init.zeros_(self.cls_token)
            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

            # Linear encoder
            self.linear_encoder = torch.nn.Linear(
                self.flatten_dim, self.hidden_size
...
@@ -173,8 +204,8 @@ class VitModel(MegatronModule):
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight
            )
            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

            args.class_token_present = self.class_token
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook
            )
...
@@ -183,21 +214,13 @@ class VitModel(MegatronModule):
        # Transformer
        self.transformer = ParallelTransformer(
            self.init_method,
            self.init_method,
            self.scaled_init_method,
            pre_process=self.pre_process,
            post_process=self.post_process
            post_process=self.post_process,
            drop_path_rate=self.drop_path_rate
        )

        if self.post_process:
            # MLP head
            if not self.finetune:
                self.mlp_head = VitMlpHead(self.hidden_size, self.num_classes)
            else:
                self.class_head = get_linear_layer(
                    self.hidden_size, num_classes, torch.nn.init.zeros_
                )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)
...
@@ -214,21 +237,22 @@ class VitModel(MegatronModule):
            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

            cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
            concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)
            concatenated_tokens = encoder_output
            if self.class_token:
                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                self.position_embeddings(self.position_ids)
                self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        hidden_states = self.transformer(hidden_states, None)

        if self.post_process:
            if not self.finetune:
                hidden_states = self.mlp_head(hidden_states)
            else:
                hidden_states = self.class_head(hidden_states[:, 0, :])
        if self.single_token_output:
            hidden_states = hidden_states[:, 0, :]

        return hidden_states
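The twod_interpolate_position_embeddings_hook above resizes a pretrained ViT position-embedding table when the fine-tuning patch grid differs from the pretraining grid. A stripped-down sketch of the same idea follows; the function name, shapes, and use of a target size rather than a scale factor are illustrative assumptions, not the hook's actual signature.

import math
import torch
import torch.nn.functional as F

def interpolate_pos_grid(pos_grid, new_hw):
    # pos_grid: [num_tokens, hidden], laid out as a square grid of patch tokens
    gs_in = int(math.sqrt(pos_grid.shape[0]))
    grid = pos_grid.transpose(0, 1).reshape(1, -1, gs_in, gs_in).float()
    grid = F.interpolate(grid, size=new_hw, mode="bilinear")
    return grid.reshape(grid.shape[1], -1).transpose(0, 1)

# e.g. a 14x14 pretraining grid resized to a 16x24 fine-tuning grid
resized = interpolate_pos_grid(torch.randn(196, 768), (16, 24))
assert resized.shape == (16 * 24, 768)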
megatron/mpu/__init__.py
...
@@ -25,6 +25,7 @@ from .initialize import get_data_parallel_group
from .initialize import get_data_parallel_rank
from .initialize import get_data_parallel_world_size
from .initialize import get_embedding_group
from .initialize import get_position_embedding_group
from .initialize import get_model_parallel_group
from .initialize import get_tensor_model_parallel_group
from .initialize import get_pipeline_model_parallel_group
...
@@ -32,10 +33,12 @@ from .initialize import get_tensor_model_parallel_rank, set_tensor_model_paralle
from .initialize import get_pipeline_model_parallel_rank, set_pipeline_model_parallel_rank
from .initialize import is_pipeline_first_stage, is_pipeline_last_stage
from .initialize import is_rank_in_embedding_group
from .initialize import is_rank_in_position_embedding_group
from .initialize import is_pipeline_stage_before_split, is_pipeline_stage_after_split
from .initialize import is_pipeline_stage_at_split
from .initialize import get_num_layers
from .initialize import get_tensor_model_parallel_src_rank
from .initialize import get_data_parallel_src_rank
from .initialize import get_pipeline_model_parallel_first_rank
from .initialize import get_pipeline_model_parallel_last_rank
from .initialize import get_pipeline_model_parallel_next_rank
...
@@ -63,6 +66,9 @@ from .random import get_cuda_rng_tracker
from .random import model_parallel_cuda_manual_seed
from .random import gather_split_1d_tensor
from .random import split_tensor_into_1d_equal_chunks
from .random import make_viewless_tensor
from .random import assert_viewless_tensor
from .random import safely_set_viewless_tensor_data
from .utils import divide
from .utils import split_tensor_along_last_dim
megatron/mpu/initialize.py
...
@@ -29,6 +29,8 @@ _PIPELINE_MODEL_PARALLEL_GROUP = None
_MODEL_PARALLEL_GROUP = None
# Embedding group.
_EMBEDDING_GROUP = None
# Position embedding group.
_POSITION_EMBEDDING_GROUP = None
# Data parallel group that the current rank belongs to.
_DATA_PARALLEL_GROUP = None
...
@@ -45,6 +47,9 @@ _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
# A list of ranks that have a copy of the embedding.
_EMBEDDING_GLOBAL_RANKS = None

# A list of ranks that have a copy of the position embedding.
_POSITION_EMBEDDING_GLOBAL_RANKS = None

# A list of global ranks for each pipeline group to ease calculation of the source
# rank when broadcasting from the first or last pipeline stage.
_PIPELINE_GLOBAL_RANKS = None
...
@@ -165,6 +170,10 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
    global _EMBEDDING_GLOBAL_RANKS
    assert _EMBEDDING_GROUP is None, \
        'embedding group is already initialized'
    global _POSITION_EMBEDDING_GROUP
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    assert _POSITION_EMBEDDING_GROUP is None, \
        'position embedding group is already initialized'
    for i in range(num_pipeline_model_parallel_groups):
        ranks = range(i, world_size, num_pipeline_model_parallel_groups)
...
@@ -176,19 +185,31 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
        # first and last stages).
        if len(ranks) > 1:
            embedding_ranks = [ranks[0], ranks[-1]]
            if pipeline_model_parallel_split_rank_ is not None and \
                    pipeline_model_parallel_split_rank_ not in embedding_ranks:
                embedding_ranks = [ranks[0],
                                   ranks[pipeline_model_parallel_split_rank_],
                                   ranks[-1]]
            position_embedding_ranks = [ranks[0]]
            if pipeline_model_parallel_split_rank_ is not None:
                if ranks[pipeline_model_parallel_split_rank_] not in embedding_ranks:
                    embedding_ranks = [ranks[0],
                                       ranks[pipeline_model_parallel_split_rank_],
                                       ranks[-1]]
                if ranks[pipeline_model_parallel_split_rank_] not in position_embedding_ranks:
                    position_embedding_ranks = [ranks[0],
                                                ranks[pipeline_model_parallel_split_rank_]]
        else:
            embedding_ranks = ranks
            position_embedding_ranks = ranks

        group = torch.distributed.new_group(embedding_ranks)
        if rank in embedding_ranks:
            _EMBEDDING_GROUP = group
        if rank in ranks:
            _EMBEDDING_GLOBAL_RANKS = embedding_ranks

        group = torch.distributed.new_group(position_embedding_ranks)
        if rank in position_embedding_ranks:
            _POSITION_EMBEDDING_GROUP = group
        if rank in ranks:
            _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
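To make the grouping concrete, here is a small worked example under assumed parallel sizes (the numbers are not part of this commit):

# Hypothetical layout: one pipeline group with ranks = [0, 4, 8, 12] and
# pipeline_model_parallel_split_rank_ = 2, so ranks[2] == 8 is the split stage.
#   embedding_ranks          -> [0, 8, 12]  (first stage, split stage, last stage)
#   position_embedding_ranks -> [0, 8]      (first stage and split stage only)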
def model_parallel_is_initialized():
    """Check if model and data parallel groups are initialized."""
...
@@ -234,6 +255,13 @@ def get_embedding_group():
    return _EMBEDDING_GROUP


def get_position_embedding_group():
    """Get the position embedding group the caller rank belongs to."""
    assert _POSITION_EMBEDDING_GROUP is not None, \
        'position embedding group is not initialized'
    return _POSITION_EMBEDDING_GROUP


def set_tensor_model_parallel_world_size(world_size):
    """Set the tensor model parallel size"""
    global _MPU_TENSOR_MODEL_PARALLEL_WORLD_SIZE
...
@@ -295,20 +323,44 @@ def get_num_layers(args, is_encoder_and_decoder_model):
    if get_pipeline_model_parallel_world_size() > 1:
        if is_encoder_and_decoder_model:
            assert args.pipeline_model_parallel_split_rank is not None
            num_ranks_in_encoder = args.pipeline_model_parallel_split_rank
            num_ranks_in_decoder = get_pipeline_model_parallel_world_size() - num_ranks_in_encoder

            # When a standalone embedding stage is used, a rank is taken from
            # the encoder's ranks, to be used for the encoder's embedding
            # layer. This way, the rank referenced by the 'split rank' remains
            # the same whether or not a standalone embedding stage is used.
            num_ranks_in_encoder = (
                args.pipeline_model_parallel_split_rank - 1
                if args.standalone_embedding_stage
                else args.pipeline_model_parallel_split_rank
            )
            num_ranks_in_decoder = args.transformer_pipeline_model_parallel_size - num_ranks_in_encoder
            assert args.num_layers % num_ranks_in_encoder == 0, \
                'num_layers must be divisible by number of ranks given to encoder'
                'num_layers (%d) must be divisible by number of ranks given to encoder (%d)' % (
                    args.num_layers, num_ranks_in_encoder)
            assert args.num_layers % num_ranks_in_decoder == 0, \
                'num_layers must be divisible by number of ranks given to decoder'
                'num_layers (%d) must be divisible by number of ranks given to decoder (%d)' % (
                    args.num_layers, num_ranks_in_decoder)
            if is_pipeline_stage_before_split():
                num_layers = args.num_layers // num_ranks_in_encoder
                num_layers = (
                    0
                    if args.standalone_embedding_stage
                    and get_pipeline_model_parallel_rank() == 0
                    else args.num_layers // num_ranks_in_encoder
                )
            else:
                num_layers = args.num_layers // num_ranks_in_decoder
        else:
            assert args.num_layers % get_pipeline_model_parallel_world_size() == 0, \
                'num_layers must be divisible by pipeline_model_parallel_size'
            num_layers = args.num_layers // get_pipeline_model_parallel_world_size()
            assert args.num_layers % args.transformer_pipeline_model_parallel_size == 0, \
                'num_layers must be divisible by transformer_pipeline_model_parallel_size'

            # When a standalone embedding stage is used, all transformer layers
            # are divided among pipeline rank >= 1, while on pipeline rank 0,
            # ranks either contain the input embedding layer (virtual pp rank 0),
            # or no layers at all (virtual pp rank >= 1).
            num_layers = (
                0
                if args.standalone_embedding_stage
                and get_pipeline_model_parallel_rank() == 0
                else args.num_layers // args.transformer_pipeline_model_parallel_size
            )
    else:
        num_layers = args.num_layers
    return num_layers
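A concrete illustration of the division above, with assumed numbers rather than values from this commit:

# Worked example: num_layers = 24, transformer_pipeline_model_parallel_size = 3,
# standalone_embedding_stage = True (so the full pipeline has 4 stages).
#   pipeline rank 0     -> 0 transformer layers (embedding-only stage)
#   pipeline ranks 1..3 -> 24 // 3 = 8 layers each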
...
@@ -352,6 +404,13 @@ def is_rank_in_embedding_group(ignore_virtual=False):
    return False


def is_rank_in_position_embedding_group():
    """Return true if current rank is in position embedding group, False otherwise."""
    rank = torch.distributed.get_rank()
    global _POSITION_EMBEDDING_GLOBAL_RANKS
    return rank in _POSITION_EMBEDDING_GLOBAL_RANKS


def is_pipeline_stage_before_split(rank=None):
    """Return True if pipeline stage executes encoder block for a model
    with both encoder and decoder."""
...
@@ -417,6 +476,15 @@ def get_tensor_model_parallel_src_rank():
    return (global_rank // local_world_size) * local_world_size


def get_data_parallel_src_rank():
    """Calculate the global rank corresponding to the first local rank
    in the data parallel group."""
    global_rank = torch.distributed.get_rank()
    data_parallel_size = get_data_parallel_world_size()
    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
    return global_rank % num_data_parallel_groups


def get_pipeline_model_parallel_first_rank():
    assert _PIPELINE_GLOBAL_RANKS is not None, \
        "Pipeline parallel group is not initialized"
...
@@ -467,3 +535,5 @@ def destroy_model_parallel():
    _DATA_PARALLEL_GROUP = None
    global _EMBEDDING_GROUP
    _EMBEDDING_GROUP = None
    global _POSITION_EMBEDDING_GROUP
    _POSITION_EMBEDDING_GROUP = None