OpenDAS / Megatron-LM / Commits / b8428a7f

Commit b8428a7f, authored Jul 15, 2022 by peng xu
Merge branch 'main' into beam_search
Parents: e5034150, 3f4e71df

Changes: 67
Showing 20 changed files with 3226 additions and 260 deletions (+3226 -260)
megatron/model/biencoder_model.py              +1    -1
megatron/model/distributed.py                  +27   -6
megatron/model/fused_layer_norm.py             +20   -2
megatron/model/gpt_model.py                    +8    -2
megatron/model/language_model.py               +45   -8
megatron/model/t5_model.py                     +7    -2
megatron/model/transformer.py                  +302  -181
megatron/model/vision/classification.py        +34   -0
megatron/model/vision/dino.py                  +288  -0
megatron/model/vision/esvit_swin_backbone.py   +849  -0
megatron/model/vision/inpainting.py            +151  -0
megatron/model/vision/knn_monitor.py           +128  -0
megatron/model/vision/mit_backbone.py          +420  -0
megatron/model/vision/swin_backbone.py         +625  -0
megatron/model/vision/utils.py                 +27   -0
megatron/model/vision/vit_backbone.py          +7    -2
megatron/mpu/__init__.py                       +9    -5
megatron/mpu/initialize.py                     +12   -5
megatron/mpu/layers.py                         +125  -36
megatron/mpu/mappings.py                       +141  -10
megatron/model/biencoder_model.py  (View file @ b8428a7f)

@@ -291,7 +291,7 @@ class PretrainedBertModel(MegatronModule):
         pool_mask = (input_ids == self.pad_id).unsqueeze(2)

         # Taking the representation of the [CLS] token of BERT
-        pooled_output = lm_output[:, 0, :]
+        pooled_output = lm_output[0, :, :]

         # Converting to float16 dtype
         pooled_output = pooled_output.to(lm_output.dtype)
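The indexing change above follows the switch elsewhere in this merge from a [b, s, h] activation layout to [s, b, h]: the [CLS] position now lives on the first (sequence) dimension. A minimal standalone sketch of why the two slices are equivalent; the shapes and variable names here are illustrative, not taken from the repository:

    import torch

    # Illustrative shapes: sequence length 4, batch 2, hidden 8.
    s, b, h = 4, 2, 8

    lm_output_bsh = torch.randn(b, s, h)                         # old layout: [b, s, h]
    lm_output_sbh = lm_output_bsh.transpose(0, 1).contiguous()   # new layout: [s, b, h]

    cls_old = lm_output_bsh[:, 0, :]   # first token per sample, old indexing
    cls_new = lm_output_sbh[0, :, :]   # first token per sample, new indexing

    assert torch.equal(cls_old, cls_new)  # same [b, h] slice either way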
megatron/model/distributed.py  (View file @ b8428a7f)

@@ -15,6 +15,7 @@
 from abc import ABC
 from abc import abstractmethod
+import math

 import torch
 from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

@@ -24,18 +25,17 @@ from megatron import mpu
 from .module import MegatronModule


 class MemoryBuffer:

-    def __init__(self, numel, dtype):
+    def __init__(self, numel, numel_padded, dtype):
         self.numel = numel
+        self.numel_padded = numel_padded
         self.dtype = dtype
-        self.data = torch.zeros(self.numel,
+        self.data = torch.zeros(self.numel_padded,
                                 dtype=self.dtype,
                                 device=torch.cuda.current_device(),
                                 requires_grad=False)

     def zero(self):
         """Reset the buffer to zero."""
         self.data.zero_()

@@ -121,8 +121,11 @@ class DistributedDataParallel(DistributedDataParallelBase):
         # the case we use continuous buffers.
         # ===================================
         self._grad_buffers = None
+        self._grad_buffer_param_index_map = None
         if self.use_contiguous_buffers:
             self._grad_buffers = {}
+            self._grad_buffer_param_index_map = {}
+            data_parallel_world_size = mpu.get_data_parallel_world_size()

             # Simple function to define buffer type.
             def _get_buffer_type(param):

@@ -139,7 +142,18 @@ class DistributedDataParallel(DistributedDataParallelBase):
             # Allocate the buffer.
             for dtype, num_elements in type_num_elements.items():
-                self._grad_buffers[dtype] = MemoryBuffer(num_elements, dtype)
+
+                # If using distributed optimizer, pad memory buffer to be
+                # multiple of data_parallel_world_size. (This padding is done
+                # due to a constraint with the reduce_scatter op, which requires
+                # all tensors have equal size. See: optimizer.py.)
+                num_elements_padded = data_parallel_world_size * \
+                    int(math.ceil(num_elements / data_parallel_world_size))
+
+                # Allocate grad buffer.
+                self._grad_buffers[dtype] = MemoryBuffer(num_elements,
+                                                         num_elements_padded,
+                                                         dtype)

             # Assume the back prop order is reverse the params order,
             # store the start index for the gradients.

@@ -149,6 +163,12 @@ class DistributedDataParallel(DistributedDataParallelBase):
                     type_num_elements[dtype] -= param.data.nelement()
                     param.main_grad = self._grad_buffers[dtype].get(
                         param.data.shape, type_num_elements[dtype])
+                    if dtype not in self._grad_buffer_param_index_map:
+                        self._grad_buffer_param_index_map[dtype] = {}
+                    self._grad_buffer_param_index_map[dtype][param] = (
+                        type_num_elements[dtype],
+                        type_num_elements[dtype] + param.data.nelement(),
+                    )

             # Backward hook.
             # Accumalation function for the gradients. We need

@@ -170,7 +190,8 @@ class DistributedDataParallel(DistributedDataParallelBase):
         # Hook used for back-prop.
         def param_hook(*unused):
             # Add the gradient to the buffer.
-            if param.grad.data is not None:
+            if param.grad is not None:
+                # The gradient function of linear layers is fused with GEMMs
                 param.main_grad.add_(param.grad.data)
                 # Now we can deallocate grad memory.
                 param.grad = None
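The padding introduced above exists so the flattened gradient buffer can later be split into equal shards, one per data-parallel rank, which is what reduce_scatter requires. A small self-contained sketch of the same arithmetic; the world size and element counts below are made-up example values, not from the repository:

    import math

    def padded_numel(num_elements: int, data_parallel_world_size: int) -> int:
        # Round up to the next multiple of the data-parallel world size,
        # mirroring the MemoryBuffer padding in the diff above.
        return data_parallel_world_size * int(
            math.ceil(num_elements / data_parallel_world_size))

    # Example: 10 gradient elements shared across 4 ranks -> pad to 12,
    # so every rank owns an equal 3-element shard.
    assert padded_numel(10, 4) == 12
    assert padded_numel(12, 4) == 12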
megatron/model/fused_layer_norm.py  (View file @ b8428a7f)

@@ -23,6 +23,8 @@ from torch.nn.parameter import Parameter
 from torch.nn import init
 import importlib

+from megatron.mpu import make_viewless_tensor
+
 try:
     from apex.contrib.layer_norm.layer_norm import FastLayerNormFN
     HAVE_PERSIST_LAYER_NORM = True

@@ -67,7 +69,9 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
 class MixedFusedLayerNorm(torch.nn.Module):

-    def __init__(self, normalized_shape, eps=1e-5, no_persist_layer_norm=True):
+    def __init__(self, normalized_shape, eps=1e-5,
+                 no_persist_layer_norm=True,
+                 sequence_parallel=False):
         super(MixedFusedLayerNorm, self).__init__()

         global fused_mix_prec_layer_norm_cuda

@@ -92,6 +96,11 @@ class MixedFusedLayerNorm(torch.nn.Module):
         self.bias = Parameter(torch.Tensor(*normalized_shape))
         self.reset_parameters()
         self.no_persist_layer_norm = no_persist_layer_norm
+        self.sequence_parallel = sequence_parallel
+
+        # set sequence parallelism flag on weight and bias parameters
+        setattr(self.weight, 'sequence_parallel', self.sequence_parallel)
+        setattr(self.bias, 'sequence_parallel', self.sequence_parallel)

     def reset_parameters(self):

@@ -106,6 +115,15 @@ class MixedFusedLayerNorm(torch.nn.Module):
             return FusedLayerNormAffineFunction.apply(
                 input, self.weight, self.bias, self.normalized_shape, self.eps)
         else:
-            return FastLayerNormFN.apply(
-                input, self.weight, self.bias, self.eps)
+            output = FastLayerNormFN.apply(
+                input, self.weight, self.bias, self.eps)
+
+            # Apex's fast layer norm function outputs a 'view' tensor (i.e., has
+            # a populated '_base' field). This will result in schedule.py's
+            # deallocate_output_tensor() throwing an error, so a viewless tensor is
+            # created to prevent this.
+            output = make_viewless_tensor(inp=output,
+                                          requires_grad=input.requires_grad,
+                                          keep_graph=True)
+
+            return output
megatron/model/gpt_model.py  (View file @ b8428a7f)

@@ -32,20 +32,26 @@ def post_language_model_processing(lm_output, labels, logit_weights,
                                    parallel_output,
                                    fp16_lm_cross_entropy):

-    # Output.
+    # Output. Format [s b h]
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)

     if labels is None:
-        return output
+        # [s b h] => [b s h]
+        return output.transpose(0,1).contiguous()
     else:
+        # [b s] => [s b]
+        labels = labels.transpose(0,1).contiguous()
         if fp16_lm_cross_entropy:
             assert output.dtype == torch.half
             loss = mpu.vocab_parallel_cross_entropy(output, labels)
         else:
             loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+
+        # [s b] => [b, s]
+        loss = loss.transpose(0,1).contiguous()
         return loss
megatron/model/language_model.py  (View file @ b8428a7f)

@@ -26,17 +26,29 @@ from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import get_linear_layer
 from megatron.model.utils import init_method_normal, scaled_init_method_normal


 def parallel_lm_logits(input_, word_embeddings_weight, parallel_output,
                        bias=None):
     """LM logits using word embedding weights."""
+    args = get_args()
     # Parallel logits.
-    input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
-    # Matrix multiply.
-    if bias is None:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight)
+    if args.async_tensor_model_parallel_allreduce or \
+            args.sequence_parallel:
+        input_parallel = input_
+        model_parallel = mpu.get_tensor_model_parallel_world_size() > 1
+        async_grad_allreduce = args.async_tensor_model_parallel_allreduce and \
+            model_parallel and not args.sequence_parallel
     else:
-        logits_parallel = F.linear(input_parallel, word_embeddings_weight, bias)
+        input_parallel = mpu.copy_to_tensor_model_parallel_region(input_)
+        async_grad_allreduce = False
+
+    # Matrix multiply.
+    logits_parallel = mpu.LinearWithGradAccumulationAndAsyncCommunication.apply(
+        input_parallel, word_embeddings_weight, bias, args.gradient_accumulation_fusion,
+        async_grad_allreduce, args.sequence_parallel)
     # Gather if needed.

     if parallel_output:
         return logits_parallel

@@ -92,12 +104,23 @@ class Pooler(MegatronModule):
     def __init__(self, hidden_size, init_method):
         super(Pooler, self).__init__()
+        args = get_args()
         self.dense = get_linear_layer(hidden_size, hidden_size, init_method)
+        self.sequence_parallel = args.sequence_parallel

     def forward(self, hidden_states, sequence_index=0):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]
         # sequence_index: index of the token to pool.
-        pooled = hidden_states[:, sequence_index, :]
+
+        # gather data along sequence dimensions
+        # same pooler is run on all tensor parallel nodes
+        if self.sequence_parallel:
+            hidden_states = mpu.gather_from_sequence_parallel_region(
+                hidden_states,
+                tensor_parallel_output_grad=False)
+
+        pooled = hidden_states[sequence_index, :, :]
         pooled = self.dense(pooled)
         pooled = torch.tanh(pooled)
         return pooled

@@ -158,6 +181,8 @@ class Embedding(MegatronModule):
         else:
             self.tokentype_embeddings = None

+        self.fp32_residual_connection = args.fp32_residual_connection
+        self.sequence_parallel = args.sequence_parallel
         # Embeddings dropout
         self.embedding_dropout = torch.nn.Dropout(embedding_dropout_prob)

@@ -199,8 +224,20 @@ class Embedding(MegatronModule):
         else:
             assert self.tokentype_embeddings is None

+        # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
+        embeddings = embeddings.transpose(0, 1).contiguous()
+
+        # If the input flag for fp32 residual connection is set, convert for float.
+        if self.fp32_residual_connection:
+            embeddings = embeddings.float()
+
         # Dropout.
-        embeddings = self.embedding_dropout(embeddings)
+        if self.sequence_parallel:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
+            with mpu.get_cuda_rng_tracker().fork():
+                embeddings = self.embedding_dropout(embeddings)
+        else:
+            embeddings = self.embedding_dropout(embeddings)

         return embeddings
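The Pooler and Embedding changes assume that, with sequence parallelism on, each tensor-parallel rank holds only a slice of the sequence dimension: the embedding scatters the sequence after the [b s h] to [s b h] transpose, and the pooler must gather it back before it can index a specific token. A toy, single-process illustration of that split along dim 0; there are no distributed calls here, and the chunk count simply stands in for the tensor-parallel world size:

    import torch

    s, b, h = 8, 2, 4
    world_size = 2                      # stand-in for the tensor-parallel group size
    embeddings = torch.randn(s, b, h)   # [s, b, h], as after the transpose in the diff

    # "Scatter": each rank keeps one contiguous slice of the sequence.
    shards = list(torch.chunk(embeddings, world_size, dim=0))  # two [s/2, b, h] pieces

    # "Gather": concatenating the slices restores the full sequence,
    # which is what the pooler needs before it can index token 0.
    gathered = torch.cat(shards, dim=0)
    assert torch.equal(gathered, embeddings)
    assert torch.equal(gathered[0, :, :], embeddings[0, :, :])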
megatron/model/t5_model.py  (View file @ b8428a7f)

@@ -152,19 +152,24 @@ class T5Model(MegatronModule):
         if self.post_process and self.add_decoder:
             decoder_output, encoder_output = lm_output
-            # Output.
+            # Output. [s, b, h]
             lm_logits = self.lm_head(decoder_output,
                                      self.word_embeddings_weight())

             if lm_labels is None:
-                return lm_logits
+                # [s b h] => [b s h]
+                return lm_logits.transpose(0,1).contiguous()
             else:
+                # [b s] => [s b]
+                lm_labels = lm_labels.transpose(0,1).contiguous()
                 if self.fp16_lm_cross_entropy:
                     assert lm_logits.dtype == torch.half
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                 else:
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                lm_labels)
+                # [s b] => [b s]
+                lm_loss = lm_loss.transpose(0,1).contiguous()
                 return lm_loss
         elif self.add_decoder and not self.add_encoder:
             decoder_output, encoder_output = lm_output
megatron/model/transformer.py  (View file @ b8428a7f)

@@ -15,10 +15,11 @@
 """Transformer."""
 import math
+from contextlib import nullcontext
 import torch
 import torch.nn.functional as F

-from megatron import get_args
+from megatron import get_timers, get_args, get_global_memory_buffer
 from megatron import mpu
 from .module import MegatronModule
 from megatron.model.enums import AttnMaskType, ModelType, LayerType, AttnType

@@ -27,6 +28,7 @@ from megatron.model.fused_softmax import FusedScaleMaskSoftmax
 from megatron.model.fused_bias_gelu import bias_gelu_impl
 from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu

 """ We use the following notation throughout this file:
      h: hidden size
      n: number of attention heads

@@ -42,7 +44,6 @@ from megatron.model.utils import attention_mask_func, openai_gelu, erf_gelu
     hyperparameters: transformer hyperparameters
 """

 class DropPath(MegatronModule):
     """Drop paths (Stochastic Depth) per sample
     (when applied in main path of residual blocks).

@@ -116,11 +117,196 @@ class ParallelMLP(MegatronModule):
         output, output_bias = self.dense_4h_to_h(intermediate_parallel)
         return output, output_bias

+class SwitchMLP(MegatronModule):
+    """
+    Routes input to one of N MLP "experts"
+    """
+    def __init__(self, init_method, output_layer_init_method):
+        super(SwitchMLP, self).__init__()
+        args = get_args()
+        self.router = torch.nn.Linear(args.hidden_size, args.num_experts)
+        self.experts = torch.nn.ModuleList()
+        for i in range(args.num_experts):
+            self.experts.append(ParallelMLP(init_method, output_layer_init_method))
+
+    def forward(self, hidden_states):
+        # hidden_states: [s, b, h]
+        s = hidden_states.size(0)
+        b = hidden_states.size(1)
+        h = hidden_states.size(2)
+        route = self.router(hidden_states)
+        route = torch.nn.functional.softmax(route, dim=2)
+        max_prob, max_ind = torch.max(route, dim=2)
+        max_prob = torch.unsqueeze(max_prob, 2)  # [s b 1]
+
+        # TODO (rprenger) TODO this could be made easier to read
+        # Converting [s, b, h] to [s*b, h].
+        # Each vector could be routed differently
+        hidden_states = hidden_states.view(-1, hidden_states.size(2))  # [s*b h]
+        max_prob = max_prob.view(-1, max_prob.size(2))  # [s*b 1]
+        max_ind = max_ind.view(-1)  # [s*b]
+
+        output_total = torch.empty_like(hidden_states)
+        output_bias_total = torch.empty_like(hidden_states)
+        #TODO (rprenger) This does each expert in serial, but it could be parallelized
+
+        for expert_num, expert in enumerate(self.experts):
+            local_indices = (max_ind == expert_num).nonzero()
+            hidden = hidden_states[local_indices,:]
+            output, output_bias = expert(hidden)
+            output_bias = output_bias.expand_as(output)
+            output_total[local_indices,:] = output
+            output_bias_total[local_indices,:] = output_bias
+
+        output_total = output_total*max_prob
+        output_bias_total = output_bias_total*max_prob
+        output_total = output_total.view(s, b, h)
+        output_bias_total = output_bias_total.view(s, b, h)
+
+        return output_total, output_bias_total
+
+
+class CoreAttention(MegatronModule):
+
+    def __init__(self, layer_number,
+                 attn_mask_type=AttnMaskType.padding):
+        super(CoreAttention, self).__init__()
+        args = get_args()
+        self.fp16 = args.fp16
+        self.bf16 = args.bf16
+
+        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
+        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
+        if self.apply_query_key_layer_scaling:
+            self.attention_softmax_in_fp32 = True
+        self.layer_number = max(1, layer_number)
+        self.attn_mask_type = attn_mask_type
+        self.sequence_parallel = args.sequence_parallel
+
+        projection_size = args.kv_channels * args.num_attention_heads
+
+        # Per attention head and per partition values.
+        world_size = mpu.get_tensor_model_parallel_world_size()
+        self.hidden_size_per_partition = mpu.divide(projection_size,
+                                                    world_size)
+        self.hidden_size_per_attention_head = mpu.divide(
+            projection_size, args.num_attention_heads)
+        self.num_attention_heads_per_partition = mpu.divide(
+            args.num_attention_heads, world_size)
+
+        coeff = None
+        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
+        if self.apply_query_key_layer_scaling:
+            coeff = self.layer_number
+            self.norm_factor *= coeff
+
+        self.scale_mask_softmax = FusedScaleMaskSoftmax(
+            self.fp16, self.bf16,
+            self.attn_mask_type,
+            args.masked_softmax_fusion,
+            attention_mask_func,
+            self.attention_softmax_in_fp32,
+            coeff)
+
+        # Dropout. Note that for a single iteration, this layer will generate
+        # different outputs on different number of parallel partitions but
+        # on average it should not be partition dependent.
+        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
+
+    def forward(self, query_layer, key_layer,
+                value_layer, attention_mask):
+
+        # ===================================
+        # Raw attention scores. [b, np, s, s]
+        # ===================================
+
+        # [b, np, sq, sk]
+        output_size = (query_layer.size(1),
+                       query_layer.size(2),
+                       query_layer.size(0),
+                       key_layer.size(0))
+
+        # [sq, b, np, hn] -> [sq, b * np, hn]
+        query_layer = query_layer.view(output_size[2],
+                                       output_size[0] * output_size[1], -1)
+        # [sk, b, np, hn] -> [sk, b * np, hn]
+        key_layer = key_layer.view(output_size[3],
+                                   output_size[0] * output_size[1], -1)
+
+        # preallocting input tensor: [b * np, sq, sk]
+        matmul_input_buffer = get_global_memory_buffer().get_tensor(
+            (output_size[0]*output_size[1], output_size[2], output_size[3]),
+            query_layer.dtype, "mpu")
+
+        # Raw attention scores. [b * np, sq, sk]
+        matmul_result = torch.baddbmm(
+            matmul_input_buffer,
+            query_layer.transpose(0, 1),   # [b * np, sq, hn]
+            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
+            beta=0.0, alpha=(1.0/self.norm_factor))
+
+        # change view to [b, np, sq, sk]
+        attention_scores = matmul_result.view(*output_size)
+
+        # ===========================
+        # Attention probs and dropout
+        # ===========================
+
+        # attention scores and attention mask [b, np, sq, sk]
+        attention_probs = self.scale_mask_softmax(attention_scores,
+                                                  attention_mask)
+
+        # This is actually dropping out entire tokens to attend to, which might
+        # seem a bit unusual, but is taken from the original Transformer paper.
+        if not self.sequence_parallel:
+            with mpu.get_cuda_rng_tracker().fork():
+                attention_probs = self.attention_dropout(attention_probs)
+        else:
+            attention_probs = self.attention_dropout(attention_probs)
+
+        # =========================
+        # Context layer. [sq, b, hp]
+        # =========================
+
+        # value_layer -> context layer.
+        # [sk, b, np, hn] --> [b, np, sq, hn]
+
+        # context layer shape: [b, np, sq, hn]
+        output_size = (value_layer.size(1),
+                       value_layer.size(2),
+                       query_layer.size(0),
+                       value_layer.size(3))
+
+        # change view [sk, b * np, hn]
+        value_layer = value_layer.view(value_layer.size(0),
+                                       output_size[0] * output_size[1], -1)
+
+        # change view [b * np, sq, sk]
+        attention_probs = attention_probs.view(output_size[0] * output_size[1],
+                                               output_size[2], -1)
+
+        # matmul: [b * np, sq, hn]
+        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
+
+        # change view [b, np, sq, hn]
+        context_layer = context_layer.view(*output_size)
+
+        # [b, np, sq, hn] --> [sq, b, np, hn]
+        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
+
+        # [sq, b, np, hn] --> [sq, b, hp]
+        new_context_layer_shape = context_layer.size()[:-2] + \
+            (self.hidden_size_per_partition,)
+        context_layer = context_layer.view(*new_context_layer_shape)
+
+        return context_layer
+

 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.

-    Self-attention layer takes input with size [b, s, h]
+    Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """

@@ -130,13 +316,6 @@ class ParallelAttention(MegatronModule):
                  attn_mask_type=AttnMaskType.padding):
         super(ParallelAttention, self).__init__()
         args = get_args()
-        self.fp16 = args.fp16
-        self.bf16 = args.bf16
-
-        self.apply_query_key_layer_scaling = args.apply_query_key_layer_scaling
-        self.attention_softmax_in_fp32 = args.attention_softmax_in_fp32
-        if self.apply_query_key_layer_scaling:
-            self.attention_softmax_in_fp32 = True
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type

@@ -146,8 +325,6 @@ class ParallelAttention(MegatronModule):
         # Per attention head and per partition values.
         world_size = mpu.get_tensor_model_parallel_world_size()
-        self.hidden_size_per_partition = mpu.divide(projection_size,
-                                                    world_size)
         self.hidden_size_per_attention_head = mpu.divide(
             projection_size, args.num_attention_heads)
         self.num_attention_heads_per_partition = mpu.divide(

@@ -174,24 +351,9 @@ class ParallelAttention(MegatronModule):
                 gather_output=False,
                 init_method=init_method)

-        coeff = None
-        self.norm_factor = math.sqrt(self.hidden_size_per_attention_head)
-        if self.apply_query_key_layer_scaling:
-            coeff = self.layer_number
-            self.norm_factor *= coeff
-
-        self.scale_mask_softmax = FusedScaleMaskSoftmax(
-            self.fp16, self.bf16,
-            self.attn_mask_type,
-            args.masked_softmax_fusion,
-            attention_mask_func,
-            self.attention_softmax_in_fp32,
-            coeff)
-
-        # Dropout. Note that for a single iteration, this layer will generate
-        # different outputs on different number of parallel partitions but
-        # on average it should not be partition dependent.
-        self.attention_dropout = torch.nn.Dropout(args.attention_dropout)
+        self.core_attention = CoreAttention(self.layer_number,
+                                            self.attn_mask_type)
+        self.checkpoint_core_attention = args.recompute_granularity == 'selective'

         # Output.
         self.dense = mpu.RowParallelLinear(

@@ -201,6 +363,23 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)

+    def _checkpointed_attention_forward(self, query_layer, key_layer,
+                                        value_layer, attention_mask):
+        """Forward method with activation checkpointing."""
+        def custom_forward(*inputs):
+            query_layer = inputs[0]
+            key_layer = inputs[1]
+            value_layer = inputs[2]
+            attention_mask = inputs[3]
+            output_ = self.core_attention(query_layer, key_layer,
+                                          value_layer, attention_mask)
+            return output_
+
+        hidden_states = mpu.checkpoint(
+            custom_forward,
+            False, query_layer, key_layer, value_layer, attention_mask)
+
+        return hidden_states
+
     def _allocate_memory(self, inference_max_sequence_len, batch_size):
         return torch.empty(

@@ -210,13 +389,11 @@ class ParallelAttention(MegatronModule):
             self.hidden_size_per_attention_head,
             dtype=self.params_dtype,
             device=torch.cuda.current_device())

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, inference_params=None):
         # hidden_states: [sq, b, h]

         # =================================================
         # Pre-allocate memory for key-values for inference.
         # =================================================

@@ -234,7 +411,6 @@ class ParallelAttention(MegatronModule):
             inference_key_memory, inference_value_memory = \
                 inference_params.key_value_memory_dict[self.layer_number]

         # =====================
         # Query, Key, and Value
         # =====================

@@ -275,7 +451,6 @@ class ParallelAttention(MegatronModule):
                 self.hidden_size_per_attention_head)
             query_layer = query_layer.view(*new_tensor_shape)

         # ==================================
         # Adjust key and value for inference
         # ==================================

@@ -297,90 +472,16 @@ class ParallelAttention(MegatronModule):
             value_layer = inference_value_memory[
                 :sequence_end, batch_start:batch_end, ...]

-        # ===================================
-        # Raw attention scores. [b, np, s, s]
-        # ===================================
-
-        # [b, np, sq, sk]
-        output_size = (query_layer.size(1),
-                       query_layer.size(2),
-                       query_layer.size(0),
-                       key_layer.size(0))
-
-        # [sq, b, np, hn] -> [sq, b * np, hn]
-        query_layer = query_layer.view(output_size[2],
-                                       output_size[0] * output_size[1], -1)
-        # [sk, b, np, hn] -> [sk, b * np, hn]
-        key_layer = key_layer.view(output_size[3],
-                                   output_size[0] * output_size[1], -1)
-
-        # preallocting result tensor: [b * np, sq, sk]
-        matmul_result = torch.empty(
-            output_size[0]*output_size[1],
-            output_size[2],
-            output_size[3],
-            dtype=query_layer.dtype,
-            device=torch.cuda.current_device())
-
-        # Raw attention scores. [b * np, sq, sk]
-        matmul_result = torch.baddbmm(
-            matmul_result,
-            query_layer.transpose(0, 1),   # [b * np, sq, hn]
-            key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
-            beta=0.0, alpha=(1.0/self.norm_factor))
-
-        # change view to [b, np, sq, sk]
-        attention_scores = matmul_result.view(*output_size)
-
-        # ===========================
-        # Attention probs and dropout
-        # ===========================
-
-        # attention scores and attention mask [b, np, sq, sk]
-        attention_probs = self.scale_mask_softmax(attention_scores,
-                                                  attention_mask)
-
-        # This is actually dropping out entire tokens to attend to, which might
-        # seem a bit unusual, but is taken from the original Transformer paper.
-        with mpu.get_cuda_rng_tracker().fork():
-            attention_probs = self.attention_dropout(attention_probs)
-
-        # =========================
-        # Context layer. [sq, b, hp]
-        # =========================
-
-        # value_layer -> context layer.
-        # [sk, b, np, hn] --> [b, np, sq, hn]
-
-        # context layer shape: [b, np, sq, hn]
-        output_size = (value_layer.size(1),
-                       value_layer.size(2),
-                       query_layer.size(0),
-                       value_layer.size(3))
-
-        # change view [sk, b * np, hn]
-        value_layer = value_layer.view(value_layer.size(0),
-                                       output_size[0] * output_size[1], -1)
-
-        # change view [b * np, sq, sk]
-        attention_probs = attention_probs.view(output_size[0] * output_size[1],
-                                               output_size[2], -1)
-
-        # matmul: [b * np, sq, hn]
-        context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
-
-        # change view [b, np, sq, hn]
-        context_layer = context_layer.view(*output_size)
-
-        # [b, np, sq, hn] --> [sq, b, np, hn]
-        context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
-
-        # [sq, b, np, hn] --> [sq, b, hp]
-        new_context_layer_shape = context_layer.size()[:-2] + \
-            (self.hidden_size_per_partition,)
-        context_layer = context_layer.view(*new_context_layer_shape)
+        # ==================================
+        # core attention computation
+        # ==================================
+
+        if self.checkpoint_core_attention:
+            context_layer = self._checkpointed_attention_forward(
+                query_layer, key_layer, value_layer, attention_mask)
+        else:
+            context_layer = self.core_attention(
+                query_layer, key_layer, value_layer, attention_mask)

         # =================
         # Output. [sq, b, h]

@@ -423,7 +524,7 @@ def bias_dropout_add_fused_inference(x: torch.Tensor,
 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.

-    Transformer layer takes input with size [b, s, h] and returns an
+    Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """

@@ -447,7 +548,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.input_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm)
+            no_persist_layer_norm=args.no_persist_layer_norm,
+            sequence_parallel=args.sequence_parallel)

         # Self attention.
         self.self_attention = ParallelAttention(

@@ -464,7 +566,8 @@ class ParallelTransformerLayer(MegatronModule):
         self.post_attention_layernorm = LayerNorm(
             args.hidden_size,
             eps=args.layernorm_epsilon,
-            no_persist_layer_norm=args.no_persist_layer_norm)
+            no_persist_layer_norm=args.no_persist_layer_norm,
+            sequence_parallel=args.sequence_parallel)

         if self.layer_type == LayerType.decoder:
             self.inter_attention = ParallelAttention(

@@ -476,16 +579,26 @@ class ParallelTransformerLayer(MegatronModule):
             self.post_inter_attention_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm)
+                no_persist_layer_norm=args.no_persist_layer_norm,
+                sequence_parallel=args.sequence_parallel)

         # MLP
-        self.mlp = ParallelMLP(init_method,
-                               output_layer_init_method)
+        if args.num_experts is not None:
+            self.mlp = SwitchMLP(init_method, output_layer_init_method)
+        else:
+            self.mlp = ParallelMLP(init_method, output_layer_init_method)
+
+        # Set bias+dropout+add fusion grad_enable execution handler.
+        TORCH_MAJOR = int(torch.__version__.split('.')[0])
+        TORCH_MINOR = int(torch.__version__.split('.')[1])
+        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        self.bias_dropout_add_exec_handler = \
+                nullcontext if use_nvfuser else torch.enable_grad

     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]

         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)

@@ -515,8 +628,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             bias_dropout_add_func = get_bias_dropout_add(self.training)

-        # re-enable torch grad to enable fused optimization.
-        with torch.enable_grad():
+        with self.bias_dropout_add_exec_handler():
             layernorm_input = bias_dropout_add_func(
                 attention_output,
                 attention_bias.expand_as(residual),

@@ -542,8 +654,7 @@ class ParallelTransformerLayer(MegatronModule):
             else:
                 residual = layernorm_input

-            # re-enable torch grad to enable fused optimization.
-            with torch.enable_grad():
+            with self.bias_dropout_add_exec_handler():
                 layernorm_input = bias_dropout_add_func(
                     attention_output,
                     attention_bias.expand_as(residual),

@@ -563,13 +674,23 @@ class ParallelTransformerLayer(MegatronModule):
             residual = layernorm_input

         if self.drop_path is None:
-            # re-enable torch grad to enable fused optimization.
-            with torch.enable_grad():
+            with self.bias_dropout_add_exec_handler():
                 output = bias_dropout_add_func(
                     mlp_output,
                     mlp_bias.expand_as(residual),
                     residual,
                     self.hidden_dropout)
+
+            # Jit compiled function creates 'view' tensor. This tensor
+            # potentially gets saved in the MPU checkpoint function context,
+            # which rejects view tensors. While making a viewless tensor here
+            # won't result in memory savings (like the data loader, or
+            # p2p_communication), it serves to document the origin of this
+            # 'view' tensor.
+            output = mpu.make_viewless_tensor(inp=output,
+                                              requires_grad=output.requires_grad,
+                                              keep_graph=True)
+
         else:
             out = torch.nn.functional.dropout(mlp_output + mlp_bias,
                                               p=self.hidden_dropout,

@@ -611,22 +732,30 @@ class ParallelTransformer(MegatronModule):
     def __init__(self, init_method, output_layer_init_method,
                  layer_type=LayerType.encoder,
                  self_attn_mask_type=AttnMaskType.padding,
+                 post_layer_norm=True,
                  pre_process=True, post_process=True,
                  drop_path_rate=0.0):
         super(ParallelTransformer, self).__init__()
         args = get_args()

+        self.layer_type = layer_type
+        self.model_type = args.model_type
         self.bf16 = args.bf16
         self.fp32_residual_connection = args.fp32_residual_connection
+        self.post_layer_norm = post_layer_norm
         self.pre_process = pre_process
         self.post_process = post_process
         self.input_tensor = None
         self.drop_path_rate = drop_path_rate

         # Store activation checkpoiting flag.
-        self.activations_checkpoint_method = args.activations_checkpoint_method
-        self.activations_checkpoint_num_layers = args.activations_checkpoint_num_layers
-        self.distribute_checkpointed_activations = args.distribute_checkpointed_activations
+        self.recompute_granularity = args.recompute_granularity
+        self.recompute_method = args.recompute_method
+        self.recompute_num_layers = args.recompute_num_layers
+        self.distribute_saved_activations = \
+            args.distribute_saved_activations and not args.sequence_parallel
+
+        self.sequence_parallel = args.sequence_parallel

         # Number of layers.
         self.num_layers = mpu.get_num_layers(

@@ -690,12 +819,13 @@ class ParallelTransformer(MegatronModule):
         self.layers = torch.nn.ModuleList(
             [build_layer(i + 1 + offset) for i in range(self.num_layers)])

-        if self.post_process:
+        if self.post_process and self.post_layer_norm:
             # Final layer norm before output.
             self.final_layernorm = LayerNorm(
                 args.hidden_size,
                 eps=args.layernorm_epsilon,
-                no_persist_layer_norm=args.no_persist_layer_norm)
+                no_persist_layer_norm=args.no_persist_layer_norm,
+                sequence_parallel=args.sequence_parallel)

     def _get_layer(self, layer_number):
         return self.layers[layer_number]

@@ -715,32 +845,33 @@ class ParallelTransformer(MegatronModule):
                 return x_
             return custom_forward

-        if self.activations_checkpoint_method == 'uniform':
+        if self.recompute_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
             # A method to further reduce memory usage reducing checkpoints.
             l = 0
             while l < self.num_layers:
                 hidden_states = mpu.checkpoint(
-                    custom(l, l + self.activations_checkpoint_num_layers),
-                    self.distribute_checkpointed_activations,
+                    custom(l, l + self.recompute_num_layers),
+                    self.distribute_saved_activations,
                     hidden_states, attention_mask,
                     encoder_output, enc_dec_attn_mask)
-                l += self.activations_checkpoint_num_layers
-        elif self.activations_checkpoint_method == 'block':
+                l += self.recompute_num_layers
+        elif self.recompute_method == 'block':
             # Checkpoint the input activation of only a set number of individual
             # Transformer layers and skip the rest.
             # A method fully use the device memory removing redundant re-computation.
             for l in range(self.num_layers):
-                if l < self.activations_checkpoint_num_layers:
+                if l < self.recompute_num_layers:
                     hidden_states = mpu.checkpoint(
                         custom(l, l + 1),
-                        self.distribute_checkpointed_activations,
+                        self.distribute_saved_activations,
                         hidden_states, attention_mask,
                         encoder_output, enc_dec_attn_mask)
                 else:
                     hidden_states = custom(l, l + 1)(
                         hidden_states, attention_mask,
                         encoder_output, enc_dec_attn_mask)
         else:
-            raise ValueError("Invalid activation checkpoint method.")
+            raise ValueError("Invalid activation recompute method.")

         return hidden_states

@@ -757,21 +888,14 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
+        # hidden_states: [s, b, h]

         # Checks.
         if inference_params:
-            assert self.activations_checkpoint_method is None, \
+            assert self.recompute_granularity is None, \
                 'inference does not work with activation checkpointing'

-        if self.pre_process:
-            # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
-            # If the input flag for fp32 residual connection is set, convert for float.
-            if self.fp32_residual_connection:
-                hidden_states = hidden_states.transpose(0, 1).contiguous().float()
-            # Otherwise, leave it as is.
-            else:
-                hidden_states = hidden_states.transpose(0, 1).contiguous()
-        else:
+        if not self.pre_process:
             # See set_input_tensor()
             hidden_states = self.input_tensor

@@ -792,37 +916,34 @@ class ParallelTransformer(MegatronModule):
         # is called here to be future-proof and corner-case-proof.
         hidden_states = mpu.make_viewless_tensor(
             hidden_states,
             requires_grad=True,
             keep_graph=True,
         )

-        # Transpose encoder output.
-        if encoder_output is not None:
-            encoder_output = encoder_output.transpose(0, 1).contiguous()
+        if self.sequence_parallel:
+            rng_context = mpu.get_cuda_rng_tracker().fork()
+        else:
+            rng_context = nullcontext()

-        # Forward pass.
-        if self.activations_checkpoint_method is not None:
-            hidden_states = self._checkpointed_forward(hidden_states,
-                                                       attention_mask,
-                                                       encoder_output,
-                                                       enc_dec_attn_mask)
-        else:
-            for index in range(self.num_layers):
-                layer = self._get_layer(index)
-                hidden_states = layer(
-                    hidden_states,
-                    attention_mask,
-                    encoder_output=encoder_output,
-                    enc_dec_attn_mask=enc_dec_attn_mask,
-                    inference_params=inference_params)
+        with rng_context:
+            # Forward pass.
+            if self.recompute_granularity == 'full':
+                hidden_states = self._checkpointed_forward(hidden_states,
+                                                           attention_mask,
+                                                           encoder_output,
+                                                           enc_dec_attn_mask)
+            else:
+                for index in range(self.num_layers):
+                    layer = self._get_layer(index)
+                    hidden_states = layer(
+                        hidden_states,
+                        attention_mask,
+                        encoder_output=encoder_output,
+                        enc_dec_attn_mask=enc_dec_attn_mask,
+                        inference_params=inference_params)

         # Final layer norm.
-        if self.post_process:
-            # Reverting data format change [s b h] --> [b s h].
-            hidden_states = hidden_states.transpose(0, 1).contiguous()
-            output = self.final_layernorm(hidden_states)
-        else:
-            output = hidden_states
+        if self.post_process and self.post_layer_norm:
+            hidden_states = self.final_layernorm(hidden_states)

-        return output
+        return hidden_states
megatron/model/vision/classification.py  (View file @ b8428a7f)

@@ -16,9 +16,11 @@
 """Vision Transformer(VIT) model."""

 import torch
+from torch.nn.init import trunc_normal_
 from megatron import get_args
 from megatron.model.utils import get_linear_layer
 from megatron.model.vision.vit_backbone import VitBackbone, VitMlpHead
+from megatron.model.vision.mit_backbone import mit_b3_avg
 from megatron.model.module import MegatronModule

 class VitClassificationModel(MegatronModule):

@@ -61,3 +63,35 @@ class VitClassificationModel(MegatronModule):
         hidden_states = self.head(hidden_states)

         return hidden_states
+
+
+class MitClassificationModel(MegatronModule):
+    """Mix vision Transformer Model."""
+
+    def __init__(self, num_classes,
+                 pre_process=True, post_process=True):
+        super(MitClassificationModel, self).__init__()
+        args = get_args()
+
+        self.hidden_size = args.hidden_size
+        self.num_classes = num_classes
+
+        self.backbone = mit_b3_avg()
+        self.head = torch.nn.Linear(512, num_classes)
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, torch.nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if isinstance(m, torch.nn.Linear) and m.bias is not None:
+                torch.nn.init.constant_(m.bias, 0)
+
+    def set_input_tensor(self, input_tensor):
+        """See megatron.model.transformer.set_input_tensor()"""
+        pass
+
+    def forward(self, input):
+        hidden_states = self.backbone(input)
+        hidden_states = self.head(hidden_states)
+
+        return hidden_states
megatron/model/vision/dino.py
0 → 100644
View file @
b8428a7f
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the Apache license found in the
# LICENSE file in the root directory of this source tree.
# copied from https://github.com/facebookresearch/dino/blob/main/main_dino.py
# reworked/refactored some parts to make it run in Megatron.
import
math
import
apex
import
einops
import
torch
import
numpy
as
np
import
torch.nn.functional
as
F
from
torch.nn.init
import
trunc_normal_
from
megatron
import
get_args
,
print_rank_0
from
megatron.model.utils
import
get_linear_layer
from
megatron.model.vision.vit_backbone
import
VitBackbone
from
megatron.model.module
import
MegatronModule
from
megatron.model.vision.mit_backbone
import
mit_b5_avg
from
megatron.model.vision.esvit_swin_backbone
import
get_swin
class
DINOLoss
(
torch
.
nn
.
Module
):
def
__init__
(
self
,
out_dim
,
ncrops
,
warmup_teacher_temp
,
teacher_temp
,
warmup_teacher_temp_epochs
,
nepochs
,
student_temp
=
0.1
,
center_momentum
=
0.9
):
super
().
__init__
()
self
.
student_temp
=
student_temp
self
.
center_momentum
=
center_momentum
self
.
ncrops
=
ncrops
self
.
register_buffer
(
"center"
,
torch
.
zeros
(
1
,
out_dim
))
# we apply a warm up for the teacher temperature because
# a too high temperature makes the training instable at the beginning
self
.
teacher_temp_schedule
=
np
.
concatenate
((
np
.
linspace
(
warmup_teacher_temp
,
teacher_temp
,
warmup_teacher_temp_epochs
),
np
.
ones
(
nepochs
-
warmup_teacher_temp_epochs
)
*
teacher_temp
))
self
.
teacher_temp
=
teacher_temp
def
forward
(
self
,
student_output
,
teacher_output
,
iteration
):
"""
Cross-entropy between softmax outputs of the teacher
and student network.
"""
args
=
get_args
()
student_out
=
student_output
/
self
.
student_temp
student_out
=
student_out
.
chunk
(
self
.
ncrops
)
epoch
=
iteration
//
args
.
iter_per_epoch
# teacher centering and sharpening
temp
=
self
.
teacher_temp_schedule
[
epoch
]
teacher_out
=
F
.
softmax
((
teacher_output
-
self
.
center
)
/
temp
,
dim
=-
1
)
teacher_out
=
teacher_out
.
detach
().
chunk
(
2
)
total_loss
=
0
n_loss_terms
=
0
for
iq
,
q
in
enumerate
(
teacher_out
):
for
v
in
range
(
len
(
student_out
)):
if
v
==
iq
:
# we skip cases where student and teacher operate on the same view
continue
loss
=
torch
.
sum
(
-
q
*
F
.
log_softmax
(
student_out
[
v
],
dim
=-
1
),
dim
=-
1
)
total_loss
+=
loss
.
mean
()
n_loss_terms
+=
1
total_loss
/=
n_loss_terms
self
.
update_center
(
teacher_output
)
return
total_loss
@
torch
.
no_grad
()
def
update_center
(
self
,
teacher_output
):
"""
Update center used for teacher output.
"""
batch_center
=
torch
.
sum
(
teacher_output
,
dim
=
0
,
keepdim
=
True
)
torch
.
distributed
.
all_reduce
(
batch_center
)
batch_center
=
batch_center
/
(
len
(
teacher_output
)
*
torch
.
distributed
.
get_world_size
())
self
.
center
=
self
.
center
*
self
.
center_momentum
+
batch_center
*
(
1
-
self
.
center_momentum
)
class DINOHead(torch.nn.Module):
    def __init__(self, in_dim, out_dim, norm_last_layer=True, nlayers=3):
        super().__init__()
        args = get_args()
        hidden_dim = args.dino_head_hidden_size
        bottleneck_dim = args.dino_bottleneck_size
        nlayers = max(nlayers, 1)
        if nlayers == 1:
            self.mlp = torch.nn.Linear(in_dim, bottleneck_dim)
        else:
            layers = [torch.nn.Linear(in_dim, hidden_dim)]
            layers.append(torch.nn.GELU())
            for _ in range(nlayers - 2):
                layers.append(torch.nn.Linear(hidden_dim, hidden_dim))
                layers.append(torch.nn.GELU())
            layers.append(torch.nn.Linear(hidden_dim, bottleneck_dim))
            self.mlp = torch.nn.Sequential(*layers)
        self.apply(self._init_weights)
        self.last_layer = torch.nn.utils.weight_norm(
            torch.nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)
        if norm_last_layer:
            self.last_layer.weight_g.requires_grad = False

    def _init_weights(self, m):
        if isinstance(m, torch.nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, torch.nn.Linear) and m.bias is not None:
                torch.nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        x = torch.nn.functional.normalize(x, dim=-1, p=2)
        x = self.last_layer(x)
        return x
class MultiCropWrapper(MegatronModule):
    """
    Perform forward pass separately on each resolution input.
    The inputs corresponding to a single resolution are clubbed and single
    forward is run on the same resolution inputs. Hence we do several
    forward passes = number of different resolutions used. We then
    concatenate all the output features and run the head forward on these
    concatenated features.
    """
    def __init__(self, backbone, head):
        super(MultiCropWrapper, self).__init__()
        # disable layers dedicated to ImageNet labels classification
        #backbone.fc, backbone.head = torch.nn.Identity(), torch.nn.Identity()
        self.backbone = backbone
        self.head = head

    def forward(self, x):
        # convert to list
        if not isinstance(x, list):
            x = [x]
        idx_crops = torch.cumsum(torch.unique_consecutive(
            torch.tensor([inp.shape[-1] for inp in x]),
            return_counts=True,
        )[1], 0)

        start_idx = 0
        for end_idx in idx_crops:
            _out = self.backbone(torch.cat(x[start_idx:end_idx]))
            if start_idx == 0:
                output = _out
            else:
                output = torch.cat((output, _out))
            start_idx = end_idx
        # Run the head forward on the concatenated features.
        if self.training:
            return self.head(output)
        else:
            return output
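

# Minimal sketch of the crop-grouping logic in MultiCropWrapper.forward above:
# crops that share a spatial resolution are concatenated and sent through the
# backbone in a single pass. The shapes below (2 global 224-px crops plus
# 6 local 96-px crops) are illustrative assumptions.
def _demo_crop_grouping():
    crops = [torch.zeros(4, 3, 224, 224) for _ in range(2)] + \
            [torch.zeros(4, 3, 96, 96) for _ in range(6)]
    idx_crops = torch.cumsum(torch.unique_consecutive(
        torch.tensor([inp.shape[-1] for inp in crops]),
        return_counts=True,
    )[1], 0)
    # Two backbone passes: one over crops[0:2], one over crops[2:8].
    assert idx_crops.tolist() == [2, 8]
    return idx_crops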
def cosine_scheduler(base_value, final_value, epochs, niter_per_ep,
                     warmup_epochs=0, start_warmup_value=0):
    warmup_schedule = np.array([])
    warmup_iters = warmup_epochs * niter_per_ep
    if warmup_epochs > 0:
        warmup_schedule = \
            np.linspace(start_warmup_value, base_value, warmup_iters)

    iters = np.arange(epochs * niter_per_ep - warmup_iters)
    schedule = final_value + 0.5 * (base_value - final_value) \
        * (1 + np.cos(np.pi * iters / len(iters)))

    schedule = np.concatenate((warmup_schedule, schedule))
    assert len(schedule) == epochs * niter_per_ep
    return schedule
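

# Quick illustrative check (assumed toy numbers) of cosine_scheduler above:
# an optional linear warm-up followed by a cosine ramp from base_value to
# final_value, producing one entry per training iteration.
def _demo_cosine_scheduler():
    sched = cosine_scheduler(base_value=0.996, final_value=1.0,
                             epochs=10, niter_per_ep=100, warmup_epochs=1)
    assert len(sched) == 10 * 100
    assert abs(sched[-1] - 1.0) < 1e-6   # ends (numerically) at final_value
    return sched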
def get_student_backbone_and_num_features(pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        student = VitBackbone(pre_process=pre_process,
                              post_process=post_process,
                              drop_path_rate=0.1,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        student = mit_b5_avg(drop_path_rate=0.1)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        student = get_swin()
        num_features = student.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                        args.vision_backbone_type))

    return student, num_features


def get_teacher_backbone_and_num_features(pre_process=True, post_process=True):
    args = get_args()

    if args.vision_backbone_type == 'vit':
        teacher = VitBackbone(pre_process=pre_process,
                              post_process=post_process,
                              single_token_output=True)
        num_features = args.hidden_size
    elif args.vision_backbone_type == 'mit':
        teacher = mit_b5_avg(drop_path_rate=0.0)
        num_features = 512
    elif args.vision_backbone_type == 'swin':
        teacher = get_swin(is_teacher=True)
        num_features = teacher.num_features
    else:
        raise Exception('{} vision backbone is not supported.'.format(
                        args.vision_backbone_type))

    return teacher, num_features
class DINOPretrainModel(MegatronModule):
    def __init__(self, pre_process=True, post_process=True):
        super(DINOPretrainModel, self).__init__()
        args = get_args()
        self.out_dim = 65536

        self.dino_loss = DINOLoss(
            self.out_dim,
            args.dino_local_crops_number + 2,
            args.dino_warmup_teacher_temp,
            args.dino_teacher_temp,
            args.dino_warmup_teacher_temp_epochs,
            300,
        )

        self.pre_process = pre_process
        self.post_process = post_process
        self.momentum_teacher = 0.996

        student_backbone, num_features = \
            get_student_backbone_and_num_features(pre_process, post_process)

        self.student = MultiCropWrapper(
            student_backbone,
            DINOHead(num_features, self.out_dim,
                     norm_last_layer=args.dino_norm_last_layer)
        )

        self.momentum_schedule = cosine_scheduler(
            self.momentum_teacher, 1,
            args.train_iters // args.iter_per_epoch,
            args.iter_per_epoch
        )

        teacher_backbone, num_features = \
            get_teacher_backbone_and_num_features(pre_process, post_process)
        self.teacher = MultiCropWrapper(
            teacher_backbone,
            DINOHead(num_features, self.out_dim)
        )
        self.teacher.load_state_dict(self.student.state_dict())

        for p in self.teacher.parameters():
            if hasattr(p, "requires_grad") and p.requires_grad is not None:
                p.requires_grad = False

    def set_input_tensor(self, tensor):
        pass

    def forward(self, input):
        student_output = None
        if self.training:
            student_output = self.student(input)
            teacher_output = self.teacher(input[:2])
        else:
            teacher_output = self.teacher(input)
        return student_output, teacher_output

    def cancel_gradients_last_layer(self, iteration):
        args = get_args()
        epoch = iteration // args.iter_per_epoch
        if epoch < args.dino_freeze_last_layer:
            for n, p in self.student.named_parameters():
                if "last_layer" in n:
                    p.grad = None

    def update_momentum(self, iteration):
        with torch.no_grad():
            m = self.momentum_schedule[iteration]
            for param_q, param_k in zip(self.student.parameters(),
                                        self.teacher.parameters()):
                param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
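

# Self-contained sketch (toy modules and an assumed momentum value) of the
# exponential-moving-average teacher update performed in update_momentum above.
def _demo_ema_update(m=0.996):
    student = torch.nn.Linear(4, 4)
    teacher = torch.nn.Linear(4, 4)
    teacher.load_state_dict(student.state_dict())
    with torch.no_grad():
        # teacher <- m * teacher + (1 - m) * student, parameter by parameter
        for param_q, param_k in zip(student.parameters(), teacher.parameters()):
            param_k.data.mul_(m).add_((1 - m) * param_q.detach().data)
    return teacher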
megatron/model/vision/esvit_swin_backbone.py
0 → 100644
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Modified by Chunyuan Li (chunyl@microsoft.com)
# Swin Transformer
# --------------------------------------------------------

import os
import logging
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial
import torch.distributed as dist
from torch.nn.init import trunc_normal_
from megatron.model.transformer import DropPath
from megatron import get_args
from megatron.model import LayerNorm
import numpy as np
from math import sqrt


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super(Mlp, self).__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size,
               W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(
        -1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size,
                     window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
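

# Illustrative round-trip check for the two helpers above (toy sizes assumed):
# partitioning into non-overlapping windows and reversing it recovers the
# input exactly whenever H and W are multiples of window_size.
def _demo_window_roundtrip():
    B, H, W, C, window_size = 2, 14, 14, 8, 7
    x = torch.randn(B, H, W, C)
    windows = window_partition(x, window_size)
    assert windows.shape == (B * (H // window_size) * (W // window_size),
                             window_size, window_size, C)
    assert torch.equal(window_reverse(windows, window_size, H, W), x)
    return windows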
class
WindowAttention
(
nn
.
Module
):
r
"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def
__init__
(
self
,
dim
,
window_size
,
num_heads
,
qkv_bias
=
True
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
):
super
(
WindowAttention
,
self
).
__init__
()
self
.
dim
=
dim
self
.
window_size
=
window_size
# Wh, Ww
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**
-
0.5
# define a parameter table of relative position bias
self
.
relative_position_bias_table
=
nn
.
Parameter
(
torch
.
zeros
((
2
*
window_size
[
0
]
-
1
)
*
(
2
*
window_size
[
1
]
-
1
),
num_heads
))
# 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h
=
torch
.
arange
(
self
.
window_size
[
0
])
coords_w
=
torch
.
arange
(
self
.
window_size
[
1
])
coords
=
torch
.
stack
(
torch
.
meshgrid
([
coords_h
,
coords_w
]))
# 2, Wh, Ww
coords_flatten
=
torch
.
flatten
(
coords
,
1
)
# 2 Wh*Ww
relative_coords
=
coords_flatten
[:,
:,
None
]
-
coords_flatten
[:,
None
,
:]
# 2, Wh*Ww, Wh*Ww
relative_coords
=
relative_coords
.
permute
(
1
,
2
,
0
).
contiguous
()
# Wh*Ww, Wh*Ww, 2
relative_coords
[:,
:,
0
]
+=
self
.
window_size
[
0
]
-
1
# shift to start from 0
relative_coords
[:,
:,
1
]
+=
self
.
window_size
[
1
]
-
1
relative_coords
[:,
:,
0
]
*=
2
*
self
.
window_size
[
1
]
-
1
relative_position_index
=
relative_coords
.
sum
(
-
1
)
# Wh*Ww, Wh*Ww
self
.
register_buffer
(
"relative_position_index"
,
relative_position_index
)
self
.
qkv
=
nn
.
Linear
(
dim
,
dim
*
3
,
bias
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
trunc_normal_
(
self
.
relative_position_bias_table
,
std
=
.
02
)
self
.
softmax
=
nn
.
Softmax
(
dim
=-
1
)
def
forward
(
self
,
x
,
mask
=
None
):
"""
Args:
x: input features with shape of (num_windows*B, N, C)
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
"""
B_
,
N
,
C
=
x
.
shape
qkv
=
self
.
qkv
(
x
).
reshape
(
B_
,
N
,
3
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
qkv
[
0
],
qkv
[
1
],
qkv
[
2
]
# make torchscript happy (cannot use tensor as tuple)
q
=
q
*
self
.
scale
attn
=
(
q
@
k
.
transpose
(
-
2
,
-
1
))
relative_position_bias
=
self
.
relative_position_bias_table
[
self
.
relative_position_index
.
view
(
-
1
)].
view
(
self
.
window_size
[
0
]
*
self
.
window_size
[
1
],
self
.
window_size
[
0
]
*
self
.
window_size
[
1
],
-
1
)
# Wh*Ww,Wh*Ww,nH
relative_position_bias
=
relative_position_bias
.
permute
(
2
,
0
,
1
).
contiguous
()
# nH, Wh*Ww, Wh*Ww
attn
=
attn
+
relative_position_bias
.
unsqueeze
(
0
)
if
mask
is
not
None
:
nW
=
mask
.
shape
[
0
]
attn
=
attn
.
view
(
B_
//
nW
,
nW
,
self
.
num_heads
,
N
,
N
)
+
mask
.
unsqueeze
(
1
).
unsqueeze
(
0
).
type
(
attn
.
type
())
attn
=
attn
.
view
(
-
1
,
self
.
num_heads
,
N
,
N
)
attn
=
self
.
softmax
(
attn
)
else
:
attn
=
self
.
softmax
(
attn
)
attn_out
=
attn
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B_
,
N
,
C
)
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
,
attn_out
def
extra_repr
(
self
)
->
str
:
return
f
'dim=
{
self
.
dim
}
, window_size=
{
self
.
window_size
}
, num_heads=
{
self
.
num_heads
}
'
def
flops
(
self
,
N
):
# calculate flops for 1 window with token length of N
flops
=
0
# qkv = self.qkv(x)
flops
+=
N
*
self
.
dim
*
3
*
self
.
dim
# attn = (q @ k.transpose(-2, -1))
flops
+=
self
.
num_heads
*
N
*
(
self
.
dim
//
self
.
num_heads
)
*
N
# x = (attn @ v)
flops
+=
self
.
num_heads
*
N
*
N
*
(
self
.
dim
//
self
.
num_heads
)
# x = self.proj(x)
flops
+=
N
*
self
.
dim
*
self
.
dim
return
flops
@
staticmethod
def
compute_macs
(
module
,
input
,
output
):
B
,
N
,
C
=
input
[
0
].
shape
module
.
__flops__
+=
module
.
flops
(
N
)
*
B
class
SwinTransformerBlock
(
nn
.
Module
):
r
"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
num_heads (int): Number of attention heads.
window_size (int): Window size.
shift_size (int): Shift size for SW-MSA.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float, optional): Stochastic depth rate. Default: 0.0
act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
dim
,
input_resolution
,
num_heads
,
window_size
=
7
,
shift_size
=
0
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
num_heads
=
num_heads
self
.
window_size
=
window_size
self
.
shift_size
=
shift_size
self
.
mlp_ratio
=
mlp_ratio
if
min
(
self
.
input_resolution
)
<=
self
.
window_size
:
# if window size is larger than input resolution, we don't partition windows
self
.
shift_size
=
0
self
.
window_size
=
min
(
self
.
input_resolution
)
assert
0
<=
self
.
shift_size
<
self
.
window_size
,
"shift_size must in 0-window_size"
self
.
norm1
=
norm_layer
(
dim
)
self
.
attn
=
WindowAttention
(
dim
,
window_size
=
(
self
.
window_size
,
self
.
window_size
),
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
)
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
nn
.
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
self
.
H
=
input_resolution
[
0
]
self
.
W
=
input_resolution
[
1
]
self
.
attn_mask_dict
=
{}
def
create_attn_mask
(
self
,
H
,
W
):
# calculate attention mask for SW-MSA
Hp
=
int
(
np
.
ceil
(
H
/
self
.
window_size
))
*
self
.
window_size
Wp
=
int
(
np
.
ceil
(
W
/
self
.
window_size
))
*
self
.
window_size
img_mask
=
torch
.
zeros
((
1
,
Hp
,
Wp
,
1
))
# 1 Hp Wp 1
h_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
w_slices
=
(
slice
(
0
,
-
self
.
window_size
),
slice
(
-
self
.
window_size
,
-
self
.
shift_size
),
slice
(
-
self
.
shift_size
,
None
))
cnt
=
0
for
h
in
h_slices
:
for
w
in
w_slices
:
img_mask
[:,
h
,
w
,
:]
=
cnt
cnt
+=
1
mask_windows
=
window_partition
(
img_mask
,
self
.
window_size
)
# nW, window_size, window_size, 1
mask_windows
=
mask_windows
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
)
attn_mask
=
mask_windows
.
unsqueeze
(
1
)
-
mask_windows
.
unsqueeze
(
2
)
attn_mask
=
attn_mask
.
masked_fill
(
attn_mask
!=
0
,
float
(
-
100.0
)).
masked_fill
(
attn_mask
==
0
,
float
(
0.0
))
return
attn_mask
def
forward
(
self
,
x
):
B
,
L
,
C
=
x
.
shape
H
=
int
(
sqrt
(
L
))
W
=
H
shortcut
=
x
x
=
self
.
norm1
(
x
)
x
=
x
.
view
(
B
,
H
,
W
,
C
)
# pad feature maps to multiples of window size
pad_l
=
pad_t
=
0
pad_r
=
(
self
.
window_size
-
W
%
self
.
window_size
)
%
self
.
window_size
pad_b
=
(
self
.
window_size
-
H
%
self
.
window_size
)
%
self
.
window_size
x
=
F
.
pad
(
x
,
(
0
,
0
,
pad_l
,
pad_r
,
pad_t
,
pad_b
))
_
,
Hp
,
Wp
,
_
=
x
.
shape
# cyclic shift
if
self
.
shift_size
>
0
:
shifted_x
=
torch
.
roll
(
x
,
shifts
=
(
-
self
.
shift_size
,
-
self
.
shift_size
),
dims
=
(
1
,
2
))
if
H
in
self
.
attn_mask_dict
.
keys
():
attn_mask
=
self
.
attn_mask_dict
[
H
]
else
:
self
.
attn_mask_dict
[
H
]
=
self
.
create_attn_mask
(
self
.
H
,
self
.
W
).
to
(
x
.
device
)
attn_mask
=
self
.
attn_mask_dict
[
H
]
else
:
shifted_x
=
x
attn_mask
=
None
# partition windows
x_windows
=
window_partition
(
shifted_x
,
self
.
window_size
)
# nW*B, window_size, window_size, C
x_windows
=
x_windows
.
view
(
-
1
,
self
.
window_size
*
self
.
window_size
,
C
)
# nW*B, window_size*window_size, C
# W-MSA/SW-MSA
attn_windows
,
attn
=
self
.
attn
(
x_windows
,
attn_mask
)
# nW*B, window_size*window_size, C
# merge windows
attn_windows
=
attn_windows
.
view
(
-
1
,
self
.
window_size
,
self
.
window_size
,
C
)
shifted_x
=
window_reverse
(
attn_windows
,
self
.
window_size
,
Hp
,
Wp
)
# B H' W' C
# reverse cyclic shift
if
self
.
shift_size
>
0
:
x
=
torch
.
roll
(
shifted_x
,
shifts
=
(
self
.
shift_size
,
self
.
shift_size
),
dims
=
(
1
,
2
))
else
:
x
=
shifted_x
if
pad_r
>
0
or
pad_b
>
0
:
x
=
x
[:,
:
H
,
:
W
,
:].
contiguous
()
x
=
x
.
view
(
B
,
H
*
W
,
C
)
# FFN
x
=
shortcut
+
self
.
drop_path
(
x
)
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
return
x
,
attn
def
extra_repr
(
self
)
->
str
:
return
f
"dim=
{
self
.
dim
}
, input_resolution=
{
self
.
input_resolution
}
, num_heads=
{
self
.
num_heads
}
, "
\
f
"window_size=
{
self
.
window_size
}
, shift_size=
{
self
.
shift_size
}
mlp_ratio=
{
self
.
mlp_ratio
}
"
def
flops
(
self
):
flops
=
0
H
,
W
=
self
.
input_resolution
# norm1
flops
+=
self
.
dim
*
H
*
W
# W-MSA/SW-MSA
nW
=
H
*
W
/
self
.
window_size
/
self
.
window_size
flops
+=
nW
*
self
.
attn
.
flops
(
self
.
window_size
*
self
.
window_size
)
# mlp
flops
+=
2
*
H
*
W
*
self
.
dim
*
self
.
dim
*
self
.
mlp_ratio
# norm2
flops
+=
self
.
dim
*
H
*
W
return
flops
class
PatchMerging
(
nn
.
Module
):
r
"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def
__init__
(
self
,
input_resolution
,
dim
,
norm_layer
=
nn
.
LayerNorm
):
super
().
__init__
()
self
.
input_resolution
=
input_resolution
self
.
dim
=
dim
self
.
reduction
=
nn
.
Linear
(
4
*
dim
,
2
*
dim
,
bias
=
False
)
self
.
norm
=
norm_layer
(
4
*
dim
)
def
forward
(
self
,
x
):
""" Forward function.
Args:
x: Input feature, tensor size (B, H*W, C).
H, W: Spatial resolution of the input feature.
"""
B
,
L
,
C
=
x
.
shape
H
=
int
(
sqrt
(
L
))
W
=
H
x
=
x
.
view
(
B
,
H
,
W
,
C
)
# padding
pad_input
=
(
H
%
2
==
1
)
or
(
W
%
2
==
1
)
if
pad_input
:
x
=
F
.
pad
(
x
,
(
0
,
0
,
0
,
W
%
2
,
0
,
H
%
2
))
x0
=
x
[:,
0
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x1
=
x
[:,
1
::
2
,
0
::
2
,
:]
# B H/2 W/2 C
x2
=
x
[:,
0
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x3
=
x
[:,
1
::
2
,
1
::
2
,
:]
# B H/2 W/2 C
x
=
torch
.
cat
([
x0
,
x1
,
x2
,
x3
],
-
1
)
# B H/2 W/2 4*C
x
=
x
.
view
(
B
,
-
1
,
4
*
C
)
# B H/2*W/2 4*C
x
=
self
.
norm
(
x
)
x
=
self
.
reduction
(
x
)
return
x
def
extra_repr
(
self
)
->
str
:
return
f
"input_resolution=
{
self
.
input_resolution
}
, dim=
{
self
.
dim
}
"
def
flops
(
self
):
H
,
W
=
self
.
input_resolution
flops
=
H
*
W
*
self
.
dim
flops
+=
(
H
//
2
)
*
(
W
//
2
)
*
4
*
self
.
dim
*
2
*
self
.
dim
return
flops
class
BasicLayer
(
nn
.
Module
):
"""A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
input_resolution (tuple[int]): Input resulotion.
depth (int): Number of blocks.
num_heads (int): Number of attention heads.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
drop (float, optional): Dropout rate. Default: 0.0
attn_drop (float, optional): Attention dropout rate. Default: 0.0
drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
"""
def
__init__
(
self
,
dim
,
input_resolution
,
depth
,
num_heads
,
window_size
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
norm_layer
=
nn
.
LayerNorm
,
downsample
=
None
):
super
().
__init__
()
self
.
dim
=
dim
self
.
input_resolution
=
input_resolution
self
.
depth
=
depth
self
.
blocks
=
nn
.
ModuleList
([
SwinTransformerBlock
(
dim
=
dim
,
input_resolution
=
input_resolution
,
num_heads
=
num_heads
,
window_size
=
window_size
,
shift_size
=
0
if
(
i
%
2
==
0
)
else
window_size
//
2
,
mlp_ratio
=
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop
,
attn_drop
=
attn_drop
,
drop_path
=
drop_path
[
i
]
if
isinstance
(
drop_path
,
list
)
else
drop_path
,
norm_layer
=
norm_layer
)
for
i
in
range
(
depth
)])
if
downsample
is
not
None
:
self
.
downsample
=
downsample
(
input_resolution
,
dim
=
dim
,
norm_layer
=
norm_layer
)
else
:
self
.
downsample
=
None
def
forward
(
self
,
x
):
for
blk
in
self
.
blocks
:
x
,
_
=
blk
(
x
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
def
forward_with_features
(
self
,
x
):
fea
=
[]
for
blk
in
self
.
blocks
:
x
,
_
=
blk
(
x
)
fea
.
append
(
x
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
,
fea
def
forward_with_attention
(
self
,
x
):
attns
=
[]
for
blk
in
self
.
blocks
:
x
,
attn
=
blk
(
x
)
attns
.
append
(
attn
)
if
self
.
downsample
is
not
None
:
x
=
self
.
downsample
(
x
)
return
x
,
attns
def
extra_repr
(
self
)
->
str
:
return
f
"dim=
{
self
.
dim
}
, input_resolution=
{
self
.
input_resolution
}
, depth=
{
self
.
depth
}
"
def
flops
(
self
):
flops
=
0
for
blk
in
self
.
blocks
:
flops
+=
blk
.
flops
()
if
self
.
downsample
is
not
None
:
flops
+=
self
.
downsample
.
flops
()
return
flops
class
PatchEmbed
(
nn
.
Module
):
""" Image to Patch Embedding
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
embed_dim
=
768
,
norm_layer
=
None
):
super
().
__init__
()
img_size
=
(
img_size
,
img_size
)
patch_size
=
(
patch_size
,
patch_size
)
patches_resolution
=
[
img_size
[
0
]
//
patch_size
[
0
],
img_size
[
1
]
//
patch_size
[
1
]]
self
.
img_size
=
img_size
self
.
patch_size
=
patch_size
self
.
patches_resolution
=
patches_resolution
self
.
num_patches
=
patches_resolution
[
0
]
*
patches_resolution
[
1
]
self
.
in_chans
=
in_chans
self
.
embed_dim
=
embed_dim
self
.
proj
=
nn
.
Conv2d
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
if
norm_layer
is
not
None
:
self
.
norm
=
norm_layer
(
embed_dim
)
else
:
self
.
norm
=
None
def
forward
(
self
,
x
):
B
,
C
,
H
,
W
=
x
.
shape
x
=
self
.
proj
(
x
).
flatten
(
2
).
transpose
(
1
,
2
)
# B Ph*Pw C
if
self
.
norm
is
not
None
:
x
=
self
.
norm
(
x
)
return
x
def
flops
(
self
):
Ho
,
Wo
=
self
.
patches_resolution
flops
=
Ho
*
Wo
*
self
.
embed_dim
*
self
.
in_chans
*
(
self
.
patch_size
[
0
]
*
self
.
patch_size
[
1
])
if
self
.
norm
is
not
None
:
flops
+=
Ho
*
Wo
*
self
.
embed_dim
return
flops
class
SwinTransformer
(
nn
.
Module
):
r
""" Swin Transformer
A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
https://arxiv.org/pdf/2103.14030
Args:
img_size (int | tuple(int)): Input image size.
patch_size (int | tuple(int)): Patch size.
in_chans (int): Number of input channels.
num_classes (int): Number of classes for classification head.
embed_dim (int): Embedding dimension.
depths (tuple(int)): Depth of Swin Transformer layers.
num_heads (tuple(int)): Number of attention heads in different layers.
window_size (int): Window size.
mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: Truee
qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
drop_rate (float): Dropout rate.
attn_drop_rate (float): Attention dropout rate.
drop_path_rate (float): Stochastic depth rate.
norm_layer (nn.Module): normalization layer.
ape (bool): If True, add absolute position embedding to the patch embedding.
patch_norm (bool): If True, add normalization after patch embedding.
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
4
,
in_chans
=
3
,
num_classes
=
1000
,
embed_dim
=
96
,
depths
=
[
2
,
2
,
6
,
2
],
num_heads
=
[
3
,
6
,
12
,
24
],
window_size
=
7
,
mlp_ratio
=
4.
,
qkv_bias
=
True
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.1
,
norm_layer
=
nn
.
LayerNorm
,
ape
=
False
,
patch_norm
=
True
,
**
kwargs
):
super
().
__init__
()
self
.
num_classes
=
num_classes
self
.
num_layers
=
len
(
depths
)
self
.
embed_dim
=
embed_dim
self
.
ape
=
ape
self
.
patch_norm
=
patch_norm
self
.
num_features
=
int
(
embed_dim
*
2
**
(
self
.
num_layers
-
1
))
self
.
mlp_ratio
=
mlp_ratio
self
.
patch_embed
=
PatchEmbed
(
img_size
=
img_size
,
patch_size
=
patch_size
,
in_chans
=
in_chans
,
embed_dim
=
embed_dim
,
norm_layer
=
norm_layer
if
self
.
patch_norm
else
None
)
num_patches
=
self
.
patch_embed
.
num_patches
patches_resolution
=
self
.
patch_embed
.
patches_resolution
self
.
patches_resolution
=
patches_resolution
if
self
.
ape
:
self
.
absolute_pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_patches
,
embed_dim
))
trunc_normal_
(
self
.
absolute_pos_embed
,
std
=
.
02
)
self
.
pos_drop
=
nn
.
Dropout
(
p
=
drop_rate
)
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
sum
(
depths
))]
# stochastic depth decay rule
self
.
layers
=
nn
.
ModuleList
()
for
i_layer
in
range
(
self
.
num_layers
):
layer
=
BasicLayer
(
dim
=
int
(
embed_dim
*
2
**
i_layer
),
input_resolution
=
(
patches_resolution
[
0
]
//
(
2
**
i_layer
),
patches_resolution
[
1
]
//
(
2
**
i_layer
)),
depth
=
depths
[
i_layer
],
num_heads
=
num_heads
[
i_layer
],
window_size
=
window_size
,
mlp_ratio
=
self
.
mlp_ratio
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
sum
(
depths
[:
i_layer
]):
sum
(
depths
[:
i_layer
+
1
])],
norm_layer
=
norm_layer
,
downsample
=
PatchMerging
if
(
i_layer
<
self
.
num_layers
-
1
)
else
None
)
self
.
layers
.
append
(
layer
)
self
.
norm
=
norm_layer
(
self
.
num_features
)
self
.
avgpool
=
nn
.
AdaptiveAvgPool1d
(
1
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
@
torch
.
jit
.
ignore
def
no_weight_decay
(
self
):
return
{
'absolute_pos_embed'
}
@
torch
.
jit
.
ignore
def
no_weight_decay_keywords
(
self
):
# todo: to be implemented
return
{
'relative_position_bias_table'
}
def
forward
(
self
,
x
):
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x_region
=
self
.
norm
(
x
)
# B L C
x
=
self
.
avgpool
(
x_region
.
transpose
(
1
,
2
))
# B C 1
x
=
torch
.
flatten
(
x
,
1
)
return
x
def
forward_feature_maps
(
self
,
x
):
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
for
layer
in
self
.
layers
:
x
=
layer
(
x
)
x_grid
=
self
.
norm
(
x
)
# B L C
x
=
self
.
avgpool
(
x_grid
.
transpose
(
1
,
2
))
# B C 1
x
=
torch
.
flatten
(
x
,
1
)
return
x
,
x_grid
def
forward_selfattention
(
self
,
x
,
n
=
1
):
# n=1 return the last layer attn map; otherwise return attn maps in all layers
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
if
n
==
1
:
return
self
.
forward_last_selfattention
(
x
)
else
:
return
self
.
forward_all_selfattention
(
x
)
def
forward_last_selfattention
(
self
,
x
):
for
i
,
layer
in
enumerate
(
self
.
layers
):
if
i
<
len
(
self
.
layers
)
-
1
:
x
=
layer
(
x
)
else
:
x
,
attns
=
layer
.
forward_with_attention
(
x
)
return
attns
[
-
1
]
def
forward_all_selfattention
(
self
,
x
):
attn_out
=
[]
for
layer
in
self
.
layers
:
x
,
attns
=
layer
.
forward_with_attention
(
x
)
attn_out
+=
attns
return
attn_out
def
forward_return_n_last_blocks
(
self
,
x
,
n
=
1
,
return_patch_avgpool
=
False
,
depth
=
[]):
num_blks
=
sum
(
depth
)
start_idx
=
num_blks
-
n
sum_cur
=
0
for
i
,
d
in
enumerate
(
depth
):
sum_cur_new
=
sum_cur
+
d
if
start_idx
>=
sum_cur
and
start_idx
<
sum_cur_new
:
start_stage
=
i
start_blk
=
start_idx
-
sum_cur
sum_cur
=
sum_cur_new
x
=
self
.
patch_embed
(
x
)
if
self
.
ape
:
x
=
x
+
self
.
absolute_pos_embed
x
=
self
.
pos_drop
(
x
)
# we will return the averaged token features from the `n` last blocks
# note: there is no [CLS] token in Swin Transformer
output
=
[]
s
=
0
for
i
,
layer
in
enumerate
(
self
.
layers
):
x
,
fea
=
layer
.
forward_with_features
(
x
)
if
i
>=
start_stage
:
for
x_
in
fea
[
start_blk
:]:
if
i
==
len
(
self
.
layers
)
-
1
:
# use the norm in the last stage
x_
=
self
.
norm
(
x_
)
x_avg
=
torch
.
flatten
(
self
.
avgpool
(
x_
.
transpose
(
1
,
2
)),
1
)
# B C
# print(f'Stage {i}, x_avg {x_avg.shape}')
output
.
append
(
x_avg
)
start_blk
=
0
return
torch
.
cat
(
output
,
dim
=-
1
)
def
flops
(
self
):
flops
=
0
flops
+=
self
.
patch_embed
.
flops
()
for
i
,
layer
in
enumerate
(
self
.
layers
):
flops
+=
layer
.
flops
()
if
dist
.
get_rank
()
==
0
:
print
(
f
"GFLOPs layer_
{
i
}
:
{
layer
.
flops
()
/
1e9
}
"
)
flops
+=
self
.
num_features
*
self
.
patches_resolution
[
0
]
*
self
.
patches_resolution
[
1
]
//
(
2
**
self
.
num_layers
)
flops
+=
self
.
num_features
*
self
.
num_classes
return
flops
def
init_weights
(
self
,
pretrained
=
''
,
pretrained_layers
=
[],
verbose
=
True
):
if
os
.
path
.
isfile
(
pretrained
):
pretrained_dict
=
torch
.
load
(
pretrained
,
map_location
=
'cpu'
)
logging
.
info
(
f
'=> loading pretrained model
{
pretrained
}
'
)
model_dict
=
self
.
state_dict
()
pretrained_dict
=
{
k
:
v
for
k
,
v
in
pretrained_dict
.
items
()
if
k
in
model_dict
.
keys
()
}
need_init_state_dict
=
{}
for
k
,
v
in
pretrained_dict
.
items
():
need_init
=
(
k
.
split
(
'.'
)[
0
]
in
pretrained_layers
or
pretrained_layers
[
0
]
is
'*'
or
'relative_position_index'
not
in
k
or
'attn_mask'
not
in
k
)
if
need_init
:
if
verbose
:
logging
.
info
(
f
'=> init
{
k
}
from
{
pretrained
}
'
)
if
'relative_position_bias_table'
in
k
and
v
.
size
()
!=
model_dict
[
k
].
size
():
relative_position_bias_table_pretrained
=
v
relative_position_bias_table_current
=
model_dict
[
k
]
L1
,
nH1
=
relative_position_bias_table_pretrained
.
size
()
L2
,
nH2
=
relative_position_bias_table_current
.
size
()
if
nH1
!=
nH2
:
logging
.
info
(
f
"Error in loading
{
k
}
, passing"
)
else
:
if
L1
!=
L2
:
logging
.
info
(
'=> load_pretrained: resized variant: {} to {}'
.
format
((
L1
,
nH1
),
(
L2
,
nH2
))
)
S1
=
int
(
L1
**
0.5
)
S2
=
int
(
L2
**
0.5
)
relative_position_bias_table_pretrained_resized
=
torch
.
nn
.
functional
.
interpolate
(
relative_position_bias_table_pretrained
.
permute
(
1
,
0
).
view
(
1
,
nH1
,
S1
,
S1
),
size
=
(
S2
,
S2
),
mode
=
'bicubic'
)
v
=
relative_position_bias_table_pretrained_resized
.
view
(
nH2
,
L2
).
permute
(
1
,
0
)
if
'absolute_pos_embed'
in
k
and
v
.
size
()
!=
model_dict
[
k
].
size
():
absolute_pos_embed_pretrained
=
v
absolute_pos_embed_current
=
model_dict
[
k
]
_
,
L1
,
C1
=
absolute_pos_embed_pretrained
.
size
()
_
,
L2
,
C2
=
absolute_pos_embed_current
.
size
()
if C1 != C2:
logging
.
info
(
f
"Error in loading
{
k
}
, passing"
)
else
:
if
L1
!=
L2
:
logging
.
info
(
'=> load_pretrained: resized variant: {} to {}'
.
format
((
1
,
L1
,
C1
),
(
1
,
L2
,
C2
))
)
S1
=
int
(
L1
**
0.5
)
S2
=
int
(
L2
**
0.5
)
absolute_pos_embed_pretrained
=
absolute_pos_embed_pretrained
.
reshape
(
-
1
,
S1
,
S1
,
C1
)
absolute_pos_embed_pretrained
=
absolute_pos_embed_pretrained
.
permute
(
0
,
3
,
1
,
2
)
absolute_pos_embed_pretrained_resized
=
torch
.
nn
.
functional
.
interpolate
(
absolute_pos_embed_pretrained
,
size
=
(
S2
,
S2
),
mode
=
'bicubic'
)
v
=
absolute_pos_embed_pretrained_resized
.
permute
(
0
,
2
,
3
,
1
).
flatten
(
1
,
2
)
need_init_state_dict
[
k
]
=
v
self
.
load_state_dict
(
need_init_state_dict
,
strict
=
False
)
def
freeze_pretrained_layers
(
self
,
frozen_layers
=
[]):
for
name
,
module
in
self
.
named_modules
():
if
(
name
.
split
(
'.'
)[
0
]
in
frozen_layers
or
'.'
.
join
(
name
.
split
(
'.'
)[
0
:
2
])
in
frozen_layers
or
(
len
(
frozen_layers
)
>
0
and
frozen_layers
[
0
]
==
'*'
)
):
for
_name
,
param
in
module
.
named_parameters
():
param
.
requires_grad
=
False
logging
.
info
(
'=> set param {} requires grad to False'
.
format
(
name
)
)
for
name
,
param
in
self
.
named_parameters
():
if
(
name
.
split
(
'.'
)[
0
]
in
frozen_layers
or
(
len
(
frozen_layers
)
>
0
and
frozen_layers
[
0
]
==
'*'
)
and
param
.
requires_grad
is
True
):
param
.
requires_grad
=
False
logging
.
info
(
'=> set param {} requires grad to False'
.
format
(
name
)
)
return
self
def get_swin(is_teacher=False):
    args = get_args()

    if args.swin_backbone_type == "tiny":
        embed_dim = 96
        depths = [2, 2, 6, 2]
        num_heads = [3, 6, 12, 24]
        drop_path_rate = 0.1
    elif args.swin_backbone_type == 'h3':
        embed_dim = 384
        depths = [2, 2, 18, 2]
        num_heads = [6, 12, 24, 48]
        drop_path_rate = 0.2
    else:
        embed_dim = 128
        depths = [2, 2, 18, 2]
        num_heads = [4, 8, 16, 32]
        drop_path_rate = 0.2

    swin = SwinTransformer(
        img_size=224,
        in_chans=3,
        num_classes=1000,
        patch_size=4,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=7,
        mlp_ratio=4,
        qkv_bias=True,
        drop_rate=0,
        attn_drop_rate=0,
        drop_path_rate=(0.0 if is_teacher else drop_path_rate),
        norm_layer=partial(LayerNorm, eps=1e-6),
        ape=False,
        patch_norm=True,
    )

    return swin
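

# Hedged usage sketch: get_swin reads swin_backbone_type from the Megatron
# arguments, so this only runs after Megatron has been initialized and
# get_args() is populated; shapes assume the default 224x224 input.
def _demo_get_swin_features():
    model = get_swin(is_teacher=True)
    images = torch.randn(2, 3, 224, 224)
    features = model(images)   # pooled features, shape (2, model.num_features)
    assert features.shape == (2, model.num_features)
    return features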
megatron/model/vision/inpainting.py
0 → 100644
# coding=utf-8
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.

import math
import apex
import einops
import torch
import torch.nn.functional as F
from megatron import get_args, print_rank_0
from megatron.model.utils import get_linear_layer
from megatron.model.vision.vit_backbone import VitBackbone
from megatron.model.module import MegatronModule
from megatron.model.vision.mit_backbone import mit_b3
from megatron.model.vision.utils import resize


class VitInpaintingModel(MegatronModule):

    def __init__(self, pre_process=True, post_process=True):
        super(VitInpaintingModel, self).__init__()
        args = get_args()

        self.pre_process = pre_process
        self.post_process = post_process
        self.hidden_size = args.hidden_size
        self.backbone = VitBackbone(
            pre_process=self.pre_process,
            post_process=self.post_process,
            class_token=False,
        )
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.seq_length = args.seq_length
        # full mask

        if self.post_process:
            self.linear_decoder = get_linear_layer(self.hidden_size,
                                                   self.backbone.flatten_dim,
                                                   torch.nn.init.zeros_)

    def set_input_tensor(self, input_tensor):
        self.backbone.set_input_tensor(input_tensor)

    def forward(self, input):
        hidden_states = self.backbone(input)

        if not self.post_process:
            return hidden_states
        decoded_output = self.linear_decoder(hidden_states)
        output = einops.rearrange(
            decoded_output,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h // self.patch_dim,
            w=self.img_w // self.patch_dim,
        )
        return output


class MLP(torch.nn.Module):
    """
    Linear Embedding
    """
    def __init__(self, input_dim=2048, embed_dim=768):
        super().__init__()
        self.proj = torch.nn.Linear(input_dim, embed_dim)

    def forward(self, x):
        x = x.flatten(2).transpose(1, 2)
        x = self.proj(x)
        return x


class MitInpaintingModel(MegatronModule):
    """Mix vision Transformer Model."""

    def __init__(self, pre_process=True, post_process=True):
        super(MitInpaintingModel, self).__init__()
        self.pre_process = pre_process
        self.post_process = post_process

        args = get_args()
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.flatten_dim = self.patch_dim * self.patch_dim * 3
        self.backbone = mit_b3()

        self.in_channels = [64, 128, 320, 512]
        self.embedding_dim = 768

        c1_in_channels, c2_in_channels, c3_in_channels, c4_in_channels = \
            self.in_channels

        self.linear_c4 = MLP(input_dim=c4_in_channels, embed_dim=self.embedding_dim)
        self.linear_c3 = MLP(input_dim=c3_in_channels, embed_dim=self.embedding_dim)
        self.linear_c2 = MLP(input_dim=c2_in_channels, embed_dim=self.embedding_dim)
        self.linear_c1 = MLP(input_dim=c1_in_channels, embed_dim=self.embedding_dim)

        self.conv_fuse = torch.nn.Conv2d(self.embedding_dim * 4,
                                         self.embedding_dim, 1, 1, bias=False)
        self.norm = apex.parallel.SyncBatchNorm(self.embedding_dim)
        self.dropout = torch.nn.Dropout2d(0.1)

        self.linear_pred = torch.nn.Conv2d(self.embedding_dim,
                                           self.flatten_dim, kernel_size=1)

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        pass

    def forward(self, input):
        c1, c2, c3, c4 = self.backbone(input)

        n, _, h, w = c4.shape
        _c4 = self.linear_c4(c4).permute(0, 2, 1).reshape(n, -1, c4.shape[2], c4.shape[3])
        _c4 = resize(_c4, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c3 = self.linear_c3(c3).permute(0, 2, 1).reshape(n, -1, c3.shape[2], c3.shape[3])
        _c3 = resize(_c3, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c2 = self.linear_c2(c2).permute(0, 2, 1).reshape(n, -1, c2.shape[2], c2.shape[3])
        _c2 = resize(_c2, size=c1.size()[2:], mode='bilinear', align_corners=False)

        _c1 = self.linear_c1(c1).permute(0, 2, 1).reshape(n, -1, c1.shape[2], c1.shape[3])

        _c = torch.cat([_c4, _c3, _c2, _c1], dim=1)
        _c = self.conv_fuse(_c)

        x = self.norm(_c)
        x = F.relu(x, inplace=True)
        x = self.dropout(x)

        x = self.linear_pred(x)

        output = einops.rearrange(
            x,
            "b (c p1 p2) h w -> b c (h p1) (w p2)",
            p1=self.patch_dim,
            p2=self.patch_dim,
            h=self.img_h // self.patch_dim,
            w=self.img_w // self.patch_dim,
        )
        return output
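

# Minimal sketch (toy sizes assumed) of the patch-to-image rearrange used in
# MitInpaintingModel.forward above: per-patch RGB predictions of shape
# (B, 3*p*p, H/p, W/p) are unfolded back into a full-resolution image.
def _demo_patch_to_image(patch_dim=4, img_h=32, img_w=32):
    x = torch.randn(2, 3 * patch_dim * patch_dim,
                    img_h // patch_dim, img_w // patch_dim)
    out = einops.rearrange(
        x, "b (c p1 p2) h w -> b c (h p1) (w p2)",
        p1=patch_dim, p2=patch_dim,
        h=img_h // patch_dim, w=img_w // patch_dim)
    assert out.shape == (2, 3, img_h, img_w)
    return out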
megatron/model/vision/knn_monitor.py
0 → 100644
import torch.nn.functional as F
import torch
from megatron import print_rank_0, get_args, mpu
from megatron.data.vit_dataset import ClassificationTransform
from megatron.data.image_folder import ImageFolder

_FEATURE_BANK = None


def build_data_loader(dataset, drop_last=True, shuffle=False):
    """Data loader. Note that batch-size is the local (per GPU) batch-size."""
    # Sampler.
    args = get_args()
    micro_batch_size = 16
    num_workers = args.num_workers
    world_size = mpu.get_data_parallel_world_size()
    rank = mpu.get_data_parallel_rank()
    sampler = torch.utils.data.distributed.DistributedSampler(
        dataset, num_replicas=world_size, rank=rank,
        drop_last=drop_last, shuffle=shuffle
    )

    # Data loader. Note that batch size is the per GPU batch size.
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=micro_batch_size,
        sampler=sampler,
        shuffle=False,
        num_workers=num_workers,
        drop_last=not drop_last,
        pin_memory=True,
    )
    return data_loader


def compute_feature_bank(model):
    args = get_args()
    global _FEATURE_BANK
    feature_bank = []
    feature_label = []

    train_ds = ImageFolder(
        root=args.data_path[0],
        transform=ClassificationTransform((args.img_h, args.img_w), train=False),
        data_per_class_fraction=1.0
    )
    classes = len(train_ds.classes)
    dataloader = build_data_loader(train_ds)

    for m in model:
        m.eval()

    with torch.no_grad():
        for i, batch in enumerate(dataloader):
            images = batch[0].cuda().contiguous()
            labels = batch[1].cuda().contiguous()
            student_feature, teacher_feature = model[0](images)
            feature = F.normalize(teacher_feature.float(), dim=1)
            feature_bank.append(feature)
            feature_label.append(labels)

    for m in model:
        m.train()

    # [N', D]
    feature_bank = torch.cat(feature_bank, dim=0).contiguous()
    feature_label = torch.cat(feature_label, dim=0).contiguous()

    feature_banks = [torch.zeros_like(feature_bank)
                     for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_banks,
                                 feature_bank,
                                 group=mpu.get_data_parallel_group())

    assert torch.all(torch.eq(feature_banks[mpu.get_data_parallel_rank()],
                              feature_bank))

    feature_labels = [torch.zeros_like(feature_label)
                      for i in range(mpu.get_data_parallel_world_size())]
    torch.distributed.all_gather(feature_labels,
                                 feature_label,
                                 group=mpu.get_data_parallel_group())

    # [D, N]
    feature_banks = torch.cat(feature_banks, dim=0).t().contiguous()
    # [N]
    feature_labels = torch.cat(feature_labels, dim=0).contiguous()
    print_rank_0("feature_banks size is {}".format(feature_banks.size()))
    print_rank_0("feature labels size is {}".format(feature_labels.size()))

    _FEATURE_BANK = (feature_banks, feature_labels, classes)


def get_feature_bank():
    global _FEATURE_BANK
    assert _FEATURE_BANK is not None
    return _FEATURE_BANK


# knn monitor as in InstDisc https://arxiv.org/abs/1805.01978
# implementation follows http://github.com/zhirongw/lemniscate.pytorch and
# https://github.com/leftthomas/SimCLR
def knn_predict(feature, feature_bank, feature_labels, classes, knn_k, knn_t):
    # compute cos similarity between each feature vector and feature bank ---> [B, N]
    sim_matrix = torch.mm(feature, feature_bank)
    # [B, K]
    sim_weight, sim_indices = sim_matrix.topk(k=knn_k, dim=-1)
    # [B, K]
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1),
                              dim=-1, index=sim_indices)
    sim_weight = (sim_weight / knn_t).exp()

    # counts for each class
    one_hot_label = torch.zeros(feature.size(0) * knn_k, classes,
                                device=sim_labels.device)
    # [B*K, C]
    one_hot_label = one_hot_label.scatter(dim=-1,
                                          index=sim_labels.view(-1, 1),
                                          value=1.0)
    # weighted score ---> [B, C]
    pred_scores = torch.sum(
        one_hot_label.view(feature.size(0), -1, classes)
        * sim_weight.unsqueeze(dim=-1),
        dim=1)

    pred_labels = pred_scores.argsort(dim=-1, descending=True)
    return pred_labels
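

# Self-contained sketch (random toy data, assumed k and temperature) of how
# knn_predict is meant to be called with the (feature_bank, feature_labels,
# classes) triple returned by get_feature_bank().
def _demo_knn_predict():
    classes, bank_size, dim = 10, 256, 32
    feature_bank = F.normalize(torch.randn(bank_size, dim), dim=1).t()   # [D, N]
    feature_labels = torch.randint(0, classes, (bank_size,))             # [N]
    query = F.normalize(torch.randn(8, dim), dim=1)                      # [B, D]
    pred = knn_predict(query, feature_bank, feature_labels,
                       classes, knn_k=20, knn_t=0.07)
    assert pred.shape == (8, classes)   # classes ranked by kNN score per query
    return pred[:, 0]                   # top-1 prediction per query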
megatron/model/vision/mit_backbone.py
0 → 100644
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# found in the LICENSE file in the root directory of this
# source tree.
# ---------------------------------------------------------------
import
math
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
functools
import
partial
from
torch.nn.init
import
trunc_normal_
from
megatron.model.transformer
import
DropPath
from
megatron.model
import
LayerNorm
class
Mlp
(
nn
.
Module
):
def
__init__
(
self
,
in_features
,
hidden_features
=
None
,
out_features
=
None
,
act_layer
=
nn
.
GELU
,
drop
=
0.
):
super
().
__init__
()
out_features
=
out_features
or
in_features
hidden_features
=
hidden_features
or
in_features
self
.
fc1
=
nn
.
Linear
(
in_features
,
hidden_features
)
self
.
dwconv
=
DWConv
(
hidden_features
)
self
.
act
=
act_layer
()
self
.
fc2
=
nn
.
Linear
(
hidden_features
,
out_features
)
self
.
drop
=
nn
.
Dropout
(
drop
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
,
H
,
W
):
x
=
self
.
fc1
(
x
)
x
=
self
.
dwconv
(
x
,
H
,
W
)
x
=
self
.
act
(
x
)
x
=
self
.
drop
(
x
)
x
=
self
.
fc2
(
x
)
x
=
self
.
drop
(
x
)
return
x
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
qkv_bias
=
False
,
qk_scale
=
None
,
attn_drop
=
0.
,
proj_drop
=
0.
,
sr_ratio
=
1
):
super
().
__init__
()
assert
dim
%
num_heads
==
0
,
f
"dim
{
dim
}
should be divided by num_heads
{
num_heads
}
."
self
.
dim
=
dim
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
self
.
scale
=
qk_scale
or
head_dim
**
-
0.5
self
.
q
=
nn
.
Linear
(
dim
,
dim
,
bias
=
qkv_bias
)
self
.
kv
=
nn
.
Linear
(
dim
,
dim
*
2
,
bias
=
qkv_bias
)
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
self
.
sr_ratio
=
sr_ratio
if
sr_ratio
>
1
:
self
.
sr
=
nn
.
Conv2d
(
dim
,
dim
,
kernel_size
=
sr_ratio
,
stride
=
sr_ratio
)
self
.
norm
=
LayerNorm
(
dim
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
,
H
,
W
):
B
,
N
,
C
=
x
.
shape
q
=
self
.
q
(
x
).
reshape
(
B
,
N
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
0
,
2
,
1
,
3
)
if
self
.
sr_ratio
>
1
:
x_
=
x
.
permute
(
0
,
2
,
1
).
reshape
(
B
,
C
,
H
,
W
)
x_
=
self
.
sr
(
x_
).
reshape
(
B
,
C
,
-
1
).
permute
(
0
,
2
,
1
)
x_
=
self
.
norm
(
x_
)
kv
=
self
.
kv
(
x_
).
reshape
(
B
,
-
1
,
2
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
else
:
kv
=
self
.
kv
(
x
).
reshape
(
B
,
-
1
,
2
,
self
.
num_heads
,
C
//
self
.
num_heads
).
permute
(
2
,
0
,
3
,
1
,
4
)
k
,
v
=
kv
[
0
],
kv
[
1
]
attn
=
(
q
@
k
.
transpose
(
-
2
,
-
1
))
*
self
.
scale
attn
=
attn
.
softmax
(
dim
=-
1
)
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B
,
N
,
C
)
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
class
Block
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
,
mlp_ratio
=
4.
,
qkv_bias
=
False
,
qk_scale
=
None
,
drop
=
0.
,
attn_drop
=
0.
,
drop_path
=
0.
,
act_layer
=
nn
.
GELU
,
norm_layer
=
LayerNorm
,
sr_ratio
=
1
):
super
().
__init__
()
self
.
norm1
=
norm_layer
(
dim
)
self
.
attn
=
Attention
(
dim
,
num_heads
=
num_heads
,
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
sr_ratio
=
sr_ratio
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.
else
nn
.
Identity
()
self
.
norm2
=
norm_layer
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
Mlp
(
in_features
=
dim
,
hidden_features
=
mlp_hidden_dim
,
act_layer
=
act_layer
,
drop
=
drop
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
,
H
,
W
):
x
=
x
+
self
.
drop_path
(
self
.
attn
(
self
.
norm1
(
x
),
H
,
W
))
x
=
x
+
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
),
H
,
W
))
return
x
class
OverlapPatchEmbed
(
nn
.
Module
):
""" Image to Patch Embedding
"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
7
,
stride
=
4
,
in_chans
=
3
,
embed_dim
=
768
):
super
().
__init__
()
img_size
=
(
img_size
,
img_size
)
patch_size
=
(
patch_size
,
patch_size
)
self
.
proj
=
nn
.
Conv2d
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
stride
,
padding
=
(
patch_size
[
0
]
//
2
,
patch_size
[
1
]
//
2
))
self
.
norm
=
LayerNorm
(
embed_dim
)
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
*
m
.
kernel_size
[
1
]
*
m
.
out_channels
fan_out
//=
m
.
groups
m
.
weight
.
data
.
normal_
(
0
,
math
.
sqrt
(
2.0
/
fan_out
))
if
m
.
bias
is
not
None
:
m
.
bias
.
data
.
zero_
()
def
forward
(
self
,
x
):
x
=
self
.
proj
(
x
)
_
,
_
,
H
,
W
=
x
.
shape
x
=
x
.
flatten
(
2
).
transpose
(
1
,
2
)
x
=
self
.
norm
(
x
)
return
x
,
H
,
W
class
MixVisionTransformer
(
nn
.
Module
):
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
num_classes
=
1000
,
embed_dims
=
[
64
,
128
,
256
,
512
],
num_heads
=
[
1
,
2
,
4
,
8
],
mlp_ratios
=
[
4
,
4
,
4
,
4
],
qkv_bias
=
False
,
qk_scale
=
None
,
drop_rate
=
0.
,
attn_drop_rate
=
0.
,
drop_path_rate
=
0.
,
norm_layer
=
LayerNorm
,
depths
=
[
3
,
4
,
6
,
3
],
sr_ratios
=
[
8
,
4
,
2
,
1
],
output_avg
=
False
):
super
().
__init__
()
self
.
num_classes
=
num_classes
self
.
depths
=
depths
self
.
output_avg
=
output_avg
# patch_embed
self
.
patch_embed1
=
OverlapPatchEmbed
(
img_size
=
img_size
,
patch_size
=
7
,
stride
=
4
,
in_chans
=
in_chans
,
embed_dim
=
embed_dims
[
0
])
self
.
patch_embed2
=
OverlapPatchEmbed
(
img_size
=
img_size
//
4
,
patch_size
=
3
,
stride
=
2
,
in_chans
=
embed_dims
[
0
],
embed_dim
=
embed_dims
[
1
])
self
.
patch_embed3
=
OverlapPatchEmbed
(
img_size
=
img_size
//
8
,
patch_size
=
3
,
stride
=
2
,
in_chans
=
embed_dims
[
1
],
embed_dim
=
embed_dims
[
2
])
self
.
patch_embed4
=
OverlapPatchEmbed
(
img_size
=
img_size
//
16
,
patch_size
=
3
,
stride
=
2
,
in_chans
=
embed_dims
[
2
],
embed_dim
=
embed_dims
[
3
])
# transformer encoder
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
drop_path_rate
,
sum
(
depths
))]
# stochastic depth decay rule
cur
=
0
self
.
block1
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dims
[
0
],
num_heads
=
num_heads
[
0
],
mlp_ratio
=
mlp_ratios
[
0
],
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
cur
+
i
],
norm_layer
=
norm_layer
,
sr_ratio
=
sr_ratios
[
0
])
for
i
in
range
(
depths
[
0
])])
self
.
norm1
=
norm_layer
(
embed_dims
[
0
])
cur
+=
depths
[
0
]
self
.
block2
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dims
[
1
],
num_heads
=
num_heads
[
1
],
mlp_ratio
=
mlp_ratios
[
1
],
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
cur
+
i
],
norm_layer
=
norm_layer
,
sr_ratio
=
sr_ratios
[
1
])
for
i
in
range
(
depths
[
1
])])
self
.
norm2
=
norm_layer
(
embed_dims
[
1
])
cur
+=
depths
[
1
]
self
.
block3
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dims
[
2
],
num_heads
=
num_heads
[
2
],
mlp_ratio
=
mlp_ratios
[
2
],
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
cur
+
i
],
norm_layer
=
norm_layer
,
sr_ratio
=
sr_ratios
[
2
])
for
i
in
range
(
depths
[
2
])])
self
.
norm3
=
norm_layer
(
embed_dims
[
2
])
cur
+=
depths
[
2
]
self
.
block4
=
nn
.
ModuleList
([
Block
(
dim
=
embed_dims
[
3
],
num_heads
=
num_heads
[
3
],
mlp_ratio
=
mlp_ratios
[
3
],
qkv_bias
=
qkv_bias
,
qk_scale
=
qk_scale
,
drop
=
drop_rate
,
attn_drop
=
attn_drop_rate
,
drop_path
=
dpr
[
cur
+
i
],
norm_layer
=
norm_layer
,
sr_ratio
=
sr_ratios
[
3
])
for
i
in
range
(
depths
[
3
])])
self
.
norm4
=
norm_layer
(
embed_dims
[
3
])
self
.
apply
(
self
.
_init_weights
)
def
_init_weights
(
self
,
m
):
if
isinstance
(
m
,
nn
.
Linear
):
trunc_normal_
(
m
.
weight
,
std
=
.
02
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
fan_out
=
m
.
kernel_size
[
0
]
                          * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        if not self.output_avg:
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)

        if self.output_avg:
            x = x[3].mean(dim=1)

        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


class mit_b0(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b0, self).__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b1(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b2(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b2, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b3(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b3, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b3_avg(MixVisionTransformer):
    def __init__(self, drop_path_rate=0.1, **kwargs):
        super(mit_b3_avg, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=drop_path_rate,
            output_avg=True)


class mit_b4(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b4, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b5(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b5, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=0.1)


class mit_b5_avg(MixVisionTransformer):
    def __init__(self, drop_path_rate=0.1, **kwargs):
        super(mit_b5_avg, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8],
            mlp_ratios=[4, 4, 4, 4], qkv_bias=True,
            norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3],
            sr_ratios=[8, 4, 2, 1], drop_rate=0.0, drop_path_rate=drop_path_rate,
            output_avg=True)
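A minimal usage sketch of the MiT backbone variants added above (not part of the commit). It assumes the new file is importable as megatron.model.vision.mit_backbone and that the MixVisionTransformer base class needs no Megatron global state; names and shapes are illustrative only.

    import torch
    # Hypothetical import path; assumes the file added in this commit is on PYTHONPATH.
    from megatron.model.vision.mit_backbone import mit_b5, mit_b3_avg

    model = mit_b5()                                # output_avg=False: hierarchical features
    feats = model(torch.randn(2, 3, 224, 224))
    print([f.shape for f in feats])                 # four stage feature maps, resolution halving per stage

    pooled = mit_b3_avg(drop_path_rate=0.1)         # output_avg=True: mean-pool last-stage tokens
    emb = pooled(torch.randn(2, 3, 224, 224))
    print(emb.shape)                                # (2, 512), the last embed_dim

The output_avg flag is what distinguishes the *_avg variants: they skip the final reshape to (B, C, H, W) and instead return the token-averaged last-stage embedding, which is convenient for contrastive or classification heads.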
megatron/model/vision/swin_backbone.py
0 → 100644
View file @ b8428a7f

# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Swin Transformer
# --------------------------------------------------------

import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from math import sqrt

from megatron import get_args
from functools import partial


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None,
                 act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x


class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.
    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None,
                 attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                        num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1],
            self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops


class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,
                       act_layer=act_layer, drop=drop)

        self.H = input_resolution[0]
        self.W = input_resolution[1]

        self.attn_mask_dict = {}

    def create_attn_mask(self, H, W):
        # calculate attention mask for SW-MSA
        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1))  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        B, L, C = x.shape
        H = int(sqrt(L))
        W = H

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops


class PatchMerging(nn.Module):
    r""" Patch Merging Layer.
    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.
    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):

        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x_b4_ds = x
        if self.downsample is not None:
            x = self.downsample(x)

        return x_b4_ds, x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops


class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding
    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops


class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
          https://arxiv.org/pdf/2103.14030
    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True,
                 use_checkpoint=False, output_avg=False, **kwargs):
        super().__init__()

        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.output_avg = output_avg

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        h = self.img_size[0] // self.patch_size[0]
        w = self.img_size[1] // self.patch_size[1]

        outs = []
        for i, layer in enumerate(self.layers):
            px, x = layer(x)
            b, n, c = px.shape

            if i != len(self.layers) - 1 or not self.output_avg:
                px = px.permute(0, 2, 1).contiguous()
                px = px.reshape(b, c, h, w)
                # is this a fair assumption ?? i think it's baked into the architecture
                h, w = h // 2, w // 2
                outs.append(px)

        if self.output_avg:
            return outs[-1].mean(dim=1)

        return outs

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops


def get_swin(drop_path_rate=0.3, output_avg=False):
    args = get_args()

    window_size = 7
    embed_dim = 128
    depths = [2, 2, 18, 2]
    num_heads = [4, 8, 16, 32]
    swin = SwinTransformer(
        img_size=(args.img_h, args.img_w,),
        in_chans=3,
        patch_size=args.patch_dim,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        drop_path_rate=drop_path_rate,
        output_avg=output_avg,
    )

    return swin
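A small sketch of how the window helpers above behave (not part of the commit). It assumes the new file is importable as megatron.model.vision.swin_backbone and only exercises window_partition / window_reverse, which are exact inverses whenever the token grid is divisible by the window size.

    import torch
    # Hypothetical import path; assumes the file added in this commit is on PYTHONPATH.
    from megatron.model.vision.swin_backbone import window_partition, window_reverse

    B, H, W, C, ws = 2, 14, 14, 96, 7            # 14x14 token grid, 7x7 windows
    x = torch.randn(B, H, W, C)
    windows = window_partition(x, ws)            # (B * 4, 7, 7, 96): four windows per image
    assert windows.shape == (B * (H // ws) * (W // ws), ws, ws, C)
    x_back = window_reverse(windows, ws, H, W)   # inverse of window_partition (pure reshapes/permutes)
    assert torch.equal(x, x_back)

This round trip is what SwinTransformerBlock.forward relies on: attention is computed per window on the partitioned view, then window_reverse stitches the windows back into the (B, H, W, C) layout before the residual connection.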
megatron/model/vision/utils.py
0 → 100644
View file @ b8428a7f

import warnings
import torch
import torch.nn.functional as F


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > output_h:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)
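For context, a minimal sketch (not part of the commit) of the resize helper above; the import path is assumed, and the helper is essentially F.interpolate plus a warning about align_corners when the input/output sizes are not of the form x+1 / nx+1.

    import torch
    # Assumes the new helper added above is importable.
    from megatron.model.vision.utils import resize

    feat = torch.randn(1, 256, 32, 32)
    # Upsample a feature map to 128x128, as a segmentation head might.
    out = resize(feat, size=(128, 128), mode='bilinear', align_corners=False)
    print(out.shape)  # torch.Size([1, 256, 128, 128])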
megatron/model/vision/vit_backbone.py
View file @ b8428a7f

...
@@ -21,7 +21,6 @@ import torch
 import apex
 import torch.nn.functional as F
 from megatron import get_args
-from megatron.model import LayerNorm
 from megatron.model.transformer import ParallelTransformer
 from megatron.model.utils import (
     get_linear_layer,
...
@@ -147,7 +146,9 @@ class VitBackbone(MegatronModule):
                  pre_process=True,
                  post_process=True,
                  class_token=True,
-                 single_token_output=False):
+                 single_token_output=False,
+                 post_layer_norm=True,
+                 drop_path_rate=0.0):
         super(VitBackbone, self).__init__(share_word_embeddings=False)
         args = get_args()
...
@@ -164,12 +165,14 @@ class VitBackbone(MegatronModule):
         self.pre_process = pre_process
         self.post_process = post_process
         self.class_token = class_token
+        self.post_layer_norm = post_layer_norm
         self.hidden_size = args.hidden_size
         self.patch_dim = args.patch_dim
         self.img_h = args.img_h
         self.img_w = args.img_w
         self.micro_batch_size = args.micro_batch_size
         self.single_token_output = single_token_output
+        self.drop_path_rate = drop_path_rate

         assert self.img_h % self.patch_dim == 0
         assert self.img_w % self.patch_dim == 0
...
@@ -216,6 +219,8 @@ class VitBackbone(MegatronModule):
             self.scaled_init_method,
             pre_process=self.pre_process,
             post_process=self.post_process,
+            post_layer_norm=self.post_layer_norm,
+            drop_path_rate=self.drop_path_rate
         )

     def set_input_tensor(self, input_tensor):
...
megatron/mpu/__init__.py
View file @ b8428a7f

...
@@ -49,17 +49,21 @@ from .initialize import get_virtual_pipeline_model_parallel_rank, set_virtual_pi
 from .initialize import initialize_model_parallel
 from .initialize import model_parallel_is_initialized

+from .layers import LinearWithGradAccumulationAndAsyncCommunication
 from .layers import ColumnParallelLinear
 from .layers import RowParallelLinear
 from .layers import VocabParallelEmbedding
 from .layers import (set_tensor_model_parallel_attributes,
                      set_defaults_if_not_set_tensor_model_parallel_attributes,
                      copy_tensor_model_parallel_attributes)

 from .mappings import copy_to_tensor_model_parallel_region
-from .mappings import gather_from_tensor_model_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
 from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import scatter_to_sequence_parallel_region
+from .mappings import gather_from_sequence_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region

 from .random import checkpoint
 from .random import get_cuda_rng_tracker
...
megatron/mpu/initialize.py
View file @ b8428a7f

...
@@ -54,6 +54,12 @@ _POSITION_EMBEDDING_GLOBAL_RANKS = None
 # rank when broadcasting from the first or last pipeline stage.
 _PIPELINE_GLOBAL_RANKS = None

+# A list of global ranks for each data parallel group to ease calculation of the source
+# rank when broadcasting weights from src to all other data parallel ranks
+_DATA_PARALLEL_GLOBAL_RANKS = None
+

 def is_unitialized():
     """Useful for code segments that may be accessed with or without mpu initialization"""
     return _DATA_PARALLEL_GROUP is None
...
@@ -124,6 +130,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
     # Build the data-parallel groups.
     global _DATA_PARALLEL_GROUP
+    global _DATA_PARALLEL_GLOBAL_RANKS
     assert _DATA_PARALLEL_GROUP is None, \
         'data parallel group is already initialized'
     all_data_parallel_group_ranks = []
...
@@ -137,6 +144,7 @@ def initialize_model_parallel(tensor_model_parallel_size_=1,
             group = torch.distributed.new_group(ranks)
             if rank in ranks:
                 _DATA_PARALLEL_GROUP = group
+                _DATA_PARALLEL_GLOBAL_RANKS = ranks

     # Build the model-parallel groups.
     global _MODEL_PARALLEL_GROUP
...
@@ -478,11 +486,10 @@ def get_tensor_model_parallel_src_rank():

 def get_data_parallel_src_rank():
     """Calculate the global rank corresponding to the first local rank
-    in the tensor model parallel group."""
-    global_rank = torch.distributed.get_rank()
-    data_parallel_size = get_data_parallel_world_size()
-    num_data_parallel_groups = torch.distributed.get_world_size() // data_parallel_size
-    return global_rank % num_data_parallel_groups
+    in the data parallel group."""
+    assert _DATA_PARALLEL_GLOBAL_RANKS is not None, \
+        "Data parallel group is not initialized"
+    return _DATA_PARALLEL_GLOBAL_RANKS[0]


 def get_pipeline_model_parallel_first_rank():
...
megatron/mpu/layers.py
View file @ b8428a7f

...
@@ -30,20 +30,21 @@ from .initialize import get_tensor_model_parallel_world_size
 from .initialize import get_tensor_model_parallel_group
 from .mappings import copy_to_tensor_model_parallel_region
 from .mappings import gather_from_tensor_model_parallel_region
+from .mappings import gather_from_sequence_parallel_region
 from .mappings import reduce_from_tensor_model_parallel_region
 from .mappings import scatter_to_tensor_model_parallel_region
+from .mappings import reduce_scatter_to_sequence_parallel_region
 from .random import get_cuda_rng_tracker
 from .utils import divide
 from .utils import split_tensor_along_last_dim
 from .utils import VocabUtility
-from megatron import get_args
+from megatron import get_args, get_global_memory_buffer

 _MODEL_PARALLEL_ATTRIBUTE_DEFAULTS = {'tensor_model_parallel': False,
                                       'partition_dim': -1,
                                       'partition_stride': 1}

 def param_is_not_tensor_parallel_duplicate(param):
     return (hasattr(param, 'tensor_model_parallel') and
             param.tensor_model_parallel) or (
...
@@ -199,16 +200,37 @@ class VocabParallelEmbedding(torch.nn.Module):
         return output


-class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
+class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     """
-    Column-parallel linear layer execution with asynchronous all-reduce
-    execution in backprop.
+    Linear layer execution with asynchronous communication and gradient accumulation
+    fusion in backprop.
     """

     @staticmethod
-    def forward(ctx, input, weight, bias):
+    def forward(ctx, input, weight, bias, gradient_accumulation_fusion,
+                async_grad_allreduce, sequence_parallel):
         ctx.save_for_backward(input, weight)
         ctx.use_bias = bias is not None
-        output = torch.matmul(input, weight.t())
+        ctx.gradient_accumulation_fusion = gradient_accumulation_fusion
+        ctx.async_grad_allreduce = async_grad_allreduce
+        ctx.sequence_parallel = sequence_parallel
+
+        if sequence_parallel:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            all_gather_buffer = \
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+            torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group())
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+
+        output = torch.matmul(total_input, weight.t())
         if bias is not None:
             output = output + bias
         return output
...
@@ -217,17 +239,75 @@ class ColumnParallelLinearWithAsyncAllreduce(torch.autograd.Function):
     def backward(ctx, grad_output):
         input, weight = ctx.saved_tensors
         use_bias = ctx.use_bias
-        grad_input = grad_output.matmul(weight)
-        # Asyncronous all-reduce
-        handle = torch.distributed.all_reduce(
-                grad_input, group=get_tensor_model_parallel_group(), async_op=True)
-        # Delay the start of weight gradient computation shortly (3us) to have
-        # all-reduce scheduled first and have GPU resources allocated
-        _ = torch.empty(1, device=grad_output.device) + 1
-        grad_weight = grad_output.t().matmul(input)
+
+        if ctx.sequence_parallel:
+            world_size = get_tensor_model_parallel_world_size()
+            dim_size = list(input.size())
+            dim_size[0] = dim_size[0] * world_size
+
+            all_gather_buffer = \
+                get_global_memory_buffer().get_tensor(dim_size, input.dtype, "mpu")
+            handle = torch.distributed._all_gather_base(
+                all_gather_buffer,
+                input,
+                group=get_tensor_model_parallel_group(), async_op=True)
+
+            # Delay the start of intput gradient computation shortly (3us) to have
+            # gather scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+            total_input = all_gather_buffer
+        else:
+            total_input = input
+        grad_input = grad_output.matmul(weight)
+
+        if ctx.sequence_parallel:
+            handle.wait()
+
+        # Convert the tensor shapes to 2D for execution compatibility
+        grad_output = grad_output.view(grad_output.shape[0] * grad_output.shape[1],
+                                       grad_output.shape[2])
+        total_input = total_input.view(total_input.shape[0] * total_input.shape[1],
+                                       total_input.shape[2])
+
+        if ctx.async_grad_allreduce:
+            # Asynchronous all-reduce
+            handle = torch.distributed.all_reduce(
+                    grad_input, group=get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # all-reduce scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.sequence_parallel:
+            assert not ctx.async_grad_allreduce
+            dim_size = list(input.size())
+            sub_grad_input = torch.empty(dim_size, dtype=input.dtype,
+                                         device=torch.cuda.current_device(),
+                                         requires_grad=False)
+            # reduce_scatter
+            handle = torch.distributed._reduce_scatter_base(
+                sub_grad_input, grad_input,
+                group=get_tensor_model_parallel_group(), async_op=True)
+            # Delay the start of weight gradient computation shortly (3us) to have
+            # reduce scatter scheduled first and have GPU resources allocated
+            _ = torch.empty(1, device=grad_output.device) + 1
+
+        if ctx.gradient_accumulation_fusion:
+            import fused_dense_cuda
+            fused_dense_cuda.wgrad_gemm_accum_fp32(total_input, grad_output, weight.main_grad)
+            grad_weight = None
+        else:
+            grad_weight = grad_output.t().matmul(total_input)
         grad_bias = grad_output.sum(dim=0) if use_bias else None
-        handle.wait()
-        return grad_input, grad_weight, grad_bias
+
+        if ctx.sequence_parallel:
+            handle.wait()
+            return sub_grad_input, grad_weight, grad_bias, None, None, None
+
+        if ctx.async_grad_allreduce:
+            handle.wait()
+
+        return grad_input, grad_weight, grad_bias, None, None, None


 class ColumnParallelLinear(torch.nn.Module):
...
@@ -240,7 +320,7 @@ class ColumnParallelLinear(torch.nn.Module):
         input_size: first dimension of matrix A.
         output_size: second dimension of matrix A.
         bias: If true, add bias
-        gather_output: If true, call all-gather on output and make Y avaiable
+        gather_output: If true, call all-gather on output and make Y available
                        to all GPUs, otherwise, every GPU will have its output
                        which is Y_i = XA_i
         init_method: method to initialize weights. Note that bias is always set
...
@@ -305,31 +385,30 @@ class ColumnParallelLinear(torch.nn.Module):
         else:
             self.register_parameter('bias', None)
         self.async_tensor_model_parallel_allreduce = (
-            not args.no_async_tensor_model_parallel_allreduce and
+            args.async_tensor_model_parallel_allreduce and
             world_size > 1)
+        self.sequence_parallel = (
+            args.sequence_parallel and
+            world_size > 1)
+        assert not self.async_tensor_model_parallel_allreduce or \
+            not self.sequence_parallel
+        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion

     def forward(self, input_):
         bias = self.bias if not self.skip_bias_add else None

-        if self.async_tensor_model_parallel_allreduce:
-            input_shape = input_.shape
-            input_ = input_.view(input_shape[0] * input_shape[1], input_shape[2])
-            # Maxtrix multiply with asynchronouse all-reduce execution
-            output_parallel = ColumnParallelLinearWithAsyncAllreduce.apply(
-                input_, self.weight, bias)
-            output_parallel = output_parallel.view(
-                input_shape[0], input_shape[1], output_parallel.shape[1])
+        if self.async_tensor_model_parallel_allreduce or \
+                self.sequence_parallel:
+            input_parallel = input_
         else:
-            # Set up backprop all-reduce.
             input_parallel = copy_to_tensor_model_parallel_region(input_)
-            # Matrix multiply.
-            output_parallel = F.linear(input_parallel, self.weight, bias)
+        # Matrix multiply.
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
+            input_parallel, self.weight, bias, self.gradient_accumulation_fusion,
+            self.async_tensor_model_parallel_allreduce, self.sequence_parallel)
         if self.gather_output:
             # All-gather across the partitions.
+            assert not self.sequence_parallel
             output = gather_from_tensor_model_parallel_region(output_parallel)
         else:
             output = output_parallel
...
@@ -410,11 +489,15 @@ class RowParallelLinear(torch.nn.Module):
             self.bias = Parameter(torch.empty(
                 self.output_size, device=torch.cuda.current_device(),
                 dtype=args.params_dtype))
+            setattr(self.bias, 'sequence_parallel', args.sequence_parallel)
+
             # Always initialize bias to zero.
             with torch.no_grad():
                 self.bias.zero_()
         else:
             self.register_parameter('bias', None)
+        self.sequence_parallel = args.sequence_parallel
+        self.gradient_accumulation_fusion = args.gradient_accumulation_fusion
...
@@ -423,11 +506,17 @@ class RowParallelLinear(torch.nn.Module):
         if self.input_is_parallel:
             input_parallel = input_
         else:
+            assert not self.sequence_parallel
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
         # Matrix multiply.
-        output_parallel = F.linear(input_parallel, self.weight)
+        output_parallel = LinearWithGradAccumulationAndAsyncCommunication.apply(
+            input_parallel, self.weight, None,
+            self.gradient_accumulation_fusion, None, None)
         # All-reduce across all the partitions.
-        output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        if self.sequence_parallel:
+            output_ = reduce_scatter_to_sequence_parallel_region(output_parallel)
+        else:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
             output_bias = None
...
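For orientation, a sketch of how the reworked column-parallel GEMM is exercised (not part of the commit). It assumes a distributed launch (e.g. torchrun), Megatron arguments already parsed with fp32 params and --sequence-parallel / --gradient-accumulation-fusion set as desired, and mpu.initialize_model_parallel() already called; shapes and sizes are illustrative.

    import torch
    from megatron import mpu

    # Forward and backward of this layer now route through
    # LinearWithGradAccumulationAndAsyncCommunication (see diff above).
    layer = mpu.ColumnParallelLinear(1024, 4096, gather_output=False)
    x = torch.randn(16, 4, 1024, device='cuda')   # [sequence, batch, hidden]
    y, y_bias = layer(x)                          # per-rank output shard of the column-parallel GEMM
    y.sum().backward()                            # weight-grad GEMM overlaps with the async communication

The point of the rewrite is visible in the backward above: the input all-gather (sequence parallel), the grad all-reduce (async tensor parallel), or the grad reduce-scatter is launched asynchronously, and the weight-gradient GEMM is scheduled behind it so communication and compute overlap.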
megatron/mpu/mappings.py
View file @ b8428a7f

...
@@ -32,13 +32,13 @@ def _reduce(input_):
     return input_


-def _split(input_):
+def _split_along_last_dim(input_):
     """Split the tensor along its last dimension and keep the
     corresponding slice."""

     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Split along last dimension.
...
@@ -51,12 +51,34 @@ def _split(input_):
     return output


-def _gather(input_):
+def _split_along_first_dim(input_):
+    """Split the tensor along its first dimension and keep the
+    corresponding slice."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    # Split along first dimension.
+    dim_size = input_.size()[0]
+    assert dim_size % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
+    local_dim_size = dim_size // world_size
+    rank = get_tensor_model_parallel_rank()
+    dim_offset = rank * local_dim_size
+
+    output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
+
+    return output
+
+
+def _gather_along_last_dim(input_):
     """Gather tensors and concatinate along the last dimension."""

     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
     if world_size == 1:
         return input_

     # Size and dimension.
...
@@ -73,6 +95,44 @@ def _gather(input_):
     return output


+def _gather_along_first_dim(input_):
+    """Gather tensors and concatinate along the first dimension."""
+
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    dim_size[0] = dim_size[0] * world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._all_gather_base(output, input_.contiguous(),
+                                       group=get_tensor_model_parallel_group())
+
+    return output
+
+
+def _reduce_scatter_along_first_dim(input_):
+    """Reduce-scatter the input tensor across model parallel group."""
+    world_size = get_tensor_model_parallel_world_size()
+    # Bypass the function if we are using only 1 GPU.
+    if world_size == 1:
+        return input_
+
+    dim_size = list(input_.size())
+    assert dim_size[0] % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
+
+    dim_size[0] = dim_size[0] // world_size
+
+    output = torch.empty(dim_size, dtype=input_.dtype,
+                         device=torch.cuda.current_device())
+    torch.distributed._reduce_scatter_base(output, input_.contiguous(),
+                                           group=get_tensor_model_parallel_group())
+    return output
+
+
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
...
@@ -110,15 +170,15 @@ class _ScatterToModelParallelRegion(torch.autograd.Function):

     @staticmethod
     def symbolic(graph, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _split(input_)
+        return _split_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _gather(grad_output)
+        return _gather_along_last_dim(grad_output)


 class _GatherFromModelParallelRegion(torch.autograd.Function):
...
@@ -126,15 +186,73 @@ class _GatherFromModelParallelRegion(torch.autograd.Function):

     @staticmethod
     def symbolic(graph, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def forward(ctx, input_):
-        return _gather(input_)
+        return _gather_along_last_dim(input_)

     @staticmethod
     def backward(ctx, grad_output):
-        return _split(grad_output)
+        return _split_along_last_dim(grad_output)
+
+
+class _ScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Split the input and keep only the corresponding chuck to the rank."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _split_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)
+
+
+class _GatherFromSequenceParallelRegion(torch.autograd.Function):
+    """Gather the input from sequence parallel region and concatinate."""
+
+    @staticmethod
+    def symbolic(graph, input_, tensor_parallel_output_grad=True):
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_, tensor_parallel_output_grad=True):
+        ctx.tensor_parallel_output_grad = tensor_parallel_output_grad
+        return _gather_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        tensor_parallel_output_grad = ctx.tensor_parallel_output_grad
+
+        # If the computation graph after the gather operation is
+        # in the tensor parallel mode, output gradients need to reduce
+        # scattered and whereas if the computation is duplicated,
+        # output gradients need to be scattered.
+        if tensor_parallel_output_grad:
+            return _reduce_scatter_along_first_dim(grad_output), None
+        else:
+            return _split_along_first_dim(grad_output), None
+
+
+class _ReduceScatterToSequenceParallelRegion(torch.autograd.Function):
+    """Reduce scatter the input from the model parallel region."""
+
+    @staticmethod
+    def symbolic(graph, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def forward(ctx, input_):
+        return _reduce_scatter_along_first_dim(input_)
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        return _gather_along_first_dim(grad_output)


 # -----------------
...
@@ -155,3 +273,16 @@ def scatter_to_tensor_model_parallel_region(input_):

 def gather_from_tensor_model_parallel_region(input_):
     return _GatherFromModelParallelRegion.apply(input_)
+
+
+def scatter_to_sequence_parallel_region(input_):
+    return _ScatterToSequenceParallelRegion.apply(input_)
+
+
+def gather_from_sequence_parallel_region(input_, tensor_parallel_output_grad=True):
+    return _GatherFromSequenceParallelRegion.apply(input_, tensor_parallel_output_grad)
+
+
+def reduce_scatter_to_sequence_parallel_region(input_):
+    return _ReduceScatterToSequenceParallelRegion.apply(input_)
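To illustrate the new sequence-parallel mappings (not part of the commit): a sketch that assumes torch.distributed is initialized and mpu.initialize_model_parallel(...) has been called, e.g. under torchrun, and that the sequence length divides the tensor-parallel world size.

    import torch
    from megatron import mpu

    tp = mpu.get_tensor_model_parallel_world_size()
    # [sequence, batch, hidden] activations, replicated on every tensor-parallel rank.
    full = torch.randn(16, 4, 1024, device='cuda')

    shard = mpu.scatter_to_sequence_parallel_region(full)        # each rank keeps 16 // tp sequence slots
    restored = mpu.gather_from_sequence_parallel_region(shard)   # all-gather along the first (sequence) dim
    assert restored.shape == full.shape

These are the first-dimension counterparts of the existing last-dimension split/gather: activations are sharded along the sequence axis between tensor-parallel regions, and the gather/reduce-scatter pair replaces the plain all-reduce when sequence parallelism is enabled.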