wxj / NeMo · Commits

Commit bc5c7fa7, authored Jan 07, 2025 by wxj
First test commit
parent 70fddd0f

Changes: 290
Showing 10 changed files with 4031 additions and 0 deletions (+4031, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/rms_norm.py (+31, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/t5_model.py (+186, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/transformer.py (+1813, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/utils.py (+79, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/classification.py (+86, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/dino.py (+291, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/esvit_swin_backbone.py (+849, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/inpainting.py (+152, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/knn_monitor.py (+129, -0)
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/mit_backbone.py (+415, -0)
Too many changes to show; only 290 of 290+ files are displayed.
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/rms_norm.py
0 → 100644 (new file)

# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

import torch
from torch import nn


class RMSNorm(torch.nn.Module):

    def __init__(self,
                 dim: int,
                 eps: float = 1e-6,
                 sequence_parallel: bool = False):
        """RMS Normalization module

        Args:
            dim (int): The width of input, i.e. hidden size
            eps (float): epsilon to use for the norm, default to 1e-6
            sequence_parallel (bool): Set to true if sequence parallelism is being used,
              this marks the weights as needing to be allreduced.
        """
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
        setattr(self.weight, 'sequence_parallel', sequence_parallel)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
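
A quick way to sanity-check the class above: RMSNorm rescales each hidden vector by the reciprocal root mean square of its components, then applies the learned per-channel weight. A minimal sketch, assuming only torch is installed and the repository root is on PYTHONPATH so the new file imports as megatron.legacy.model.rms_norm; shapes are illustrative.

import torch
from megatron.legacy.model.rms_norm import RMSNorm  # path added by this commit

norm = RMSNorm(dim=16, eps=1e-6)
x = torch.randn(2, 4, 16)                       # [batch, seq, hidden]
y = norm(x)                                     # same shape as x
rms_inv = x.float().pow(2).mean(-1, keepdim=True).add(1e-6).rsqrt()
assert torch.allclose(y, (x.float() * rms_inv).type_as(x) * norm.weight, atol=1e-6)
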
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/t5_model.py
0 → 100644 (new file)

# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""T5 model."""

import torch

from megatron.training import get_args
from megatron.core import tensor_parallel
from megatron.legacy.model.enums import AttnMaskType
from megatron.legacy.model.language_model import parallel_lm_logits, get_language_model
from megatron.legacy.model import LayerNorm
from megatron.legacy.model.utils import (
    openai_gelu,
    get_linear_layer
)
from .module import MegatronModule


def t5_extended_attention_mask(attention_mask_list):

    def attn_mask_postprocess(attn_mask):
        # [b, 1, s, s]
        extended_attention_mask = attn_mask.unsqueeze(1)
        return extended_attention_mask

    return [attn_mask_postprocess(attn_mask) for attn_mask in attention_mask_list]


def t5_position_ids(token_ids):
    # Create position ids
    seq_length = token_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=token_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(token_ids)

    return position_ids


class T5LMHead(MegatronModule):
    """Masked LM head for T5

    Args:
        mpu_vocab_size: model parallel size of vocabulary.
        parallel_output: whether the output logits are distributed or not.
    """

    def __init__(self, mpu_vocab_size, parallel_output):
        super(T5LMHead, self).__init__()

        self.bias = torch.nn.Parameter(torch.zeros(mpu_vocab_size))
        self.bias.model_parallel = True
        self.bias.partition_dim = 0
        self.bias.stride = 1
        self.parallel_output = parallel_output

    def forward(self, hidden_states, word_embeddings_weight):
        output = parallel_lm_logits(hidden_states,
                                    word_embeddings_weight,
                                    self.parallel_output,
                                    bias=self.bias)
        return output


class T5Model(MegatronModule):
    """T5 Language model."""

    def __init__(self,
                 config,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True,
                 add_encoder=True,
                 add_decoder=True):
        super().__init__(config=config)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.add_encoder = add_encoder
        self.add_decoder = add_decoder

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            add_encoder=add_encoder,
            add_decoder=add_decoder,
            encoder_attn_mask_type=AttnMaskType.padding,
            pre_process=self.pre_process,
            post_process=self.post_process)

        self.initialize_word_embeddings()

        if self.post_process and self.add_decoder:
            self.lm_head = T5LMHead(
                self.shared_embedding_or_output_weight().size(0),
                parallel_output)
            self._lm_head_key = 'lm_head'

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, encoder_input_ids, decoder_input_ids, encoder_attn_mask,
                decoder_attn_mask, encoder_decoder_attn_mask,
                tokentype_ids=None, lm_labels=None, enc_hidden_states=None):

        # Converting the attention masks to proper parameter settings
        encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask = t5_extended_attention_mask(
            [encoder_attn_mask, decoder_attn_mask, encoder_decoder_attn_mask])

        encoder_position_ids = t5_position_ids(encoder_input_ids)
        decoder_position_ids = t5_position_ids(decoder_input_ids)

        lm_output = self.language_model(encoder_input_ids,
                                        encoder_position_ids,
                                        encoder_attn_mask,
                                        decoder_input_ids,
                                        decoder_position_ids,
                                        decoder_attn_mask,
                                        encoder_decoder_attn_mask,
                                        tokentype_ids=tokentype_ids,
                                        enc_hidden_states=enc_hidden_states)

        if self.post_process and self.add_decoder:
            decoder_output, encoder_output = lm_output
            # Output. [s, b, h]
            lm_logits = self.lm_head(decoder_output,
                                     self.shared_embedding_or_output_weight())

            if lm_labels is None:
                # [s b h] => [b s h]
                return lm_logits.transpose(0, 1).contiguous()
            else:
                # [b s] => [s b]
                lm_labels = lm_labels.transpose(0, 1).contiguous()
                if self.fp16_lm_cross_entropy:
                    assert lm_logits.dtype == torch.half
                    lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                else:
                    lm_loss = tensor_parallel.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                           lm_labels)
                # [s b] => [b s]
                lm_loss = lm_loss.transpose(0, 1).contiguous()
                return lm_loss
        elif self.add_decoder and not self.add_encoder:
            decoder_output, encoder_output = lm_output
            return decoder_output
        else:
            encoder_output = lm_output
            return encoder_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):
        """For easy load when model is combined with other heads,
        add an extra key."""

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(prefix=prefix,
                                                                 keep_vars=keep_vars)
        if self.post_process and self.add_decoder:
            state_dict_[self._lm_head_key] \
                = self.lm_head.state_dict_for_save_checkpoint(prefix=prefix,
                                                              keep_vars=keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        self.language_model.load_state_dict(
            state_dict[self._language_model_key], strict=strict)
        if self.post_process and self.add_decoder:
            self.lm_head.load_state_dict(state_dict[self._lm_head_key],
                                         strict=strict)
        # Load word embeddings.
        if self.post_process and not self.pre_process and self.add_decoder:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
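
The two module-level helpers at the top of t5_model.py are shape utilities: t5_position_ids builds a [b, s] grid of position indices, and t5_extended_attention_mask inserts a broadcast head dimension into each [b, s, s] mask. A minimal sketch of the same arithmetic with plain torch (no Megatron imports needed), useful for checking the expected shapes; sizes are illustrative.

import torch

token_ids = torch.randint(0, 100, (2, 5))                      # [b, s]
position_ids = torch.arange(5, dtype=torch.long).unsqueeze(0).expand_as(token_ids)
print(position_ids.shape)                                       # torch.Size([2, 5])

attn_mask = torch.ones(2, 5, 5, dtype=torch.bool)               # [b, s, s]
extended = attn_mask.unsqueeze(1)                               # [b, 1, s, s], broadcasts over heads
print(extended.shape)                                           # torch.Size([2, 1, 5, 5])
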
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/transformer.py
0 → 100644 (new file)

# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Transformer."""
from contextlib import nullcontext
import os
import math
import numpy as np
import torch
import torch.nn.functional as F
from typing import Optional

from megatron import core
from megatron.training import get_timers, get_args, get_num_microbatches
from .module import MegatronModule
from megatron.core import mpu, tensor_parallel
from megatron.core.enums import ModelType
from megatron.legacy.model.enums import AttnMaskType, LayerType, AttnType
from megatron.legacy.model.fused_softmax import FusedScaleMaskSoftmax
from megatron.legacy.model.fused_bias_gelu import bias_gelu_impl
from megatron.core.models.common.embeddings.rotary_pos_embedding import RotaryEmbedding, apply_rotary_pos_emb
from megatron.legacy.model.utils import attention_mask_func, openai_gelu, erf_gelu, get_norm
from megatron.core.tensor_parallel import (
    gather_from_sequence_parallel_region_to_moe,
    reduce_scatter_to_sequence_parallel_region_from_moe,
    get_cuda_rng_tracker,
    get_data_parallel_rng_tracker_name
)
from megatron.core.parallel_state import get_tensor_model_parallel_group, get_tensor_and_expert_parallel_group
from megatron.core.jit import jit_fuser

try:
    from einops import rearrange
except ImportError:
    rearrange = None

try:
    from flash_attn.flash_attn_interface import flash_attn_unpadded_func
except ImportError:
    try:
        from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_unpadded_func
    except ImportError:
        flash_attn_unpadded_func = None

""" We use the following notation throughout this file:
     h: hidden size
     n: number of attention heads
     p: number of model parallel partitions
     np: n/p
     hp: h/p
     hn: h/n
     b: batch size
     s: sequence length
     l: number of layers
    Transformer takes input of size [s, b, h] and returns a
    tensor of the same size. We use the following arguments:
        hyperparameters: transformer hyperparameters
"""


class DropPath(MegatronModule):
    """Drop paths (Stochastic Depth) per sample
    (when applied in main path of residual blocks).
    """

    def __init__(self, drop_prob=0.):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, hidden_state):
        if self.drop_prob == 0. or not self.training:
            return hidden_state
        keep_prob = 1 - self.drop_prob
        # work with diff dim tensors, not just 2D ConvNets
        # hidden_state: [s, b, h]
        shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
        random_tensor = keep_prob + \
            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
        random_tensor.floor_()  # binarize
        output = hidden_state.div(keep_prob) * random_tensor
        return output
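
DropPath zeroes the entire residual branch for a random subset of samples during training and rescales the survivors by 1/keep_prob, so the expected activation is unchanged. A minimal sketch of the same masking arithmetic, assuming only torch; shapes follow the [s, b, h] convention noted in the class.

import torch

drop_prob, keep_prob = 0.2, 0.8
hidden_state = torch.randn(4, 6, 8)                            # [s, b, h]
shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
mask = (keep_prob + torch.rand(shape)).floor_()                # 1 with prob keep_prob, else 0, per sample
out = hidden_state.div(keep_prob) * mask                       # dropped samples become 0, survivors rescaled
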
class ParallelMLP(MegatronModule):
    """MLP.

    MLP will take the input with h hidden state, project it to 4*h
    hidden dimension, perform nonlinear transformation, and project the
    state back into h hidden dimension.
    """

    def __init__(self, config, is_expert=False):
        super(ParallelMLP, self).__init__()
        args = get_args()

        self.add_bias = config.add_bias_linear

        ffn_hidden_size = config.ffn_hidden_size
        if config.gated_linear_unit:
            ffn_hidden_size *= 2

        # Project to 4h. If using swiglu double the output width, see https://arxiv.org/pdf/2002.05202.pdf
        self.dense_h_to_4h = tensor_parallel.ColumnParallelLinear(
            config.hidden_size,
            ffn_hidden_size,
            config=config,
            init_method=config.init_method,
            bias=self.add_bias,
            gather_output=False,
            skip_bias_add=True,
            is_expert=is_expert,
        )

        self.bias_gelu_fusion = False
        self.activation_func = None
        self.swiglu = args.swiglu

        if args.openai_gelu:
            self.activation_func = openai_gelu
        elif args.onnx_safe:
            self.activation_func = erf_gelu
        elif args.swiglu:
            def swiglu(x):
                x = torch.chunk(x, 2, dim=-1)
                return F.silu(x[0]) * x[1]
            self.activation_func = swiglu
        elif args.squared_relu:
            def squared_relu(x):
                return torch.pow(F.relu(x), 2)
            self.activation_func = squared_relu
        else:
            self.bias_gelu_fusion = args.bias_gelu_fusion
            self.activation_func = F.gelu

        # Project back to h.
        self.dense_4h_to_h = tensor_parallel.RowParallelLinear(
            config.ffn_hidden_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=self.add_bias,
            skip_bias_add=True,
            input_is_parallel=True,
            is_expert=is_expert,
        )

    def forward(self, hidden_states):

        # [s, b, 4hp]
        intermediate_parallel, bias_parallel = self.dense_h_to_4h(hidden_states)

        if self.bias_gelu_fusion:
            assert self.add_bias is True
            assert self.activation_func == F.gelu
            intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
        else:
            if bias_parallel is not None:
                intermediate_parallel = intermediate_parallel + bias_parallel
            intermediate_parallel = self.activation_func(intermediate_parallel)

        # [s, b, h]
        output, output_bias = self.dense_4h_to_h(intermediate_parallel)
        return output, output_bias


def sinkhorn(cost, tol=0.0001):
    cost = torch.exp(cost)
    d0 = torch.ones(cost.size(0), device=cost.device, dtype=cost.dtype)
    d1 = torch.ones(cost.size(1), device=cost.device, dtype=cost.dtype)

    eps = 0.00000001
    error = 1e9
    d1_old = d1
    while error > tol:
        d0 = (1 / d0.size(0)) * 1 / (torch.sum(d1 * cost, 1) + eps)
        d1 = (1 / d1.size(0)) * 1 / (torch.sum(d0.unsqueeze(1) * cost, 0) + eps)
        error = torch.mean(torch.abs(d1_old - d1))
        d1_old = d1
    return d1 * cost * d0.unsqueeze(1)
def get_router_linear_layer(config):
    args = get_args()
    router = torch.nn.Linear(args.hidden_size, args.num_experts, bias=False)
    with get_cuda_rng_tracker().fork(get_data_parallel_rng_tracker_name()):
        config.init_method(router.weight)
    setattr(router.weight, 'sequence_parallel', config.sequence_parallel)
    return router


class SwitchMLP(MegatronModule):
    """
    Routes input to one of N MLP "experts"
    """

    def __init__(self, config):
        super(SwitchMLP, self).__init__()
        args = get_args()
        self.router = get_router_linear_layer(config)
        self.expert_parallel_size = mpu.get_expert_model_parallel_world_size()
        self.sequence_parallel = config.sequence_parallel
        self.add_bias = config.add_bias_linear

        assert args.num_experts % self.expert_parallel_size == 0
        self.num_local_experts = args.num_experts // self.expert_parallel_size
        local_expert_indices_offset = mpu.get_expert_model_parallel_rank() * self.num_local_experts
        self.local_expert_indices = [local_expert_indices_offset + i for i in range(self.num_local_experts)]

        self.local_experts = torch.nn.ModuleList()
        for i in range(self.num_local_experts):
            self.local_experts.append(ParallelMLP(config, is_expert=True))

    def gather_indices(self, local_indices):
        """ Gather tensors and concatenate along the first dimension."""
        group = get_tensor_and_expert_parallel_group()
        world_size = torch.distributed.get_world_size(group=group)
        # Bypass the function if we are using only 1 GPU.
        if world_size == 1:
            return local_indices

        dim_size = list(local_indices.size())
        dim_size[0] = dim_size[0] * world_size

        # TODO pre allocate memory
        output = torch.empty(dim_size, dtype=local_indices.dtype,
                             device=torch.cuda.current_device())
        torch.distributed._all_gather_base(
            output, local_indices.contiguous(), group=group)
        return output

    def forward(self, hidden_states):
        # hidden_states: [b, s, h]
        args = get_args()

        s = hidden_states.size(0)
        b = hidden_states.size(1)
        h = hidden_states.size(2)
        route = self.router(hidden_states).view(-1, args.num_experts)

        # TODO (rprenger) Right now we're just using the sinkhorn algorithm
        # for load balancing. There should be an option to do no load balancing
        # and the algorithm and parameters should be further tested
        if self.training:
            with torch.no_grad():
                sinkroute = sinkhorn(route.detach().to(dtype=torch.float32))
                _, max_ind = torch.max(sinkroute, dim=1)
            route = torch.sigmoid(route)
            max_prob = route[torch.arange(route.size(0)), max_ind]
        else:
            route = torch.sigmoid(route)
            max_prob, max_ind = torch.max(route, dim=1)

        max_prob = torch.unsqueeze(max_prob, 1)
        hidden_states = hidden_states.view(-1, hidden_states.size(2))

        # TODO (rprenger) TODO this could be made easier to read
        # Converting [s, b, h] to [s*b, h].
        # Each vector could be routed differently
        if self.sequence_parallel or (self.expert_parallel_size > 1):
            global_hidden_states = \
                gather_from_sequence_parallel_region_to_moe(hidden_states)
            global_indices = self.gather_indices(max_ind)
        else:
            global_hidden_states = hidden_states
            global_indices = max_ind

        output_total = torch.zeros_like(global_hidden_states)
        if self.add_bias:
            output_bias_total = torch.zeros_like(global_hidden_states)

        for expert_num, expert in enumerate(self.local_experts):
            local_expert_index = self.local_expert_indices[expert_num]
            local_indices = (global_indices == local_expert_index).nonzero()
            hidden = global_hidden_states[local_indices, :]
            output, output_bias = expert(hidden)
            output_total[local_indices, :] = output
            if self.add_bias:
                output_bias = output_bias.expand_as(output)
                output_bias_total[local_indices, :] = output_bias

        if self.sequence_parallel or (self.expert_parallel_size > 1):
            output_total = \
                reduce_scatter_to_sequence_parallel_region_from_moe(output_total)
            if self.add_bias:
                output_bias_total = \
                    reduce_scatter_to_sequence_parallel_region_from_moe(output_bias_total)

                # bias is duplicated across tensor parallelism ranks;
                # reduce scatter reduces bias across tensor parallel_ranks
                output_bias_total = \
                    output_bias_total / mpu.get_tensor_model_parallel_world_size()

        output_total = output_total * max_prob
        output_total = output_total.view(s, b, h)
        if self.add_bias:
            output_bias_total = output_bias_total * max_prob
            output_bias_total = output_bias_total.view(s, b, h)
        else:
            output_bias_total = None

        return output_total, output_bias_total


class CoreAttention(MegatronModule):

    def __init__(self, layer_number, config,
                 attn_mask_type=AttnMaskType.padding):
        super(CoreAttention, self).__init__()
        self.fp16 = config.fp16
        self.bf16 = config.bf16

        self.apply_query_key_layer_scaling = config.apply_query_key_layer_scaling
        self.attention_softmax_in_fp32 = config.attention_softmax_in_fp32
        if self.apply_query_key_layer_scaling:
            self.attention_softmax_in_fp32 = True
        self.layer_number = max(1, layer_number)
        self.attn_mask_type = attn_mask_type
        self.sequence_parallel = config.sequence_parallel

        projection_size = config.kv_channels * config.num_attention_heads

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_partition = core.utils.divide(projection_size,
                                                           world_size)
        self.hidden_size_per_attention_head = core.utils.divide(
            projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        coeff = None
        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
        if self.apply_query_key_layer_scaling:
            coeff = self.layer_number
            self.norm_factor *= coeff

        self.scale_mask_softmax = FusedScaleMaskSoftmax(
            self.fp16, self.bf16,
            self.attn_mask_type,
            config.masked_softmax_fusion,
            attention_mask_func,
            self.attention_softmax_in_fp32,
            coeff)

        # Dropout. Note that for a single iteration, this layer will generate
        # different outputs on different number of parallel partitions but
        # on average it should not be partition dependent.
        self.attention_dropout = torch.nn.Dropout(config.attention_dropout)

    def forward(self, query_layer, key_layer,
                value_layer, attention_mask):

        # ===================================
        # Raw attention scores. [b, np, s, s]
        # ===================================

        # [b, np, sq, sk]
        output_size = (query_layer.size(1),
                       query_layer.size(2),
                       query_layer.size(0),
                       key_layer.size(0))

        # [sq, b, np, hn] -> [sq, b * np, hn]
        query_layer = query_layer.reshape(output_size[2],
                                          output_size[0] * output_size[1], -1)
        # [sk, b, np, hn] -> [sk, b * np, hn]
        key_layer = key_layer.view(output_size[3],
                                   output_size[0] * output_size[1], -1)

        # preallocating input tensor: [b * np, sq, sk]
        matmul_input_buffer = mpu.get_global_memory_buffer().get_tensor(
            (output_size[0] * output_size[1], output_size[2], output_size[3]),
            query_layer.dtype, "mpu")

        # Raw attention scores. [b * np, sq, sk]
        matmul_result = torch.baddbmm(
            matmul_input_buffer,
            query_layer.transpose(0, 1),                 # [b * np, sq, hn]
            key_layer.transpose(0, 1).transpose(1, 2),   # [b * np, hn, sk]
            beta=0.0, alpha=(1.0 / self.norm_factor))

        # change view to [b, np, sq, sk]
        attention_scores = matmul_result.view(*output_size)

        # ===========================
        # Attention probs and dropout
        # ===========================

        # attention scores and attention mask [b, np, sq, sk]
        attention_probs = self.scale_mask_softmax(attention_scores,
                                                  attention_mask)

        # This is actually dropping out entire tokens to attend to, which might
        # seem a bit unusual, but is taken from the original Transformer paper.
        if not self.sequence_parallel:
            with tensor_parallel.get_cuda_rng_tracker().fork():
                attention_probs = self.attention_dropout(attention_probs)
        else:
            attention_probs = self.attention_dropout(attention_probs)

        # =========================
        # Context layer. [sq, b, hp]
        # =========================

        # value_layer -> context layer.
        # [sk, b, np, hn] --> [b, np, sq, hn]

        # context layer shape: [b, np, sq, hn]
        output_size = (value_layer.size(1),
                       value_layer.size(2),
                       query_layer.size(0),
                       value_layer.size(3))

        # change view [sk, b * np, hn]
        value_layer = value_layer.view(value_layer.size(0),
                                       output_size[0] * output_size[1], -1)

        # change view [b * np, sq, sk]
        attention_probs = attention_probs.view(output_size[0] * output_size[1],
                                               output_size[2], -1)

        # matmul: [b * np, sq, hn]
        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))

        # change view [b, np, sq, hn]
        context_layer = context_layer.view(*output_size)

        # [b, np, sq, hn] --> [sq, b, np, hn]
        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()

        # [sq, b, np, hn] --> [sq, b, hp]
        new_context_layer_shape = context_layer.size()[:-2] + \
            (self.hidden_size_per_partition,)
        context_layer = context_layer.view(*new_context_layer_shape)

        return context_layer


class FlashSelfAttention(torch.nn.Module):
    """Implement the scaled dot product attention with softmax.
    Arguments
    ---------
        softmax_scale: The temperature to use for the softmax attention.
                      (default: 1/sqrt(d_keys) where d_keys is computed at
                      runtime)
        attention_dropout: The dropout rate to apply to the attention
                           (default: 0.0)
    """

    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
                 device=None, dtype=None):
        super().__init__()
        assert flash_attn_unpadded_func is not None, ('Please install FlashAttention first, '
                                                      'e.g., with pip install flash-attn')
        assert rearrange is not None, 'Please install einops first, e.g., with pip install einops'
        self.causal = causal
        self.softmax_scale = softmax_scale
        self.dropout_p = attention_dropout

    def forward(self, q, k, v):
        """Implements the multihead softmax attention.
        Arguments
        ---------
            q, k, v: The tensor containing the query, key, and value. (B, S, H, D)
        """

        assert all((i.dtype in [torch.float16, torch.bfloat16] for i in (q, k, v)))
        assert all((i.is_cuda for i in (q, k, v)))

        batch_size, seqlen_q = q.shape[0], q.shape[1]
        seqlen_k = k.shape[1]

        q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [q, k, v]]
        cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32,
                                    device=q.device)

        if self.training:
            # during training q,k,v always have same seqlen
            assert seqlen_k == seqlen_q

            is_causal = self.causal
            cu_seqlens_k = cu_seqlens_q
            dropout_p = self.dropout_p
        else:
            # turn off FA causal mask after first inference autoregressive iteration
            # only on first autoregressive step q,k,v have same seqlen
            is_causal = seqlen_q == seqlen_k
            cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int32,
                                        device=q.device)
            dropout_p = 0

        output = flash_attn_unpadded_func(
            q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
            dropout_p,
            softmax_scale=self.softmax_scale, causal=is_causal
        )

        output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        return output
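
Because flash_attn_unpadded_func takes tensors flattened over batch and sequence, the wrapper above describes sequence boundaries with cumulative sequence-length offsets instead of padding masks. The offset arithmetic can be checked with plain torch (no flash-attn or GPU needed); the sizes below are illustrative.

import torch

batch_size, seqlen_q = 3, 5
cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int32)
print(cu_seqlens_q)   # tensor([ 0,  5, 10, 15], dtype=torch.int32); sequence i occupies rows cu[i]:cu[i+1]
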
class ParallelAttention(MegatronModule):
    """Parallel self-attention layer abstract class.

    Self-attention layer takes input with size [s, b, h]
    and returns output of the same size.
    """

    def __init__(self, config, layer_number,
                 attention_type=AttnType.self_attn,
                 attn_mask_type=AttnMaskType.padding):
        super(ParallelAttention, self).__init__()
        args = get_args()
        self.layer_number = max(1, layer_number)
        self.attention_type = attention_type
        self.attn_mask_type = attn_mask_type
        self.params_dtype = config.params_dtype
        self.sequence_parallel = config.sequence_parallel
        self.config = config
        self.group_query_attention = args.group_query_attention
        self.num_query_groups = args.num_query_groups

        query_projection_size = config.kv_channels * config.num_attention_heads
        if self.group_query_attention:
            kv_projection_size = args.kv_channels * args.num_query_groups
        else:
            kv_projection_size = args.kv_channels * args.num_attention_heads

        self.use_flash_attn = args.use_flash_attn \
            and attention_type == AttnType.self_attn \
            and self.attn_mask_type == AttnMaskType.causal
        if self.use_flash_attn:
            if flash_attn_unpadded_func is None:
                raise ImportError('FlashAttention is not installed, please install with '
                                  'pip install flash-attn')
            assert attention_type == AttnType.self_attn, ('FlashAttention code path only supports '
                                                          'self-attention for now')
            assert self.attn_mask_type == AttnMaskType.causal, ('FlashAttention code path only '
                                                                'supports causal mask for now')
            if rearrange is None:
                raise ImportError('einops is not installed, please install with pip install einops')

        # Per attention head and per partition values.
        world_size = mpu.get_tensor_model_parallel_world_size()
        self.hidden_size_per_attention_head = core.utils.divide(
            query_projection_size, config.num_attention_heads)
        self.num_attention_heads_per_partition = core.utils.divide(
            config.num_attention_heads, world_size)

        if self.group_query_attention:
            if args.num_query_groups % world_size != 0:
                raise NotImplementedError('Currently the num_query_groups should be '
                                          'a multiple of the tensor parallel size')
            self.num_query_groups_per_partition = core.utils.divide(
                args.num_query_groups, world_size)
        else:
            self.num_query_groups_per_partition = self.num_attention_heads_per_partition

        # Strided linear layer.
        if attention_type == AttnType.self_attn:
            self.query_key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                query_projection_size + 2 * kv_projection_size,
                config=config,
                init_method=config.init_method,
                bias=args.add_bias_linear or args.add_qkv_bias,
                gather_output=False)
        else:
            assert attention_type == AttnType.cross_attn

            if self.group_query_attention:
                raise NotImplementedError("Grouped query attention not implemented for cross-attention.")
            assert query_projection_size == kv_projection_size

            self.query = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                query_projection_size,
                config=config,
                init_method=config.init_method,
                bias=config.add_bias_linear,
                gather_output=False)

            self.key_value = tensor_parallel.ColumnParallelLinear(
                config.hidden_size,
                2 * kv_projection_size,
                config=config,
                init_method=config.init_method,
                bias=config.add_bias_linear,
                gather_output=False)

        self.core_attention = CoreAttention(self.layer_number, config,
                                            self.attn_mask_type)
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        if self.use_flash_attn:
            self.core_attention_flash = FlashSelfAttention(
                causal=True, attention_dropout=config.attention_dropout
            )

        # Output.
        self.dense = tensor_parallel.RowParallelLinear(
            query_projection_size,
            config.hidden_size,
            config=config,
            init_method=config.output_layer_init_method,
            bias=args.add_bias_linear,
            input_is_parallel=True,
            skip_bias_add=True)

    def _checkpointed_attention_forward(self, query_layer, key_layer,
                                        value_layer, attention_mask,
                                        rotary_pos_emb=None):
        """Forward method with activation checkpointing."""
        def custom_forward(*inputs):
            query_layer = inputs[0]
            key_layer = inputs[1]
            value_layer = inputs[2]
            attention_mask = inputs[3]
            output_ = self.core_attention(query_layer, key_layer,
                                          value_layer, attention_mask)
            return output_

        q_pos_emb, k_pos_emb = (None, None) if rotary_pos_emb is None \
            else rotary_pos_emb

        hidden_states = tensor_parallel.checkpoint(
            custom_forward,
            False, query_layer, key_layer, value_layer, attention_mask,
            q_pos_emb, k_pos_emb)

        return hidden_states

    def _allocate_memory(self, inference_max_sequence_len, batch_size, num_attention_heads):
        return torch.empty(
            inference_max_sequence_len,
            batch_size,
            num_attention_heads,
            self.hidden_size_per_attention_head,
            dtype=self.params_dtype,
            device=torch.cuda.current_device())

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [sq, b, h]

        # =================================================
        # Pre-allocate memory for key-values for inference.
        # =================================================
        is_first_step = False
        if inference_params:
            if self.layer_number not in inference_params.key_value_memory_dict:
                inf_max_seq_len = inference_params.max_sequence_length
                inf_max_batch_size = inference_params.max_batch_size
                inference_key_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)
                inference_value_memory = self._allocate_memory(
                    inf_max_seq_len, inf_max_batch_size,
                    self.num_query_groups_per_partition)

                inference_params.key_value_memory_dict[self.layer_number] = (
                    inference_key_memory, inference_value_memory)
                is_first_step = True
            else:
                inference_key_memory, inference_value_memory = \
                    inference_params.key_value_memory_dict[self.layer_number]

        # =====================
        # Query, Key, and Value
        # =====================
        if self.attention_type == AttnType.self_attn:

            # Attention heads [sq, b, h] --> [sq, b, ng * (np/ng + 2) * hn)]
            mixed_x_layer, _ = self.query_key_value(hidden_states)

            # [sq, b, hp] --> [sq, b, ng, (np/ng + 2) * hn]
            new_tensor_shape = mixed_x_layer.size()[:-1] + (
                self.num_query_groups_per_partition,
                (
                    (self.num_attention_heads_per_partition // self.num_query_groups_per_partition + 2)
                    * self.hidden_size_per_attention_head
                ),
            )
            mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)

            # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
            (query_layer,
             key_layer,
             value_layer) = torch.split(
                mixed_x_layer,
                [
                    (
                        self.num_attention_heads_per_partition // self.num_query_groups_per_partition
                        * self.hidden_size_per_attention_head
                    ),
                    self.hidden_size_per_attention_head,
                    self.hidden_size_per_attention_head
                ],
                dim=3)

            # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn] -
            query_layer = query_layer.view(query_layer.size(0),
                                           query_layer.size(1), -1,
                                           self.hidden_size_per_attention_head)
        else:
            # Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
            mixed_kv_layer, _ = self.key_value(encoder_output)

            # [sk, b, (np * 2 * hn)] --> [sk, b, np, 2 * hn]
            new_tensor_shape = mixed_kv_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 2 * self.hidden_size_per_attention_head)
            mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)

            # [sk, b, np, 2 * hn] --> 2 [sk, b, np, hn]
            (key_layer,
             value_layer) = tensor_parallel.split_tensor_along_last_dim(mixed_kv_layer, 2)

            # Attention head [sq, b, h] --> [sq, b, hp]
            query_layer, _ = self.query(hidden_states)
            # [sq, b, hp] --> [sq, b, np, hn]
            new_tensor_shape = query_layer.size()[:-1] + \
                (self.num_attention_heads_per_partition,
                 self.hidden_size_per_attention_head)
            query_layer = query_layer.view(*new_tensor_shape)

        # ==================================
        # Adjust key and value for inference
        # ==================================

        # duplicate the pos_emb for self attention
        if rotary_pos_emb is not None:
            if isinstance(rotary_pos_emb, tuple):
                rotary_pos_emb = rotary_pos_emb
            else:
                rotary_pos_emb = ((rotary_pos_emb,) * 2)

        if inference_params:
            batch_start = inference_params.batch_size_offset
            batch_end = batch_start + key_layer.size(1)
            assert batch_end <= inference_key_memory.size(1)
            sequence_start = inference_params.sequence_len_offset
            sequence_end = sequence_start + key_layer.size(0)
            assert sequence_end <= inference_key_memory.size(0)
            # Copy key and values.
            inference_key_memory[sequence_start:sequence_end,
                                 batch_start:batch_end, ...] = key_layer
            inference_value_memory[sequence_start:sequence_end,
                                   batch_start:batch_end, ...] = value_layer
            key_layer = inference_key_memory[
                :sequence_end, batch_start:batch_end, ...]
            value_layer = inference_value_memory[
                :sequence_end, batch_start:batch_end, ...]

            # adjust the key rotary positional embedding
            if rotary_pos_emb is not None:
                q_pos_emb, k_pos_emb = rotary_pos_emb
                # need to cross check this condition during inference
                # if not set_inference_key_value_memory:
                if not is_first_step:
                    # In inference, we compute one token at a time.
                    # Select the correct positional embedding
                    # (only the last token in the sequence)
                    q_pos_emb = q_pos_emb[sequence_end - 1: sequence_end]
                else:
                    # In the first forward pass of inference,
                    # we use the entire provided prefix.
                    # q_pos_emb here has the rope embeddings of the entire
                    # prefix + to-be-generated output so
                    # we slice to just the prefix.
                    q_pos_emb = q_pos_emb[:sequence_end, :, :, :]
                k_pos_emb = k_pos_emb[:sequence_end, :, :, :]
                rotary_pos_emb = (q_pos_emb, k_pos_emb)

        # ==================================
        # core attention computation
        # ==================================

        # expand the key_layer and value_layer [sk, b, ng, hn] -> [sk, b, np, hn]
        if self.num_attention_heads_per_partition // self.num_query_groups_per_partition > 1:
            key_layer = key_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim=2)
            value_layer = value_layer.repeat_interleave(
                self.num_attention_heads_per_partition // self.num_query_groups_per_partition,
                dim=2)

        # apply relative positional encoding (rotary embedding)
        if rotary_pos_emb is not None:
            q_pos_emb, k_pos_emb = rotary_pos_emb
            query_layer = apply_rotary_pos_emb(query_layer, q_pos_emb, self.config)
            key_layer = apply_rotary_pos_emb(key_layer, k_pos_emb, self.config)
            # TODO, can apply positional embedding to value_layer so it has
            # absolute positional embedding.
            # otherwise, only relative positional embedding takes effect
            # value_layer = apply_rotary_pos_emb(value_layer, k_pos_emb)

        if not self.use_flash_attn:
            if self.checkpoint_core_attention:
                context_layer = self._checkpointed_attention_forward(
                    query_layer, key_layer, value_layer, attention_mask)
            else:
                context_layer = self.core_attention(
                    query_layer, key_layer, value_layer, attention_mask)
        else:
            q, k, v = [rearrange(x, 's b ... -> b s ...').contiguous()
                       for x in (query_layer, key_layer, value_layer)]
            if not self.sequence_parallel:
                with tensor_parallel.get_cuda_rng_tracker().fork():
                    context_layer = self.core_attention_flash(q, k, v)
            else:
                context_layer = self.core_attention_flash(q, k, v)
            context_layer = rearrange(context_layer, 'b s h d -> s b (h d)').contiguous()

        # =================
        # Output. [sq, b, h]
        # =================

        output, bias = self.dense(context_layer)

        return output, bias


def bias_dropout_add(x, bias, residual, prob, training):
    # type: (Tensor, Optional[Tensor], Tensor, float, bool) -> Tensor
    if bias is not None:
        x = x + bias
    out = torch.nn.functional.dropout(x, p=prob, training=training)
    out = residual + out
    return out


def get_bias_dropout_add(training):
    def _bias_dropout_add(x, bias, residual, prob):
        return bias_dropout_add(x, bias, residual, prob, training)
    return _bias_dropout_add


@jit_fuser
def bias_dropout_add_fused_train(x: torch.Tensor,
                                 bias: Optional[torch.Tensor],
                                 residual: torch.Tensor,
                                 prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, True)


@jit_fuser
def bias_dropout_add_fused_inference(x: torch.Tensor,
                                     bias: Optional[torch.Tensor],
                                     residual: torch.Tensor,
                                     prob: float) -> torch.Tensor:
    return bias_dropout_add(x, bias, residual, prob, False)
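
All of the bias-dropout-add variants above share the same body: add the (optional) bias, apply dropout, then add the residual; the fused train/inference pair only bakes the training flag into a jit-fusable signature. In eval mode dropout is the identity, so the result reduces to residual + x + bias. A quick check of the plain function, assuming only torch and the definitions above.

import torch

x, bias, residual = torch.randn(4, 8), torch.randn(8), torch.randn(4, 8)
out = bias_dropout_add(x, bias, residual, prob=0.1, training=False)
assert torch.allclose(out, residual + x + bias)   # dropout is the identity when training=False
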
class
ParallelTransformerLayer
(
MegatronModule
):
"""A single transformer layer.
Transformer layer takes input with size [s, b, h] and returns an
output of the same size.
"""
def
__init__
(
self
,
config
,
layer_number
,
layer_type
=
LayerType
.
encoder
,
self_attn_mask_type
=
AttnMaskType
.
padding
,
drop_path_rate
=
0.
):
args
=
get_args
()
super
(
ParallelTransformerLayer
,
self
).
__init__
()
self
.
layer_number
=
layer_number
self
.
layer_type
=
layer_type
self
.
apply_residual_connection_post_norm
\
=
config
.
apply_residual_connection_post_layernorm
self
.
bf16
=
config
.
bf16
self
.
fp32_residual_connection
=
config
.
fp32_residual_connection
# Normalize the input data.
self
.
input_norm
=
get_norm
(
config
)
# Self attention.
self
.
self_attention
=
ParallelAttention
(
config
,
layer_number
,
attention_type
=
AttnType
.
self_attn
,
attn_mask_type
=
self_attn_mask_type
)
self
.
hidden_dropout
=
config
.
hidden_dropout
self
.
bias_dropout_fusion
=
config
.
bias_dropout_fusion
self
.
drop_path
=
DropPath
(
drop_path_rate
)
if
drop_path_rate
>
0.0
else
None
# Normalize the attention output
self
.
post_attention_norm
=
get_norm
(
config
)
# Cross attention.
if
self
.
layer_type
in
(
LayerType
.
decoder
,
LayerType
.
retro_decoder
,
LayerType
.
retro_decoder_with_retriever
,
LayerType
.
retro_encoder
):
self
.
inter_attention
=
ParallelAttention
(
config
,
layer_number
,
attention_type
=
AttnType
.
cross_attn
)
# Normalize the attention output.
self
.
post_inter_attention_norm
=
get_norm
(
config
)
# MLP
if
args
.
num_experts
is
not
None
:
self
.
mlp
=
SwitchMLP
(
config
)
else
:
self
.
mlp
=
ParallelMLP
(
config
)
# Set bias+dropout+add fusion grad_enable execution handler.
TORCH_MAJOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
0
])
TORCH_MINOR
=
int
(
torch
.
__version__
.
split
(
'.'
)[
1
])
use_nvfuser
=
TORCH_MAJOR
>
1
or
(
TORCH_MAJOR
==
1
and
TORCH_MINOR
>=
10
)
self
.
bias_dropout_add_exec_handler
=
\
nullcontext
if
use_nvfuser
else
torch
.
enable_grad
if
args
.
retro_add_retriever
:
self
.
retro_num_neighbors
=
args
.
retro_num_neighbors
self
.
retro_chunk_length
=
args
.
retro_chunk_length
self
.
retro_retrieved_length
=
\
args
.
retro_num_retrieved_chunks
*
args
.
retro_chunk_length
# Retriever (bi-directional transformer with cross attention)
if
layer_type
==
LayerType
.
retro_decoder_with_retriever
:
self
.
retriever
=
ParallelTransformer
(
config
=
config
,
model_type
=
ModelType
.
retro_encoder
,
self_attn_mask_type
=
AttnMaskType
.
padding
,
pre_process
=
True
,
post_process
=
False
,
)
self
.
_retriever_key
=
'retriever'
else
:
self
.
retriever
=
None
def
default_decoder_cross_attention
(
self
,
encoder_output
,
enc_dec_attn_mask
,
norm_input
,
norm_output
,
bias_dropout_add_func
):
'''Cross attention for a standard encoder-decoder model.'''
# Attention.
attention_output
,
attention_bias
=
\
self
.
inter_attention
(
norm_output
,
enc_dec_attn_mask
,
encoder_output
=
encoder_output
)
# Residual connection.
if
self
.
apply_residual_connection_post_norm
:
residual
=
norm_output
else
:
residual
=
norm_input
if
attention_bias
is
not
None
:
attention_bias
=
attention_bias
.
expand_as
(
residual
)
# Bias-dropout-add.
with
self
.
bias_dropout_add_exec_handler
():
norm_input
=
bias_dropout_add_func
(
attention_output
,
attention_bias
,
residual
,
self
.
hidden_dropout
)
# Normalize.
norm_output
=
self
.
post_inter_attention_norm
(
norm_input
)
return
norm_input
,
norm_output
def
retro_encoder_cross_attention
(
self
,
retriever_output
,
norm_input
,
norm_output
,
bias_dropout_add_func
):
"""Cross attention for Retro encoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
"""
ns
,
bs
,
d
=
norm_output
.
shape
# [r, bs * l * k, d]
# Divide sequence dimension into chunks.
chunked_outputs
=
norm_output
.
reshape
(
self
.
retro_retrieved_length
,
-
1
,
self
.
retro_num_neighbors
,
d
)
chunked_outputs_before_norm
=
\
norm_input
.
reshape
(
self
.
retro_retrieved_length
,
-
1
,
self
.
retro_num_neighbors
,
d
)
# [r, bs*l, k, d]
# Per-chunk attention.
norm_inputs
=
[]
norm_outputs
=
[]
for
k
in
range
(
self
.
retro_num_neighbors
):
# Attention.
chunked_output
=
chunked_outputs
[:,:,
k
].
contiguous
()
attention_output
,
attention_bias
=
\
self
.
inter_attention
(
chunked_output
,
# Q (neighbor embedding)
None
,
encoder_output
=
retriever_output
)
# K, V (hidden act)
# Residual connection.
if
self
.
apply_residual_connection_post_norm
:
residual
=
chunked_output
else
:
residual
=
chunked_outputs_before_norm
[:,:,
k
]
# Re-enable torch grad to enable fused optimization.
with
torch
.
enable_grad
():
norm_input
=
bias_dropout_add_func
(
attention_output
,
None
if
attention_bias
is
None
else
attention_bias
.
expand_as
(
residual
),
residual
,
self
.
hidden_dropout
)
norm_inputs
.
append
(
norm_input
)
# Layer norm.
norm_output
=
self
.
post_inter_attention_norm
(
norm_input
)
norm_outputs
.
append
(
norm_output
)
# Concatenate layer norms.
# norm_input : [r, k * bs * l, d]
# norm_output : [r, k * bs * l, d]
norm_input
=
torch
.
stack
(
norm_inputs
,
dim
=
1
).
reshape
(
ns
,
bs
,
d
)
norm_output
=
torch
.
stack
(
norm_outputs
,
dim
=
1
).
reshape
(
ns
,
bs
,
d
)
return
norm_input
,
norm_output
def
retro_decoder_cross_attention
(
self
,
retriever_input
,
retriever_output
,
retriever_attn_mask
,
norm_input
,
norm_output
,
inference_params
,
bias_dropout_add_func
):
"""Cross attention for Retro decoder.
Notation:
ns : Sequence length.
bs : Batch size.
d : Hidden size.
l : Number of chunks per sample (i.e., seq_length/chunk_length).
m : Number of tokens per chunk.
k : Number of neighbors.
r : Number of retrieved tokens (neighbors + continuation).
"""
ns
,
bs
,
d
=
norm_output
.
shape
l
=
int
(
np
.
ceil
(
ns
/
self
.
retro_chunk_length
))
# Retrieve neighbors.
if
self
.
layer_type
==
LayerType
.
retro_decoder_with_retriever
:
first_ns
=
ns
%
self
.
retro_chunk_length
if
first_ns
>
0
:
first_chunk
,
rest_chunk
=
\
norm_output
[:
first_ns
],
norm_output
[
first_ns
:]
first_chunk
=
torch
.
nn
.
functional
.
pad
(
first_chunk
,
(
0
,
0
,
0
,
0
,
0
,
self
.
retro_chunk_length
-
first_ns
),
'constant'
,
0
)
chunked_output
=
\
torch
.
cat
((
first_chunk
,
rest_chunk
),
dim
=
0
)
# [l * m, bs, d]
else
:
chunked_output
=
norm_output
# [l * m, bs, d]
chunked_output
=
chunked_output
\
.
reshape
(
l
,
self
.
retro_chunk_length
,
bs
,
d
)
\
.
permute
(
1
,
2
,
0
,
3
)
\
.
reshape
(
self
.
retro_chunk_length
,
bs
*
l
,
d
)
\
.
contiguous
()
# Get Encoder Output
retriever_output
=
self
.
retriever
(
hidden_states
=
retriever_input
,
attention_mask
=
retriever_attn_mask
,
retriever_output
=
chunked_output
,
retriever_attn_mask
=
retriever_attn_mask
,
inference_params
=
inference_params
)
# [r, k * bs * l , d]
retriever_output
=
retriever_output
.
reshape
(
self
.
retro_retrieved_length
*
self
.
retro_num_neighbors
,
bs
*
l
,
d
)
# [r * k, bs * l, d]
# Chunks.
pad
=
(
ns
-
1
)
%
self
.
retro_chunk_length
attending_chunks
=
norm_output
[
pad
:]
padded_chunks
=
torch
.
nn
.
functional
.
pad
(
attending_chunks
,
(
0
,
0
,
0
,
0
,
0
,
self
.
retro_chunk_length
-
1
),
'constant'
,
0
)
padded_chunked_output
=
padded_chunks
\
.
reshape
(
l
,
self
.
retro_chunk_length
,
bs
,
d
)
\
.
permute
(
1
,
2
,
0
,
3
)
padded_chunked_output
=
padded_chunked_output
.
reshape
(
self
.
retro_chunk_length
,
bs
*
l
,
d
).
contiguous
()
# Encoder output.
attention_output
,
attention_bias
=
\
self
.
inter_attention
(
padded_chunked_output
,
None
,
encoder_output
=
retriever_output
)
# Residual connection.
if
self
.
apply_residual_connection_post_norm
:
residual
=
norm_output
else
:
residual
=
norm_input
# Re-enable torch grad to enable fused optimization.
with
torch
.
enable_grad
():
norm_input
=
bias_dropout_add_func
(
attention_output
,
None
if
attention_bias
is
None
else
attention_bias
.
expand_as
(
attention_output
),
torch
.
zeros_like
(
attention_output
),
self
.
hidden_dropout
)
norm_input
=
norm_input
\
.
reshape
(
self
.
retro_chunk_length
,
bs
,
l
,
d
)
\
.
permute
(
2
,
0
,
1
,
3
)
# [l, m, bs, d]
norm_input
=
norm_input
.
reshape
(
self
.
retro_chunk_length
*
l
,
bs
,
d
)
norm_input
=
torch
.
nn
.
functional
.
pad
(
norm_input
,
(
0
,
0
,
0
,
0
,
pad
,
0
),
'constant'
,
0
)[:
ns
]
# [ns, b, d]
# TODO: better redesign with inference param
args
=
get_args
()
norm_input
=
args
.
retro_attention_gate
*
norm_input
+
residual
# Layer norm post the decoder attention
norm_output
=
self
.
post_inter_attention_norm
(
norm_input
)
return
retriever_output
,
norm_input
,
norm_output
def
forward
(
self
,
hidden_states
,
attention_mask
,
encoder_output
=
None
,
enc_dec_attn_mask
=
None
,
retriever_input
=
None
,
retriever_output
=
None
,
retriever_attn_mask
=
None
,
inference_params
=
None
,
rotary_pos_emb
=
None
):
# Update the params in case the retro param changes during inference
# TODO: better redesign with inference param
args
=
get_args
()
if
args
.
retro_add_retriever
:
self
.
retro_num_neighbors
=
args
.
retro_num_neighbors
self
.
retro_chunk_length
=
args
.
retro_chunk_length
self
.
retro_retrieved_length
=
\
args
.
retro_num_retrieved_chunks
*
args
.
retro_chunk_length
# hidden_states: [s, b, h]
# Layer norm at the beginning of the transformer layer.
norm_output
=
self
.
input_norm
(
hidden_states
)
# Self attention.
attention_output
,
attention_bias
=
\
self
.
self_attention
(
norm_output
,
attention_mask
,
inference_params
=
inference_params
,
rotary_pos_emb
=
rotary_pos_emb
)
# Residual connection.
if
self
.
apply_residual_connection_post_norm
:
residual
=
norm_output
else
:
residual
=
hidden_states
if
self
.
drop_path
is
None
:
# jit scripting for a nn.module (with dropout) is not
# trigerring the fusion kernel. For now, we use two
# different nn.functional routines to account for varying
# dropout semantics during training and inference phases.
if
self
.
bias_dropout_fusion
:
if
self
.
training
:
bias_dropout_add_func
=
bias_dropout_add_fused_train
else
:
bias_dropout_add_func
=
bias_dropout_add_fused_inference
else
:
bias_dropout_add_func
=
get_bias_dropout_add
(
self
.
training
)
if
attention_bias
is
not
None
:
attention_bias
=
attention_bias
.
expand_as
(
residual
)
with
self
.
bias_dropout_add_exec_handler
():
norm_input
=
bias_dropout_add_func
(
attention_output
,
attention_bias
,
residual
,
self
.
hidden_dropout
)
else
:
out
=
torch
.
nn
.
functional
.
dropout
(
attention_output
+
attention_bias
,
p
=
self
.
hidden_dropout
,
training
=
self
.
training
)
norm_input
=
residual
+
self
.
drop_path
(
out
)
# Layer norm post the self attention.
norm_output
=
self
.
post_attention_norm
(
norm_input
)
# Cross attention.
if
self
.
layer_type
==
LayerType
.
encoder
:
pass
elif
self
.
layer_type
==
LayerType
.
decoder
:
norm_input
,
norm_output
=
\
self
.
default_decoder_cross_attention
(
encoder_output
,
enc_dec_attn_mask
,
norm_input
,
norm_output
,
bias_dropout_add_func
)
elif
self
.
layer_type
==
LayerType
.
retro_encoder
:
norm_input
,
norm_output
=
\
self
.
retro_encoder_cross_attention
(
retriever_output
,
norm_input
,
norm_output
,
bias_dropout_add_func
)
elif
self
.
layer_type
in
(
LayerType
.
retro_decoder
,
LayerType
.
retro_decoder_with_retriever
):
retriever_output
,
norm_input
,
norm_output
=
\
self
.
retro_decoder_cross_attention
(
retriever_input
,
retriever_output
,
retriever_attn_mask
,
norm_input
,
norm_output
,
inference_params
,
bias_dropout_add_func
)
else
:
raise
Exception
(
"Unsupported layer type, '%s'."
%
self
.
layer_type
.
name
)
# MLP.
mlp_output
,
mlp_bias
=
self
.
mlp
(
norm_output
)
# Second residual connection.
if
self
.
apply_residual_connection_post_norm
:
residual
=
norm_output
else
:
residual
=
norm_input
if
self
.
drop_path
is
None
:
if
mlp_bias
is
not
None
:
mlp_bias
=
mlp_bias
.
expand_as
(
residual
)
with
self
.
bias_dropout_add_exec_handler
():
output
=
bias_dropout_add_func
(
mlp_output
,
mlp_bias
,
residual
,
self
.
hidden_dropout
)
# Jit compiled function creates 'view' tensor. This tensor
# potentially gets saved in the MPU checkpoint function context,
# which rejects view tensors. While making a viewless tensor here
# won't result in memory savings (like the data loader, or
# p2p_communication), it serves to document the origin of this
# 'view' tensor.
output
=
core
.
utils
.
make_viewless_tensor
(
inp
=
output
,
requires_grad
=
output
.
requires_grad
,
keep_graph
=
True
)
else
:
if
mlp_bias
is
not
None
:
mlp_output
=
mlp_output
+
mlp_bias
out
=
torch
.
nn
.
functional
.
dropout
(
mlp_output
,
p
=
self
.
hidden_dropout
,
training
=
self
.
training
)
output
=
residual
+
self
.
drop_path
(
out
)
if
self
.
layer_type
==
LayerType
.
retro_decoder_with_retriever
:
return
output
,
retriever_output
else
:
return
output
class
NoopTransformerLayer
(
MegatronModule
):
"""A single 'no-op' transformer layer.
The sole purpose of this layer is for when a standalone embedding layer
is used (i.e., args.standalone_embedding_stage == True). In this case,
zero transformer layers are assigned when pipeline rank == 0. Additionally,
when virtual pipeline rank >= 1, zero total model parameters are created
(virtual rank 0 contains the input embedding). This results in the model's
input and output tensors being the same, which causes an error when
performing certain memory optimiations on the output tensor (e.g.,
deallocating it). Thus, this layer disconnects the input from the output
via a clone. Since ranks containing a no-op layer are generally under-
utilized (both compute and memory), there's no worry of any performance
degredation.
"""
def
__init__
(
self
,
layer_number
):
super
().
__init__
()
self
.
layer_number
=
layer_number
def
forward
(
self
,
hidden_states
,
attention_mask
,
encoder_output
=
None
,
enc_dec_attn_mask
=
None
,
inference_params
=
None
):
return
hidden_states
.
clone
()
def
_get_num_layers
(
args
,
model_type
,
is_decoder
=
False
):
"""Compute the number of transformer layers resident on the current rank."""
is_encoder_and_decoder_model
=
(
model_type
==
ModelType
.
encoder_and_decoder
)
if
model_type
==
ModelType
.
retro_encoder
:
num_layers
=
args
.
retro_encoder_layers
elif
mpu
.
get_pipeline_model_parallel_world_size
()
>
1
:
if
is_encoder_and_decoder_model
:
assert
args
.
pipeline_model_parallel_split_rank
is
not
None
# When a standalone embedding stage is used, a rank is taken from
# the encoder's ranks, to be used for the encoder's embedding
# layer. This way, the rank referenced by the 'split rank' remains
# the same whether or not a standalone embedding stage is used.
num_ranks_in_encoder
=
(
args
.
pipeline_model_parallel_split_rank
-
1
if
args
.
standalone_embedding_stage
else
args
.
pipeline_model_parallel_split_rank
)
num_ranks_in_decoder
=
args
.
transformer_pipeline_model_parallel_size
-
num_ranks_in_encoder
assert
args
.
encoder_num_layers
%
num_ranks_in_encoder
==
0
,
\
'encoder_num_layers (%d) must be divisible by number of ranks given to encoder (%d)'
%
(
args
.
encoder_num_layers
,
num_ranks_in_encoder
)
assert
args
.
decoder_num_layers
%
num_ranks_in_decoder
==
0
,
\
'decoder_num_layers (%d) must be divisible by number of ranks given to decoder (%d)'
%
(
args
.
decoder_num_layers
,
num_ranks_in_decoder
)
if
mpu
.
is_pipeline_stage_before_split
():
num_layers
=
(
0
if
args
.
standalone_embedding_stage
and
mpu
.
get_pipeline_model_parallel_rank
()
==
0
else
args
.
encoder_num_layers
//
num_ranks_in_encoder
)
else
:
num_layers
=
args
.
decoder_num_layers
//
num_ranks_in_decoder
else
:
assert
args
.
num_layers
==
args
.
encoder_num_layers
assert
args
.
num_layers
%
args
.
transformer_pipeline_model_parallel_size
==
0
,
\
'num_layers must be divisible by transformer_pipeline_model_parallel_size'
# When a standalone embedding stage is used, all transformer layers
# are divided among pipeline rank >= 1, while on pipeline rank 0,
# ranks either contain the input embedding layer (virtual pp rank 0),
# or no layers at all (virtual pp rank >= 1).
num_layers
=
(
0
if
args
.
standalone_embedding_stage
and
mpu
.
get_pipeline_model_parallel_rank
()
==
0
else
args
.
num_layers
//
args
.
transformer_pipeline_model_parallel_size
)
else
:
if
not
is_decoder
:
num_layers
=
args
.
encoder_num_layers
else
:
num_layers
=
args
.
decoder_num_layers
return
num_layers
def
_get_layer_type
(
model_type
,
default_layer_type
,
retro_layer_numbers
,
layer_number
):
args
=
get_args
()
if
args
.
retro_add_retriever
and
layer_number
in
retro_layer_numbers
:
if
model_type
==
ModelType
.
retro_decoder
:
return
LayerType
.
retro_decoder_with_retriever
\
if
layer_number
==
retro_layer_numbers
[
0
]
\
else
LayerType
.
retro_decoder
elif
model_type
==
ModelType
.
retro_encoder
:
return
LayerType
.
retro_encoder
else
:
raise
Exception
(
"Unsupported model type, '%s'."
%
model_type
)
else
:
return
default_layer_type
class ParallelTransformer(MegatronModule):
    """Transformer class."""

    def __init__(self, config,
                 model_type, layer_type=LayerType.encoder,
                 self_attn_mask_type=AttnMaskType.padding,
                 post_norm=True,
                 pre_process=True,
                 post_process=True,
                 drop_path_rate=0.0):
        super(ParallelTransformer, self).__init__()
        args = get_args()

        self.layer_type = layer_type
        self.model_type = model_type
        self.bf16 = config.bf16
        self.fp32_residual_connection = config.fp32_residual_connection
        self.post_norm = post_norm
        self.pre_process = pre_process
        self.post_process = post_process
        self.input_tensor = None
        self.drop_path_rate = drop_path_rate
        self.transformer_impl = args.transformer_impl
        self.retro_add_retriever = args.retro_add_retriever

        # Store activation checkpointing flag.
        self.recompute_granularity = config.recompute_granularity
        self.recompute_method = config.recompute_method
        self.recompute_num_layers = config.recompute_num_layers
        self.distribute_saved_activations = \
            config.distribute_saved_activations and not config.sequence_parallel

        self.sequence_parallel = config.sequence_parallel

        # Transformer Engine Init.
        self.transformer_engine_v_0_10 = False
        self.transformer_engine_v_0_11 = False
        self.transformer_engine_v_0_8 = False
        if self.transformer_impl == 'transformer_engine':
            global transformer_engine
            import transformer_engine
            from importlib.metadata import version
            from pkg_resources import packaging

            te_version = packaging.version.Version(version("transformer-engine"))
            if te_version >= packaging.version.Version("0.8.0"):
                self.transformer_engine_v_0_8 = True
            if te_version >= packaging.version.Version("0.10.0"):
                self.transformer_engine_v_0_10 = True
            if te_version >= packaging.version.Version("0.11.0"):
                self.transformer_engine_v_0_11 = True

            del version, packaging

            assert not args.squared_relu, "TransformerEngine does not support squared relu activation."

        self.use_fp8 = args.fp8 is not None
        self.fp8_recipe = None
        self.fp8_group = None
        if self.use_fp8:
            assert args.transformer_impl == 'transformer_engine', \
                'transformer-engine required for fp8 training and inference'
            self.fp8_group = mpu.get_amax_reduction_group()
            if args.fp8 == "e4m3":
                fp8_format = transformer_engine.common.recipe.Format.E4M3
            elif args.fp8 == "hybrid":
                fp8_format = transformer_engine.common.recipe.Format.HYBRID
            else:
                raise ValueError("The DelayedScaling recipe only supports E4M3 and HYBRID formats.")
            self.fp8_recipe = transformer_engine.common.recipe.DelayedScaling(
                margin=args.fp8_margin,
                interval=args.fp8_interval,
                fp8_format=fp8_format,
                amax_history_len=args.fp8_amax_history_len,
                amax_compute_algo=args.fp8_amax_compute_algo,
                override_linear_precision=(False, False, not args.fp8_wgrad),
            )

        self.num_microbatches_in_previous_step = -1
        self.microbatch_count = 0
        self.checkpoint_core_attention = config.recompute_granularity == 'selective'

        # Number of layers.
        self.num_layers = _get_num_layers(args, model_type,
                                          layer_type == LayerType.decoder)

        self.drop_path_rates = [
            rate.item() for rate in
            torch.linspace(0, self.drop_path_rate, config.num_layers)]

        self.retro_layer_numbers = None
        if model_type == ModelType.retro_decoder:
            retro_layer_start = 6 if config.num_layers <= 15 else 9
            self.retro_layer_numbers = \
                np.arange(retro_layer_start, args.num_layers + 1, 3).tolist()
        if model_type == ModelType.retro_encoder:
            self.retro_layer_numbers = [1]

        # Transformer layers.
        if args.retro_add_retriever:
            assert self.recompute_granularity != 'full', \
                "Full recompute not supported for Retro."
            assert args.transformer_impl == 'local', \
                "Transformer engine does not support Retro layers."

        def build_layer(layer_number):
            if args.transformer_impl == 'local':
                current_layer_type = _get_layer_type(
                    model_type, layer_type, self.retro_layer_numbers,
                    layer_number)
                return ParallelTransformerLayer(
                    config,
                    layer_number,
                    layer_type=current_layer_type,
                    self_attn_mask_type=self_attn_mask_type,
                    drop_path_rate=self.drop_path_rates[layer_number - 1])
            else:
                # This argument is only available from TE v0.10 onwards.
                extra_transformer_engine_kwargs = {}
                if self.transformer_engine_v_0_8:
                    extra_transformer_engine_kwargs["bias"] = args.add_bias_linear
                if self.transformer_engine_v_0_10:
                    extra_transformer_engine_kwargs["activation"] = "swiglu" if args.swiglu else "gelu"
                if self.transformer_engine_v_0_11:
                    extra_transformer_engine_kwargs["normalization"] = args.normalization
                assert config.attention_softmax_in_fp32, "TransformerEngine only supports softmax compute in FP32."
                assert (
                    (bool(int(os.getenv("NVTE_APPLY_QK_LAYER_SCALING", "0"))) and args.fp16)
                    == config.apply_query_key_layer_scaling
                ), "Unsupported config for apply_query_key_layer_scaling in TransformerEngine."
                return transformer_engine.pytorch.TransformerLayer(
                    config.hidden_size,
                    config.ffn_hidden_size,
                    config.num_attention_heads,
                    layernorm_epsilon=config.layernorm_epsilon,
                    hidden_dropout=config.hidden_dropout,
                    attention_dropout=config.attention_dropout,
                    init_method=config.init_method,
                    output_layer_init_method=config.output_layer_init_method,
                    layer_number=layer_number,
                    kv_channels=config.kv_channels,
                    self_attn_mask_type=self_attn_mask_type.name,
                    tp_group=mpu.get_tensor_model_parallel_group(),
                    get_rng_state_tracker=tensor_parallel.get_cuda_rng_tracker,
                    fuse_wgrad_accumulation=config.gradient_accumulation_fusion,
                    seq_length=args.seq_length,
                    micro_batch_size=args.micro_batch_size,
                    sequence_parallel=config.sequence_parallel,
                    params_dtype=config.params_dtype,
                    apply_residual_connection_post_layernorm=config.apply_residual_connection_post_layernorm,
                    output_layernorm=False,
                    layer_type="encoder",
                    drop_path_rate=self.drop_path_rates[layer_number - 1],
                    set_parallel_mode=True,
                    fuse_qkv_params=True,
                    **extra_transformer_engine_kwargs)

        if config.virtual_pipeline_model_parallel_size is not None:
            assert config.num_layers % config.virtual_pipeline_model_parallel_size == 0, \
                'num_layers_per_stage must be divisible by ' \
                'virtual_pipeline_model_parallel_size'
            assert args.model_type != ModelType.encoder_and_decoder
            # Number of layers in each model chunk is the number of layers in the stage,
            # divided by the number of model chunks in a stage.
            self.num_layers = self.num_layers // config.virtual_pipeline_model_parallel_size
            # With 8 layers, 2 stages, and 4 model chunks, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0]  [2]  [4]  [6]
            # Stage 1: [1]  [3]  [5]  [7]
            # With 8 layers, 2 stages, and 2 virtual stages, we want an assignment of
            # layers to stages like (each list is a model chunk):
            # Stage 0: [0, 1]  [4, 5]
            # Stage 1: [2, 3]  [6, 7]
            offset = mpu.get_virtual_pipeline_model_parallel_rank() * (
                config.num_layers // config.virtual_pipeline_model_parallel_size) + \
                (mpu.get_pipeline_model_parallel_rank() * self.num_layers)
        else:
            # Each stage gets a contiguous set of layers.
            if args.model_type == ModelType.encoder_and_decoder and \
                    mpu.get_pipeline_model_parallel_world_size() > 1:
                pipeline_rank = mpu.get_pipeline_model_parallel_rank()
                if layer_type == LayerType.encoder:
                    offset = pipeline_rank * self.num_layers
                else:
                    num_ranks_in_enc = args.pipeline_model_parallel_split_rank
                    offset = (pipeline_rank - num_ranks_in_enc) * self.num_layers
            else:
                offset = mpu.get_pipeline_model_parallel_rank() * self.num_layers

        if self.num_layers == 0:
            # When a standalone embedding stage is used (e.g.,
            # args.standalone_embedding_stage == True), virtual pipeline ranks
            # on pipeline rank 0 will have zero transformer layers assigned to
            # them. This results in the model's input and output tensors to be
            # the same, which will cause failure for certain output tensor
            # optimizations (e.g., pipeline output deallocation). To remedy
            # this, we assign a 'no-op' layer on these ranks, which will
            # disconnect the input tensor from the output tensor.
            self.num_layers = 1
            self.layers = torch.nn.ModuleList([NoopTransformerLayer(1)])
        else:
            self.layers = torch.nn.ModuleList(
                [build_layer(i + 1 + offset) for i in range(self.num_layers)])

        # Update dropout rate for Retro encoder.
        if model_type == ModelType.retro_encoder:
            for layer in self.layers:
                if layer.self_attention.use_flash_attn:
                    layer.self_attention.core_attention_flash.dropout_p = \
                        torch.nn.Dropout(args.retro_encoder_attention_dropout)
                else:
                    layer.self_attention.core_attention.attention_dropout.p = \
                        args.retro_encoder_attention_dropout
                layer.hidden_dropout = args.retro_encoder_hidden_dropout

        if self.post_process and self.post_norm:
            # Final layer norm before output.
            self.final_norm = get_norm(config)
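    # --- Illustrative sketch (not part of the original file) ------------------
    # A standalone rendering of the interleaved layer offset computed above,
    # assuming 8 total layers, 2 pipeline stages and 4 virtual chunks (so one
    # layer per chunk), which reproduces the assignment in the comment:
    # Stage 0: [0] [2] [4] [6]   Stage 1: [1] [3] [5] [7]
    #
    #   total_layers, pipeline_stages, virtual_chunks = 8, 2, 4
    #   layers_per_chunk = total_layers // pipeline_stages // virtual_chunks   # 1
    #   for stage in range(pipeline_stages):
    #       chunks = []
    #       for virtual_rank in range(virtual_chunks):
    #           offset = virtual_rank * (total_layers // virtual_chunks) + stage * layers_per_chunk
    #           chunks.append(list(range(offset, offset + layers_per_chunk)))
    #       print(f"Stage {stage}: {chunks}")
    #   # Stage 0: [[0], [2], [4], [6]]
    #   # Stage 1: [[1], [3], [5], [7]]
    # ---------------------------------------------------------------------------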
    def _get_layer(self, layer_number):
        return self.layers[layer_number]

    def _checkpointed_forward(self, hidden_states, attention_mask,
                              encoder_output, enc_dec_attn_mask,
                              rotary_pos_emb, is_first_microbatch):
        """Forward method with activation checkpointing."""
        def custom(start, end):
            def custom_forward(*args, **kwargs):
                x_, *args = args
                for index in range(start, end):
                    layer = self._get_layer(index)
                    x_ = layer(x_, *args, **kwargs)
                return x_
            return custom_forward

        te_forward_kwargs = {}
        if self.transformer_impl == 'transformer_engine':
            te_forward_kwargs['is_first_microbatch'] = is_first_microbatch
            if self.transformer_engine_v_0_10:
                te_forward_kwargs['rotary_pos_emb'] = rotary_pos_emb

        if self.recompute_method == 'uniform':
            # Uniformly divide the total number of Transformer layers and
            # checkpoint the input activation of each divided chunk.
            # A method to further reduce memory usage by reducing the number
            # of stored activations.
            l = 0
            while l < self.num_layers:
                if self.transformer_impl == 'transformer_engine':
                    hidden_states = transformer_engine.pytorch.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        tensor_parallel.get_cuda_rng_tracker,
                        mpu.get_tensor_model_parallel_group(),
                        hidden_states, attention_mask, encoder_output,
                        enc_dec_attn_mask, **te_forward_kwargs)
                else:
                    hidden_states = tensor_parallel.checkpoint(
                        custom(l, l + self.recompute_num_layers),
                        self.distribute_saved_activations,
                        hidden_states, attention_mask,
                        encoder_output, enc_dec_attn_mask,
                        None, None, None, None, rotary_pos_emb)

                l += self.recompute_num_layers

        elif self.recompute_method == 'block':
            # Checkpoint the input activation of only a set number of individual
            # Transformer layers and skip the rest.
            # A method that fully uses the device memory while removing
            # redundant re-computation.
            for l in range(self.num_layers):
                if l < self.recompute_num_layers:
                    if self.transformer_impl == 'transformer_engine':
                        hidden_states = transformer_engine.pytorch.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
                            tensor_parallel.get_cuda_rng_tracker,
                            mpu.get_tensor_model_parallel_group(),
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, **te_forward_kwargs)
                    else:
                        hidden_states = tensor_parallel.checkpoint(
                            custom(l, l + 1),
                            self.distribute_saved_activations,
                            hidden_states, attention_mask,
                            encoder_output, enc_dec_attn_mask,
                            None, None, None, None, rotary_pos_emb)
                else:
                    if self.transformer_impl == 'transformer_engine':
                        hidden_states = custom(l, l + 1)(
                            hidden_states, attention_mask, encoder_output,
                            enc_dec_attn_mask, **te_forward_kwargs)
                    else:
                        hidden_states = custom(l, l + 1)(
                            hidden_states, attention_mask,
                            encoder_output, enc_dec_attn_mask,
                            None, None, None, None, rotary_pos_emb)
        else:
            raise ValueError("Invalid activation recompute method.")

        return hidden_states
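    # --- Illustrative sketch (not part of the original file) ------------------
    # How the two recompute methods above partition layers, shown standalone
    # and assuming 8 layers on this stage with recompute_num_layers = 2:
    #
    #   num_layers, recompute_num_layers = 8, 2
    #
    #   # 'uniform': every chunk of `recompute_num_layers` consecutive layers
    #   # is checkpointed, so only one input activation per chunk is kept.
    #   uniform_chunks = [list(range(l, l + recompute_num_layers))
    #                     for l in range(0, num_layers, recompute_num_layers)]
    #   print(uniform_chunks)   # [[0, 1], [2, 3], [4, 5], [6, 7]] -> all checkpointed
    #
    #   # 'block': only the first `recompute_num_layers` individual layers are
    #   # checkpointed; the remaining layers run without recomputation.
    #   checkpointed = [l for l in range(num_layers) if l < recompute_num_layers]
    #   regular = [l for l in range(num_layers) if l >= recompute_num_layers]
    #   print(checkpointed, regular)   # [0, 1] [2, 3, 4, 5, 6, 7]
    # ---------------------------------------------------------------------------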
    def set_input_tensor(self, input_tensor):
        """Set input tensor to be used instead of forward()'s input.

        When doing pipeline parallelism the input from the previous
        stage comes from communication, not from the input, so the
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
        self.input_tensor = input_tensor

    def forward(self, hidden_states, attention_mask,
                encoder_output=None, enc_dec_attn_mask=None,
                retriever_input=None,
                retriever_output=None,
                retriever_attn_mask=None,
                inference_params=None,
                rotary_pos_emb=None):
        # hidden_states: [s, b, h]

        # Checks.
        if inference_params:
            assert self.recompute_granularity is None, \
                'inference does not work with activation checkpointing'

        if not self.pre_process:
            # See set_input_tensor()
            hidden_states = self.input_tensor

        # Viewless tensor.
        # - We only need to create a viewless tensor in the case of micro batch
        #   size (mbs) == 1, since in this case, 'hidden_states.transpose()'
        #   above creates a view tensor, and '.contiguous()' is a pass-through.
        #   For mbs >= 2, '.contiguous()' creates a new tensor, eliminating
        #   the need to make it viewless.
        #
        #   However, we don't explicitly check mbs == 1 here because
        #   make_viewless_tensor() has negligible overhead when its input
        #   is already viewless.
        #
        # - For the 'else' case above, calling make_viewless_tensor() here is
        #   likely redundant, since p2p_communication.py (likely originator)
        #   already creates viewless tensors. That said, make_viewless_tensor()
        #   is called here to be future-proof and corner-case-proof.
        hidden_states = core.utils.make_viewless_tensor(
            hidden_states,
            requires_grad=True,
            keep_graph=True,
        )

        # RNG context.
        if self.sequence_parallel:
            rng_context = tensor_parallel.get_cuda_rng_tracker().fork()
        else:
            rng_context = nullcontext()

        # Forward layers.
        with rng_context:
            # The fp8_autocast context manager is a no-op when enabled=True
            # The if...else serves to short circuit name resolution for fp8_autocast
            with transformer_engine.pytorch.fp8_autocast(
                enabled=self.use_fp8,
                fp8_recipe=self.fp8_recipe,
                fp8_group=self.fp8_group
            ) if self.use_fp8 else nullcontext():
                # Determine if the current iteration is the first microbatch.
                if self.num_microbatches_in_previous_step != get_num_microbatches():
                    self.microbatch_count = 0  # Reset count on new batch size rampup interval
                self.num_microbatches_in_previous_step = get_num_microbatches()
                is_first_microbatch = self.microbatch_count % get_num_microbatches() == 0

                # Forward pass.
                if self.recompute_granularity == 'full':
                    hidden_states = self._checkpointed_forward(hidden_states,
                                                               attention_mask,
                                                               encoder_output,
                                                               enc_dec_attn_mask,
                                                               rotary_pos_emb,
                                                               is_first_microbatch)
                else:
                    forward_kwargs = {
                        'encoder_output': encoder_output,
                        'enc_dec_attn_mask': enc_dec_attn_mask,
                        'inference_params': inference_params,
                    }

                    if self.transformer_impl == 'transformer_engine':
                        forward_kwargs['is_first_microbatch'] = is_first_microbatch
                        forward_kwargs['checkpoint_core_attention'] = self.checkpoint_core_attention
                        if self.transformer_engine_v_0_10:
                            forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                    else:
                        forward_kwargs['rotary_pos_emb'] = rotary_pos_emb
                        forward_kwargs['retriever_input'] = retriever_input
                        forward_kwargs['retriever_output'] = retriever_output
                        forward_kwargs['retriever_attn_mask'] = retriever_attn_mask

                    for index in range(self.num_layers):
                        layer = self._get_layer(index)

                        hidden_states = layer(
                            hidden_states,
                            attention_mask,
                            **forward_kwargs)

                        # First Retro decoder layer returns both hidden_states
                        # and retriever_output. Make retriever_output available
                        # to subsequent Retro layers.
                        if isinstance(hidden_states, tuple):
                            assert len(hidden_states) == 2
                            hidden_states, retriever_output = hidden_states
                            forward_kwargs["retriever_output"] = retriever_output

                # Skip counter update for eval and activation checkpointing
                if torch.is_grad_enabled() and self.training:
                    self.microbatch_count += 1

        # Final layer norm.
        if self.post_process and self.post_norm:
            hidden_states = self.final_norm(hidden_states)

        return hidden_states
    def load_state_dict(self, state_dict, strict=True):
        """Customize load."""

        # Handle renaming layernorm -> norm in component names
        state_dict_ = {}
        for key in state_dict.keys():
            # Bypass TransformerEngine module parameters.
            if "layernorm_qkv" in key or "layernorm_mlp" in key:
                state_dict_[key] = state_dict[key]
                continue
            newkey = key.replace("layernorm", "norm")
            state_dict_[newkey] = state_dict[key]

        super().load_state_dict(state_dict_, strict)
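# --- Illustrative sketch (not part of the original file) ----------------------
# What the customized load_state_dict above does to checkpoint keys, shown on a
# toy dict (the key names here are made up for illustration):
old_keys = [
    "layers.0.input_layernorm.weight",                 # renamed: layernorm -> norm
    "layers.0.self_attention.layernorm_qkv.weight",    # TE parameter, kept as-is
    "final_layernorm.bias",                            # renamed: layernorm -> norm
]
renamed = {}
for key in old_keys:
    if "layernorm_qkv" in key or "layernorm_mlp" in key:
        renamed[key] = key
        continue
    renamed[key] = key.replace("layernorm", "norm")
print(renamed)
# {'layers.0.input_layernorm.weight': 'layers.0.input_norm.weight',
#  'layers.0.self_attention.layernorm_qkv.weight': 'layers.0.self_attention.layernorm_qkv.weight',
#  'final_layernorm.bias': 'final_norm.bias'}
# ------------------------------------------------------------------------------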
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/utils.py
0 → 100644
View file @
bc5c7fa7
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utilities for models."""

import math

import torch

from megatron.training import get_args
from megatron.legacy.model import LayerNorm, RMSNorm
from megatron.core.jit import jit_fuser


def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_


def attention_mask_func(attention_scores, attention_mask):
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


def get_linear_layer(rows, columns, init_method):
    """Simple linear layer with weight initialization."""
    layer = torch.nn.Linear(rows, columns)
    if get_args().perform_initialization:
        init_method(layer.weight)
    with torch.no_grad():
        layer.bias.zero_()
    return layer


@jit_fuser
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))

def openai_gelu(x):
    return gelu_impl(x)


# This is actually Python equivalent of torch.nn.functional.gelu(), also with type hints for ONNX exporter
@jit_fuser
def erf_gelu(x):
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))


def get_norm(config):
    args = get_args()
    if args.normalization == "LayerNorm":
        return LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=not config.persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)
    elif args.normalization == "RMSNorm":
        if args.apply_layernorm_1p:
            raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')

        return RMSNorm(dim=config.hidden_size,
                       eps=config.layernorm_epsilon,
                       sequence_parallel=config.sequence_parallel)
    else:
        raise Exception(f"unsupported norm type '{args.normalization}'.")
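# --- Illustrative sketch (not part of the original file) ----------------------
# The tanh-based gelu_impl and the erf-based erf_gelu above should agree
# closely; a quick standalone check, plus the std that scaled_init_method_normal
# would use for a hypothetical 24-layer model with sigma = 0.02:
import math
import torch

x = torch.linspace(-3, 3, steps=7)
tanh_gelu = 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x * (1.0 + 0.044715 * x * x)))
erf_gelu_ref = x * 0.5 * (torch.erf(x / 1.41421) + torch.ones_like(x))
print(torch.max(torch.abs(tanh_gelu - erf_gelu_ref)))   # prints a small value; the two nearly coincide

sigma, num_layers = 0.02, 24      # assumed values for illustration
print(sigma / math.sqrt(2.0 * num_layers))   # ~0.0029, the scaled per-layer init std
# ------------------------------------------------------------------------------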
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/classification.py
0 → 100644
View file @
bc5c7fa7
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Vision Transformer(VIT) model."""

import torch
from torch.nn.init import trunc_normal_
from megatron.training import get_args
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.vision.vit_backbone import VitBackbone, VitMlpHead
from megatron.legacy.model.vision.mit_backbone import mit_b3_avg
from megatron.legacy.model.module import MegatronModule


class VitClassificationModel(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self, config, num_classes, finetune=False,
                 pre_process=True, post_process=True):
        super(VitClassificationModel, self).__init__()
        args = get_args()
        self.config = config

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes
        self.finetune = finetune
        self.pre_process = pre_process
        self.post_process = post_process
        self.backbone = VitBackbone(
            config=config,
            pre_process=self.pre_process,
            post_process=self.post_process,
            single_token_output=True
        )

        if self.post_process:
            if not self.finetune:
                self.head = VitMlpHead(config, self.hidden_size, self.num_classes)
            else:
                self.head = get_linear_layer(
                    self.hidden_size,
                    self.num_classes,
                    torch.nn.init.zeros_
                )

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        self.backbone.set_input_tensor(input_tensor)

    def forward(self, input):
        hidden_states = self.backbone(input)

        if self.post_process:
            hidden_states = self.head(hidden_states)

        return hidden_states


class MitClassificationModel(MegatronModule):
    """Mix vision Transformer Model."""

    def __init__(self, num_classes,
                 pre_process=True, post_process=True):
        super(MitClassificationModel, self).__init__()
        args = get_args()

        self.hidden_size = args.hidden_size
        self.num_classes = num_classes

        self.backbone = mit_b3_avg()
        self.head = torch.nn.Linear(512, num_classes)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, torch.nn.Linear) and m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        hidden_states = self.backbone(input)
        hidden_states = self.head(hidden_states)

        return hidden_states
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/dino.py
0 → 100644
View file @
bc5c7fa7
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.

# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import math
import apex
import einops
import torch
import numpy as np
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from megatron.training import get_args, print_rank_0
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.vision.vit_backbone import VitBackbone
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.vision.mit_backbone import mit_b5_avg
from megatron.legacy.model.vision.esvit_swin_backbone import get_swin


class DINOLoss(torch.nn.Module):
    def __init__(self, out_dim, ncrops, warmup_teacher_temp, teacher_temp,
                 warmup_teacher_temp_epochs, nepochs, student_temp=0.1,
                 center_momentum=0.9):
        super().__init__()
        self.student_temp = student_temp
        self.center_momentum = center_momentum
        self.ncrops = ncrops
        self.register_buffer("center", torch.zeros(1, out_dim))

        # we apply a warm up for the teacher temperature because
        # a too high temperature makes the training unstable at the beginning
        self.teacher_temp_schedule = np.concatenate((
            np.linspace(warmup_teacher_temp,
                        teacher_temp, warmup_teacher_temp_epochs),
            np.ones(nepochs - warmup_teacher_temp_epochs) * teacher_temp
        ))
        self.teacher_temp = teacher_temp

    def forward(self, student_output, teacher_output, iteration):
        """
        Cross-entropy between softmax outputs of the teacher
        and student network.
        """
        args = get_args()
        student_out = student_output / self.student_temp
        student_out = student_out.chunk(self.ncrops)

        epoch = iteration // args.iter_per_epoch

        # teacher centering and sharpening
        temp = self.teacher_temp_schedule[epoch]
        teacher_out = F.softmax((teacher_output - self.center) / temp, dim=-1)
        teacher_out = teacher_out.detach().chunk(2)

        total_loss = 0
        n_loss_terms = 0
        for iq, q in enumerate(teacher_out):
            for v in range(len(student_out)):
                if v == iq:
                    # we skip cases where student and teacher operate on the same view
                    continue
                loss = torch.sum(-q * F.log_softmax(student_out[v], dim=-1), dim=-1)
                total_loss += loss.mean()
                n_loss_terms += 1
        total_loss /= n_loss_terms
        self.update_center(teacher_output)
        return total_loss

    @torch.no_grad()
    def update_center(self, teacher_output):
        """
        Update center used for teacher output.
        """
        batch_center = torch.sum(teacher_output, dim=0, keepdim=True)
        torch.distributed.all_reduce(batch_center)
        batch_center = batch_center / (len(teacher_output) * torch.distributed.get_world_size())

        self.center = self.center * self.center_momentum + batch_center * (1 - self.center_momentum)
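# --- Illustrative sketch (not part of the original file) ----------------------
# The center above is an exponential moving average of the teacher outputs; a
# single-process sketch of the same update (no all_reduce), with made-up
# numbers:
import torch

center_momentum = 0.9
center = torch.zeros(1, 4)
teacher_output = torch.tensor([[1.0, 2.0, 3.0, 4.0],
                               [3.0, 4.0, 5.0, 6.0]])
batch_center = teacher_output.sum(dim=0, keepdim=True) / len(teacher_output)
center = center * center_momentum + batch_center * (1 - center_momentum)
print(center)   # tensor([[0.2000, 0.3000, 0.4000, 0.5000]])
# ------------------------------------------------------------------------------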
class DINOHead(torch.nn.Module):
    def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
        super().__init__()
        args = get_args()
        hidden_dim = args.dino_head_hidden_size
        bottleneck_dim = args.dino_bottleneck_size
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [torch.nn.Linear(in_dim, hidden_dim)]
            layers.append(torch.nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
                layers.append(torch.nn.GELU())
            layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = torch.nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, torch.nn.Linear) and m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = torch.nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
class MultiCropWrapper(MegatronModule):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.
    """
    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        # backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)

        start_idx = 0
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx:end_idx]))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        if self.training:
            return self.head(output)
        else:
            return output


def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=0):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = \
            np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) \
               * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
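# --- Illustrative sketch (not part of the original file) ----------------------
# cosine_scheduler produces one value per training iteration: an optional
# linear warmup to base_value, then a cosine decay toward final_value. The
# same arithmetic on small made-up numbers:
import numpy as np

base_value, final_value = 1.0, 0.1
epochs, niter_per_ep, warmup_epochs = 2, 5, 1
warmup = np.linspace(0, base_value, warmup_epochs * niter_per_ep)      # [0., .25, .5, .75, 1.]
iters = np.arange(epochs * niter_per_ep - len(warmup))                 # 5 decay iterations
decay = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
schedule = np.concatenate((warmup, decay))
assert len(schedule) == epochs * niter_per_ep                          # 10 values in total
print(round(schedule[-1], 3))   # 0.186 -- approaches final_value as the iteration count grows
# ------------------------------------------------------------------------------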
def get_student_backbone_and_num_features(config, pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        student = VitBackbone(config,
                              pre_process=pre_process,
                              post_process=post_process,
                              drop_path_rate=0.1,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        student = mit_b5_avg(drop_path_rate=0.1)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        student = get_swin()
        num_features = student.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                            args.vision_backbone_type))

    return student, num_features

def get_teacher_backbone_and_num_features(config, pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        teacher = VitBackbone(config,
                              pre_process=pre_process,
                              post_process=post_process,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        teacher = mit_b5_avg(drop_path_rate=0.0)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        teacher = get_swin(is_teacher=True)
        num_features = teacher.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                            args.vision_backbone_type))

    return teacher, num_features


class DINOPretrainModel(MegatronModule):
    def __init__(self, config, pre_process=True, post_process=True):
        super(DINOPretrainModel, self).__init__()
        args = get_args()
        self.config = config
        self.out_dim = 65536

        self.dino_loss = DINOLoss(
            self.out_dim,
            args.dino_local_crops_number + 2,
            args.dino_warmup_teacher_temp,
            args.dino_teacher_temp,
            args.dino_warmup_teacher_temp_epochs,
            300,
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.momentum_teacher = 0.996

        student_backbone, num_features = \
            get_student_backbone_and_num_features(config, pre_process, post_process)

        self.student = MultiCropWrapper(
            student_backbone,
            DINOHead(num_features, self.out_dim,
                     norm_last_layer=args.dino_norm_last_layer)
        )

        self.momentum_schedule = cosine_scheduler(
            self.momentum_teacher, 1,
            args.train_iters // args.iter_per_epoch,
            args.iter_per_epoch
        )

        teacher_backbone, num_features = \
            get_teacher_backbone_and_num_features(config, pre_process, post_process)
        self.teacher = MultiCropWrapper(
            teacher_backbone,
            DINOHead(num_features, self.out_dim)
        )
        self.teacher.load_state_dict(self.student.state_dict())

        for p in self.teacher.parameters():
            if hasattr(p, "requires_grad") and p.requires_grad is not None:
                p.requires_grad = False

    def set_input_tensor(self, tensor):
        pass

    def forward(self, input):
        student_output = None
        if self.training:
            student_output = self.student(input)
            teacher_output = self.teacher(input[:2])
        else:
            teacher_output = self.teacher(input)

        return student_output, teacher_output

    def cancel_gradients_last_layer(self, iteration):
        args = get_args()
        epoch = iteration // args.iter_per_epoch
        if epoch < args.dino_freeze_last_layer:
            for n, p in self.student.named_parameters():
                if "last_layer" in n:
                    p.grad = None

    def update_momentum(self, iteration):
        with torch.no_grad():
            m = self.momentum_schedule[iteration]
            for param_q, param_k in zip(self.student.parameters(),
                                        self.teacher.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
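# --- Illustrative sketch (not part of the original file) ----------------------
# update_momentum above is a parameter-wise EMA: teacher <- m * teacher +
# (1 - m) * student. The same update applied to a single made-up tensor:
import torch

m = 0.996
student_param = torch.tensor([1.0, 2.0])
teacher_param = torch.tensor([0.0, 0.0])
teacher_param.mul_(m).add_((1 - m) * student_param.detach())
print(teacher_param)   # tensor([0.0040, 0.0080])
# ------------------------------------------------------------------------------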
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/esvit_swin_backbone.py
0 → 100644
View file @
bc5c7fa7
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# --------------------------------------------------------
# Modified by Chunyuan Li (chunyl@microsoft.com)
# Swin Transformer
# --------------------------------------------------------

import os
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from megatron.legacy.model.transformer import DropPath
from megatron.training import get_args
from megatron.legacy.model import LayerNorm
import numpy as np
from math import sqrt


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None,
                 out_features=None, act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
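# --- Illustrative sketch (not part of the original file) ----------------------
# window_partition and window_reverse are inverses when H and W are multiples
# of the window size; a quick standalone round-trip check with made-up shapes:
import torch

B, H, W, C, window_size = 2, 8, 8, 3, 4
x = torch.randn(B, H, W, C)

windows = x.view(B, H // window_size, window_size, W // window_size, window_size, C) \
           .permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
print(windows.shape)    # torch.Size([8, 4, 4, 3]) -> (num_windows*B, ws, ws, C)

nB = int(windows.shape[0] / (H * W / window_size / window_size))
restored = windows.view(nB, H // window_size, W // window_size, window_size, window_size, -1) \
                  .permute(0, 1, 3, 2, 4, 5).contiguous().view(nB, H, W, -1)
print(torch.equal(x, restored))   # True
# ------------------------------------------------------------------------------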
class
WindowAttention
(
nn
.
Module
):
r
"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def
__init__
(
self
,
dim
,
window_size
,
num_heads
,
qkv_bias
=
True
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
):
super
(
WindowAttention
,
self
).
__init__
()
self
.
dim
=
dim
self
.
window_size
=
window_size
# Wh, Ww
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**
-
0.5
# define a parameter table of relative position bias
self
.
relative_position_bias_table
=
nn
.
Parameter
(
torch
.
zeros
((
2
*
window_size
[
0
]
-
1
)
*
(
2
*
window_size
[
1
]
-
1
),
num_heads
))
# 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h
=
torch
.
arange
(
self
.
window_size
[
0
])
coords_w
=
torch
.
arange
(
self
.
window_size
[
1
])
coords
=
torch
.
stack
(
torch
.
meshgrid
([
coords_h
,
coords_w
]))
# 2, Wh, Ww
coords_flatten
=
torch
.
flatten
(
coords
,
1
)
# 2 Wh*Ww
relative_coords
=
coords_flatten
[:,
:,
None
]
-
coords_flatten
[:,
None
,
:]
# 2, Wh*Ww, Wh*Ww
relative_coords
=
relative_coords
.
permute
(
1
,
2
,
0
).
contiguous
()
# Wh*Ww, Wh*Ww, 2
relative_coords
[:,
:,
0
]
+=
self
.
window_size
[
0
]
-
1
# shift to start from 0
relative_coords
[:,
:,
1
]
+=
self
.
window_size
[
1
]
-
1
relative_coords
[:,
:,
0
]
*=
2
*
self
.
window_size
[
1
]
-
1
relative_position_index
=
relative_coords
.
sum
(
-
1
)
# Wh*Ww, Wh*Ww
self
.
register_buffer
(
"relative_position_index"
,
relative_position_index
)
self
.
qkv
=
nn
.
Linear
(
dim
,
dim
*
3
,
bias
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
trunc_normal_
(
self
.
relative_position_bias_table
,
std
=
.
02
)
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
def
forward
(
self
,
x
,
mask
=
None
):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
B_
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
# make torchscript happy (cannot use tensor as tuple)
q
=
q
*
self
.
scale
attn
=
(
q
@
k
.
transpose
(
-
2
,
-
1
))
relative_position_bias
=
self
.
relative_position_bias_table
[
self
.
relative_position_index
.
view
(
-
1
)].
view
(
self
.
window_size
[
0
]
*
self
.
window_size
[
1
],
self
.
window_size
[
0
]
*
self
.
window_size
[
1
],
-
1
)
# Wh*Ww,Wh*Ww,nH
relative_position_bias
=
relative_position_bias
.
permute
(
2
,
0
,
1
).
contiguous
()
# nH, Wh*Ww, Wh*Ww
attn
=
attn
+
relative_position_bias
.
unsqueeze
(
0
)
if
mask
is
not
None
:
nW
=
mask
.
shape
[
0
]
attn
=
attn
.
view
(
B_
//
nW
,
nW
,
self
.
num_heads
,
N
,
N
)
+
mask
.
unsqueeze
(
1
).
unsqueeze
(
0
).
type
(
attn
.
type
())
attn
=
attn
.
view
(
-
1
,
self
.
num_heads
,
N
,
N
)
attn
=
self
.
softmax
(
attn
)
else
:
attn
=
self
.
softmax
(
attn
)
attn_out
=
attn
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B_
,
N
,
C
)
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
,
attn_out
def
extra_repr
(
self
)
->
str
:
return
f
'dim=
{
self
.
dim
}
, window_size=
{
self
.
window_size
}
, num_heads=
{
self
.
num_heads
}
'
def
flops
(
self
,
N
):
# calculate flops for 1 window with token length of N
flops
=
0
# qkv = self.qkv(x)
flops
+=
N
*
self
.
dim
*
3
*
self
.
dim
# attn = (q @ k.transpose(-2, -1))
flops
+=
self
.
num_heads
*
N
*
(
self
.
dim
//
self
.
num_heads
)
*
N
# x = (attn @ v)
flops
+=
self
.
num_heads
*
N
*
N
*
(
self
.
dim
//
self
.
num_heads
)
# x = self.proj(x)
flops
+=
N
*
self
.
dim
*
self
.
dim
return
flops
@
staticmethod
def
compute_macs
(
module
,
input
,
output
):
B
,
N
,
C
=
input
[
0
].
shape
module
.
__flops__
+=
module
.
flops
(
N
)
*
B
class
SwinTransformerBlock
(
nn
.
Module
):
r
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
dim
,
input_resolution
,
num_heads
,
window_size
=
7
,
shift_size
=
0
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
num_heads
=
num_heads
self
.
window_size
=
window_size
self
.
shift_size
=
shift_size
self
.
mlp_ratio
=
mlp_ratio
if
min
(
self
.
input_resolution
)
<=
self
.
window_size
:
# if window size is larger than input resolution, we don't partition windows
self
.
shift_size
=
0
self
.
window_size
=
min
(
self
.
input_resolution
)
assert
0
<=
self
.
shift_size
<
self
.
window_size
,
"shift_size must in 0-window_size"
self
.
norm1
=
norm_layer
(
dim
)
self
.
attn
=
WindowAttention
(
dim
,
window_size
=
(
self
.
window_size
,
self
.
window_size
),
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
)
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
nn
.
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
self
.
H
=
input_resolution
[
0
]
self
.
W
=
input_resolution
[
1
]
self
.
attn_mask_dict
=
{}
def
create_attn_mask
(
self
,
H
,
W
):
# calculate attention mask for SW-MSA
Hp
=
int
(
np
.
ceil
(
H
/
self
.
window_size
))
*
self
.
window_size
Wp
=
int
(
np
.
ceil
(
W
/
self
.
window_size
))
*
self
.
window_size
img_mask
=
torch
.
zeros
((
1
,
Hp
,
Wp
,
1
))
# 1 Hp Wp 1
h_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
w_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
cnt
=
0
for
h
in
h_slices
:
for
w
in
w_slices
:
img_mask
[:,
h
,
w
,
:]
=
cnt
cnt
+=
1
mask_windows
=
window_partition
(
img_mask
,
self
.
window_size
)
# nW, window_size, window_size, 1
mask_windows
=
mask_windows
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
)
attn_mask
=
mask_windows
.
unsqueeze
(
1
)
-
mask_windows
.
unsqueeze
(
2
)
attn_mask
=
attn_mask
.
masked_fill
(
attn_mask
!=
0
,
float
(
-
100.0
)).
masked_fill
(
attn_mask
==
0
,
float
(
0.0
))
return
attn_mask
def
forward
(
self
,
x
):
B
,
L
,
C
=
x
.
shape
H
=
int
(
sqrt
(
L
))
W
=
H
shortcut
=
x
x
=
self
.
norm1
(
x
)
x
=
x
.
view
(
B
,
H
,
W
,
C
)
# pad feature maps to multiples of window size
pad_l
=
pad_t
=
0
pad_r
=
(
self
.
window_size
-
W
%
self
.
window_size
)
%
self
.
window_size
pad_b
=
(
self
.
window_size
-
H
%
self
.
window_size
)
%
self
.
window_size
x
=
F
.
pad
(
x
,
(
0
,
0
,
pad_l
,
pad_r
,
pad_t
,
pad_b
))
_
,
Hp
,
Wp
,
_
=
x
.
shape
# cyclic shift
if
self
.
shift_size
>
0
:
shifted_x
=
torch
.
roll
(
x
,
shifts
=
(
-
self
.
shift_size
,
-
self
.
shift_size
),
dims
=
(
1
,
2
))
if
H
in
self
.
attn_mask_dict
.
keys
():
attn_mask
=
self
.
attn_mask_dict
[
H
]
else
:
self
.
attn_mask_dict
[
H
]
=
self
.
create_attn_mask
(
self
.
H
,
self
.
W
).
to
(
x
.
device
)
attn_mask
=
self
.
attn_mask_dict
[
H
]
else
:
shifted_x
=
x
attn_mask
=
None
# partition windows
x_windows
=
window_partition
(
shifted_x
,
self
.
window_size
)
# nW*B, window_size, window_size, C
x_windows
=
x_windows
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
,
C
)
# nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows
,
attn
=
self
.
attn
(
x_windows
,
attn_mask
)
# nW*B, window_size*window_size, C
# merge windows
attn_windows
=
attn_windows
.
view
(
-
1
,
self
.
window_size
,
self
.
window_size
,
C
)
shifted_x
=
window_reverse
(
attn_windows
,
self
.
window_size
,
Hp
,
Wp
)
# B H' W' C
# reverse cyclic shift
if
self
.
shift_size
>
0
:
x
=
torch
.
roll
(
shifted_x
,
shifts
=
(
self
.
shift_size
,
self
.
shift_size
),
dims
=
(
1
,
2
))
else
:
x
=
shifted_x
if
pad_r
>
0
or
pad_b
>
0
:
x
=
x
[:,
:
H
,
:
W
,
:].
contiguous
()
x
=
x
.
view
(
B
,
H
*
W
,
C
)
# FFN
x
=
shortcut
+
self
.
drop_path
(
x
)
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
,
attn
def
extra_repr
(
self
)
->
str
:
return
f
"dim=
{
self
.
dim
}
, input_resolution=
{
self
.
input_resolution
}
, num_heads=
{
self
.
num_heads
}
, "
\
f
"window_size=
{
self
.
window_size
}
, shift_size=
{
self
.
shift_size
}
mlp_ratio=
{
self
.
mlp_ratio
}
"
def
flops
(
self
):
flops
=
0
H
,
W
=
self
.
input_resolution
# norm1
flops
+=
self
.
dim
*
H
*
W
# W-MSA/SW-MSA
nW
=
H
*
W
/
self
.
window_size
/
self
.
window_size
flops
+=
nW
*
self
.
attn
.
flops
(
self
.
window_size
*
self
.
window_size
)
# mlp
flops
+=
2
*
H
*
W
*
self
.
dim
*
self
.
dim
*
self
.
mlp_ratio
# norm2
flops
+=
self
.
dim
*
H
*
W
return
flops
class
PatchMerging
(
nn
.
Module
):
r
"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
input_resolution
,
dim
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
input_resolution
=
input_resolution
self
.
dim
=
dim
self
.
reduction
=
nn
.
Linear
(
4
*
dim
,
2
*
dim
,
bias
=
False
)
self
.
norm
=
norm_layer
(
4
*
dim
)
def
forward
(
self
,
x
):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B
,
L
,
C
=
x
.
shape
H
=
int
(
sqrt
(
L
))
W
=
H
x
=
x
.
view
(
B
,
H
,
W
,
C
)
# padding
pad_input
=
(
H
%
2
==
1
)
or
(
W
%
2
==
1
)
if
pad_input
:
x
=
F
.
pad
(
x
,
(
0
,
0
,
0
,
W
%
2
,
0
,
H
%
2
))
x0
=
x
[:,
0
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x1
=
x
[:,
1
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x2
=
x
[:,
0
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x3
=
x
[:,
1
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x
=
torch
.
cat
([
x0
,
x1
,
x2
,
x3
],
-
1
)
# B H/2 W/2 4*C
x
=
x
.
view
(
B
,
-
1
,
4
*
C
)
# B H/2*W/2 4*C
x
=
self
.
norm
(
x
)
x
=
self
.
reduction
(
x
)
return
x
def
extra_repr
(
self
)
->
str
:
return
f
"input_resolution=
{
self
.
input_resolution
}
, dim=
{
self
.
dim
}
"
def
flops
(
self
):
H
,
W
=
self
.
input_resolution
flops
=
H
*
W
*
self
.
dim
flops
+=
(
H
//
2
)
*
(
W
//
2
)
*
4
*
self
.
dim
*
2
*
self
.
dim
return
flops
class
BasicLayer
(
nn
.
Module
):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
"""
def
__init__
(
self
,
dim
,
input_resolution
,
depth
,
num_heads
,
window_size
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
norm_layer
=
nn
.
LayerNorm
,
downsample
=
None
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
depth
=
depth
self
.
blocks
=
nn
.
ModuleList
([
SwinTransformerBlock
(
dim
=
dim
,
input_resolution
=
input_resolution
,
num_heads
=
num_heads
,
window_size
=
window_size
,
shift_size
=
0
if
(
i
%
2
==
0
)
else
window_size
//
2
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop
,
attn_drop
=
attn_drop
,
drop_path
=
drop_path
[
i
]
if
isinstance
(
drop_path
,
list
)
else
drop_path
,
norm_layer
=
norm_layer
)
for
i
in
range
(
depth
)])
if
downsample
is
not
None
:
self
.
downsample
=
downsample
(
input_resolution
,
dim
=
dim
,
norm_layer
=
norm_layer
)
else
:
self
.
downsample
=
None
def
forward
(
self
,
x
):
for
blk
in
self
.
blocks
:
x
,
_
=
blk
(
x
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
def
forward_with_features
(
self
,
x
):
fea
=
[]
for
blk
in
self
.
blocks
:
x
,
_
=
blk
(
x
)
fea
.
append
(
x
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
,
fea
def
forward_with_attention
(
self
,
x
):
attns
=
[]
for
blk
in
self
.
blocks
:
x
,
attn
=
blk
(
x
)
attns
.
append
(
attn
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
,
attns
def
extra_repr
(
self
)
->
str
:
return
f
"dim=
{
self
.
dim
}
, input_resolution=
{
self
.
input_resolution
}
, depth=
{
self
.
depth
}
"
def
flops
(
self
):
flops
=
0
for
blk
in
self
.
blocks
:
flops
+=
blk
.
flops
()
if
self
.
downsample
is
not
None
:
flops
+=
self
.
downsample
.
flops
()
return
flops
class
PatchEmbed
(
nn
.
Module
):
""" Image to Patch Embedding
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
embed_dim
=
768
,
norm_layer
=
None
):
super
().
__init__
()
img_size
=
(
img_size
,
img_size
)
patch_size
=
(
patch_size
,
patch_size
)
patches_resolution
=
[
img_size
[
0
]
//
patch_size
[
0
],
img_size
[
1
]
//
patch_size
[
1
]]
self
.
img_size
=
img_size
self
.
patch_size
=
patch_size
self
.
patches_resolution
=
patches_resolution
self
.
num_patches
=
patches_resolution
[
0
]
*
patches_resolution
[
1
]
self
.
in_chans
=
in_chans
self
.
embed_dim
=
embed_dim
self
.
proj
=
nn
.
Conv2d
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
if
norm_layer
is
not
None
:
self
.
norm
=
norm_layer
(
embed_dim
)
else
:
self
.
norm
=
None
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
x
=
self
.
proj
(
x
).
flatten
(
2
).
transpose
(
1
,
2
)
# B Ph*Pw C
if
self
.
norm
is
not
None
:
x
=
self
.
norm
(
x
)
return
x
def
flops
(
self
):
Ho
,
Wo
=
self
.
patches_resolution
flops
=
Ho
*
Wo
*
self
.
embed_dim
*
self
.
in_chans
*
(
self
.
patch_size
[
0
]
*
self
.
patch_size
[
1
])
if
self
.
norm
is
not
None
:
flops
+=
Ho
*
Wo
*
self
.
embed_dim
return
flops
class
SwinTransformer
(
nn
.
Module
):
r
""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size.
patch_size (int | tuple(int)): Patch size.
in_chans (int): Number of input channels.
num_classes (int): Number of classes for classification head.
embed_dim (int): Embedding dimension.
depths (tuple(int)): Depth of Swin Transformer layers.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): normalization layer.
ape (bool): If True, add absolute position embedding to the patch embedding.
patch_norm (bool): If True, add normalization after patch embedding.
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
4
,
in_chans
=
3
,
num_classes
=
1000
,
embed_dim
=
96
,
depths
=
[
2
,
2
,
6
,
2
],
num_heads
=
[
3
,
6
,
12
,
24
],
window_size
=
7
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.1
,
norm_layer
=
nn
.
LayerNorm
,
ape
=
False
,
patch_norm
=
True
,
**
kwargs
):
super
().
__init__
()
self
.
num_classes
=
num_classes
self
.
num_layers
=
len
(
depths
)
self
.
embed_dim
=
embed_dim
self
.
ape
=
ape
self
.
patch_norm
=
patch_norm
self
.
num_features
=
int
(
embed_dim
*
2
**
(
self
.
num_layers
-
1
))
self
.
mlp_ratio
=
mlp_ratio
self
.
patch_embed
=
PatchEmbed
(
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
in_chans
,
embed_dim
=
embed_dim
,
norm_layer
=
norm_layer
if
self
.
patch_norm
else
None
)
num_patches
=
self
.
patch_embed
.
num_patches
patches_resolution
=
self
.
patch_embed
.
patches_resolution
self
.
patches_resolution
=
patches_resolution
if
self
.
ape
:
self
.
absolute_pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_patches
,
embed_dim
))
trunc_normal_
(
self
.
absolute_pos_embed
,
std
=
.
02
)
self
.
pos_drop
=
nn
.
Dropout
(
p
=
drop_rate
)
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
sum
(
depths
))]
# stochastic depth decay rule
self
.
layers
=
nn
.
ModuleList
()
for
i_layer
in
range
(
self
.
num_layers
):
layer
=
BasicLayer
(
dim
=
int
(
embed_dim
*
2
**
i_layer
),
input_resolution
=
(
patches_resolution
[
0
]
//
(
2
**
i_layer
),
patches_resolution
[
1
]
//
(
2
**
i_layer
)),
depth
=
depths
[
i_layer
],
num_heads
=
num_heads
[
i_layer
],
window_size
=
window_size
,
mlp_ratio
=
self
.
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
sum
(
depths
[:
i_layer
]):
sum
(
depths
[:
i_layer
+
1
])],
norm_layer
=
norm_layer
,
downsample
=
PatchMerging
if
(
i_layer
<
self
.
num_layers
-
1
)
else
None
)
self
.
layers
.
append
(
layer
)
self
.
norm
=
norm_layer
(
self
.
num_features
)
self
.
avgpool
=
nn
.
AdaptiveAvgPool1d
(
1
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
return
{
'absolute_pos_embed'
}
@
torch
.
jit
.
ignore
def
no_weight_decay_keywords
(
self
):
# todo: to be implemented
return
{
'relative_position_bias_table'
}
def
forward
(
self
,
x
):
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x_region
=
self
.
norm
(
x
)
# B L C
x
=
self
.
avgpool
(
x_region
.
transpose
(
1
,
2
))
# B C 1
x
=
torch
.
flatten
(
x
,
1
)
return
x
def
forward_feature_maps
(
self
,
x
):
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x_grid
=
self
.
norm
(
x
)
# B L C
x
=
self
.
avgpool
(
x_grid
.
transpose
(
1
,
2
))
# B C 1
x
=
torch
.
flatten
(
x
,
1
)
return
x
,
x_grid
def
forward_selfattention
(
self
,
x
,
n
=
1
):
# n=1 return the last layer attn map; otherwise return attn maps in all layers
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
if
n
==
1
:
return
self
.
forward_last_selfattention
(
x
)
else
:
return
self
.
forward_all_selfattention
(
x
)
def
forward_last_selfattention
(
self
,
x
):
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
i
<
len
(
self
.
layers
)
-
1
:
x
=
layer
(
x
)
else
:
x
,
attns
=
layer
.
forward_with_attention
(
x
)
return
attns
[
-
1
]
def
forward_all_selfattention
(
self
,
x
):
attn_out
=
[]
for
layer
in
self
.
layers
:
x
,
attns
=
layer
.
forward_with_attention
(
x
)
attn_out
+=
attns
return
attn_out
def
forward_return_n_last_blocks
(
self
,
x
,
n
=
1
,
return_patch_avgpool
=
False
,
depth
=
[]):
num_blks
=
sum
(
depth
)
start_idx
=
num_blks
-
n
sum_cur
=
0
for
i
,
d
in
enumerate
(
depth
):
sum_cur_new
=
sum_cur
+
d
if
start_idx
>=
sum_cur
and
start_idx
<
sum_cur_new
:
start_stage
=
i
start_blk
=
start_idx
-
sum_cur
sum_cur
=
sum_cur_new
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
# we will return the averaged token features from the `n` last blocks
# note: there is no [CLS] token in Swin Transformer
output
=
[]
s
=
0
for
i
,
layer
in
enumerate
(
self
.
layers
):
x
,
fea
=
layer
.
forward_with_features
(
x
)
if
i
>=
start_stage
:
for
x_
in
fea
[
start_blk
:]:
if
i
==
len
(
self
.
layers
)
-
1
:
# use the norm in the last stage
x_
=
self
.
norm
(
x_
)
x_avg
=
torch
.
flatten
(
self
.
avgpool
(
x_
.
transpose
(
1
,
2
)),
1
)
# B C
# print(f'Stage {i}, x_avg {x_avg.shape}')
output
.
append
(
x_avg
)
start_blk
=
0
return
torch
.
cat
(
output
,
dim
=-
1
)
def
flops
(
self
):
flops
=
0
flops
+=
self
.
patch_embed
.
flops
()
for
i
,
layer
in
enumerate
(
self
.
layers
):
flops
+=
layer
.
flops
()
if
dist
.
get_rank
()
==
0
:
print
(
f
"GFLOPs layer_
{
i
}
:
{
layer
.
flops
()
/
1e9
}
"
)
flops
+=
self
.
num_features
*
self
.
patches_resolution
[
0
]
*
self
.
patches_resolution
[
1
]
//
(
2
**
self
.
num_layers
)
flops
+=
self
.
num_features
*
self
.
num_classes
return
flops
    def init_weights(self, pretrained='', pretrained_layers=[], verbose=True):
        if os.path.isfile(pretrained):
            pretrained_dict = torch.load(pretrained, map_location='cpu')
            logging.info(f'=> loading pretrained model {pretrained}')
            model_dict = self.state_dict()
            pretrained_dict = {
                k: v for k, v in pretrained_dict.items()
                if k in model_dict.keys()
            }
            need_init_state_dict = {}
            for k, v in pretrained_dict.items():
                need_init = (
                    k.split('.')[0] in pretrained_layers
                    or pretrained_layers[0] == '*'
                    or 'relative_position_index' not in k
                    or 'attn_mask' not in k
                )
                if need_init:
                    if verbose:
                        logging.info(f'=> init {k} from {pretrained}')

                    if 'relative_position_bias_table' in k and v.size() != model_dict[k].size():
                        relative_position_bias_table_pretrained = v
                        relative_position_bias_table_current = model_dict[k]
                        L1, nH1 = relative_position_bias_table_pretrained.size()
                        L2, nH2 = relative_position_bias_table_current.size()
                        if nH1 != nH2:
                            logging.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logging.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((L1, nH1), (L2, nH2)))
                                S1 = int(L1 ** 0.5)
                                S2 = int(L2 ** 0.5)
                                relative_position_bias_table_pretrained_resized = torch.nn.functional.interpolate(
                                    relative_position_bias_table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
                                    size=(S2, S2),
                                    mode='bicubic')
                                v = relative_position_bias_table_pretrained_resized.view(nH2, L2).permute(1, 0)

                    if 'absolute_pos_embed' in k and v.size() != model_dict[k].size():
                        absolute_pos_embed_pretrained = v
                        absolute_pos_embed_current = model_dict[k]
                        _, L1, C1 = absolute_pos_embed_pretrained.size()
                        _, L2, C2 = absolute_pos_embed_current.size()
                        if C1 != C2:
                            logging.info(f"Error in loading {k}, passing")
                        else:
                            if L1 != L2:
                                logging.info(
                                    '=> load_pretrained: resized variant: {} to {}'
                                    .format((1, L1, C1), (1, L2, C2)))
                                S1 = int(L1 ** 0.5)
                                S2 = int(L2 ** 0.5)
                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.reshape(-1, S1, S1, C1)
                                absolute_pos_embed_pretrained = absolute_pos_embed_pretrained.permute(0, 3, 1, 2)
                                absolute_pos_embed_pretrained_resized = torch.nn.functional.interpolate(
                                    absolute_pos_embed_pretrained, size=(S2, S2), mode='bicubic')
                                v = absolute_pos_embed_pretrained_resized.permute(0, 2, 3, 1).flatten(1, 2)

                    need_init_state_dict[k] = v
            self.load_state_dict(need_init_state_dict, strict=False)
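
Aside (illustrative, not part of the file): the relative-position bias-table resize above, run in isolation. The window sizes 7 and 12 here are assumptions chosen only to make the shapes concrete.

    import torch
    nH1 = nH2 = 4
    S1, S2 = 2 * 7 - 1, 2 * 12 - 1                      # 13 and 23
    table = torch.randn(S1 * S1, nH1)                   # pretrained table: (169, 4)
    resized = torch.nn.functional.interpolate(
        table.permute(1, 0).view(1, nH1, S1, S1),
        size=(S2, S2), mode='bicubic')
    table_new = resized.view(nH2, S2 * S2).permute(1, 0)
    assert table_new.shape == (S2 * S2, nH2)            # (529, 4), ready for the larger window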
    def freeze_pretrained_layers(self, frozen_layers=[]):
        for name, module in self.named_modules():
            if (
                name.split('.')[0] in frozen_layers
                or '.'.join(name.split('.')[0:2]) in frozen_layers
                or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
            ):
                for _name, param in module.named_parameters():
                    param.requires_grad = False
                logging.info('=> set param {} requires grad to False'.format(name))

        for name, param in self.named_parameters():
            if (
                name.split('.')[0] in frozen_layers
                or (len(frozen_layers) > 0 and frozen_layers[0] == '*')
                and param.requires_grad is True
            ):
                param.requires_grad = False
                logging.info('=> set param {} requires grad to False'.format(name))

        return self
def get_swin(is_teacher=False):
    args = get_args()

    if args.swin_backbone_type == "tiny":
        embed_dim = 96
        depths = [2, 2, 6, 2]
        num_heads = [3, 6, 12, 24]
        drop_path_rate = 0.1
    elif args.swin_backbone_type == 'h3':
        embed_dim = 384
        depths = [2, 2, 18, 2]
        num_heads = [6, 12, 24, 48]
        drop_path_rate = 0.2
    else:
        embed_dim = 128
        depths = [2, 2, 18, 2]
        num_heads = [4, 8, 16, 32]
        drop_path_rate = 0.2

    swin = SwinTransformer(
        img_size=224,
        in_chans=3,
        num_classes=1000,
        patch_size=4,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0,
        attn_drop_rate=0,
        drop_path_rate=(0.0 if is_teacher else drop_path_rate),
        norm_layer=partial(LayerNorm, eps=1e-6),
        ape=False,
        patch_norm=True,
    )

    return swin
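
Aside (illustrative, not part of the file): a minimal sketch of how get_swin() might be paired into a student/teacher setup, assuming Megatron's get_args() has already been initialized; the EMA update itself is an assumption about the surrounding training loop and is not shown.

    student = get_swin(is_teacher=False)     # keeps the configured drop_path_rate
    teacher = get_swin(is_teacher=True)      # drop_path_rate forced to 0.0 above
    teacher.load_state_dict(student.state_dict())
    for p in teacher.parameters():
        p.requires_grad = False              # teacher weights typically follow an EMA of the student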
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/inpainting.py  (new file, 0 → 100644)
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
import apex
import einops
import torch
import torch.nn.functional as F
from megatron.training import get_args, print_rank_0
from megatron.legacy.model.utils import get_linear_layer
from megatron.legacy.model.vision.vit_backbone import VitBackbone
from megatron.legacy.model.module import MegatronModule
from megatron.legacy.model.vision.mit_backbone import mit_b3
from megatron.legacy.model.vision.utils import resize


class VitInpaintingModel(MegatronModule):

    def __init__(self, config, pre_process=True, post_process=True):
        super(VitInpaintingModel, self).__init__()
        args = get_args()

        self.config = config
        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = config.hidden_size
        self.backbone = VitBackbone(
            config=config,
            pre_process=self.pre_process,
            post_process=self.post_process,
            class_token=False,
        )
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.seq_length = args.seq_length
        # full mask

        if self.post_process:
            self.linear_decoder = get_linear_layer(self.hidden_size,
                                                   self.backbone.flatten_dim,
                                                   torch.nn.init.zeros_)

    def set_input_tensor(self, input_tensor):
        self.backbone.set_input_tensor(input_tensor)

    def forward(self, input):
        hidden_states = self.backbone(input)

        if not self.post_process:
            return hidden_states
        decoded_output = self.linear_decoder(hidden_states)
        output = einops.rearrange(
            decoded_output,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h // self.patch_dim,
            w=self.img_w // self.patch_dim,
        )
        return output
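
Aside (illustrative, not part of the file): a standalone check of the decoder rearrange above, assuming patch_dim = 16 and a 224x224 image, i.e. a 14x14 grid of 196 tokens.

    import einops, torch
    p, grid = 16, 14
    decoded = torch.randn(2, grid * grid, p * p * 3)     # b, (h w), (p1 p2 c)
    img = einops.rearrange(decoded, "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
                           p1=p, p2=p, h=grid, w=grid)
    assert img.shape == (2, 3, 224, 224)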
class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x
class MitInpaintingModel(MegatronModule):
    """Mix vision Transformer Model."""

    def __init__(self, pre_process=True, post_process=True):
        super(MitInpaintingModel, self).__init__()
        self.pre_process = pre_process
        self.post_process = post_process

        args = get_args()
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.flatten_dim = self.patch_dim * self.patch_dim * 3
        self.backbone = mit_b3()

        self.in_channels = [64, 128, 320, 512]
        self.embedding_dim = 768

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4, self.embedding_dim, 1, 1, bias=False)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(0.1)

        self.linear_pred = torch.nn.Conv2d(self.embedding_dim, self.flatten_dim, kernel_size=1)

    def set_input_tensor(self, input_tensor):
        """See megatron.legacy.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        c1, c2, c3, c4 = self.backbone(input)

        n, _, h, w = c4.shape
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        _c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
        _c = self.conv_fuse(_c)

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)

        x = self.linear_pred(x)

        output = einops.rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h // self.patch_dim,
            w=self.img_w // self.patch_dim,
        )
        return output
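
Aside (illustrative, not part of the file): the final rearrange above inverts the patchification from a (c p1 p2) channel layout; again assuming patch_dim = 16 and a 224x224 image.

    import einops, torch
    p, grid = 16, 14
    x = torch.randn(2, 3 * p * p, grid, grid)            # linear_pred output: b, (c p1 p2), h, w
    img = einops.rearrange(x, "b (c p1 p2) h w -> b c (h p1) (w p2)", p1=p, p2=p)
    assert img.shape == (2, 3, 224, 224)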
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/knn_monitor.py  (new file, 0 → 100644)
import torch.nn.functional as F
import torch
from megatron.training import print_rank_0, get_args
from megatron.core import mpu
from megatron.legacy.data.vit_dataset import ClassificationTransform
from megatron.legacy.data.image_folder import ImageFolder

_FEATURE_BANK = None


def build_data_loader(dataset, drop_last=True, shuffle=False):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""
    # Sampler.
    args = get_args()
    micro_batch_size = 16
    num_workers = args.num_workers
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=not drop_last,
        pin_memory=True,
    )
    return data_loader
def compute_feature_bank(model):
    args = get_args()
    global _FEATURE_BANK
    feature_bank = []
    feature_label = []

    train_ds = ImageFolder(
        root=args.data_path[0],
        transform=ClassificationTransform((args.img_h, args.img_w), train=False),
        data_per_class_fraction=1.0
    )
    classes = len(train_ds.classes)
    dataloader = build_data_loader(train_ds)

    for m in model:
        m.eval()

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            images = batch[0].cuda().contiguous()
            labels = batch[1].cuda().contiguous()
            student_feature, teacher_feature = model[0](images)
            feature = F.normalize(teacher_feature.float(), dim=1)
            feature_bank.append(feature)
            feature_label.append(labels)

    for m in model:
        m.train()

    # [N', D]
    feature_bank = torch.cat(feature_bank, dim=0).contiguous()
    feature_label = torch.cat(feature_label, dim=0).contiguous()

    feature_banks = [torch.zeros_like(feature_bank)
                     for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_banks,
                                 feature_bank,
                                 group=mpu.get_data_parallel_group())

    assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
                              feature_bank))

    feature_labels = [torch.zeros_like(feature_label)
                      for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_labels,
                                 feature_label,
                                 group=mpu.get_data_parallel_group())

    # [D, N]
    feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
    # [N]
    feature_labels = torch.cat(feature_labels, dim=0).contiguous()
    print_rank_0("feature_banks size is {}".format(feature_banks.size()))
    print_rank_0("feature labels size is {}".format(feature_labels.size()))

    _FEATURE_BANK = (feature_banks, feature_labels, classes)
def get_feature_bank():
    global _FEATURE_BANK
    assert _FEATURE_BANK is not None
    return _FEATURE_BANK
# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
                              dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes,
                                device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1,
                                          index=sim_labels.view(-1, 1),
                                          value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(
        one_hot_label.view(feature.size(0), -1, classes) * sim_weight.unsqueeze(dim=-1),
        dim=1)

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels
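
Aside (illustrative, not part of the file): exercising knn_predict on random tensors to make the expected shapes concrete; the knn_k and knn_t values are assumptions, not project defaults.

    import torch
    import torch.nn.functional as F
    B, D, N, classes = 4, 256, 1000, 10
    feature = F.normalize(torch.randn(B, D), dim=1)        # query features, unit-norm rows
    feature_bank = F.normalize(torch.randn(D, N), dim=0)   # bank stored as [D, N], unit-norm columns
    feature_labels = torch.randint(0, classes, (N,))
    pred = knn_predict(feature, feature_bank, feature_labels, classes, knn_k=20, knn_t=0.07)
    assert pred.shape == (B, classes)                      # classes sorted by descending k-NN score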
Megatron-LM-core_r0.7.0.beta/megatron/legacy/model/vision/mit_backbone.py  (new file, 0 → 100644)
# Copyright (c) 2023, NVIDIA Corporation. All rights reserved.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from torch.nn.init import trunc_normal_
from megatron.legacy.model.transformer import DropPath
from megatron.legacy.model import LayerNorm


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x
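
Aside (illustrative, not part of the file): the effect of sr_ratio above. For a stage-1 grid of 56x56 tokens (N = 3136) with sr_ratio = 8, keys and values come from a reduced 7x7 = 49-token grid, so the attention matrix is [B, heads, 3136, 49] instead of [B, heads, 3136, 3136], while the query (and output) resolution is unchanged. The dim/num_heads values below are assumptions, and the module is assumed to run on a device where the imported fused LayerNorm is usable.

    attn_mod = Attention(dim=64, num_heads=1, sr_ratio=8)
    x = torch.randn(2, 56 * 56, 64)
    y = attn_mod(x, 56, 56)
    assert y.shape == (2, 56 * 56, 64)      # full query resolution preserved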
class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x
class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W
class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.output_avg = output_avg

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        if not self.output_avg:
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)

        if self.output_avg:
            x = x[3].mean(dim=1)

        return x
class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x
class mit_b0(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b0, self).__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b1(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b2(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b2, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b3(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b3, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b3_avg(MixVisionTransformer):
    def __init__(self, drop_path_rate=0.1, **kwargs):
        super(mit_b3_avg, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)


class mit_b4(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b4, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b5(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b5, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b5_avg(MixVisionTransformer):
    def __init__(self, drop_path_rate=0.1, **kwargs):
        super(mit_b5_avg, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
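
Aside (illustrative, not part of the file): the feature pyramid produced by forward_features for a 224x224 input with the mit_b3 configuration (embed_dims = [64, 128, 320, 512]); the shapes follow from the stride-4/2/2/2 patch embeddings. Running this assumes the imported fused LayerNorm is usable on the current device.

    model = mit_b3()
    feats = model.forward_features(torch.randn(1, 3, 224, 224))
    # feats[0]: (1,  64, 56, 56)   stride 4
    # feats[1]: (1, 128, 28, 28)   stride 8
    # feats[2]: (1, 320, 14, 14)   stride 16
    # feats[3]: (1, 512,  7,  7)   stride 32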