Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
huaerkl
fairseq-data2vec_pytorch
Commits
72f5785f
Commit
72f5785f
authored
Aug 15, 2023
by
huaerkl
Browse files
v1.0
parents
Pipeline
#505
canceled with stages
Changes
508
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6936 additions
and
0 deletions
+6936
-0
examples/data2vec/data/path_dataset.py
examples/data2vec/data/path_dataset.py
+64
-0
examples/data2vec/fb_convert_beit_cp.py
examples/data2vec/fb_convert_beit_cp.py
+165
-0
examples/data2vec/models/__init__.py
examples/data2vec/models/__init__.py
+0
-0
examples/data2vec/models/audio_classification.py
examples/data2vec/models/audio_classification.py
+614
-0
examples/data2vec/models/data2vec2.py
examples/data2vec/models/data2vec2.py
+813
-0
examples/data2vec/models/data2vec_audio.py
examples/data2vec/models/data2vec_audio.py
+537
-0
examples/data2vec/models/data2vec_image_classification.py
examples/data2vec/models/data2vec_image_classification.py
+143
-0
examples/data2vec/models/data2vec_text.py
examples/data2vec/models/data2vec_text.py
+517
-0
examples/data2vec/models/data2vec_text_classification.py
examples/data2vec/models/data2vec_text_classification.py
+141
-0
examples/data2vec/models/data2vec_vision.py
examples/data2vec/models/data2vec_vision.py
+727
-0
examples/data2vec/models/mae.py
examples/data2vec/models/mae.py
+829
-0
examples/data2vec/models/mae_image_classification.py
examples/data2vec/models/mae_image_classification.py
+386
-0
examples/data2vec/models/modalities/__init__.py
examples/data2vec/models/modalities/__init__.py
+0
-0
examples/data2vec/models/modalities/audio.py
examples/data2vec/models/modalities/audio.py
+192
-0
examples/data2vec/models/modalities/base.py
examples/data2vec/models/modalities/base.py
+684
-0
examples/data2vec/models/modalities/images.py
examples/data2vec/models/modalities/images.py
+256
-0
examples/data2vec/models/modalities/modules.py
examples/data2vec/models/modalities/modules.py
+589
-0
examples/data2vec/models/modalities/text.py
examples/data2vec/models/modalities/text.py
+161
-0
examples/data2vec/models/utils.py
examples/data2vec/models/utils.py
+55
-0
examples/data2vec/scripts/convert_audioset_labels.py
examples/data2vec/scripts/convert_audioset_labels.py
+63
-0
No files found.
Too many changes to show.
To preserve performance only
508 of 508+
files are displayed.
Plain diff
Email patch
examples/data2vec/data/path_dataset.py
0 → 100644
View file @
72f5785f
import
glob
import
os
from
typing
import
List
,
Optional
,
Tuple
import
logging
import
numpy
as
np
import
torchvision.transforms.functional
as
TF
import
PIL
from
PIL
import
Image
from
torchvision.datasets
import
VisionDataset
logger
=
logging
.
getLogger
(
__name__
)
class
PathDataset
(
VisionDataset
):
def
__init__
(
self
,
root
:
List
[
str
],
loader
:
None
=
None
,
transform
:
Optional
[
str
]
=
None
,
extra_transform
:
Optional
[
str
]
=
None
,
mean
:
Optional
[
List
[
float
]]
=
None
,
std
:
Optional
[
List
[
float
]]
=
None
,
):
super
().
__init__
(
root
=
root
)
PIL
.
Image
.
MAX_IMAGE_PIXELS
=
256000001
self
.
files
=
[]
for
folder
in
self
.
root
:
self
.
files
.
extend
(
sorted
(
glob
.
glob
(
os
.
path
.
join
(
folder
,
"**"
,
"*.jpg"
),
recursive
=
True
))
)
self
.
files
.
extend
(
sorted
(
glob
.
glob
(
os
.
path
.
join
(
folder
,
"**"
,
"*.png"
),
recursive
=
True
))
)
self
.
transform
=
transform
self
.
extra_transform
=
extra_transform
self
.
mean
=
mean
self
.
std
=
std
self
.
loader
=
loader
logger
.
info
(
f
"loaded
{
len
(
self
.
files
)
}
samples from
{
root
}
"
)
assert
(
mean
is
None
)
==
(
std
is
None
)
def
__len__
(
self
)
->
int
:
return
len
(
self
.
files
)
def
__getitem__
(
self
,
idx
)
->
Tuple
[
np
.
ndarray
,
np
.
ndarray
]:
path
=
self
.
files
[
idx
]
if
self
.
loader
is
not
None
:
return
self
.
loader
(
path
),
None
img
=
Image
.
open
(
path
).
convert
(
"RGB"
)
if
self
.
transform
is
not
None
:
img
=
self
.
transform
(
img
)
img
=
TF
.
to_tensor
(
img
)
if
self
.
mean
is
not
None
and
self
.
std
is
not
None
:
img
=
TF
.
normalize
(
img
,
self
.
mean
,
self
.
std
)
return
img
,
None
examples/data2vec/fb_convert_beit_cp.py
0 → 100644
View file @
72f5785f
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
argparse
import
torch
from
omegaconf
import
OmegaConf
from
fairseq.criterions.model_criterion
import
ModelCriterionConfig
from
fairseq.dataclass.configs
import
FairseqConfig
from
tasks
import
ImageClassificationConfig
,
ImagePretrainingConfig
from
models.data2vec_image_classification
import
(
Data2VecImageClassificationConfig
,
Data2VecImageClassificationModel
,
)
from
models.data2vec_vision
import
Data2VecVisionConfig
,
Data2VecVisionModel
def
get_parser
():
parser
=
argparse
.
ArgumentParser
(
description
=
"convert beit checkpoint into data2vec - vision checkpoint"
)
# fmt: off
parser
.
add_argument
(
'checkpoint'
,
help
=
'checkpoint to convert'
)
parser
.
add_argument
(
'--output'
,
required
=
True
,
metavar
=
'PATH'
,
help
=
'where to output converted checkpoint'
)
parser
.
add_argument
(
'--type'
,
type
=
str
,
choices
=
[
'vision'
,
'image_classification'
],
default
=
'image_classification'
,
help
=
'type of model to upgrade'
)
parser
.
add_argument
(
'--inception_norms'
,
action
=
'store_true'
,
default
=
False
)
# fmt: on
return
parser
def
update_checkpoint
(
model_dict
,
prefix
,
is_nested
):
replace_paths
=
{
"cls_token"
:
"model.cls_emb"
if
is_nested
else
"cls_emb"
,
"patch_embed"
:
"model.patch_embed"
if
is_nested
else
"patch_embed"
,
"mask_token"
:
"mask_emb"
,
}
starts_with
=
{
"patch_embed.proj"
:
"model.patch_embed.conv"
if
is_nested
else
"patch_embed.conv"
,
"lm_head"
:
"final_proj"
,
"fc_norm"
:
"fc_norm"
,
"head"
:
"head"
,
}
partial
=
{
"mlp.fc1"
:
"mlp.0"
,
"mlp.fc2"
:
"mlp.2"
,
}
for
k
in
list
(
model_dict
.
keys
()):
for
sw
,
r
in
starts_with
.
items
():
if
k
.
startswith
(
sw
):
replace_paths
[
k
]
=
k
.
replace
(
sw
,
r
)
for
p
,
r
in
partial
.
items
():
if
p
in
k
:
replace_paths
[
k
]
=
prefix
+
k
.
replace
(
p
,
r
)
if
prefix
!=
""
:
for
k
in
list
(
model_dict
.
keys
()):
if
k
not
in
replace_paths
:
replace_paths
[
k
]
=
prefix
+
k
for
k
in
list
(
model_dict
.
keys
()):
if
k
in
replace_paths
:
model_dict
[
replace_paths
[
k
]]
=
model_dict
[
k
]
if
k
!=
replace_paths
[
k
]:
del
model_dict
[
k
]
return
model_dict
def
main
():
parser
=
get_parser
()
args
=
parser
.
parse_args
()
cp
=
torch
.
load
(
args
.
checkpoint
,
map_location
=
"cpu"
)
cfg
=
FairseqConfig
(
criterion
=
ModelCriterionConfig
(
_name
=
"model"
,
log_keys
=
[
"correct"
]),
)
if
args
.
type
==
"image_classification"
:
cfg
.
task
=
ImageClassificationConfig
(
_name
=
"image_classification"
,
data
=
"."
,
)
if
args
.
inception_norms
:
cfg
.
task
.
normalization_mean
=
[
0.5
,
0.5
,
0.5
]
cfg
.
task
.
normalization_std
=
[
0.5
,
0.5
,
0.5
]
cfg
.
model
=
Data2VecImageClassificationConfig
(
_name
=
"data2vec_image_classification"
,
)
cfg
.
model
.
pretrained_model_args
=
FairseqConfig
(
model
=
Data2VecVisionConfig
(
_name
=
"data2vec_vision"
,
shared_rel_pos_bias
=
False
),
task
=
ImagePretrainingConfig
(
_name
=
"image_pretraining"
,
),
)
cfg
=
OmegaConf
.
create
(
cfg
)
state
=
{
"cfg"
:
OmegaConf
.
to_container
(
cfg
,
resolve
=
True
,
enum_to_str
=
True
),
"model"
:
cp
[
"module"
],
"best_loss"
:
None
,
"optimizer"
:
None
,
"extra_state"
:
{},
}
model
=
Data2VecImageClassificationModel
(
cfg
.
model
)
model
.
load_state_dict
(
update_checkpoint
(
state
[
"model"
],
prefix
=
"model.encoder."
,
is_nested
=
True
),
strict
=
True
,
)
elif
args
.
type
==
"vision"
:
cfg
.
task
=
ImagePretrainingConfig
(
_name
=
"image_pretraining"
,
data
=
"."
,
)
if
args
.
inception_norms
:
cfg
.
task
.
normalization_mean
=
[
0.5
,
0.5
,
0.5
]
cfg
.
task
.
normalization_std
=
[
0.5
,
0.5
,
0.5
]
cfg
.
model
=
Data2VecVisionConfig
(
_name
=
"data2vec_vision"
,
)
cfg
=
OmegaConf
.
create
(
cfg
)
state
=
{
"cfg"
:
OmegaConf
.
to_container
(
cfg
,
resolve
=
True
,
enum_to_str
=
True
),
"model"
:
cp
[
"model"
],
"best_loss"
:
None
,
"optimizer"
:
None
,
"extra_state"
:
{},
}
model
=
Data2VecVisionModel
(
cfg
.
model
)
model
.
load_state_dict
(
update_checkpoint
(
state
[
"model"
],
prefix
=
"encoder."
,
is_nested
=
False
),
strict
=
True
,
)
else
:
raise
Exception
(
"unsupported type "
+
args
.
type
)
print
(
state
[
"cfg"
],
state
.
keys
())
torch
.
save
(
state
,
args
.
output
)
if
__name__
==
"__main__"
:
main
()
examples/data2vec/models/__init__.py
0 → 100644
View file @
72f5785f
examples/data2vec/models/audio_classification.py
0 → 100644
View file @
72f5785f
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
contextlib
import
logging
import
re
from
dataclasses
import
dataclass
,
field
from
typing
import
Any
,
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
numpy
as
np
from
omegaconf
import
II
,
MISSING
,
open_dict
from
fairseq
import
checkpoint_utils
,
tasks
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.dataclass.utils
import
convert_namespace_to_omegaconf
from
fairseq.models
import
(
BaseFairseqModel
,
register_model
,
)
from
fairseq.models.wav2vec.wav2vec2
import
MASKING_DISTRIBUTION_CHOICES
from
fairseq.modules
import
TransposeLast
from
fairseq.tasks
import
FairseqTask
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
AudioClassificationConfig
(
FairseqDataclass
):
model_path
:
str
=
field
(
default
=
MISSING
,
metadata
=
{
"help"
:
"path to wav2vec 2.0 model"
}
)
no_pretrained_weights
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"if true, does not load pretrained weights"
}
)
dropout_input
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"dropout to apply to the input (after feat extr)"
},
)
final_dropout
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"dropout after transformer and before final projection"
},
)
dropout
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"dropout probability inside wav2vec 2.0 model"
}
)
attention_dropout
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"dropout probability for attention weights inside wav2vec 2.0 model"
},
)
activation_dropout
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"dropout probability after activation in FFN inside wav2vec 2.0 model"
},
)
# masking
apply_mask
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"apply masking during fine-tuning"
}
)
mask_length
:
int
=
field
(
default
=
10
,
metadata
=
{
"help"
:
"repeat the mask indices multiple times"
}
)
mask_prob
:
float
=
field
(
default
=
0.5
,
metadata
=
{
"help"
:
"probability of replacing a token with mask (normalized by length)"
},
)
mask_selection
:
MASKING_DISTRIBUTION_CHOICES
=
field
(
default
=
"static"
,
metadata
=
{
"help"
:
"how to choose masks"
}
)
mask_other
:
float
=
field
(
default
=
0
,
metadata
=
{
"help"
:
"secondary mask argument (used for more complex distributions), "
"see help in compute_mask_indices"
},
)
no_mask_overlap
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"whether to allow masks to overlap"
}
)
mask_min_space
:
Optional
[
int
]
=
field
(
default
=
1
,
metadata
=
{
"help"
:
"min space between spans (if no overlap is enabled)"
},
)
require_same_masks
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"whether to number of masked timesteps must be the same across all "
"examples in a batch"
},
)
mask_dropout
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"percent of masks to unmask for each sample"
},
)
# channel masking
mask_channel_length
:
int
=
field
(
default
=
10
,
metadata
=
{
"help"
:
"length of the mask for features (channels)"
}
)
mask_channel_prob
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"probability of replacing a feature with 0"
}
)
mask_channel_selection
:
MASKING_DISTRIBUTION_CHOICES
=
field
(
default
=
"static"
,
metadata
=
{
"help"
:
"how to choose mask length for channel masking"
},
)
mask_channel_other
:
float
=
field
(
default
=
0
,
metadata
=
{
"help"
:
"secondary mask argument (used for more complex distributions), "
"see help in compute_mask_indicesh"
},
)
no_mask_channel_overlap
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"whether to allow channel masks to overlap"
}
)
freeze_finetune_updates
:
int
=
field
(
default
=
0
,
metadata
=
{
"help"
:
"dont finetune wav2vec for this many updates"
}
)
feature_grad_mult
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"reset feature grad mult in wav2vec 2.0 to this"
}
)
layerdrop
:
float
=
field
(
default
=
0.0
,
metadata
=
{
"help"
:
"probability of dropping a layer in wav2vec 2.0"
}
)
mask_channel_min_space
:
Optional
[
int
]
=
field
(
default
=
1
,
metadata
=
{
"help"
:
"min space between spans (if no overlap is enabled)"
},
)
mask_channel_before
:
bool
=
False
normalize
:
bool
=
II
(
"task.normalize"
)
data
:
str
=
II
(
"task.data"
)
# this holds the loaded wav2vec args
d2v_args
:
Any
=
None
offload_activations
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"offload_activations"
}
)
min_params_to_wrap
:
int
=
field
(
default
=
int
(
1e8
),
metadata
=
{
"help"
:
"minimum number of params for a layer to be wrapped with FSDP() when "
"training with --ddp-backend=fully_sharded. Smaller values will "
"improve memory efficiency, but may make torch.distributed "
"communication less efficient due to smaller input sizes. This option "
"is set to 0 (i.e., always wrap) when --checkpoint-activations or "
"--offload-activations are passed."
},
)
checkpoint_activations
:
bool
=
field
(
default
=
False
,
metadata
=
{
"help"
:
"recompute activations and save memory for extra compute"
},
)
ddp_backend
:
str
=
II
(
"distributed_training.ddp_backend"
)
prediction_mode
:
str
=
"lin_softmax"
eval_prediction_mode
:
Optional
[
str
]
=
None
conv_kernel
:
int
=
-
1
conv_stride
:
int
=
1
two_convs
:
bool
=
False
extreme_factor
:
float
=
1.0
conv_feature_layers
:
Optional
[
str
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"string describing convolutional feature extraction layers in form of a python list that contains "
"[(dim, kernel_size, stride), ...]"
},
)
mixup_prob
:
float
=
1.0
source_mixup
:
float
=
-
1
same_mixup
:
bool
=
True
label_mixup
:
bool
=
False
gain_mode
:
str
=
"none"
@
register_model
(
"audio_classification"
,
dataclass
=
AudioClassificationConfig
)
class
AudioClassificationModel
(
BaseFairseqModel
):
def
__init__
(
self
,
cfg
:
AudioClassificationConfig
,
num_classes
):
super
().
__init__
()
self
.
apply_mask
=
cfg
.
apply_mask
self
.
cfg
=
cfg
arg_overrides
=
{
"dropout"
:
cfg
.
dropout
,
"activation_dropout"
:
cfg
.
activation_dropout
,
"dropout_input"
:
cfg
.
dropout_input
,
"attention_dropout"
:
cfg
.
attention_dropout
,
"mask_length"
:
cfg
.
mask_length
,
"mask_prob"
:
cfg
.
mask_prob
,
"require_same_masks"
:
getattr
(
cfg
,
"require_same_masks"
,
True
),
"mask_dropout"
:
getattr
(
cfg
,
"mask_dropout"
,
0
),
"mask_selection"
:
cfg
.
mask_selection
,
"mask_other"
:
cfg
.
mask_other
,
"no_mask_overlap"
:
cfg
.
no_mask_overlap
,
"mask_channel_length"
:
cfg
.
mask_channel_length
,
"mask_channel_prob"
:
cfg
.
mask_channel_prob
,
"mask_channel_before"
:
cfg
.
mask_channel_before
,
"mask_channel_selection"
:
cfg
.
mask_channel_selection
,
"mask_channel_other"
:
cfg
.
mask_channel_other
,
"no_mask_channel_overlap"
:
cfg
.
no_mask_channel_overlap
,
"encoder_layerdrop"
:
cfg
.
layerdrop
,
"feature_grad_mult"
:
cfg
.
feature_grad_mult
,
"checkpoint_activations"
:
cfg
.
checkpoint_activations
,
"offload_activations"
:
cfg
.
offload_activations
,
"min_params_to_wrap"
:
cfg
.
min_params_to_wrap
,
"mixup"
:
-
1
,
}
if
cfg
.
conv_feature_layers
is
not
None
:
arg_overrides
[
"conv_feature_layers"
]
=
cfg
.
conv_feature_layers
if
cfg
.
d2v_args
is
None
:
state
=
checkpoint_utils
.
load_checkpoint_to_cpu
(
cfg
.
model_path
,
arg_overrides
)
d2v_args
=
state
.
get
(
"cfg"
,
None
)
if
d2v_args
is
None
:
d2v_args
=
convert_namespace_to_omegaconf
(
state
[
"args"
])
d2v_args
.
criterion
=
None
d2v_args
.
lr_scheduler
=
None
cfg
.
d2v_args
=
d2v_args
logger
.
info
(
d2v_args
)
else
:
state
=
None
d2v_args
=
cfg
.
d2v_args
model_normalized
=
d2v_args
.
task
.
get
(
"normalize"
,
d2v_args
.
model
.
get
(
"normalize"
,
False
)
)
assert
cfg
.
normalize
==
model_normalized
,
(
"Fine-tuning works best when data normalization is the same. "
"Please check that --normalize is set or unset for both pre-training and here"
)
if
hasattr
(
cfg
,
"checkpoint_activations"
)
and
cfg
.
checkpoint_activations
:
with
open_dict
(
d2v_args
):
d2v_args
.
model
.
checkpoint_activations
=
cfg
.
checkpoint_activations
d2v_args
.
task
.
data
=
cfg
.
data
task
=
tasks
.
setup_task
(
d2v_args
.
task
)
model
=
task
.
build_model
(
d2v_args
.
model
,
from_checkpoint
=
True
)
model
.
remove_pretraining_modules
()
if
state
is
not
None
and
not
cfg
.
no_pretrained_weights
:
self
.
load_model_weights
(
state
,
model
,
cfg
)
d
=
d2v_args
.
model
.
encoder_embed_dim
self
.
d2v_model
=
model
self
.
final_dropout
=
nn
.
Dropout
(
cfg
.
final_dropout
)
self
.
freeze_finetune_updates
=
cfg
.
freeze_finetune_updates
self
.
num_updates
=
0
for
p
in
self
.
parameters
():
p
.
param_group
=
"pretrained"
if
cfg
.
prediction_mode
==
"proj_avg_proj"
:
self
.
proj
=
nn
.
Linear
(
d
,
d
*
2
)
self
.
proj2
=
nn
.
Linear
(
d
*
2
,
num_classes
)
for
p
in
self
.
proj
.
parameters
():
p
.
param_group
=
"projection"
for
p
in
self
.
proj2
.
parameters
():
p
.
param_group
=
"projection"
elif
self
.
cfg
.
prediction_mode
==
"summary_proj"
:
self
.
proj
=
nn
.
Linear
(
d
//
3
,
num_classes
)
for
p
in
self
.
proj
.
parameters
():
p
.
param_group
=
"projection"
elif
self
.
cfg
.
conv_kernel
>
1
and
not
self
.
cfg
.
two_convs
:
self
.
proj
=
nn
.
Sequential
(
TransposeLast
(),
nn
.
Conv1d
(
d
,
num_classes
,
kernel_size
=
self
.
cfg
.
conv_kernel
,
stride
=
self
.
cfg
.
conv_stride
),
TransposeLast
(),
)
for
p
in
self
.
proj
.
parameters
():
p
.
param_group
=
"projection"
elif
self
.
cfg
.
conv_kernel
>
0
and
self
.
cfg
.
two_convs
:
self
.
proj
=
nn
.
Sequential
(
TransposeLast
(),
nn
.
Conv1d
(
d
,
d
,
kernel_size
=
self
.
cfg
.
conv_kernel
,
stride
=
self
.
cfg
.
conv_stride
),
TransposeLast
(),
nn
.
GELU
(),
nn
.
Linear
(
d
,
num_classes
),
)
for
p
in
self
.
proj
.
parameters
():
p
.
param_group
=
"projection"
else
:
self
.
proj
=
nn
.
Linear
(
d
,
num_classes
)
for
p
in
self
.
proj
.
parameters
():
p
.
param_group
=
"projection"
def
upgrade_state_dict_named
(
self
,
state_dict
,
name
):
super
().
upgrade_state_dict_named
(
state_dict
,
name
)
return
state_dict
@
classmethod
def
build_model
(
cls
,
cfg
:
AudioClassificationConfig
,
task
:
FairseqTask
):
"""Build a new model instance."""
assert
hasattr
(
task
,
"labels"
),
f
"Task
{
task
}
must have an attribute 'labels'"
return
cls
(
cfg
,
len
(
task
.
labels
))
def
load_model_weights
(
self
,
state
,
model
,
cfg
):
if
cfg
.
ddp_backend
==
"fully_sharded"
:
from
fairseq.distributed
import
FullyShardedDataParallel
for
name
,
module
in
model
.
named_modules
():
if
"encoder.layers"
in
name
and
len
(
name
.
split
(
"."
))
==
3
:
# Only for layers, we do a special handling and load the weights one by one
# We dont load all weights together as that wont be memory efficient and may
# cause oom
new_dict
=
{
k
.
replace
(
name
+
"."
,
""
):
v
for
(
k
,
v
)
in
state
[
"model"
].
items
()
if
name
+
"."
in
k
}
assert
isinstance
(
module
,
FullyShardedDataParallel
)
with
module
.
summon_full_params
():
module
.
load_state_dict
(
new_dict
,
strict
=
True
)
module
.
_reset_lazy_init
()
# Once layers are loaded, filter them out and load everything else.
r
=
re
.
compile
(
"encoder.layers.\d."
)
filtered_list
=
list
(
filter
(
r
.
match
,
state
[
"model"
].
keys
()))
new_big_dict
=
{
k
:
v
for
(
k
,
v
)
in
state
[
"model"
].
items
()
if
k
not
in
filtered_list
}
model
.
load_state_dict
(
new_big_dict
,
strict
=
False
)
else
:
if
"_ema"
in
state
[
"model"
]:
del
state
[
"model"
][
"_ema"
]
model
.
load_state_dict
(
state
[
"model"
],
strict
=
False
)
def
set_num_updates
(
self
,
num_updates
):
"""Set the number of parameters updates."""
super
().
set_num_updates
(
num_updates
)
self
.
num_updates
=
num_updates
def
compute_gain
(
self
,
sound
,
fs
=
16_000
,
min_db
=-
80.0
,
mode
=
"A_weighting"
):
if
fs
==
16000
:
n_fft
=
2048
elif
fs
==
44100
:
n_fft
=
4096
else
:
raise
Exception
(
"Invalid fs {}"
.
format
(
fs
))
stride
=
n_fft
//
2
def
a_weight
(
fs
,
n_fft
,
min_db
=-
80.0
):
freq
=
np
.
linspace
(
0
,
fs
//
2
,
n_fft
//
2
+
1
)
freq_sq
=
np
.
power
(
freq
,
2
)
freq_sq
[
0
]
=
1.0
weight
=
2.0
+
20.0
*
(
2
*
np
.
log10
(
12194
)
+
2
*
np
.
log10
(
freq_sq
)
-
np
.
log10
(
freq_sq
+
12194
**
2
)
-
np
.
log10
(
freq_sq
+
20.6
**
2
)
-
0.5
*
np
.
log10
(
freq_sq
+
107.7
**
2
)
-
0.5
*
np
.
log10
(
freq_sq
+
737.9
**
2
)
)
weight
=
np
.
maximum
(
weight
,
min_db
)
return
weight
gain
=
[]
for
i
in
range
(
0
,
len
(
sound
)
-
n_fft
+
1
,
stride
):
if
mode
==
"RMSE"
:
g
=
np
.
mean
(
sound
[
i
:
i
+
n_fft
]
**
2
)
elif
mode
==
"A_weighting"
:
spec
=
np
.
fft
.
rfft
(
np
.
hanning
(
n_fft
+
1
)[:
-
1
]
*
sound
[
i
:
i
+
n_fft
])
power_spec
=
np
.
abs
(
spec
)
**
2
a_weighted_spec
=
power_spec
*
np
.
power
(
10
,
a_weight
(
fs
,
n_fft
)
/
10
)
g
=
np
.
sum
(
a_weighted_spec
)
else
:
raise
Exception
(
"Invalid mode {}"
.
format
(
mode
))
gain
.
append
(
g
)
gain
=
np
.
array
(
gain
)
gain
=
np
.
maximum
(
gain
,
np
.
power
(
10
,
min_db
/
10
))
gain_db
=
10
*
np
.
log10
(
gain
)
return
gain_db
# adapted from https://github.com/mil-tokyo/bc_learning_sound/blob/master/utils.py
def
compute_gain_torch
(
self
,
sound
,
fs
=
16_000
,
min_db
=-
80.0
,
mode
=
"A_weighting"
):
if
fs
==
16000
:
n_fft
=
2048
elif
fs
==
44100
:
n_fft
=
4096
else
:
raise
Exception
(
"Invalid fs {}"
.
format
(
fs
))
if
mode
==
"A_weighting"
:
if
not
hasattr
(
self
,
f
"a_weight"
):
self
.
a_weight
=
{}
if
fs
not
in
self
.
a_weight
:
def
a_weight
(
fs
,
n_fft
,
min_db
=-
80.0
):
freq
=
np
.
linspace
(
0
,
fs
//
2
,
n_fft
//
2
+
1
)
freq_sq
=
freq
**
2
freq_sq
[
0
]
=
1.0
weight
=
2.0
+
20.0
*
(
2
*
np
.
log10
(
12194
)
+
2
*
np
.
log10
(
freq_sq
)
-
np
.
log10
(
freq_sq
+
12194
**
2
)
-
np
.
log10
(
freq_sq
+
20.6
**
2
)
-
0.5
*
np
.
log10
(
freq_sq
+
107.7
**
2
)
-
0.5
*
np
.
log10
(
freq_sq
+
737.9
**
2
)
)
weight
=
np
.
maximum
(
weight
,
min_db
)
return
weight
self
.
a_weight
[
fs
]
=
torch
.
from_numpy
(
np
.
power
(
10
,
a_weight
(
fs
,
n_fft
,
min_db
)
/
10
)
).
to
(
device
=
sound
.
device
)
sound
=
sound
.
unfold
(
-
1
,
n_fft
,
n_fft
//
2
)
if
mode
==
"RMSE"
:
sound
=
sound
**
2
g
=
sound
.
mean
(
-
1
)
elif
mode
==
"A_weighting"
:
w
=
torch
.
hann_window
(
n_fft
,
device
=
sound
.
device
)
*
sound
spec
=
torch
.
fft
.
rfft
(
w
)
power_spec
=
spec
.
abs
()
**
2
a_weighted_spec
=
power_spec
*
self
.
a_weight
[
fs
]
g
=
a_weighted_spec
.
sum
(
-
1
)
else
:
raise
Exception
(
"Invalid mode {}"
.
format
(
mode
))
gain
=
torch
.
maximum
(
g
,
torch
.
tensor
(
10
**
(
min_db
/
10
),
device
=
g
.
device
))
gain_db
=
10
*
torch
.
log10
(
gain
)
return
gain_db
def
forward
(
self
,
source
,
padding_mask
,
label
=
None
,
**
kwargs
):
if
self
.
cfg
.
source_mixup
>=
0
and
self
.
training
and
self
.
cfg
.
mixup_prob
>
0
:
with
torch
.
no_grad
():
mixed_source
=
source
mix_mask
=
None
if
self
.
cfg
.
mixup_prob
<
1
:
mix_mask
=
(
torch
.
empty
((
source
.
size
(
0
),),
device
=
source
.
device
)
.
bernoulli_
(
self
.
cfg
.
mixup_prob
)
.
bool
()
)
mixed_source
=
source
[
mix_mask
]
r
=
(
torch
.
FloatTensor
(
1
if
self
.
cfg
.
same_mixup
else
mixed_source
.
size
(
0
)
)
.
uniform_
(
max
(
1e-6
,
self
.
cfg
.
source_mixup
),
1
)
.
to
(
dtype
=
source
.
dtype
,
device
=
source
.
device
)
)
mixup_perm
=
torch
.
randperm
(
source
.
size
(
0
))
s2
=
source
[
mixup_perm
]
if
self
.
cfg
.
gain_mode
==
"none"
:
p
=
r
.
unsqueeze
(
-
1
)
if
mix_mask
is
not
None
:
s2
=
s2
[
mix_mask
]
else
:
if
self
.
cfg
.
gain_mode
==
"naive_rms"
:
G1
=
source
.
pow
(
2
).
mean
(
dim
=-
1
).
sqrt
()
else
:
G1
,
_
=
self
.
compute_gain_torch
(
source
,
mode
=
self
.
cfg
.
gain_mode
).
max
(
-
1
)
G1
=
G1
.
to
(
dtype
=
source
.
dtype
)
G2
=
G1
[
mixup_perm
]
if
mix_mask
is
not
None
:
G1
=
G1
[
mix_mask
]
G2
=
G2
[
mix_mask
]
s2
=
s2
[
mix_mask
]
p
=
1
/
(
1
+
10
**
((
G1
-
G2
)
/
20
)
*
(
1
-
r
)
/
r
)
p
=
p
.
unsqueeze
(
-
1
)
mixed
=
(
p
*
mixed_source
)
+
(
1
-
p
)
*
s2
if
mix_mask
is
None
:
source
=
mixed
/
torch
.
sqrt
(
p
**
2
+
(
1
-
p
)
**
2
)
else
:
source
[
mix_mask
]
=
mixed
/
torch
.
sqrt
(
p
**
2
+
(
1
-
p
)
**
2
)
if
label
is
not
None
and
self
.
cfg
.
label_mixup
:
r
=
r
.
unsqueeze
(
-
1
)
if
mix_mask
is
None
:
label
=
label
*
r
+
(
1
-
r
)
*
label
[
mixup_perm
]
else
:
label
[
mix_mask
]
=
(
label
[
mix_mask
]
*
r
+
(
1
-
r
)
*
label
[
mixup_perm
][
mix_mask
]
)
d2v_args
=
{
"source"
:
source
,
"padding_mask"
:
padding_mask
,
"mask"
:
self
.
apply_mask
and
self
.
training
,
}
ft
=
self
.
freeze_finetune_updates
<=
self
.
num_updates
with
torch
.
no_grad
()
if
not
ft
else
contextlib
.
ExitStack
():
res
=
self
.
d2v_model
.
extract_features
(
**
d2v_args
)
x
=
res
[
"x"
]
padding_mask
=
res
[
"padding_mask"
]
if
padding_mask
is
not
None
:
x
[
padding_mask
]
=
0
x
=
self
.
final_dropout
(
x
)
if
self
.
training
or
(
self
.
cfg
.
eval_prediction_mode
is
None
or
self
.
cfg
.
eval_prediction_mode
==
""
):
prediction_mode
=
self
.
cfg
.
prediction_mode
else
:
prediction_mode
=
self
.
cfg
.
eval_prediction_mode
if
prediction_mode
==
"average_before"
:
x
=
x
.
mean
(
dim
=
1
)
if
prediction_mode
!=
"summary_mha"
and
prediction_mode
!=
"summary_proj"
and
prediction_mode
!=
"cls"
:
x
=
self
.
proj
(
x
)
logits
=
True
if
prediction_mode
==
"lin_softmax"
:
x
=
F
.
logsigmoid
(
x
.
float
())
x
=
torch
.
logsumexp
(
x
+
x
,
dim
=
1
)
-
torch
.
logsumexp
(
x
,
dim
=
1
)
x
=
x
.
clamp
(
max
=
0
)
x
=
x
-
torch
.
log
(
-
(
torch
.
expm1
(
x
)))
elif
prediction_mode
==
"extremized_odds"
:
x
=
x
.
float
().
sum
(
dim
=
1
)
x
=
x
*
self
.
cfg
.
extreme_factor
elif
prediction_mode
==
"average_before"
:
x
=
x
.
float
()
elif
prediction_mode
==
"average"
:
x
=
x
.
float
().
mean
(
dim
=
1
)
elif
prediction_mode
==
"average_sigmoid"
:
x
=
torch
.
sigmoid
(
x
.
float
())
x
=
x
.
mean
(
dim
=
1
)
logits
=
False
elif
prediction_mode
==
"max"
:
x
,
_
=
x
.
float
().
max
(
dim
=
1
)
elif
prediction_mode
==
"max_sigmoid"
:
x
=
torch
.
sigmoid
(
x
.
float
())
x
,
_
=
x
.
float
().
max
(
dim
=
1
)
logits
=
False
elif
prediction_mode
==
"proj_avg_proj"
:
x
=
x
.
mean
(
dim
=
1
)
x
=
self
.
proj2
(
x
)
elif
prediction_mode
==
"summary_mha"
or
prediction_mode
==
"summary_proj"
:
x
=
self
.
d2v_model
.
summary
(
x
,
padding_mask
,
proj
=
prediction_mode
==
"summary_proj"
)
x
=
x
.
type_as
(
source
)
x
=
self
.
proj
(
x
)
elif
prediction_mode
==
"cls"
:
x
=
x
[:,
0
]
x
=
self
.
proj
(
x
)
else
:
raise
Exception
(
f
"unknown prediction mode
{
prediction_mode
}
"
)
if
label
is
None
:
return
torch
.
sigmoid
(
x
)
if
logits
else
x
x
=
torch
.
nan_to_num
(
x
)
if
logits
:
loss
=
F
.
binary_cross_entropy_with_logits
(
x
,
label
.
float
(),
reduction
=
"none"
)
else
:
loss
=
F
.
binary_cross_entropy
(
x
,
label
.
float
(),
reduction
=
"none"
)
result
=
{
"losses"
:
{
"main"
:
loss
,
},
"sample_size"
:
label
.
sum
(),
}
if
not
self
.
training
:
result
[
"_predictions"
]
=
torch
.
sigmoid
(
x
)
if
logits
else
x
result
[
"_targets"
]
=
label
return
result
examples/data2vec/models/data2vec2.py
0 → 100644
View file @
72f5785f
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
logging
import
math
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
,
Callable
from
functools
import
partial
import
numpy
as
np
from
omegaconf
import
II
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.distributed
as
dist
from
fairseq.modules
import
EMAModule
,
EMAModuleConfig
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.models
import
BaseFairseqModel
,
register_model
from
examples.data2vec.data.modality
import
Modality
from
examples.data2vec.models.modalities.base
import
(
MaskSeed
,
D2vModalityConfig
,
ModalitySpecificEncoder
,
get_annealed_rate
,
)
from
examples.data2vec.models.modalities.modules
import
(
D2vDecoderConfig
,
AltBlock
,
Decoder1d
,
)
from
examples.data2vec.models.modalities.audio
import
(
D2vAudioConfig
,
AudioEncoder
,
)
from
examples.data2vec.models.modalities.images
import
(
D2vImageConfig
,
ImageEncoder
,
)
from
examples.data2vec.models.modalities.text
import
(
D2vTextConfig
,
TextEncoder
,
)
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
D2vModalitiesConfig
(
FairseqDataclass
):
audio
:
D2vAudioConfig
=
D2vAudioConfig
()
image
:
D2vImageConfig
=
D2vImageConfig
()
text
:
D2vTextConfig
=
D2vTextConfig
()
@
dataclass
class
Data2VecMultiConfig
(
FairseqDataclass
):
loss_beta
:
float
=
field
(
default
=
0
,
metadata
=
{
"help"
:
"beta for smooth l1 loss. 0 means use l2 loss"
}
)
loss_scale
:
Optional
[
float
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
depth
:
int
=
8
start_drop_path_rate
:
float
=
0
end_drop_path_rate
:
float
=
0
num_heads
:
int
=
12
norm_eps
:
float
=
1e-6
norm_affine
:
bool
=
True
encoder_dropout
:
float
=
0.1
post_mlp_drop
:
float
=
0.1
attention_dropout
:
float
=
0.1
activation_dropout
:
float
=
0.0
dropout_input
:
float
=
0.0
layerdrop
:
float
=
0.0
embed_dim
:
int
=
768
mlp_ratio
:
float
=
4
layer_norm_first
:
bool
=
False
average_top_k_layers
:
int
=
field
(
default
=
8
,
metadata
=
{
"help"
:
"how many layers to average"
}
)
end_of_block_targets
:
bool
=
False
clone_batch
:
int
=
1
layer_norm_target_layer
:
bool
=
False
batch_norm_target_layer
:
bool
=
False
instance_norm_target_layer
:
bool
=
False
instance_norm_targets
:
bool
=
False
layer_norm_targets
:
bool
=
False
ema_decay
:
float
=
field
(
default
=
0.999
,
metadata
=
{
"help"
:
"initial ema decay rate"
})
ema_same_dtype
:
bool
=
True
log_norms
:
bool
=
True
ema_end_decay
:
float
=
field
(
default
=
0.9999
,
metadata
=
{
"help"
:
"final ema decay rate"
}
)
# when to finish annealing ema decay rate
ema_anneal_end_step
:
int
=
II
(
"optimization.max_update"
)
ema_encoder_only
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"whether to momentum update only the shared transformer encoder"
},
)
max_update
:
int
=
II
(
"optimization.max_update"
)
modalities
:
D2vModalitiesConfig
=
D2vModalitiesConfig
()
shared_decoder
:
Optional
[
D2vDecoderConfig
]
=
None
min_target_var
:
float
=
field
(
default
=
0.1
,
metadata
=
{
"help"
:
"stop training if target var falls below this"
}
)
min_pred_var
:
float
=
field
(
default
=
0.01
,
metadata
=
{
"help"
:
"stop training if prediction var falls below this"
},
)
supported_modality
:
Optional
[
Modality
]
=
None
mae_init
:
bool
=
False
seed
:
int
=
II
(
"common.seed"
)
skip_ema
:
bool
=
False
cls_loss
:
float
=
0
recon_loss
:
float
=
0
d2v_loss
:
float
=
1
decoder_group
:
bool
=
False
@
register_model
(
"data2vec_multi"
,
dataclass
=
Data2VecMultiConfig
)
class
Data2VecMultiModel
(
BaseFairseqModel
):
def
make_modality_encoder
(
self
,
cfg
:
D2vModalityConfig
,
embed_dim
:
int
,
make_block
:
Callable
[[
float
],
nn
.
ModuleList
],
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
],
layer_norm_first
:
bool
,
alibi_biases
,
task
,
)
->
ModalitySpecificEncoder
:
if
cfg
.
type
==
Modality
.
AUDIO
:
enc_cls
=
AudioEncoder
elif
cfg
.
type
==
Modality
.
IMAGE
:
enc_cls
=
ImageEncoder
elif
cfg
.
type
==
Modality
.
TEXT
:
enc_cls
=
TextEncoder
if
hasattr
(
task
,
"text_task"
):
task
=
task
.
text_task
else
:
raise
Exception
(
f
"unsupported modality
{
cfg
.
type
}
"
)
return
enc_cls
(
cfg
,
embed_dim
,
make_block
,
norm_layer
,
layer_norm_first
,
alibi_biases
,
task
,
)
def
__init__
(
self
,
cfg
:
Data2VecMultiConfig
,
modalities
,
skip_ema
=
False
,
task
=
None
):
super
().
__init__
()
self
.
cfg
=
cfg
self
.
modalities
=
modalities
self
.
task
=
task
make_layer_norm
=
partial
(
nn
.
LayerNorm
,
eps
=
cfg
.
norm_eps
,
elementwise_affine
=
cfg
.
norm_affine
)
def
make_block
(
drop_path
,
dim
=
None
,
heads
=
None
):
return
AltBlock
(
cfg
.
embed_dim
if
dim
is
None
else
dim
,
cfg
.
num_heads
if
heads
is
None
else
heads
,
cfg
.
mlp_ratio
,
qkv_bias
=
True
,
drop
=
cfg
.
encoder_dropout
,
attn_drop
=
cfg
.
attention_dropout
,
mlp_drop
=
cfg
.
activation_dropout
,
post_mlp_drop
=
cfg
.
post_mlp_drop
,
drop_path
=
drop_path
,
norm_layer
=
make_layer_norm
,
layer_norm_first
=
cfg
.
layer_norm_first
,
ffn_targets
=
not
cfg
.
end_of_block_targets
,
)
self
.
alibi_biases
=
{}
self
.
modality_encoders
=
nn
.
ModuleDict
()
for
mod
in
self
.
modalities
:
mod_cfg
=
getattr
(
cfg
.
modalities
,
mod
.
name
.
lower
())
enc
=
self
.
make_modality_encoder
(
mod_cfg
,
cfg
.
embed_dim
,
make_block
,
make_layer_norm
,
cfg
.
layer_norm_first
,
self
.
alibi_biases
,
task
,
)
self
.
modality_encoders
[
mod
.
name
]
=
enc
self
.
ema
=
None
self
.
average_top_k_layers
=
cfg
.
average_top_k_layers
self
.
loss_beta
=
cfg
.
loss_beta
self
.
loss_scale
=
cfg
.
loss_scale
self
.
dropout_input
=
nn
.
Dropout
(
cfg
.
dropout_input
)
dpr
=
np
.
linspace
(
cfg
.
start_drop_path_rate
,
cfg
.
end_drop_path_rate
,
cfg
.
depth
)
self
.
blocks
=
nn
.
ModuleList
([
make_block
(
dpr
[
i
])
for
i
in
range
(
cfg
.
depth
)])
self
.
norm
=
None
if
cfg
.
layer_norm_first
:
self
.
norm
=
make_layer_norm
(
cfg
.
embed_dim
)
if
self
.
cfg
.
mae_init
:
self
.
apply
(
self
.
_init_weights
)
else
:
from
fairseq.modules.transformer_sentence_encoder
import
init_bert_params
self
.
apply
(
init_bert_params
)
for
mod_enc
in
self
.
modality_encoders
.
values
():
mod_enc
.
reset_parameters
()
if
not
skip_ema
:
self
.
ema
=
self
.
make_ema_teacher
(
cfg
.
ema_decay
)
self
.
shared_decoder
=
(
Decoder1d
(
cfg
.
shared_decoder
,
cfg
.
embed_dim
)
if
self
.
cfg
.
shared_decoder
is
not
None
else
None
)
if
self
.
shared_decoder
is
not
None
:
self
.
shared_decoder
.
apply
(
self
.
_init_weights
)
self
.
recon_proj
=
None
if
cfg
.
recon_loss
>
0
:
self
.
recon_proj
=
nn
.
Linear
(
cfg
.
embed_dim
,
cfg
.
embed_dim
)
for
pn
,
p
in
self
.
named_parameters
():
if
len
(
p
.
shape
)
==
1
or
pn
.
endswith
(
".bias"
)
or
"alibi_scale"
in
pn
:
p
.
optim_overrides
=
{
"optimizer"
:
{
"weight_decay_scale"
:
0
}}
if
cfg
.
decoder_group
and
"decoder"
in
pn
:
p
.
param_group
=
"decoder"
self
.
num_updates
=
0
def
_init_weights
(
self
,
m
):
try
:
from
apex.normalization
import
FusedLayerNorm
fn
=
FusedLayerNorm
except
:
fn
=
nn
.
LayerNorm
if
isinstance
(
m
,
nn
.
Linear
):
torch
.
nn
.
init
.
xavier_uniform_
(
m
.
weight
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
)
or
isinstance
(
m
,
fn
):
if
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
if
m
.
weight
is
not
None
:
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
@
torch
.
no_grad
()
def
make_ema_teacher
(
self
,
ema_decay
):
ema_config
=
EMAModuleConfig
(
ema_decay
=
ema_decay
,
ema_fp32
=
True
,
log_norms
=
self
.
cfg
.
log_norms
,
add_missing_params
=
False
,
)
model_copy
=
self
.
make_target_model
()
return
EMAModule
(
model_copy
,
ema_config
,
copy_model
=
False
,
)
def
make_target_model
(
self
):
logger
.
info
(
"making target model"
)
model_copy
=
Data2VecMultiModel
(
self
.
cfg
,
self
.
modalities
,
skip_ema
=
True
,
task
=
self
.
task
)
if
self
.
cfg
.
ema_encoder_only
:
model_copy
=
model_copy
.
blocks
for
p_s
,
p_t
in
zip
(
self
.
blocks
.
parameters
(),
model_copy
.
parameters
()):
p_t
.
data
.
copy_
(
p_s
.
data
)
else
:
for
p_s
,
p_t
in
zip
(
self
.
parameters
(),
model_copy
.
parameters
()):
p_t
.
data
.
copy_
(
p_s
.
data
)
for
mod_enc
in
model_copy
.
modality_encoders
.
values
():
mod_enc
.
decoder
=
None
if
not
mod_enc
.
modality_cfg
.
ema_local_encoder
:
mod_enc
.
local_encoder
=
None
mod_enc
.
project_features
=
None
model_copy
.
requires_grad_
(
False
)
return
model_copy
def
set_num_updates
(
self
,
num_updates
):
super
().
set_num_updates
(
num_updates
)
if
self
.
ema
is
not
None
and
(
(
self
.
num_updates
==
0
and
num_updates
>
1
)
or
self
.
num_updates
>=
num_updates
):
pass
elif
self
.
training
and
self
.
ema
is
not
None
:
ema_weight_decay
=
None
if
self
.
cfg
.
ema_decay
!=
self
.
cfg
.
ema_end_decay
:
if
num_updates
>=
self
.
cfg
.
ema_anneal_end_step
:
decay
=
self
.
cfg
.
ema_end_decay
else
:
decay
=
get_annealed_rate
(
self
.
cfg
.
ema_decay
,
self
.
cfg
.
ema_end_decay
,
num_updates
,
self
.
cfg
.
ema_anneal_end_step
,
)
self
.
ema
.
set_decay
(
decay
,
weight_decay
=
ema_weight_decay
)
if
self
.
ema
.
get_decay
()
<
1
:
self
.
ema
.
step
(
self
.
blocks
if
self
.
cfg
.
ema_encoder_only
else
self
)
self
.
num_updates
=
num_updates
def
state_dict
(
self
,
destination
=
None
,
prefix
=
""
,
keep_vars
=
False
):
state
=
super
().
state_dict
(
destination
,
prefix
,
keep_vars
)
if
self
.
ema
is
not
None
:
state
[
prefix
+
"_ema"
]
=
self
.
ema
.
fp32_params
return
state
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
*
args
,
**
kwargs
):
k
=
prefix
+
"_ema"
if
self
.
ema
is
not
None
:
assert
k
in
state_dict
self
.
ema
.
restore
(
state_dict
[
k
],
True
)
del
state_dict
[
k
]
elif
k
in
state_dict
:
del
state_dict
[
k
]
return
super
().
_load_from_state_dict
(
state_dict
,
prefix
,
*
args
,
**
kwargs
)
@
classmethod
def
build_model
(
cls
,
cfg
:
Data2VecMultiConfig
,
task
=
None
):
"""Build a new model instance."""
if
task
is
None
or
not
hasattr
(
task
,
"supported_modalities"
):
modalities
=
(
[
cfg
.
supported_modality
]
if
cfg
.
supported_modality
is
not
None
else
[
Modality
.
AUDIO
,
Modality
.
IMAGE
,
Modality
.
TEXT
,
]
)
else
:
modalities
=
task
.
supported_modalities
return
cls
(
cfg
,
modalities
,
task
=
task
,
skip_ema
=
cfg
.
skip_ema
)
def
forward
(
self
,
source
,
target
=
None
,
id
=
None
,
mode
=
None
,
padding_mask
=
None
,
mask
=
True
,
features_only
=
False
,
force_remove_masked
=
False
,
remove_extra_tokens
=
True
,
precomputed_mask
=
None
,
):
if
mode
is
None
:
assert
self
.
cfg
.
supported_modality
is
not
None
mode
=
self
.
cfg
.
supported_modality
if
isinstance
(
mode
,
Modality
):
mode
=
mode
.
name
feature_extractor
=
self
.
modality_encoders
[
mode
]
mask_seeds
=
None
if
id
is
not
None
:
mask_seeds
=
MaskSeed
(
seed
=
self
.
cfg
.
seed
,
update
=
self
.
num_updates
,
ids
=
id
)
extractor_out
=
feature_extractor
(
source
,
padding_mask
,
mask
,
remove_masked
=
not
features_only
or
force_remove_masked
,
clone_batch
=
self
.
cfg
.
clone_batch
if
not
features_only
else
1
,
mask_seeds
=
mask_seeds
,
precomputed_mask
=
precomputed_mask
,
)
x
=
extractor_out
[
"x"
]
encoder_mask
=
extractor_out
[
"encoder_mask"
]
masked_padding_mask
=
extractor_out
[
"padding_mask"
]
masked_alibi_bias
=
extractor_out
.
get
(
"alibi_bias"
,
None
)
alibi_scale
=
extractor_out
.
get
(
"alibi_scale"
,
None
)
if
self
.
dropout_input
is
not
None
:
x
=
self
.
dropout_input
(
x
)
layer_results
=
[]
for
i
,
blk
in
enumerate
(
self
.
blocks
):
if
(
not
self
.
training
or
self
.
cfg
.
layerdrop
==
0
or
(
np
.
random
.
random
()
>
self
.
cfg
.
layerdrop
)
):
ab
=
masked_alibi_bias
if
ab
is
not
None
and
alibi_scale
is
not
None
:
scale
=
(
alibi_scale
[
i
]
if
alibi_scale
.
size
(
0
)
>
1
else
alibi_scale
.
squeeze
(
0
)
)
ab
=
ab
*
scale
.
type_as
(
ab
)
x
,
lr
=
blk
(
x
,
padding_mask
=
masked_padding_mask
,
alibi_bias
=
ab
,
)
if
features_only
:
layer_results
.
append
(
lr
)
if
self
.
norm
is
not
None
:
x
=
self
.
norm
(
x
)
if
features_only
:
if
remove_extra_tokens
:
x
=
x
[:,
feature_extractor
.
modality_cfg
.
num_extra_tokens
:]
if
masked_padding_mask
is
not
None
:
masked_padding_mask
=
masked_padding_mask
[
:,
feature_extractor
.
modality_cfg
.
num_extra_tokens
:
]
return
{
"x"
:
x
,
"padding_mask"
:
masked_padding_mask
,
"layer_results"
:
layer_results
,
"mask"
:
encoder_mask
,
}
xs
=
[]
if
self
.
shared_decoder
is
not
None
:
dx
=
self
.
forward_decoder
(
x
,
feature_extractor
,
self
.
shared_decoder
,
encoder_mask
,
)
xs
.
append
(
dx
)
if
feature_extractor
.
decoder
is
not
None
:
dx
=
self
.
forward_decoder
(
x
,
feature_extractor
,
feature_extractor
.
decoder
,
encoder_mask
,
)
xs
.
append
(
dx
)
orig_x
=
x
assert
len
(
xs
)
>
0
p
=
next
(
self
.
ema
.
model
.
parameters
())
device
=
x
.
device
dtype
=
x
.
dtype
ema_device
=
p
.
device
ema_dtype
=
p
.
dtype
if
not
self
.
cfg
.
ema_same_dtype
:
dtype
=
ema_dtype
if
ema_device
!=
device
or
ema_dtype
!=
dtype
:
logger
.
info
(
f
"adjusting ema dtype to
{
dtype
}
and device to
{
device
}
"
)
self
.
ema
.
model
=
self
.
ema
.
model
.
to
(
dtype
=
dtype
,
device
=
device
)
ema_dtype
=
dtype
def
to_device
(
d
):
for
k
,
p
in
d
.
items
():
if
isinstance
(
d
[
k
],
dict
):
to_device
(
d
[
k
])
else
:
d
[
k
]
=
p
.
to
(
device
=
device
)
to_device
(
self
.
ema
.
fp32_params
)
tm
=
self
.
ema
.
model
with
torch
.
no_grad
():
tm
.
eval
()
if
self
.
cfg
.
ema_encoder_only
:
assert
target
is
None
ema_input
=
extractor_out
[
"local_features"
]
ema_input
=
feature_extractor
.
contextualized_features
(
ema_input
.
to
(
dtype
=
ema_dtype
),
padding_mask
,
mask
=
False
,
remove_masked
=
False
,
)
ema_blocks
=
tm
else
:
ema_blocks
=
tm
.
blocks
if
feature_extractor
.
modality_cfg
.
ema_local_encoder
:
inp
=
(
target
.
to
(
dtype
=
ema_dtype
)
if
target
is
not
None
else
source
.
to
(
dtype
=
ema_dtype
)
)
ema_input
=
tm
.
modality_encoders
[
mode
](
inp
,
padding_mask
,
mask
=
False
,
remove_masked
=
False
,
)
else
:
assert
target
is
None
ema_input
=
extractor_out
[
"local_features"
]
ema_feature_enc
=
tm
.
modality_encoders
[
mode
]
ema_input
=
ema_feature_enc
.
contextualized_features
(
ema_input
.
to
(
dtype
=
ema_dtype
),
padding_mask
,
mask
=
False
,
remove_masked
=
False
,
)
ema_padding_mask
=
ema_input
[
"padding_mask"
]
ema_alibi_bias
=
ema_input
.
get
(
"alibi_bias"
,
None
)
ema_alibi_scale
=
ema_input
.
get
(
"alibi_scale"
,
None
)
ema_input
=
ema_input
[
"x"
]
y
=
[]
ema_x
=
[]
extra_tokens
=
feature_extractor
.
modality_cfg
.
num_extra_tokens
for
i
,
blk
in
enumerate
(
ema_blocks
):
ab
=
ema_alibi_bias
if
ab
is
not
None
and
alibi_scale
is
not
None
:
scale
=
(
ema_alibi_scale
[
i
]
if
ema_alibi_scale
.
size
(
0
)
>
1
else
ema_alibi_scale
.
squeeze
(
0
)
)
ab
=
ab
*
scale
.
type_as
(
ab
)
ema_input
,
lr
=
blk
(
ema_input
,
padding_mask
=
ema_padding_mask
,
alibi_bias
=
ab
,
)
y
.
append
(
lr
[:,
extra_tokens
:])
ema_x
.
append
(
ema_input
[:,
extra_tokens
:])
y
=
self
.
make_targets
(
y
,
self
.
average_top_k_layers
)
orig_targets
=
y
if
self
.
cfg
.
clone_batch
>
1
:
y
=
y
.
repeat_interleave
(
self
.
cfg
.
clone_batch
,
0
)
masked
=
encoder_mask
.
mask
.
unsqueeze
(
-
1
)
masked_b
=
encoder_mask
.
mask
.
bool
()
y
=
y
[
masked_b
]
if
xs
[
0
].
size
(
1
)
==
masked_b
.
size
(
1
):
xs
=
[
x
[
masked_b
]
for
x
in
xs
]
else
:
xs
=
[
x
.
reshape
(
-
1
,
x
.
size
(
-
1
))
for
x
in
xs
]
sample_size
=
masked
.
sum
().
long
()
result
=
{
"losses"
:
{},
"sample_size"
:
sample_size
,
}
sample_size
=
result
[
"sample_size"
]
if
self
.
cfg
.
cls_loss
>
0
:
assert
extra_tokens
>
0
cls_target
=
orig_targets
.
mean
(
dim
=
1
)
if
self
.
cfg
.
clone_batch
>
1
:
cls_target
=
cls_target
.
repeat_interleave
(
self
.
cfg
.
clone_batch
,
0
)
cls_pred
=
x
[:,
extra_tokens
-
1
]
result
[
"losses"
][
"cls"
]
=
self
.
d2v_loss
(
cls_pred
,
cls_target
)
*
(
self
.
cfg
.
cls_loss
*
sample_size
)
if
self
.
cfg
.
recon_loss
>
0
:
with
torch
.
no_grad
():
target
=
feature_extractor
.
patchify
(
source
)
mean
=
target
.
mean
(
dim
=-
1
,
keepdim
=
True
)
var
=
target
.
var
(
dim
=-
1
,
keepdim
=
True
)
target
=
(
target
-
mean
)
/
(
var
+
1.0e-6
)
**
0.5
if
self
.
cfg
.
clone_batch
>
1
:
target
=
target
.
repeat_interleave
(
self
.
cfg
.
clone_batch
,
0
)
if
masked_b
is
not
None
:
target
=
target
[
masked_b
]
recon
=
xs
[
0
]
if
self
.
recon_proj
is
not
None
:
recon
=
self
.
recon_proj
(
recon
)
result
[
"losses"
][
"recon"
]
=
(
self
.
d2v_loss
(
recon
,
target
.
float
())
*
self
.
cfg
.
recon_loss
)
if
self
.
cfg
.
d2v_loss
>
0
:
for
i
,
x
in
enumerate
(
xs
):
reg_loss
=
self
.
d2v_loss
(
x
,
y
)
n
=
f
"
{
mode
}
_regression_
{
i
}
"
if
len
(
xs
)
>
1
else
f
"
{
mode
}
_regression"
result
[
"losses"
][
n
]
=
reg_loss
*
self
.
cfg
.
d2v_loss
suffix
=
""
if
len
(
self
.
modalities
)
==
1
else
f
"_
{
mode
}
"
with
torch
.
no_grad
():
if
encoder_mask
is
not
None
:
result
[
"masked_pct"
]
=
1
-
(
encoder_mask
.
ids_keep
.
size
(
1
)
/
encoder_mask
.
ids_restore
.
size
(
1
)
)
for
i
,
x
in
enumerate
(
xs
):
n
=
f
"pred_var
{
suffix
}
_
{
i
}
"
if
len
(
xs
)
>
1
else
f
"pred_var
{
suffix
}
"
result
[
n
]
=
self
.
compute_var
(
x
.
float
())
if
self
.
ema
is
not
None
:
for
k
,
v
in
self
.
ema
.
logs
.
items
():
result
[
k
]
=
v
y
=
y
.
float
()
result
[
f
"target_var
{
suffix
}
"
]
=
self
.
compute_var
(
y
)
if
self
.
num_updates
>
5000
:
if
result
[
f
"target_var
{
suffix
}
"
]
<
self
.
cfg
.
min_target_var
:
logger
.
error
(
f
"target var is
{
result
[
f
'target_var
{
suffix
}
'].item()
}
<
{
self
.
cfg
.
min_target_var
}
, exiting (
{
mode
}
)"
)
raise
Exception
(
f
"target var is
{
result
[
f
'target_var
{
suffix
}
'].item()
}
<
{
self
.
cfg
.
min_target_var
}
, exiting (
{
mode
}
)"
)
for
k
in
result
.
keys
():
if
k
.
startswith
(
"pred_var"
)
and
result
[
k
]
<
self
.
cfg
.
min_pred_var
:
logger
.
error
(
f
"
{
k
}
is
{
result
[
k
].
item
()
}
<
{
self
.
cfg
.
min_pred_var
}
, exiting (
{
mode
}
)"
)
raise
Exception
(
f
"
{
k
}
is
{
result
[
k
].
item
()
}
<
{
self
.
cfg
.
min_pred_var
}
, exiting (
{
mode
}
)"
)
result
[
"ema_decay"
]
=
self
.
ema
.
get_decay
()
*
1000
return
result
def
forward_decoder
(
self
,
x
,
feature_extractor
,
decoder
,
mask_info
,
):
x
=
feature_extractor
.
decoder_input
(
x
,
mask_info
)
x
=
decoder
(
*
x
)
return
x
def
d2v_loss
(
self
,
x
,
y
):
x
=
x
.
view
(
-
1
,
x
.
size
(
-
1
)).
float
()
y
=
y
.
view
(
-
1
,
x
.
size
(
-
1
))
if
self
.
loss_beta
==
0
:
loss
=
F
.
mse_loss
(
x
,
y
,
reduction
=
"none"
)
else
:
loss
=
F
.
smooth_l1_loss
(
x
,
y
,
reduction
=
"none"
,
beta
=
self
.
loss_beta
)
if
self
.
loss_scale
is
not
None
:
scale
=
self
.
loss_scale
else
:
scale
=
1
/
math
.
sqrt
(
x
.
size
(
-
1
))
reg_loss
=
loss
*
scale
return
reg_loss
def
make_targets
(
self
,
y
,
num_layers
):
with
torch
.
no_grad
():
target_layer_results
=
y
[
-
num_layers
:]
permuted
=
False
if
self
.
cfg
.
instance_norm_target_layer
or
self
.
cfg
.
batch_norm_target_layer
:
target_layer_results
=
[
tl
.
transpose
(
1
,
2
)
for
tl
in
target_layer_results
# BTC -> BCT
]
permuted
=
True
if
self
.
cfg
.
batch_norm_target_layer
:
target_layer_results
=
[
F
.
batch_norm
(
tl
.
float
(),
running_mean
=
None
,
running_var
=
None
,
training
=
True
)
for
tl
in
target_layer_results
]
if
self
.
cfg
.
instance_norm_target_layer
:
target_layer_results
=
[
F
.
instance_norm
(
tl
.
float
())
for
tl
in
target_layer_results
]
if
permuted
:
target_layer_results
=
[
tl
.
transpose
(
1
,
2
)
for
tl
in
target_layer_results
# BCT -> BTC
]
if
self
.
cfg
.
layer_norm_target_layer
:
target_layer_results
=
[
F
.
layer_norm
(
tl
.
float
(),
tl
.
shape
[
-
1
:])
for
tl
in
target_layer_results
]
y
=
target_layer_results
[
0
].
float
()
for
tl
in
target_layer_results
[
1
:]:
y
.
add_
(
tl
.
float
())
y
=
y
.
div_
(
len
(
target_layer_results
))
if
self
.
cfg
.
layer_norm_targets
:
y
=
F
.
layer_norm
(
y
,
y
.
shape
[
-
1
:])
if
self
.
cfg
.
instance_norm_targets
:
y
=
F
.
instance_norm
(
y
.
transpose
(
1
,
2
)).
transpose
(
1
,
2
)
return
y
@
staticmethod
def
compute_var
(
y
):
y
=
y
.
view
(
-
1
,
y
.
size
(
-
1
))
if
dist
.
is_initialized
():
zc
=
torch
.
tensor
(
y
.
size
(
0
)).
cuda
()
zs
=
y
.
sum
(
dim
=
0
)
zss
=
(
y
**
2
).
sum
(
dim
=
0
)
dist
.
all_reduce
(
zc
)
dist
.
all_reduce
(
zs
)
dist
.
all_reduce
(
zss
)
var
=
zss
/
(
zc
-
1
)
-
(
zs
**
2
)
/
(
zc
*
(
zc
-
1
))
return
torch
.
sqrt
(
var
+
1e-6
).
mean
()
else
:
return
torch
.
sqrt
(
y
.
var
(
dim
=
0
)
+
1e-6
).
mean
()
def
extract_features
(
self
,
source
,
mode
=
None
,
padding_mask
=
None
,
mask
=
False
,
remove_extra_tokens
=
True
):
res
=
self
.
forward
(
source
,
mode
=
mode
,
padding_mask
=
padding_mask
,
mask
=
mask
,
features_only
=
True
,
remove_extra_tokens
=
remove_extra_tokens
,
)
return
res
def
remove_pretraining_modules
(
self
,
modality
=
None
,
keep_decoder
=
False
):
self
.
ema
=
None
self
.
cfg
.
clone_batch
=
1
self
.
recon_proj
=
None
if
not
keep_decoder
:
self
.
shared_decoder
=
None
modality
=
modality
.
lower
()
if
modality
is
not
None
else
None
for
k
in
list
(
self
.
modality_encoders
.
keys
()):
if
modality
is
not
None
and
k
.
lower
()
!=
modality
:
del
self
.
modality_encoders
[
k
]
else
:
self
.
modality_encoders
[
k
].
remove_pretraining_modules
(
keep_decoder
=
keep_decoder
)
if
not
keep_decoder
:
self
.
modality_encoders
[
k
].
decoder
=
None
examples/data2vec/models/data2vec_audio.py
0 → 100644
View file @
72f5785f
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
logging
import
math
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
omegaconf
import
II
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.distributed
as
dist
from
fairseq.modules
import
EMAModule
,
EMAModuleConfig
from
fairseq.data.data_utils
import
compute_mask_indices
from
fairseq.models
import
BaseFairseqModel
,
register_model
from
fairseq.models.wav2vec
import
(
ConvFeatureExtractionModel
,
Wav2Vec2Config
,
TransformerEncoder
,
)
from
fairseq.modules
import
(
GradMultiply
,
LayerNorm
,
)
from
fairseq.utils
import
index_put
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
Data2VecAudioConfig
(
Wav2Vec2Config
):
loss_beta
:
float
=
field
(
default
=
0
,
metadata
=
{
"help"
:
"beta for smooth l1 loss. 0 means use l2 loss"
}
)
loss_scale
:
Optional
[
float
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
average_top_k_layers
:
int
=
field
(
default
=
8
,
metadata
=
{
"help"
:
"how many layers to average"
}
)
layer_norm_target_layer
:
bool
=
False
instance_norm_target_layer
:
bool
=
False
instance_norm_targets
:
bool
=
False
layer_norm_targets
:
bool
=
False
batch_norm_target_layer
:
bool
=
False
group_norm_target_layer
:
bool
=
False
ema_decay
:
float
=
field
(
default
=
0.999
,
metadata
=
{
"help"
:
"initial ema decay rate"
})
ema_end_decay
:
float
=
field
(
default
=
0.9999
,
metadata
=
{
"help"
:
"final ema decay rate"
}
)
# when to finish annealing ema decay rate
ema_anneal_end_step
:
int
=
II
(
"optimization.max_update"
)
ema_transformer_only
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"whether to momentum update only the transformer"
},
)
ema_layers_only
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"whether to momentum update only the transformer layers"
},
)
max_update
:
int
=
II
(
"optimization.max_update"
)
min_target_var
:
float
=
field
(
default
=
0.1
,
metadata
=
{
"help"
:
"stop training if target var falls below this"
}
)
min_pred_var
:
float
=
field
(
default
=
0.01
,
metadata
=
{
"help"
:
"stop training if prediction var falls below this"
},
)
def
get_annealed_rate
(
start
,
end
,
curr_step
,
total_steps
):
r
=
end
-
start
pct_remaining
=
1
-
curr_step
/
total_steps
return
end
-
r
*
pct_remaining
@
register_model
(
"data2vec_audio"
,
dataclass
=
Data2VecAudioConfig
)
class
Data2VecAudioModel
(
BaseFairseqModel
):
def
__init__
(
self
,
cfg
:
Data2VecAudioConfig
):
super
().
__init__
()
self
.
cfg
=
cfg
feature_enc_layers
=
eval
(
cfg
.
conv_feature_layers
)
self
.
extractor_embed
=
feature_enc_layers
[
-
1
][
0
]
self
.
ema
=
None
self
.
embed
=
cfg
.
encoder_embed_dim
self
.
average_top_k_layers
=
cfg
.
average_top_k_layers
self
.
loss_beta
=
cfg
.
loss_beta
self
.
loss_scale
=
cfg
.
loss_scale
self
.
feature_extractor
=
ConvFeatureExtractionModel
(
conv_layers
=
feature_enc_layers
,
dropout
=
0.0
,
mode
=
cfg
.
extractor_mode
,
conv_bias
=
cfg
.
conv_bias
,
)
self
.
post_extract_proj
=
nn
.
Linear
(
self
.
extractor_embed
,
cfg
.
encoder_embed_dim
)
self
.
mask_prob
=
cfg
.
mask_prob
self
.
mask_selection
=
cfg
.
mask_selection
self
.
mask_other
=
cfg
.
mask_other
self
.
mask_length
=
cfg
.
mask_length
self
.
no_mask_overlap
=
cfg
.
no_mask_overlap
self
.
mask_min_space
=
cfg
.
mask_min_space
self
.
mask_channel_prob
=
cfg
.
mask_channel_prob
self
.
mask_channel_before
=
cfg
.
mask_channel_before
self
.
mask_channel_selection
=
cfg
.
mask_channel_selection
self
.
mask_channel_other
=
cfg
.
mask_channel_other
self
.
mask_channel_length
=
cfg
.
mask_channel_length
self
.
no_mask_channel_overlap
=
cfg
.
no_mask_channel_overlap
self
.
mask_channel_min_space
=
cfg
.
mask_channel_min_space
self
.
dropout_input
=
nn
.
Dropout
(
cfg
.
dropout_input
)
self
.
dropout_features
=
nn
.
Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.extractor_embed)

        self.final_proj = nn.Linear(self.embed, self.embed)

        self.num_updates = 0

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_layers_only:
            self.cfg.ema_transformer_only = True
            for k, _ in self.encoder.pos_conv.named_parameters():
                skip_keys.add(f"pos_conv.{k}")

        self.ema = EMAModule(
            self.encoder if self.cfg.ema_transformer_only else self,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.final_proj is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.encoder if self.cfg.ema_transformer_only else self)

        self.num_updates = num_updates

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params
        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @classmethod
    def build_model(cls, cfg: Data2VecAudioConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def apply_mask(
        self,
        x,
        padding_mask,
        mask_indices=None,
        mask_channel_indices=None,
    ):
        B, T, C = x.shape

        if self.mask_channel_prob > 0 and self.mask_channel_before:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        if self.mask_prob > 0:
            if mask_indices is None:
                mask_indices = compute_mask_indices(
                    (B, T),
                    padding_mask,
                    self.mask_prob,
                    self.mask_length,
                    self.mask_selection,
                    self.mask_other,
                    min_masks=1,
                    no_overlap=self.no_mask_overlap,
                    min_space=self.mask_min_space,
                    require_same_masks=self.cfg.require_same_masks,
                    mask_dropout=self.cfg.mask_dropout,
                )
                mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x = index_put(x, mask_indices, self.mask_emb)
        else:
            mask_indices = None

        if self.mask_channel_prob > 0 and not self.mask_channel_before:
            if mask_channel_indices is None:
                mask_channel_indices = compute_mask_indices(
                    (B, C),
                    None,
                    self.mask_channel_prob,
                    self.mask_channel_length,
                    self.mask_channel_selection,
                    self.mask_channel_other,
                    no_overlap=self.no_mask_channel_overlap,
                    min_space=self.mask_channel_min_space,
                )
                mask_channel_indices = (
                    torch.from_numpy(mask_channel_indices)
                    .to(x.device)
                    .unsqueeze(1)
                    .expand(-1, T, -1)
                )
            x = index_put(x, mask_channel_indices, 0)

        return x, mask_indices

    def _get_feat_extract_output_lengths(self, input_lengths: torch.LongTensor):
        """
        Computes the output length of the convolutional layers
        """

        def _conv_out_length(input_length, kernel_size, stride):
            return torch.floor((input_length - kernel_size) / stride + 1)

        conv_cfg_list = eval(self.cfg.conv_feature_layers)

        for i in range(len(conv_cfg_list)):
            input_lengths = _conv_out_length(
                input_lengths, conv_cfg_list[i][1], conv_cfg_list[i][2]
            )

        return input_lengths.to(torch.long)

    def forward(
        self,
        source,
        padding_mask=None,
        mask=True,
        features_only=False,
        layer=None,
        mask_indices=None,
        mask_channel_indices=None,
        padding_count=None,
    ):
        features = source

        if self.feature_grad_mult > 0:
            features = self.feature_extractor(features)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(features)

        features = features.transpose(1, 2)

        features = self.layer_norm(features)

        orig_padding_mask = padding_mask

        if padding_mask is not None and padding_mask.any():
            input_lengths = (1 - padding_mask.long()).sum(-1)
            # apply conv formula to get real output_lengths
            output_lengths = self._get_feat_extract_output_lengths(input_lengths)

            padding_mask = torch.zeros(
                features.shape[:2], dtype=features.dtype, device=features.device
            )

            # these two operations makes sure that all values
            # before the output lengths indices are attended to
            padding_mask[
                (
                    torch.arange(padding_mask.shape[0], device=padding_mask.device),
                    output_lengths - 1,
                )
            ] = 1
            padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
        else:
            padding_mask = None

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        pre_encoder_features = None
        if self.cfg.ema_transformer_only:
            pre_encoder_features = features.clone()

        features = self.dropout_input(features)

        if mask:
            x, mask_indices = self.apply_mask(
                features,
                padding_mask,
                mask_indices=mask_indices,
                mask_channel_indices=mask_channel_indices,
            )
        else:
            x = features
            mask_indices = None

        x, layer_results = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=layer,
        )

        if features_only:
            return {
                "x": x,
                "padding_mask": padding_mask,
                "layer_results": layer_results,
            }

        result = {
            "losses": {},
        }

        with torch.no_grad():
            self.ema.model.eval()

            if self.cfg.ema_transformer_only:
                y, layer_results = self.ema.model.extract_features(
                    pre_encoder_features,
                    padding_mask=padding_mask,
                    min_layer=self.cfg.encoder_layers - self.average_top_k_layers,
                )
                y = {
                    "x": y,
                    "padding_mask": padding_mask,
                    "layer_results": layer_results,
                }
            else:
                y = self.ema.model.extract_features(
                    source=source,
                    padding_mask=orig_padding_mask,
                    mask=False,
                )

            target_layer_results = [l[2] for l in y["layer_results"]]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    tl.permute(1, 2, 0) for tl in target_layer_results  # TBC -> BCT
                ]
                permuted = True

            if self.cfg.batch_norm_target_layer:
                target_layer_results = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in target_layer_results
                ]

            if self.cfg.instance_norm_target_layer:
                target_layer_results = [
                    F.instance_norm(tl.float()) for tl in target_layer_results
                ]

            if permuted:
                target_layer_results = [
                    tl.transpose(1, 2) for tl in target_layer_results  # BCT -> BTC
                ]

            if self.cfg.group_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-2:])
                    for tl in target_layer_results
                ]

            if self.cfg.layer_norm_target_layer:
                target_layer_results = [
                    F.layer_norm(tl.float(), tl.shape[-1:])
                    for tl in target_layer_results
                ]

            y = sum(target_layer_results) / len(target_layer_results)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.float().transpose(1, 2)).transpose(1, 2)

            if not permuted:
                y = y.transpose(0, 1)

            y = y[mask_indices]

        x = x[mask_indices]
        x = self.final_proj(x)

        sz = x.size(-1)

        if self.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.loss_beta
            ).sum(dim=-1)

        if self.loss_scale is not None:
            scale = self.loss_scale
        else:
            scale = 1 / math.sqrt(sz)

        result["losses"]["regression"] = loss.sum() * scale

        if "sample_size" not in result:
            result["sample_size"] = loss.numel()

        with torch.no_grad():
            result["target_var"] = self.compute_var(y)
            result["pred_var"] = self.compute_var(x.float())

        if self.num_updates > 5000 and result["target_var"] < self.cfg.min_target_var:
            logger.error(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
            raise Exception(
                f"target var is {result['target_var'].item()} < {self.cfg.min_target_var}, exiting"
            )
        if self.num_updates > 5000 and result["pred_var"] < self.cfg.min_pred_var:
            logger.error(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )
            raise Exception(
                f"pred var is {result['pred_var'].item()} < {self.cfg.min_pred_var}, exiting"
            )

        if self.ema is not None:
            result["ema_decay"] = self.ema.get_decay() * 1000

        return result

    @staticmethod
    def compute_var(y):
        y = y.view(-1, y.size(-1))
        if dist.is_initialized():
            zc = torch.tensor(y.size(0)).cuda()
            zs = y.sum(dim=0)
            zss = (y ** 2).sum(dim=0)

            dist.all_reduce(zc)
            dist.all_reduce(zs)
            dist.all_reduce(zss)

            var = zss / (zc - 1) - (zs ** 2) / (zc * (zc - 1))
            return torch.sqrt(var + 1e-6).mean()
        else:
            return torch.sqrt(y.var(dim=0) + 1e-6).mean()

    def extract_features(self, source, padding_mask, mask=False, layer=None):
        res = self.forward(
            source,
            padding_mask,
            mask=mask,
            features_only=True,
            layer=layer,
        )
        return res

    def remove_pretraining_modules(self, last_layer=None):
        self.final_proj = None
        self.ema = None
        if last_layer is not None:
            self.encoder.layers = nn.ModuleList(
                l for i, l in enumerate(self.encoder.layers) if i <= last_layer
            )
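The forward pass above builds its regression targets by averaging the top-K transformer layer outputs of the EMA teacher and optionally normalizing them before the masked regression. Below is a minimal, self-contained sketch of that target construction outside fairseq; the helper name build_targets and the toy tensor shapes are illustrative assumptions, and only the instance-norm variant of the normalization options is shown.

import torch
import torch.nn.functional as F

def build_targets(layer_outputs, k, instance_norm=True):
    # layer_outputs: list of [T, B, C] teacher activations, analogous to the
    # `layer_results` consumed in forward() above.
    tls = layer_outputs[-k:]
    if instance_norm:
        # TBC -> BCT, normalize each channel over time, then back to BTC
        tls = [F.instance_norm(tl.permute(1, 2, 0).float()).transpose(1, 2) for tl in tls]
    y = sum(tls) / len(tls)
    if not instance_norm:
        y = y.transpose(0, 1)  # TBC -> BTC
    return y

# toy check: 12 layers of [T=50, B=2, C=768] activations, average the top 8
outs = [torch.randn(50, 2, 768) for _ in range(12)]
print(build_targets(outs, k=8).shape)  # torch.Size([2, 50, 768])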
examples/data2vec/models/data2vec_image_classification.py  0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging

from dataclasses import dataclass
from typing import Any

from omegaconf import II, MISSING

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model


logger = logging.getLogger(__name__)


@dataclass
class Data2VecImageClassificationConfig(FairseqDataclass):
    model_path: str = MISSING
    no_pretrained_weights: bool = False
    num_classes: int = 1000
    mixup: float = 0.8
    cutmix: float = 1.0
    label_smoothing: float = 0.1

    pretrained_model_args: Any = None
    data: str = II("task.data")


@register_model(
    "data2vec_image_classification", dataclass=Data2VecImageClassificationConfig
)
class Data2VecImageClassificationModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecImageClassificationConfig):
        super().__init__()
        self.cfg = cfg

        if cfg.pretrained_model_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
            pretrained_args = state.get("cfg", None)
            pretrained_args.criterion = None
            pretrained_args.lr_scheduler = None
            cfg.pretrained_model_args = pretrained_args

            logger.info(pretrained_args)
        else:
            state = None
            pretrained_args = cfg.pretrained_model_args

        pretrained_args.task.data = cfg.data
        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)

        model.remove_pretraining_modules()

        self.model = model

        if state is not None and not cfg.no_pretrained_weights:
            self.load_model_weights(state, model, cfg)

        self.fc_norm = nn.LayerNorm(pretrained_args.model.embed_dim)
        self.head = nn.Linear(pretrained_args.model.embed_dim, cfg.num_classes)

        self.head.weight.data.mul_(1e-3)
        self.head.bias.data.mul_(1e-3)

        self.mixup_fn = None

        if cfg.mixup > 0 or cfg.cutmix > 0:
            from timm.data import Mixup

            self.mixup_fn = Mixup(
                mixup_alpha=cfg.mixup,
                cutmix_alpha=cfg.cutmix,
                cutmix_minmax=None,
                prob=1.0,
                switch_prob=0.5,
                mode="batch",
                label_smoothing=cfg.label_smoothing,
                num_classes=cfg.num_classes,
            )

    def load_model_weights(self, state, model, cfg):
        if "_ema" in state["model"]:
            del state["model"]["_ema"]
        model.load_state_dict(state["model"], strict=True)

    @classmethod
    def build_model(cls, cfg: Data2VecImageClassificationConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def forward(
        self,
        img,
        label=None,
    ):
        if self.training and self.mixup_fn is not None and label is not None:
            img, label = self.mixup_fn(img, label)

        x = self.model(img, mask=False)
        x = x[:, 1:]
        x = self.fc_norm(x.mean(1))
        x = self.head(x)

        if label is None:
            return x

        if self.training and self.mixup_fn is not None:
            loss = -label * F.log_softmax(x.float(), dim=-1)
        else:
            loss = F.cross_entropy(
                x.float(),
                label,
                label_smoothing=self.cfg.label_smoothing if self.training else 0,
                reduction="none",
            )

        result = {
            "losses": {"regression": loss},
            "sample_size": img.size(0),
        }

        if not self.training:
            with torch.no_grad():
                pred = x.argmax(-1)
                correct = (pred == label).sum()
                result["correct"] = correct

        return result
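When Mixup/CutMix is active, the forward() above scores soft labels with -label * log_softmax(x) instead of F.cross_entropy. A small sanity-check sketch (made-up logits and labels, summed over classes for comparison) showing the two branches agree when the mixed label collapses to a one-hot vector:

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, -1.0]])
hard = torch.tensor([0])
soft = F.one_hot(hard, num_classes=3).float()  # what Mixup emits when lam == 1

ce = F.cross_entropy(logits, hard, reduction="none")
soft_ce = (-soft * F.log_softmax(logits, dim=-1)).sum(dim=-1)
print(torch.allclose(ce, soft_ce))  # True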
examples/data2vec/models/data2vec_text.py  0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Optional
import logging
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from omegaconf import II

from fairseq.dataclass import FairseqDataclass
from fairseq.modules import EMAModule, EMAModuleConfig
from fairseq.models import (
    FairseqEncoder,
    FairseqEncoderModel,
    register_model,
)
from fairseq.models.roberta.model import RobertaLMHead, RobertaClassificationHead
from fairseq.models.transformer import TransformerEncoder, TransformerConfig
from fairseq.modules.transformer_sentence_encoder import init_bert_params

logger = logging.getLogger(__name__)


@dataclass
class Data2VecTextConfig(FairseqDataclass):
    max_positions: int = II("task.tokens_per_sample")

    head_layers: int = 1

    transformer: TransformerConfig = TransformerConfig()

    load_checkpoint_heads: bool = field(
        default=False,
        metadata={"help": "(re-)register and load heads when loading checkpoints"},
    )

    loss_beta: float = field(
        default=0, metadata={"help": "beta for smooth l1 loss. 0 means use l2 loss"}
    )
    loss_scale: Optional[float] = field(
        default=None,
        metadata={
            "help": "scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
        },
    )
    average_top_k_layers: int = field(
        default=8, metadata={"help": "how many layers to average"}
    )

    layer_norm_target_layer: bool = False
    instance_norm_target_layer: bool = False
    batch_norm_target_layer: bool = False
    instance_norm_targets: bool = False
    layer_norm_targets: bool = False

    ema_decay: float = field(default=0.999, metadata={"help": "initial ema decay rate"})
    ema_end_decay: float = field(
        default=0.9999, metadata={"help": "final ema decay rate"}
    )

    # when to finish annealing ema decay rate
    ema_anneal_end_step: int = II("optimization.max_update")

    ema_transformer_layers_only: bool = field(
        default=True,
        metadata={"help": "whether to momentum update only the transformer layers"},
    )


def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining


@register_model("data2vec_text", dataclass=Data2VecTextConfig)
class Data2VecTextModel(FairseqEncoderModel):
    def __init__(self, cfg: Data2VecTextConfig, encoder):
        super().__init__(encoder)
        self.cfg = cfg

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        self.classification_heads = nn.ModuleDict()

    @classmethod
    def build_model(cls, cfg, task):
        """Build a new model instance."""

        encoder = Data2VecTextEncoder(cfg, task.source_dictionary, task.cfg.data)

        return cls(cfg, encoder)

    def forward(
        self,
        src_tokens,
        target_tokens=None,
        features_only=False,
        return_all_hiddens=False,
        classification_head_name=None,
        **kwargs,
    ):
        if classification_head_name is not None:
            features_only = True

        res = self.encoder(
            src_tokens, target_tokens, features_only, return_all_hiddens, **kwargs
        )

        if isinstance(res, tuple):
            x, extra = res
        else:
            return res

        if classification_head_name is not None:
            x = self.classification_heads[classification_head_name](x)
        return x, extra

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """Get normalized probabilities (or log probs) from a net's output."""
        logits = net_output[0].float()
        if log_probs:
            return F.log_softmax(logits, dim=-1)
        else:
            return F.softmax(logits, dim=-1)

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = RobertaClassificationHead(
            input_dim=self.cfg.transformer.encoder.embed_dim,
            inner_dim=inner_dim or self.cfg.transformer.encoder.embed_dim,
            num_classes=num_classes,
            activation_fn="tanh",
            pooler_dropout=0,
        )

    @property
    def supported_targets(self):
        return {"self"}

    def upgrade_state_dict_named(self, state_dict, name):
        prefix = name + "." if name != "" else ""

        # rename decoder -> encoder before upgrading children modules
        for k in list(state_dict.keys()):
            if k.startswith(prefix + "decoder"):
                new_k = prefix + "encoder" + k[len(prefix + "decoder") :]
                state_dict[new_k] = state_dict[k]
                del state_dict[k]

        # rename emb_layer_norm -> layernorm_embedding
        for k in list(state_dict.keys()):
            if ".emb_layer_norm." in k:
                new_k = k.replace(".emb_layer_norm.", ".layernorm_embedding.")
                state_dict[new_k] = state_dict[k]
                del state_dict[k]

            if self.encoder.regression_head is not None:
                if ".lm_head." in k:
                    new_k = k.replace(".lm_head.", ".regression_head.")
                    state_dict[new_k] = state_dict[k]
                    del state_dict[k]
            else:
                if ".regression_head." in k:
                    del state_dict[k]

        # upgrade children modules
        super().upgrade_state_dict_named(state_dict, name)

        # Handle new classification heads present in the state dict.
        current_head_names = (
            []
            if not hasattr(self, "classification_heads")
            or self.classification_heads is None
            else self.classification_heads.keys()
        )
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(prefix + "classification_heads."):
                continue

            head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
            num_classes = state_dict[
                prefix + "classification_heads." + head_name + ".out_proj.weight"
            ].size(0)
            inner_dim = state_dict[
                prefix + "classification_heads." + head_name + ".dense.weight"
            ].size(0)

            if self.cfg.load_checkpoint_heads:
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "not present in current model: {}".format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif (
                    num_classes
                    != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim
                    != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "with different dimensions than current model: {}".format(
                            head_name, k
                        )
                    )
                    keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if (
            hasattr(self, "classification_heads")
            and self.classification_heads is not None
            and len(self.classification_heads) > 0
        ):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                if prefix + "classification_heads." + k not in state_dict:
                    logger.info("Overwriting " + prefix + "classification_heads." + k)
                    state_dict[prefix + "classification_heads." + k] = v

        for k in list(state_dict.keys()):
            if k.startswith(prefix + "encoder.lm_head.") or k.startswith(
                prefix + "encoder.emb_head."
            ):
                del state_dict[k]

        self.encoder.lm_head = None

        if self.encoder.target_model is None:
            for k in list(state_dict.keys()):
                if k.startswith(prefix + "encoder.target_model."):
                    del state_dict[k]

        if (self.encoder.ema is None) and (prefix + "encoder._ema" in state_dict):
            del state_dict[prefix + "encoder._ema"]

    def remove_pretraining_modules(self, last_layer=None):
        self.encoder.lm_head = None
        self.encoder.regression_head = None
        self.encoder.ema = None
        self.classification_heads = None

        if last_layer is not None:
            self.encoder.sentence_encoder.layers = nn.ModuleList(
                l
                for i, l in enumerate(self.encoder.sentence_encoder.layers)
                if i <= last_layer
            )
            self.encoder.sentence_encoder.layer_norm = None


class Data2VecTextEncoder(FairseqEncoder):
    def __init__(self, cfg: Data2VecTextConfig, dictionary, task_data):
        super().__init__(dictionary)

        self.cfg = cfg

        embed_tokens = self.build_embedding(
            len(dictionary), cfg.transformer.encoder.embed_dim, dictionary.pad()
        )

        self.sentence_encoder = self.build_encoder(cfg, dictionary, embed_tokens)
        self.mask_idx = dictionary.index("<mask>")
        assert self.mask_idx != dictionary.unk(), dictionary.symbols

        self.ema = None
        self.average_top_k_layers = cfg.average_top_k_layers
        self.loss_scale = cfg.loss_scale

        assert self.cfg.head_layers >= 1

        embed_dim = cfg.transformer.encoder.embed_dim
        curr_dim = embed_dim
        projs = []
        for i in range(self.cfg.head_layers - 1):
            next_dim = embed_dim * 2 if i == 0 else curr_dim
            projs.append(nn.Linear(curr_dim, next_dim))
            projs.append(nn.GELU())
            curr_dim = next_dim

        projs.append(nn.Linear(curr_dim, embed_dim))
        self.regression_head = nn.Sequential(*projs)

        self.num_updates = 0

    def build_embedding(self, vocab_size, embedding_dim, padding_idx):
        return nn.Embedding(vocab_size, embedding_dim, padding_idx)

    def build_encoder(self, cfg, dictionary, embed_tokens):
        encoder = TransformerEncoder(
            cfg.transformer, dictionary, embed_tokens, return_fc=True
        )
        encoder.apply(init_bert_params)
        return encoder

    def build_lm_head(self, embed_dim, output_dim, activation_fn, weight):
        return RobertaLMHead(embed_dim, output_dim, activation_fn, weight)

    def make_ema_teacher(self):
        ema_config = EMAModuleConfig(
            ema_decay=self.cfg.ema_decay,
            ema_fp32=True,
        )
        skip_keys = set()
        if self.cfg.ema_transformer_layers_only:
            for k, _ in self.sentence_encoder.embed_positions.named_parameters():
                skip_keys.add(f"embed_tokens.{k}")
            for k, _ in self.sentence_encoder.embed_positions.named_parameters():
                skip_keys.add(f"embed_positions.{k}")
            if self.sentence_encoder.layernorm_embedding is not None:
                for (
                    k,
                    _,
                ) in self.sentence_encoder.layernorm_embedding.named_parameters():
                    skip_keys.add(f"layernorm_embedding.{k}")
            if self.sentence_encoder.layer_norm is not None:
                for k, _ in self.sentence_encoder.layer_norm.named_parameters():
                    skip_keys.add(f"layernorm_embedding.{k}")

        self.ema = EMAModule(
            self.sentence_encoder,
            ema_config,
            skip_keys=skip_keys,
        )

    def set_num_updates(self, num_updates):
        super().set_num_updates(num_updates)

        if self.ema is None and self.regression_head is not None:
            logger.info(f"making ema teacher")
            self.make_ema_teacher()
        elif self.training and self.ema is not None:
            if self.cfg.ema_decay != self.cfg.ema_end_decay:
                if num_updates >= self.cfg.ema_anneal_end_step:
                    decay = self.cfg.ema_end_decay
                else:
                    decay = get_annealed_rate(
                        self.cfg.ema_decay,
                        self.cfg.ema_end_decay,
                        num_updates,
                        self.cfg.ema_anneal_end_step,
                    )
                self.ema.set_decay(decay)
            if self.ema.get_decay() < 1:
                self.ema.step(self.sentence_encoder)

    def state_dict(self, destination=None, prefix="", keep_vars=False):
        state = super().state_dict(destination, prefix, keep_vars)
        if self.ema is not None:
            state[prefix + "_ema"] = self.ema.fp32_params
        return state

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        if self.ema is not None:
            k = prefix + "_ema"
            assert k in state_dict
            self.ema.restore(state_dict[k], True)
            del state_dict[k]
        return super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def forward(
        self,
        src_tokens,
        target_tokens=None,
        features_only=False,
        return_all_hiddens=False,
        masked_tokens=None,
        **unused,
    ):
        """
        Args:
            src_tokens (LongTensor): input tokens of shape `(batch, src_len)`
            features_only (bool, optional): skip LM head and just return
                features. If True, the output will be of shape
                `(batch, src_len, embed_dim)`.
            return_all_hiddens (bool, optional): also return all of the
                intermediate hidden states (default: False).

        Returns:
            tuple:
                - the LM output of shape `(batch, src_len, vocab)`
                - a dictionary of additional data, where 'inner_states'
                  is a list of hidden states. Note that the hidden
                  states have shape `(src_len, batch, vocab)`.
        """

        x, extra = self.extract_features(
            src_tokens, return_all_hiddens=return_all_hiddens
        )

        if features_only:
            return x, extra

        assert target_tokens is not None

        with torch.no_grad():
            # use EMA parameter as the teacher
            self.ema.model.eval()

            encoder_out = self.ema.model(
                target_tokens,
                return_all_hiddens=True,
            )
            y = encoder_out["fc_results"]

            y = y[-self.average_top_k_layers :]

            permuted = False
            if self.cfg.instance_norm_target_layer or self.cfg.batch_norm_target_layer:
                y = [tl.permute(1, 2, 0) for tl in y]  # TBC -> BCT
                permuted = True

            if self.cfg.batch_norm_target_layer:
                y = [
                    F.batch_norm(
                        tl.float(), running_mean=None, running_var=None, training=True
                    )
                    for tl in y
                ]

            if self.cfg.instance_norm_target_layer:
                y = [F.instance_norm(tl.float()) for tl in y]

            if permuted:
                y = [tl.transpose(1, 2) for tl in y]  # BCT -> BTC

            if self.cfg.layer_norm_target_layer:
                y = [F.layer_norm(tl.float(), tl.shape[-1:]) for tl in y]

            y = sum(y) / len(y)

            if not permuted:
                y = y.transpose(0, 1)

            if self.cfg.layer_norm_targets:
                y = F.layer_norm(y.float(), y.shape[-1:])

            if self.cfg.instance_norm_targets:
                y = F.instance_norm(y.transpose(1, 2)).transpose(1, 2)

        masked_indices = src_tokens.eq(self.mask_idx)

        x = x[masked_indices]
        y = y[masked_indices]

        x = self.regression_head(x)

        sz = x.size(-1)
        if self.cfg.loss_beta == 0:
            loss = F.mse_loss(x.float(), y.float(), reduction="none").sum(dim=-1)
        else:
            loss = F.smooth_l1_loss(
                x.float(), y.float(), reduction="none", beta=self.cfg.loss_beta
            ).sum(dim=-1)

        result = {
            "losses": {
                "main": loss.sum() / math.sqrt(sz)
                if self.loss_scale <= 0
                else loss.sum() * self.loss_scale,
            },
            "sample_size": loss.numel(),
        }

        # logging other values
        other_logs = {"ema_decay": self.ema.get_decay() * 1000}
        result["logs"] = other_logs
        return result

    def extract_features(self, src_tokens, return_all_hiddens=False, **kwargs):
        encoder_out = self.sentence_encoder(
            src_tokens,
            return_all_hiddens=return_all_hiddens,
            token_embeddings=kwargs.get("token_embeddings", None),
        )
        # T x B x C -> B x T x C
        features = encoder_out["encoder_out"][0].transpose(0, 1)
        inner_states = encoder_out["encoder_states"] if return_all_hiddens else None
        return features, {
            "inner_states": inner_states,
            "encoder_embedding": encoder_out["encoder_embedding"][0],
        }

    def output_layer(self, features, masked_tokens=None, **unused):
        return self.lm_head(features, masked_tokens)

    def max_positions(self):
        """Maximum output length supported by the encoder."""
        return self.cfg.max_positions
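get_annealed_rate above moves the EMA decay linearly from ema_decay to ema_end_decay, and set_num_updates clamps it to the end value once ema_anneal_end_step is reached. A quick numeric illustration with the config defaults (0.999 -> 0.9999); the 100000-step anneal horizon is assumed for illustration, since the real value comes from optimization.max_update:

def get_annealed_rate(start, end, curr_step, total_steps):
    r = end - start
    pct_remaining = 1 - curr_step / total_steps
    return end - r * pct_remaining

for step in (0, 25_000, 50_000, 100_000):
    print(step, round(get_annealed_rate(0.999, 0.9999, step, 100_000), 6))
# 0 0.999 / 25000 0.999225 / 50000 0.99945 / 100000 0.9999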
examples/data2vec/models/data2vec_text_classification.py  0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging

from dataclasses import dataclass
from typing import Any

from omegaconf import II, MISSING

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, tasks
from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.roberta.model import RobertaClassificationHead

from examples.data2vec.data.modality import Modality


logger = logging.getLogger(__name__)


@dataclass
class Data2VecTextClassificationConfig(FairseqDataclass):
    pooler_dropout: float = 0.0
    pooler_activation_fn: str = "tanh"
    quant_noise_pq: int = 0
    quant_noise_pq_block_size: int = 8
    spectral_norm_classification_head: bool = False

    model_path: str = MISSING
    no_pretrained_weights: bool = False

    pretrained_model_args: Any = None


@register_model(
    "data2vec_text_classification", dataclass=Data2VecTextClassificationConfig
)
class Data2VecTextClassificationModel(BaseFairseqModel):
    def __init__(self, cfg: Data2VecTextClassificationConfig):
        super().__init__()
        self.cfg = cfg

        if cfg.pretrained_model_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.model_path, {})
            pretrained_args = state.get("cfg", None)
            pretrained_args.criterion = None
            pretrained_args.lr_scheduler = None
            cfg.pretrained_model_args = pretrained_args

            logger.info(pretrained_args)
        else:
            state = None
            pretrained_args = cfg.pretrained_model_args

        task = tasks.setup_task(pretrained_args.task)
        model = task.build_model(pretrained_args.model, from_checkpoint=True)

        model.remove_pretraining_modules()

        self.model = model

        if state is not None and not cfg.no_pretrained_weights:
            self.load_model_weights(state, model, cfg)

        self.classification_heads = nn.ModuleDict()

    def load_model_weights(self, state, model, cfg):
        for k in list(state["model"].keys()):
            if (
                k.startswith("shared_decoder")
                or k.startswith("_ema")
                or "decoder" in k
            ):
                logger.info(f"Deleting {k} from checkpoint")
                del state["model"][k]
        model.load_state_dict(state["model"], strict=True)

    @classmethod
    def build_model(cls, cfg: Data2VecTextClassificationConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        embed_dim = self.cfg.pretrained_model_args.model.embed_dim
        self.classification_heads[name] = RobertaClassificationHead(
            input_dim=embed_dim,
            inner_dim=inner_dim or embed_dim,
            num_classes=num_classes,
            activation_fn=self.cfg.pooler_activation_fn,
            pooler_dropout=self.cfg.pooler_dropout,
            q_noise=self.cfg.quant_noise_pq,
            qn_block_size=self.cfg.quant_noise_pq_block_size,
            do_spectral_norm=self.cfg.spectral_norm_classification_head,
        )

    def forward(
        self,
        source,
        id,
        padding_mask,
        features_only=True,
        remove_extra_tokens=True,
        classification_head_name=None,
    ):
        encoder_out = self.model(
            source,
            id=id,
            mode=Modality.TEXT,
            padding_mask=padding_mask,
            mask=False,
            features_only=features_only,
            remove_extra_tokens=remove_extra_tokens,
        )
        logits = self.classification_heads[classification_head_name](encoder_out["x"])
        return logits, encoder_out
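load_model_weights above drops every shared_decoder*, _ema*, and *decoder* entry from the checkpoint before the strict load_state_dict call. A tiny standalone sketch of that filtering; the key names below are made up for illustration only:

state = {
    "model": {
        "encoder.layers.0.self_attn.k_proj.weight": 0,
        "shared_decoder.proj.weight": 0,
        "_ema": 0,
        "modality_encoders.TEXT.decoder.blocks.0.weight": 0,
    }
}
# mirror the filter used in load_model_weights
for k in list(state["model"].keys()):
    if k.startswith("shared_decoder") or k.startswith("_ema") or "decoder" in k:
        del state["model"][k]
print(sorted(state["model"]))  # ['encoder.layers.0.self_attn.k_proj.weight']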
examples/data2vec/models/data2vec_vision.py  0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import
logging
import
math
import
numpy
as
np
import
random
from
dataclasses
import
dataclass
,
field
from
typing
import
Optional
from
omegaconf
import
II
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.distributed
as
dist
from
fairseq.modules
import
EMAModule
,
EMAModuleConfig
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.models
import
BaseFairseqModel
,
register_model
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
Data2VecVisionConfig
(
FairseqDataclass
):
layer_scale_init_value
:
float
=
field
(
default
=
1e-4
,
metadata
=
{
"help"
:
"rescale layer outputs, 0 to disable"
}
)
num_mask_patches
:
int
=
field
(
default
=
75
,
metadata
=
{
"help"
:
"number of the visual tokens/patches need be masked"
},
)
min_mask_patches_per_block
:
int
=
16
max_mask_patches_per_block
:
int
=
196
image_size
:
int
=
224
patch_size
:
int
=
16
in_channels
:
int
=
3
shared_rel_pos_bias
:
bool
=
True
drop_path
:
float
=
0.1
attention_dropout
:
float
=
0.0
depth
:
int
=
12
embed_dim
:
int
=
768
num_heads
:
int
=
12
mlp_ratio
:
int
=
4
loss_beta
:
float
=
field
(
default
=
0
,
metadata
=
{
"help"
:
"beta for smooth l1 loss. 0 means use l2 loss"
}
)
loss_scale
:
Optional
[
float
]
=
field
(
default
=
None
,
metadata
=
{
"help"
:
"scale the reconstruction loss by this constant. if None then scales by 1/sqrt(dim)"
},
)
average_top_k_layers
:
int
=
field
(
default
=
8
,
metadata
=
{
"help"
:
"how many layers to average"
}
)
end_of_block_targets
:
bool
=
True
layer_norm_target_layer
:
bool
=
False
instance_norm_target_layer
:
bool
=
False
batch_norm_target_layer
:
bool
=
False
instance_norm_targets
:
bool
=
False
layer_norm_targets
:
bool
=
False
ema_decay
:
float
=
field
(
default
=
0.999
,
metadata
=
{
"help"
:
"initial ema decay rate"
})
ema_end_decay
:
float
=
field
(
default
=
0.9999
,
metadata
=
{
"help"
:
"final ema decay rate"
}
)
# when to finish annealing ema decay rate
ema_anneal_end_step
:
int
=
II
(
"optimization.max_update"
)
ema_transformer_only
:
bool
=
field
(
default
=
True
,
metadata
=
{
"help"
:
"whether to momentum update only the transformer layers"
},
)
def
get_annealed_rate
(
start
,
end
,
curr_step
,
total_steps
):
r
=
end
-
start
pct_remaining
=
1
-
curr_step
/
total_steps
return
end
-
r
*
pct_remaining
@
register_model
(
"data2vec_vision"
,
dataclass
=
Data2VecVisionConfig
)
class
Data2VecVisionModel
(
BaseFairseqModel
):
def
__init__
(
self
,
cfg
:
Data2VecVisionConfig
):
super
().
__init__
()
self
.
cfg
=
cfg
self
.
ema
=
None
self
.
average_top_k_layers
=
cfg
.
average_top_k_layers
self
.
loss_beta
=
cfg
.
loss_beta
self
.
loss_scale
=
(
cfg
.
loss_scale
if
cfg
.
loss_scale
is
not
None
else
1
/
math
.
sqrt
(
cfg
.
embed_dim
)
)
self
.
patch_embed
=
PatchEmbed
(
img_size
=
cfg
.
image_size
,
patch_size
=
cfg
.
patch_size
,
in_chans
=
cfg
.
in_channels
,
embed_dim
=
cfg
.
embed_dim
,
)
patch_size
=
self
.
patch_embed
.
patch_size
self
.
window_size
=
(
cfg
.
image_size
//
patch_size
[
0
],
cfg
.
image_size
//
patch_size
[
1
],
)
self
.
cls_emb
=
nn
.
Parameter
(
torch
.
FloatTensor
(
1
,
1
,
cfg
.
embed_dim
))
self
.
mask_emb
=
nn
.
Parameter
(
torch
.
FloatTensor
(
1
,
1
,
cfg
.
embed_dim
))
nn
.
init
.
trunc_normal_
(
self
.
cls_emb
,
0.02
)
nn
.
init
.
trunc_normal_
(
self
.
mask_emb
,
0.02
)
self
.
encoder
=
TransformerEncoder
(
cfg
,
self
.
patch_embed
.
patch_shape
)
self
.
final_proj
=
nn
.
Linear
(
cfg
.
embed_dim
,
cfg
.
embed_dim
)
self
.
num_updates
=
0
def
make_ema_teacher
(
self
):
ema_config
=
EMAModuleConfig
(
ema_decay
=
self
.
cfg
.
ema_decay
,
ema_fp32
=
True
,
)
self
.
ema
=
EMAModule
(
self
.
encoder
if
self
.
cfg
.
ema_transformer_only
else
self
,
ema_config
,
)
def
set_num_updates
(
self
,
num_updates
):
super
().
set_num_updates
(
num_updates
)
if
self
.
ema
is
None
and
self
.
final_proj
is
not
None
:
logger
.
info
(
f
"making ema teacher"
)
self
.
make_ema_teacher
()
elif
self
.
training
and
self
.
ema
is
not
None
:
if
self
.
cfg
.
ema_decay
!=
self
.
cfg
.
ema_end_decay
:
if
num_updates
>=
self
.
cfg
.
ema_anneal_end_step
:
decay
=
self
.
cfg
.
ema_end_decay
else
:
decay
=
get_annealed_rate
(
self
.
cfg
.
ema_decay
,
self
.
cfg
.
ema_end_decay
,
num_updates
,
self
.
cfg
.
ema_anneal_end_step
,
)
self
.
ema
.
set_decay
(
decay
)
if
self
.
ema
.
get_decay
()
<
1
:
self
.
ema
.
step
(
self
.
encoder
if
self
.
cfg
.
ema_transformer_only
else
self
)
self
.
num_updates
=
num_updates
def
state_dict
(
self
,
destination
=
None
,
prefix
=
""
,
keep_vars
=
False
):
state
=
super
().
state_dict
(
destination
,
prefix
,
keep_vars
)
if
self
.
ema
is
not
None
:
state
[
prefix
+
"_ema"
]
=
self
.
ema
.
fp32_params
return
state
def
_load_from_state_dict
(
self
,
state_dict
,
prefix
,
*
args
,
**
kwargs
):
if
self
.
ema
is
not
None
:
k
=
prefix
+
"_ema"
assert
k
in
state_dict
self
.
ema
.
restore
(
state_dict
[
k
],
True
)
del
state_dict
[
k
]
return
super
().
_load_from_state_dict
(
state_dict
,
prefix
,
*
args
,
**
kwargs
)
@
classmethod
def
build_model
(
cls
,
cfg
:
Data2VecVisionConfig
,
task
=
None
):
"""Build a new model instance."""
return
cls
(
cfg
)
def
make_mask
(
self
,
bsz
,
num_masks
,
min_masks
,
max_masks
):
height
,
width
=
self
.
window_size
masks
=
np
.
zeros
(
shape
=
(
bsz
,
height
,
width
),
dtype
=
np
.
int
)
for
i
in
range
(
bsz
):
mask
=
masks
[
i
]
mask_count
=
0
min_aspect
=
0.3
max_aspect
=
1
/
min_aspect
log_aspect_ratio
=
(
math
.
log
(
min_aspect
),
math
.
log
(
max_aspect
))
def
_mask
(
mask
,
max_mask_patches
):
delta
=
0
for
attempt
in
range
(
10
):
target_area
=
random
.
uniform
(
min_masks
,
max_mask_patches
)
aspect_ratio
=
math
.
exp
(
random
.
uniform
(
*
log_aspect_ratio
))
h
=
int
(
round
(
math
.
sqrt
(
target_area
*
aspect_ratio
)))
w
=
int
(
round
(
math
.
sqrt
(
target_area
/
aspect_ratio
)))
if
w
<
width
and
h
<
height
:
top
=
random
.
randint
(
0
,
height
-
h
)
left
=
random
.
randint
(
0
,
width
-
w
)
num_masked
=
mask
[
top
:
top
+
h
,
left
:
left
+
w
].
sum
()
# Overlap
if
0
<
h
*
w
-
num_masked
<=
max_mask_patches
:
for
i
in
range
(
top
,
top
+
h
):
for
j
in
range
(
left
,
left
+
w
):
if
mask
[
i
,
j
]
==
0
:
mask
[
i
,
j
]
=
1
delta
+=
1
if
delta
>
0
:
break
return
delta
while
mask_count
<
num_masks
:
max_mask_patches
=
min
(
num_masks
-
mask_count
,
max_masks
)
delta
=
_mask
(
mask
,
max_mask_patches
)
if
delta
==
0
:
break
else
:
mask_count
+=
delta
return
torch
.
from_numpy
(
masks
)
def
forward
(
self
,
img
,
mask
:
bool
=
True
,
layer_results
:
bool
=
False
,
):
x
=
self
.
patch_embed
(
img
)
batch_size
,
seq_len
,
_
=
x
.
size
()
if
mask
:
mask_indices
=
self
.
make_mask
(
img
.
size
(
0
),
self
.
cfg
.
num_mask_patches
,
self
.
cfg
.
min_mask_patches_per_block
,
self
.
cfg
.
max_mask_patches_per_block
,
)
bool_mask
=
mask_indices
.
view
(
mask_indices
.
size
(
0
),
-
1
).
bool
()
else
:
mask_indices
=
bool_mask
=
None
cls_tokens
=
self
.
cls_emb
.
expand
(
batch_size
,
-
1
,
-
1
)
x
=
torch
.
cat
((
cls_tokens
,
x
),
dim
=
1
)
if
self
.
ema
is
not
None
:
with
torch
.
no_grad
():
self
.
ema
.
model
.
eval
()
if
self
.
cfg
.
ema_transformer_only
:
y
=
self
.
ema
.
model
(
x
,
layer_results
=
"end"
if
self
.
cfg
.
end_of_block_targets
else
"fc"
,
)
else
:
y
=
self
.
ema
.
model
(
img
,
mask
=
False
,
layer_results
=
True
,
)
y
=
y
[
-
self
.
cfg
.
average_top_k_layers
:]
permuted
=
False
if
self
.
cfg
.
instance_norm_target_layer
or
self
.
cfg
.
batch_norm_target_layer
:
y
=
[
tl
.
transpose
(
1
,
2
)
for
tl
in
y
]
# BTC -> BCT
permuted
=
True
if
self
.
cfg
.
batch_norm_target_layer
:
y
=
[
F
.
batch_norm
(
tl
.
float
(),
running_mean
=
None
,
running_var
=
None
,
training
=
True
)
for
tl
in
y
]
if
self
.
cfg
.
instance_norm_target_layer
:
y
=
[
F
.
instance_norm
(
tl
.
float
())
for
tl
in
y
]
if
permuted
:
y
=
[
tl
.
transpose
(
1
,
2
)
for
tl
in
y
]
# BCT -> BTC
if
self
.
cfg
.
layer_norm_target_layer
:
y
=
[
F
.
layer_norm
(
tl
.
float
(),
tl
.
shape
[
-
1
:])
for
tl
in
y
]
y
=
sum
(
y
)
/
len
(
y
)
if
self
.
cfg
.
layer_norm_targets
:
y
=
F
.
layer_norm
(
y
.
float
(),
y
.
shape
[
-
1
:])
if
self
.
cfg
.
instance_norm_targets
:
y
=
F
.
instance_norm
(
y
.
float
().
transpose
(
1
,
2
)).
transpose
(
1
,
2
)
y
=
y
[
bool_mask
].
float
()
if
mask_indices
is
not
None
:
mask_token
=
self
.
mask_emb
.
expand
(
batch_size
,
seq_len
,
-
1
)
w
=
mask_indices
.
view
(
mask_indices
.
size
(
0
),
-
1
,
1
).
type_as
(
mask_token
)
x
[:,
1
:]
=
x
[:,
1
:]
*
(
1
-
w
)
+
mask_token
*
w
if
layer_results
:
enc_layer_results
=
"end"
if
self
.
cfg
.
end_of_block_targets
else
"fc"
else
:
enc_layer_results
=
None
x
=
self
.
encoder
(
x
,
layer_results
=
enc_layer_results
)
if
layer_results
or
mask_indices
is
None
:
return
x
x
=
x
[
bool_mask
].
float
()
if
self
.
loss_beta
==
0
:
loss
=
F
.
mse_loss
(
x
,
y
,
reduction
=
"none"
).
sum
(
dim
=-
1
)
else
:
loss
=
F
.
smooth_l1_loss
(
x
,
y
,
reduction
=
"none"
,
beta
=
self
.
loss_beta
).
sum
(
dim
=-
1
)
if
self
.
loss_scale
>
0
:
loss
=
loss
*
self
.
loss_scale
result
=
{
"losses"
:
{
"regression"
:
loss
.
sum
()},
"sample_size"
:
loss
.
numel
(),
"target_var"
:
self
.
compute_var
(
y
),
"pred_var"
:
self
.
compute_var
(
x
),
"ema_decay"
:
self
.
ema
.
get_decay
()
*
1000
,
}
return
result
@
staticmethod
def
compute_var
(
y
):
y
=
y
.
view
(
-
1
,
y
.
size
(
-
1
))
if
dist
.
is_initialized
():
zc
=
torch
.
tensor
(
y
.
size
(
0
)).
cuda
()
zs
=
y
.
sum
(
dim
=
0
)
zss
=
(
y
**
2
).
sum
(
dim
=
0
)
dist
.
all_reduce
(
zc
)
dist
.
all_reduce
(
zs
)
dist
.
all_reduce
(
zss
)
var
=
zss
/
(
zc
-
1
)
-
(
zs
**
2
)
/
(
zc
*
(
zc
-
1
))
return
torch
.
sqrt
(
var
+
1e-6
).
mean
()
else
:
return
torch
.
sqrt
(
y
.
var
(
dim
=
0
)
+
1e-6
).
mean
()
def
remove_pretraining_modules
(
self
,
last_layer
=
None
):
self
.
final_proj
=
None
self
.
ema
=
None
self
.
encoder
.
norm
=
nn
.
Identity
()
self
.
mask_emb
=
None
if
last_layer
is
not
None
:
self
.
encoder
.
layers
=
nn
.
ModuleList
(
l
for
i
,
l
in
enumerate
(
self
.
encoder
.
layers
)
if
i
<=
last_layer
)
class
PatchEmbed
(
nn
.
Module
):
"""Image to Patch Embedding"""
def
__init__
(
self
,
img_size
=
224
,
patch_size
=
16
,
in_chans
=
3
,
embed_dim
=
768
):
super
().
__init__
()
if
isinstance
(
img_size
,
int
):
img_size
=
img_size
,
img_size
if
isinstance
(
patch_size
,
int
):
patch_size
=
patch_size
,
patch_size
num_patches
=
(
img_size
[
1
]
//
patch_size
[
1
])
*
(
img_size
[
0
]
//
patch_size
[
0
])
self
.
patch_shape
=
(
img_size
[
0
]
//
patch_size
[
0
],
img_size
[
1
]
//
patch_size
[
1
])
self
.
img_size
=
img_size
self
.
patch_size
=
patch_size
self
.
num_patches
=
num_patches
self
.
conv
=
nn
.
Conv2d
(
in_chans
,
embed_dim
,
kernel_size
=
patch_size
,
stride
=
patch_size
)
def
forward
(
self
,
x
):
# BCHW -> BTC
x
=
self
.
conv
(
x
).
flatten
(
2
).
transpose
(
1
,
2
)
return
x
class
Attention
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
=
8
,
qkv_bias
=
True
,
attn_drop
=
0.0
,
proj_drop
=
0.0
,
window_size
=
None
,
attn_head_dim
=
None
,
):
super
().
__init__
()
self
.
num_heads
=
num_heads
head_dim
=
dim
//
num_heads
if
attn_head_dim
is
not
None
:
head_dim
=
attn_head_dim
all_head_dim
=
head_dim
*
self
.
num_heads
self
.
scale
=
head_dim
**
-
0.5
self
.
qkv
=
nn
.
Linear
(
dim
,
all_head_dim
*
3
,
bias
=
False
)
if
qkv_bias
:
self
.
q_bias
=
nn
.
Parameter
(
torch
.
zeros
(
all_head_dim
))
self
.
v_bias
=
nn
.
Parameter
(
torch
.
zeros
(
all_head_dim
))
else
:
self
.
q_bias
=
None
self
.
v_bias
=
None
if
window_size
:
self
.
window_size
=
window_size
self
.
num_relative_distance
=
(
2
*
window_size
[
0
]
-
1
)
*
(
2
*
window_size
[
1
]
-
1
)
+
3
self
.
relative_position_bias_table
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_relative_distance
,
num_heads
)
)
# 2*Wh-1 * 2*Ww-1, nH
# cls to token & token 2 cls & cls to cls
# get pair-wise relative position index for each token inside the window
coords_h
=
torch
.
arange
(
window_size
[
0
])
coords_w
=
torch
.
arange
(
window_size
[
1
])
coords
=
torch
.
stack
(
torch
.
meshgrid
([
coords_h
,
coords_w
]))
# 2, Wh, Ww
coords_flatten
=
torch
.
flatten
(
coords
,
1
)
# 2, Wh*Ww
relative_coords
=
(
coords_flatten
[:,
:,
None
]
-
coords_flatten
[:,
None
,
:]
)
# 2, Wh*Ww, Wh*Ww
relative_coords
=
relative_coords
.
permute
(
1
,
2
,
0
).
contiguous
()
# Wh*Ww, Wh*Ww, 2
relative_coords
[:,
:,
0
]
+=
window_size
[
0
]
-
1
# shift to start from 0
relative_coords
[:,
:,
1
]
+=
window_size
[
1
]
-
1
relative_coords
[:,
:,
0
]
*=
2
*
window_size
[
1
]
-
1
relative_position_index
=
torch
.
zeros
(
size
=
(
window_size
[
0
]
*
window_size
[
1
]
+
1
,)
*
2
,
dtype
=
relative_coords
.
dtype
,
)
relative_position_index
[
1
:,
1
:]
=
relative_coords
.
sum
(
-
1
)
# Wh*Ww, Wh*Ww
relative_position_index
[
0
,
0
:]
=
self
.
num_relative_distance
-
3
relative_position_index
[
0
:,
0
]
=
self
.
num_relative_distance
-
2
relative_position_index
[
0
,
0
]
=
self
.
num_relative_distance
-
1
self
.
register_buffer
(
"relative_position_index"
,
relative_position_index
)
else
:
self
.
window_size
=
None
self
.
relative_position_bias_table
=
None
self
.
relative_position_index
=
None
self
.
attn_drop
=
nn
.
Dropout
(
attn_drop
)
self
.
proj
=
nn
.
Linear
(
all_head_dim
,
dim
)
self
.
proj_drop
=
nn
.
Dropout
(
proj_drop
)
def
forward
(
self
,
x
,
rel_pos_bias
=
None
):
B
,
N
,
C
=
x
.
shape
qkv_bias
=
None
if
self
.
q_bias
is
not
None
:
qkv_bias
=
torch
.
cat
(
(
self
.
q_bias
,
torch
.
zeros_like
(
self
.
v_bias
,
requires_grad
=
False
),
self
.
v_bias
,
)
)
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
qkv
=
F
.
linear
(
input
=
x
,
weight
=
self
.
qkv
.
weight
,
bias
=
qkv_bias
)
qkv
=
qkv
.
reshape
(
B
,
N
,
3
,
self
.
num_heads
,
-
1
).
permute
(
2
,
0
,
3
,
1
,
4
)
q
,
k
,
v
=
(
qkv
[
0
],
qkv
[
1
],
qkv
[
2
],
)
# make torchscript happy (cannot use tensor as tuple)
q
=
q
*
self
.
scale
attn
=
q
@
k
.
transpose
(
-
2
,
-
1
)
if
self
.
relative_position_bias_table
is
not
None
:
assert
1
==
2
relative_position_bias
=
self
.
relative_position_bias_table
[
self
.
relative_position_index
.
view
(
-
1
)
].
view
(
self
.
window_size
[
0
]
*
self
.
window_size
[
1
]
+
1
,
self
.
window_size
[
0
]
*
self
.
window_size
[
1
]
+
1
,
-
1
,
)
# Wh*Ww,Wh*Ww,nH
relative_position_bias
=
relative_position_bias
.
permute
(
2
,
0
,
1
).
contiguous
()
# nH, Wh*Ww, Wh*Ww
attn
=
attn
+
relative_position_bias
.
unsqueeze
(
0
)
print
(
"attn.size() :"
,
attn
.
size
())
print
(
"rel_pos_bias.size() :"
,
rel_pos_bias
.
size
())
if
rel_pos_bias
is
not
None
:
attn
=
attn
+
rel_pos_bias
attn
=
attn
.
softmax
(
dim
=-
1
)
attn
=
self
.
attn_drop
(
attn
)
x
=
(
attn
@
v
).
transpose
(
1
,
2
).
reshape
(
B
,
N
,
-
1
)
x
=
self
.
proj
(
x
)
x
=
self
.
proj_drop
(
x
)
return
x
class
RelativePositionBias
(
nn
.
Module
):
def
__init__
(
self
,
window_size
,
num_heads
):
super
().
__init__
()
self
.
window_size
=
window_size
self
.
num_relative_distance
=
(
2
*
window_size
[
0
]
-
1
)
*
(
2
*
window_size
[
1
]
-
1
)
+
3
self
.
relative_position_bias_table
=
nn
.
Parameter
(
torch
.
zeros
(
self
.
num_relative_distance
,
num_heads
)
)
# get pair-wise relative position index for each token inside the window
coords_h
=
torch
.
arange
(
window_size
[
0
])
coords_w
=
torch
.
arange
(
window_size
[
1
])
coords
=
torch
.
stack
(
torch
.
meshgrid
([
coords_h
,
coords_w
]))
# 2, Wh, Ww
coords_flatten
=
torch
.
flatten
(
coords
,
1
)
# 2, Wh*Ww
relative_coords
=
(
coords_flatten
[:,
:,
None
]
-
coords_flatten
[:,
None
,
:]
)
# 2, Wh*Ww, Wh*Ww
relative_coords
=
relative_coords
.
permute
(
1
,
2
,
0
).
contiguous
()
# Wh*Ww, Wh*Ww, 2
relative_coords
[:,
:,
0
]
+=
window_size
[
0
]
-
1
# shift to start from 0
relative_coords
[:,
:,
1
]
+=
window_size
[
1
]
-
1
relative_coords
[:,
:,
0
]
*=
2
*
window_size
[
1
]
-
1
relative_position_index
=
torch
.
zeros
(
size
=
(
window_size
[
0
]
*
window_size
[
1
]
+
1
,)
*
2
,
dtype
=
relative_coords
.
dtype
)
relative_position_index
[
1
:,
1
:]
=
relative_coords
.
sum
(
-
1
)
# Wh*Ww, Wh*Ww
relative_position_index
[
0
,
0
:]
=
self
.
num_relative_distance
-
3
relative_position_index
[
0
:,
0
]
=
self
.
num_relative_distance
-
2
relative_position_index
[
0
,
0
]
=
self
.
num_relative_distance
-
1
self
.
register_buffer
(
"relative_position_index"
,
relative_position_index
)
def
forward
(
self
):
relative_position_bias
=
self
.
relative_position_bias_table
[
self
.
relative_position_index
.
view
(
-
1
)
].
view
(
self
.
window_size
[
0
]
*
self
.
window_size
[
1
]
+
1
,
self
.
window_size
[
0
]
*
self
.
window_size
[
1
]
+
1
,
-
1
,
)
# Wh*Ww,Wh*Ww,nH
print
(
"self.window_size :"
,
self
.
window_size
)
print
(
"self.num_relative_distance :"
,
self
.
num_relative_distance
)
print
(
"self.relative_position_index :"
,
self
.
relative_position_index
.
size
(),
self
.
relative_position_index
)
print
(
"relative_position_bias.size(), relative_position_bias :"
,
relative_position_bias
.
size
(),
relative_position_bias
)
print
(
"self.relative_position_bias_table.size(), self.relative_position_bias_table :"
,
self
.
relative_position_bias_table
.
size
(),
self
.
relative_position_bias_table
)
return
relative_position_bias
.
permute
(
2
,
0
,
1
).
contiguous
()
# nH, Wh*Ww, Wh*Ww
class
DropPath
(
nn
.
Module
):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def
__init__
(
self
,
drop_prob
=
None
):
super
(
DropPath
,
self
).
__init__
()
self
.
drop_prob
=
drop_prob
def
forward
(
self
,
x
):
if
self
.
drop_prob
==
0.0
or
not
self
.
training
:
return
x
keep_prob
=
1
-
self
.
drop_prob
shape
=
(
x
.
shape
[
0
],)
+
(
1
,)
*
(
x
.
ndim
-
1
)
# work with diff dim tensors, not just 2D ConvNets
random_tensor
=
keep_prob
+
torch
.
rand
(
shape
,
dtype
=
x
.
dtype
,
device
=
x
.
device
)
random_tensor
.
floor_
()
output
=
x
.
div
(
keep_prob
)
*
random_tensor
return
output
def
extra_repr
(
self
)
->
str
:
return
"p={}"
.
format
(
self
.
drop_prob
)
class
Block
(
nn
.
Module
):
def
__init__
(
self
,
dim
,
num_heads
,
mlp_ratio
=
4.0
,
drop
=
0.0
,
attn_drop
=
0.0
,
drop_path
=
0.0
,
init_values
=
None
,
window_size
=
None
,
):
super
().
__init__
()
self
.
norm1
=
nn
.
LayerNorm
(
dim
)
self
.
attn
=
Attention
(
dim
,
num_heads
=
num_heads
,
attn_drop
=
attn_drop
,
proj_drop
=
drop
,
window_size
=
window_size
,
)
self
.
drop_path
=
DropPath
(
drop_path
)
if
drop_path
>
0.0
else
nn
.
Identity
()
self
.
norm2
=
nn
.
LayerNorm
(
dim
)
mlp_hidden_dim
=
int
(
dim
*
mlp_ratio
)
self
.
mlp
=
nn
.
Sequential
(
nn
.
Linear
(
dim
,
mlp_hidden_dim
),
nn
.
GELU
(),
nn
.
Linear
(
mlp_hidden_dim
,
dim
),
nn
.
Dropout
(
drop
),
)
if
init_values
>
0
:
self
.
gamma_1
=
nn
.
Parameter
(
init_values
*
torch
.
ones
((
dim
)),
requires_grad
=
True
)
self
.
gamma_2
=
nn
.
Parameter
(
init_values
*
torch
.
ones
((
dim
)),
requires_grad
=
True
)
else
:
self
.
gamma_1
,
self
.
gamma_2
=
None
,
None
def
forward
(
self
,
x
,
rel_pos_bias
=
None
):
print
(
"inside block :"
,
x
.
size
())
if
self
.
gamma_1
is
None
:
x
=
x
+
self
.
drop_path
(
self
.
attn
(
self
.
norm1
(
x
),
rel_pos_bias
=
rel_pos_bias
))
fc_feature
=
self
.
drop_path
(
self
.
mlp
(
self
.
norm2
(
x
)))
x
=
x
+
fc_feature
else
:
x
=
x
+
self
.
drop_path
(
self
.
gamma_1
*
self
.
attn
(
self
.
norm1
(
x
),
rel_pos_bias
=
rel_pos_bias
)
)
fc_feature
=
self
.
drop_path
(
self
.
gamma_2
*
self
.
mlp
(
self
.
norm2
(
x
)))
x
=
x
+
fc_feature
return
x
,
fc_feature
class
TransformerEncoder
(
nn
.
Module
):
def
__init__
(
self
,
cfg
:
Data2VecVisionConfig
,
patch_shape
):
super
().
__init__
()
self
.
rel_pos_bias
=
None
if
cfg
.
shared_rel_pos_bias
:
self
.
rel_pos_bias
=
RelativePositionBias
(
window_size
=
patch_shape
,
num_heads
=
cfg
.
num_heads
)
dpr
=
[
x
.
item
()
for
x
in
torch
.
linspace
(
0
,
cfg
.
drop_path
,
cfg
.
depth
)
]
# stochastic depth decay rule
print
(
"TransformerEncoder > patch_shape :"
,
patch_shape
)
self
.
blocks
=
nn
.
ModuleList
(
Block
(
dim
=
cfg
.
embed_dim
,
num_heads
=
cfg
.
num_heads
,
attn_drop
=
cfg
.
attention_dropout
,
drop_path
=
dpr
[
i
],
init_values
=
cfg
.
layer_scale_init_value
,
window_size
=
patch_shape
if
not
cfg
.
shared_rel_pos_bias
else
None
,
)
for
i
in
range
(
cfg
.
depth
)
)
self
.
norm
=
nn
.
LayerNorm
(
cfg
.
embed_dim
)
self
.
apply
(
self
.
init_weights
)
self
.
fix_init_weight
()
def
init_weights
(
self
,
m
):
std
=
0.02
if
isinstance
(
m
,
nn
.
Linear
):
nn
.
init
.
trunc_normal_
(
m
.
weight
,
std
=
std
)
if
isinstance
(
m
,
nn
.
Linear
)
and
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
elif
isinstance
(
m
,
nn
.
LayerNorm
):
nn
.
init
.
constant_
(
m
.
bias
,
0
)
nn
.
init
.
constant_
(
m
.
weight
,
1.0
)
elif
isinstance
(
m
,
nn
.
Conv2d
):
nn
.
init
.
trunc_normal_
(
m
.
weight
,
std
=
std
)
if
m
.
bias
is
not
None
:
nn
.
init
.
constant_
(
m
.
bias
,
0
)
def
fix_init_weight
(
self
):
def
rescale
(
param
,
layer_id
):
param
.
div_
(
math
.
sqrt
(
2.0
*
layer_id
))
for
layer_id
,
layer
in
enumerate
(
self
.
blocks
):
rescale
(
layer
.
attn
.
proj
.
weight
.
data
,
layer_id
+
1
)
rescale
(
layer
.
mlp
[
2
].
weight
.
data
,
layer_id
+
1
)
def
extract_features
(
self
,
x
,
layer_results
):
rel_pos_bias
=
self
.
rel_pos_bias
()
if
self
.
rel_pos_bias
is
not
None
else
None
z
=
[]
for
i
,
blk
in
enumerate
(
self
.
blocks
):
x
,
fc_feature
=
blk
(
x
,
rel_pos_bias
=
rel_pos_bias
)
if
layer_results
==
"end"
:
z
.
append
(
x
)
elif
layer_results
==
"fc"
:
z
.
append
(
fc_feature
)
return
z
if
layer_results
else
self
.
norm
(
x
)
def
forward
(
self
,
x
,
layer_results
=
None
):
x
=
self
.
extract_features
(
x
,
layer_results
=
layer_results
)
if
layer_results
:
return
[
z
[:,
1
:]
for
z
in
x
]
x
=
x
[:,
1
:]
return
x
examples/data2vec/models/mae.py  0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit
import
logging
from
dataclasses
import
dataclass
from
functools
import
partial
from
timm.models.vision_transformer
import
PatchEmbed
,
Block
import
torch
import
torch.nn
as
nn
import
numpy
as
np
from
fairseq.dataclass
import
FairseqDataclass
from
fairseq.models
import
BaseFairseqModel
,
register_model
from
fairseq.models.wav2vec.wav2vec2
import
TransformerSentenceEncoderLayer
try
:
from
apex.normalization
import
FusedLayerNorm
except
:
FusedLayerNorm
=
nn
.
LayerNorm
import
torch.nn.functional
as
F
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
MaeConfig
(
FairseqDataclass
):
input_size
:
int
=
224
in_chans
:
int
=
3
patch_size
:
int
=
16
embed_dim
:
int
=
768
depth
:
int
=
12
num_heads
:
int
=
12
decoder_embed_dim
:
int
=
512
decoder_depth
:
int
=
8
decoder_num_heads
:
int
=
16
mlp_ratio
:
int
=
4
norm_eps
:
float
=
1e-6
drop_path_rate
:
float
=
0.0
mask_ratio
:
float
=
0.75
norm_pix_loss
:
bool
=
True
w2v_block
:
bool
=
False
alt_block
:
bool
=
False
alt_block2
:
bool
=
False
alt_attention
:
bool
=
False
block_dropout
:
float
=
0
attention_dropout
:
float
=
0
activation_dropout
:
float
=
0
layer_norm_first
:
bool
=
False
fused_ln
:
bool
=
True
end_of_block_targets
:
bool
=
True
no_decoder_embed
:
bool
=
False
no_decoder_pos_embed
:
bool
=
False
mask_noise_std
:
float
=
0
single_qkv
:
bool
=
False
use_rel_pos_bias
:
bool
=
False
no_cls
:
bool
=
False
def
modify_relative_position_bias
(
orig_bias
,
bsz
,
mask
):
if
mask
is
None
:
return
orig_bias
.
unsqueeze
(
0
).
repeat
(
bsz
,
1
,
1
,
1
)
# heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
heads
,
max_seq_len
,
max_seq_len
=
orig_bias
.
shape
# includes CLS token
mask_for_rel_pos_bias
=
torch
.
cat
(
(
torch
.
zeros
(
bsz
,
1
,
dtype
=
mask
.
dtype
,
device
=
mask
.
device
),
mask
),
dim
=
1
).
bool
()
# bsz x seqlen (add CLS token)
unmasked_for_rel_pos_bias
=
~
mask_for_rel_pos_bias
unmasked_for_rel_pos_bias
=
unmasked_for_rel_pos_bias
.
unsqueeze
(
1
).
repeat
(
1
,
heads
,
1
)
# bsz x seq_len => bsz x heads x seq_len
b_t_t_rel_pos_bias
=
orig_bias
.
unsqueeze
(
0
).
repeat
(
bsz
,
1
,
1
,
1
)
# heads x seq_len x seq_len => bsz x heads x seq_len x seq_len
b_t_t_rel_pos_bias
=
b_t_t_rel_pos_bias
.
masked_select
(
unmasked_for_rel_pos_bias
.
unsqueeze
(
-
1
)
)
b_t_t_rel_pos_bias
=
b_t_t_rel_pos_bias
.
view
(
bsz
,
heads
,
-
1
,
max_seq_len
)
new_len
=
b_t_t_rel_pos_bias
.
size
(
-
2
)
b_t_t_rel_pos_bias
=
b_t_t_rel_pos_bias
.
masked_select
(
unmasked_for_rel_pos_bias
.
unsqueeze
(
-
2
)
)
b_t_t_rel_pos_bias
=
b_t_t_rel_pos_bias
.
view
(
bsz
,
heads
,
new_len
,
new_len
)
return
b_t_t_rel_pos_bias
class AltBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        use_rel_pos_bias=False,
        window_size=None,
        alt_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        from timm.models.vision_transformer import Attention, DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.use_rel_pos_bias = use_rel_pos_bias
        if use_rel_pos_bias:
            self.attn = AltAttention(
                dim,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                qk_scale=qk_scale,
                attn_drop=attn_drop,
                proj_drop=drop,
                window_size=window_size,
            )
        else:
            if alt_attention:
                from .multi.modules import AltAttention as AltAttention2

                self.attn = AltAttention2(
                    dim,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=drop,
                )
            else:
                self.attn = Attention(
                    dim,
                    num_heads=num_heads,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    attn_drop=attn_drop,
                    proj_drop=drop,
                )
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

    def forward(self, x, rel_pos_bias=None, pos_mask=None):
        if self.layer_norm_first:
            if self.use_rel_pos_bias:
                x = x + self.drop_path(
                    self.attn(
                        self.norm1(x), rel_pos_bias=rel_pos_bias, pos_mask=pos_mask
                    )
                )
            else:
                x = x + self.drop_path(self.attn(self.norm1(x)))
            t = self.mlp(self.norm2(x))
            x = x + self.drop_path(t)
            if not self.ffn_targets:
                t = x
            return x, t
        else:
            if self.use_rel_pos_bias:
                x = x + self.drop_path(
                    self.attn(x, rel_pos_bias=rel_pos_bias, pos_mask=pos_mask)
                )
            else:
                x = x + self.drop_path(self.attn(x))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(x))
            if not self.ffn_targets:
                t = x
            return x, t

class AltAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        window_size=None,
        attn_head_dim=None,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        if attn_head_dim is not None:
            head_dim = attn_head_dim
        all_head_dim = head_dim * self.num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
        if qkv_bias:
            self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
            self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
        else:
            self.q_bias = None
            self.v_bias = None

        if window_size:
            self.window_size = window_size
            self.num_relative_distance = (2 * window_size[0] - 1) * (
                2 * window_size[1] - 1
            ) + 3
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(self.num_relative_distance, num_heads)
            )  # 2*Wh-1 * 2*Ww-1, nH
            # cls to token & token 2 cls & cls to cls

            # get pair-wise relative position index for each token inside the window
            coords_h = torch.arange(window_size[0])
            coords_w = torch.arange(window_size[1])
            coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
            coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
            relative_coords = (
                coords_flatten[:, :, None] - coords_flatten[:, None, :]
            )  # 2, Wh*Ww, Wh*Ww
            relative_coords = relative_coords.permute(
                1, 2, 0
            ).contiguous()  # Wh*Ww, Wh*Ww, 2
            relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
            relative_coords[:, :, 1] += window_size[1] - 1
            relative_coords[:, :, 0] *= 2 * window_size[1] - 1
            relative_position_index = torch.zeros(
                size=(window_size[0] * window_size[1] + 1,) * 2,
                dtype=relative_coords.dtype,
            )
            relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
            relative_position_index[0, 0:] = self.num_relative_distance - 3
            relative_position_index[0:, 0] = self.num_relative_distance - 2
            relative_position_index[0, 0] = self.num_relative_distance - 1

            self.register_buffer("relative_position_index", relative_position_index)
        else:
            self.window_size = None
            self.relative_position_bias_table = None
            self.relative_position_index = None

        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(all_head_dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, rel_pos_bias=None, pos_mask=None):
        B, N, C = x.shape
        qkv_bias = None
        if self.q_bias is not None:
            qkv_bias = torch.cat(
                (
                    self.q_bias,
                    torch.zeros_like(self.v_bias, requires_grad=False),
                    self.v_bias,
                )
            )
        # qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
        qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)

        if self.relative_position_bias_table is not None:
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)
            ].view(
                self.window_size[0] * self.window_size[1] + 1,
                self.window_size[0] * self.window_size[1] + 1,
                -1,
            )  # Wh*Ww,Wh*Ww,nH
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1
            ).contiguous()  # nH, Wh*Ww, Wh*Ww
            attn = attn + modify_relative_position_bias(
                relative_position_bias, x.size(0), pos_mask
            )

        if rel_pos_bias is not None:
            attn = attn + rel_pos_bias

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class RelativePositionBias(nn.Module):
    def __init__(self, window_size, num_heads):
        super().__init__()
        self.window_size = window_size
        self.num_relative_distance = (2 * window_size[0] - 1) * (
            2 * window_size[1] - 1
        ) + 3
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(self.num_relative_distance, num_heads)
        )

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(window_size[0])
        coords_w = torch.arange(window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = (
            coords_flatten[:, :, None] - coords_flatten[:, None, :]
        )  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(
            1, 2, 0
        ).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * window_size[1] - 1
        relative_position_index = torch.zeros(
            size=(window_size[0] * window_size[1] + 1,) * 2,
            dtype=relative_coords.dtype,
        )
        relative_position_index[1:, 1:] = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        relative_position_index[0, 0:] = self.num_relative_distance - 3
        relative_position_index[0:, 0] = self.num_relative_distance - 2
        relative_position_index[0, 0] = self.num_relative_distance - 1

        self.register_buffer("relative_position_index", relative_position_index)

    def forward(self):
        relative_position_bias = self.relative_position_bias_table[
            self.relative_position_index.view(-1)
        ].view(
            self.window_size[0] * self.window_size[1] + 1,
            self.window_size[0] * self.window_size[1] + 1,
            -1,
        )  # Wh*Ww,Wh*Ww,nH
        return relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """
    grid_size: int of the grid height and width
    return:
    pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    assert embed_dim % 2 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    embed_dim: output dimension for each position
    pos: a list of positions to be encoded: size (M,)
    out: (M, D)
    """
    assert embed_dim % 2 == 0
    # np.float (used upstream) was removed in numpy >= 1.24; np.float64 is equivalent
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000 ** omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb

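
# Illustrative usage sketch (not part of the original file): building the fixed 2-D
# sin-cos table for a 14x14 patch grid, as initialize_weights() does further below.
# Shapes assume embed_dim is even; with cls_token=True a zero row is prepended.
def _sketch_sincos_table(embed_dim=768, grid_size=14):
    table = get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=True)
    assert table.shape == (grid_size * grid_size + 1, embed_dim)
    return table
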
def interpolate_pos_embed(model, checkpoint_model):
    if "pos_embed" in checkpoint_model:
        pos_embed_checkpoint = checkpoint_model["pos_embed"]
        embedding_size = pos_embed_checkpoint.shape[-1]
        num_patches = model.patch_embed.num_patches
        num_extra_tokens = model.pos_embed.shape[-2] - num_patches
        # height (== width) for the checkpoint position embedding
        orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
        # height (== width) for the new position embedding
        new_size = int(num_patches ** 0.5)
        # class_token and dist_token are kept unchanged
        if orig_size != new_size:
            print(
                "Position interpolate from %dx%d to %dx%d"
                % (orig_size, orig_size, new_size, new_size)
            )
            extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
            # only the position tokens are interpolated
            pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
            pos_tokens = pos_tokens.reshape(
                -1, orig_size, orig_size, embedding_size
            ).permute(0, 3, 1, 2)
            pos_tokens = torch.nn.functional.interpolate(
                pos_tokens,
                size=(new_size, new_size),
                mode="bicubic",
                align_corners=False,
            )
            pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
            new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
            checkpoint_model["pos_embed"] = new_pos_embed

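
# Illustrative sketch (not part of the original file): interpolate_pos_embed is meant
# to be run on a checkpoint's state dict before load_state_dict when the fine-tuning
# resolution differs from the pre-training one. The checkpoint path below is hypothetical.
def _sketch_resize_checkpoint_positions(model, checkpoint_path="/path/to/checkpoint.pt"):
    state = torch.load(checkpoint_path, map_location="cpu")
    interpolate_pos_embed(model, state["model"])  # rewrites state["model"]["pos_embed"] in place
    model.load_state_dict(state["model"], strict=False)
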
@register_model("mae", dataclass=MaeConfig)
class MaeModel(BaseFairseqModel):
    def __init__(self, cfg: MaeConfig):
        super().__init__()
        self.cfg = cfg

        self.mask_ratio = cfg.mask_ratio

        # --------------------------------------------------------------------------
        # MAE encoder specifics
        self.patch_embed = PatchEmbed(
            cfg.input_size, cfg.patch_size, cfg.in_chans, cfg.embed_dim
        )
        num_patches = self.patch_embed.num_patches

        self.cls_token = (
            nn.Parameter(torch.zeros(1, 1, cfg.embed_dim)) if not cfg.no_cls else None
        )
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + int(not cfg.no_cls), cfg.embed_dim),
            requires_grad=False,
        )  # fixed sin-cos embedding

        norm_layer = partial(nn.LayerNorm, eps=cfg.norm_eps)

        dpr = [
            x.item() for x in torch.linspace(0, cfg.drop_path_rate, cfg.depth)
        ]  # stochastic depth decay rule

        def make_block(drop_path):
            if cfg.w2v_block:
                return TransformerSentenceEncoderLayer(
                    embedding_dim=cfg.embed_dim,
                    ffn_embedding_dim=cfg.embed_dim * cfg.mlp_ratio,
                    num_attention_heads=cfg.num_heads,
                    dropout=cfg.block_dropout,
                    attention_dropout=cfg.attention_dropout,
                    activation_dropout=cfg.activation_dropout,
                    activation_fn="gelu",
                    layer_norm_first=cfg.layer_norm_first,
                    drop_path=drop_path,
                    norm_eps=1e-6,
                    single_qkv=cfg.single_qkv,
                    fused_ln=cfg.fused_ln,
                )
            elif cfg.alt_block:
                window_size = (
                    cfg.input_size // self.patch_embed.patch_size[0],
                    cfg.input_size // self.patch_embed.patch_size[1],
                )
                return AltBlock(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    layer_norm_first=cfg.layer_norm_first,
                    ffn_targets=not cfg.end_of_block_targets,
                    use_rel_pos_bias=cfg.use_rel_pos_bias,
                    window_size=window_size
                    if (self.cfg.use_rel_pos_bias and not self.cfg.shared_rel_pos_bias)
                    else None,
                    alt_attention=cfg.alt_attention,
                )
            elif cfg.alt_block2:
                from .multi.modules import AltBlock as AltBlock2

                return AltBlock2(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                    layer_norm_first=cfg.layer_norm_first,
                    ffn_targets=not cfg.end_of_block_targets,
                )
            else:
                return Block(
                    cfg.embed_dim,
                    cfg.num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                    drop_path=drop_path,
                )

        self.blocks = nn.ModuleList([make_block(dpr[i]) for i in range(cfg.depth)])
        self.norm = norm_layer(cfg.embed_dim)
        # --------------------------------------------------------------------------

        # --------------------------------------------------------------------------
        # MAE decoder specifics
        self.decoder_embed = (
            nn.Linear(cfg.embed_dim, cfg.decoder_embed_dim, bias=True)
            if not cfg.no_decoder_embed
            else None
        )

        self.mask_token = (
            nn.Parameter(
                torch.zeros(
                    1,
                    1,
                    cfg.decoder_embed_dim
                    if not cfg.no_decoder_embed
                    else cfg.embed_dim,
                )
            )
            if cfg.mask_noise_std <= 0
            else None
        )

        self.decoder_pos_embed = (
            nn.Parameter(
                torch.zeros(
                    1,
                    num_patches + 1,
                    cfg.decoder_embed_dim
                    if not cfg.no_decoder_embed
                    else cfg.embed_dim,
                ),
                requires_grad=False,
            )
            if not cfg.no_decoder_pos_embed
            else None
        )

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    cfg.decoder_embed_dim,
                    cfg.decoder_num_heads,
                    cfg.mlp_ratio,
                    qkv_bias=True,
                    qk_scale=None,
                    norm_layer=norm_layer,
                )
                for _ in range(cfg.decoder_depth)
            ]
        )

        self.decoder_norm = norm_layer(cfg.decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            cfg.decoder_embed_dim, cfg.patch_size ** 2 * cfg.in_chans, bias=True
        )  # decoder to patch
        # --------------------------------------------------------------------------

        self.norm_pix_loss = cfg.norm_pix_loss

        self.initialize_weights()

        for pn, p in self.named_parameters():
            if len(p.shape) == 1 or pn.endswith(".bias"):
                p.param_group = "no_decay"
            else:
                p.param_group = "with_decay"

    def initialize_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(
            self.pos_embed.shape[-1],
            int(self.patch_embed.num_patches ** 0.5),
            cls_token=not self.cfg.no_cls,
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        if self.decoder_pos_embed is not None:
            decoder_pos_embed = get_2d_sincos_pos_embed(
                self.decoder_pos_embed.shape[-1],
                int(self.patch_embed.num_patches ** 0.5),
                cls_token=not self.cfg.no_cls,
            )
            self.decoder_pos_embed.data.copy_(
                torch.from_numpy(decoder_pos_embed).float().unsqueeze(0)
            )

        # initialize patch_embed like nn.Linear (instead of nn.Conv2d)
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        if self.cls_token is not None:
            torch.nn.init.normal_(self.cls_token, std=0.02)

        if self.mask_token is not None:
            torch.nn.init.normal_(self.mask_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)
    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm) or isinstance(m, FusedLayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
    def patchify(self, imgs):
        """
        imgs: (N, 3, H, W)
        x: (N, L, patch_size**2 *3)
        """
        p = self.patch_embed.patch_size[0]
        assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0

        h = w = imgs.shape[2] // p
        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
        x = torch.einsum("nchpwq->nhwpqc", x)
        x = x.reshape(shape=(imgs.shape[0], h * w, p ** 2 * 3))
        return x

    def unpatchify(self, x):
        """
        x: (N, L, patch_size**2 *3)
        imgs: (N, 3, H, W)
        """
        p = self.patch_embed.patch_size[0]
        h = w = int(x.shape[1] ** 0.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
        x = torch.einsum("nhwpqc->nchpwq", x)
        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
        return imgs
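
    # Illustrative helper (added for clarity, not in the original class): patchify and
    # unpatchify are exact inverses for square images whose side is a multiple of the
    # patch size, which is what the reconstruction target below relies on.
    def _sketch_patchify_roundtrip(self, imgs):
        patches = self.patchify(imgs)      # (N, L, p*p*3)
        recon = self.unpatchify(patches)   # (N, 3, H, W)
        assert torch.allclose(recon, imgs)
        return patches.shape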
    def random_masking(self, x, mask_ratio):
        """
        Perform per-sample random masking by per-sample shuffling.
        Per-sample shuffling is done by argsort random noise.
        x: [N, L, D], sequence
        """
        N, L, D = x.shape  # batch, length, dim
        len_keep = int(L * (1 - mask_ratio))

        noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]

        # sort noise for each sample
        ids_shuffle = torch.argsort(
            noise, dim=1
        )  # ascend: small is keep, large is remove
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        # keep the first subset
        ids_keep = ids_shuffle[:, :len_keep]
        x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))

        # generate the binary mask: 0 is keep, 1 is remove
        mask = torch.ones([N, L], device=x.device)
        mask[:, :len_keep] = 0
        # unshuffle to get the binary mask
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return x_masked, mask, ids_restore  # x_masked is actually unmasked x

    @classmethod
    def build_model(cls, cfg: MaeConfig, task=None):
        """Build a new model instance."""

        return cls(cfg)
    def forward_encoder(self, x, mask_ratio):
        # embed patches
        x = self.patch_embed(x)

        # add pos embed w/o cls token
        # if self.cls_token is not None:
        #     x = x + self.pos_embed
        # else:
        x = x + self.pos_embed[:, 1:, :]

        # masking: length -> length * mask_ratio
        if mask_ratio > 0:
            x, mask, ids_restore = self.random_masking(x, mask_ratio)
        else:
            mask = ids_restore = None

        # append cls token
        if self.cls_token is not None:
            cls_token = self.cls_token + self.pos_embed[:, :1, :]
            cls_tokens = cls_token.expand(x.shape[0], -1, -1)
            x = torch.cat((cls_tokens, x), dim=1)

        # apply Transformer blocks
        for blk in self.blocks:
            x = blk(x)

        if self.norm is not None:
            x = self.norm(x)

        return x, mask, ids_restore
    def forward_decoder(self, x, ids_restore):
        # embed tokens
        x = self.decoder_embed(x)

        # append mask tokens to sequence
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        if self.cls_token is not None:
            x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
        else:
            x_ = torch.cat([x, mask_tokens], dim=1)  # no cls token

        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )  # unshuffle

        if self.cls_token is not None:
            x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token

        # add pos embed
        x = x + self.decoder_pos_embed

        # apply Transformer blocks
        for blk in self.decoder_blocks:
            x = blk(x)
        x = self.decoder_norm(x)

        # predictor projection
        x = self.decoder_pred(x)

        if self.cls_token is not None:
            # remove cls token
            x = x[:, 1:, :]

        return x
    def forward_loss(self, imgs, pred, mask):
        """
        imgs: [N, 3, H, W]
        pred: [N, L, p*p*3]
        mask: [N, L], 0 is keep, 1 is remove,
        """
        target = self.patchify(imgs)
        if self.norm_pix_loss:
            mean = target.mean(dim=-1, keepdim=True)
            var = target.var(dim=-1, keepdim=True)
            target = (target - mean) / (var + 1.0e-6) ** 0.5

        loss = (pred - target) ** 2
        loss = loss.mean(dim=-1)  # [N, L], mean loss per patch

        loss = (loss * mask).sum()
        return loss, mask.sum()
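
    # Illustrative helper (added for clarity, not in the original class): the loss above
    # is a *sum* of per-patch MSE over masked patches only; the returned sample_size
    # (mask.sum()) is what the training criterion divides by to recover a mean.
    def _sketch_masked_mean_loss(self, pred, target, mask):
        per_patch = ((pred - target) ** 2).mean(dim=-1)  # (N, L)
        loss_sum = (per_patch * mask).sum()              # masked patches only
        return loss_sum / mask.sum()                     # mean over masked patches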
    def forward(self, imgs, predictions_only=False):
        latent, mask, ids_restore = self.forward_encoder(
            imgs, self.mask_ratio if not predictions_only else 0
        )

        if predictions_only:
            return latent

        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
        loss, sample_size = self.forward_loss(imgs, pred, mask)

        result = {
            "losses": {"regression": loss},
            "sample_size": sample_size,
        }

        return result

    def remove_pretraining_modules(self):
        self.decoder_embed = None
        self.decoder_blocks = None
        self.decoder_norm = None
        self.decoder_pos_embed = None
        self.decoder_pred = None
        self.mask_token = None
        if self.cfg.layer_norm_first:
            self.norm = None

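
# Illustrative end-to-end sketch (not part of the original file), assuming timm is
# installed so PatchEmbed/Block resolve: pre-training mode returns the loss dict,
# while predictions_only=True skips masking and the decoder and returns encoder features.
def _sketch_mae_pretraining_step():
    model = MaeModel(MaeConfig())
    imgs = torch.randn(2, 3, 224, 224)
    out = model(imgs)  # {"losses": {"regression": ...}, "sample_size": ...}
    feats = model(imgs, predictions_only=True)  # (2, 197, 768) with the default cls token
    return out["sample_size"], feats.shape
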
examples/data2vec/models/mae_image_classification.py 0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# The code in this file is adapted from the BeiT implementation which can be found here:
# https://github.com/microsoft/unilm/tree/master/beit

import logging

from dataclasses import dataclass
from enum import Enum, auto
from typing import Any, Optional

import numpy as np
from omegaconf import II, MISSING

import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import checkpoint_utils, tasks
from omegaconf import open_dict

from fairseq.dataclass import FairseqDataclass
from fairseq.models import BaseFairseqModel, register_model
from .mae import interpolate_pos_embed


logger = logging.getLogger(__name__)


class PredictionMode(Enum):
    MEAN_POOLING = auto()
    CLS_TOKEN = auto()
    LIN_SOFTMAX = auto()


@dataclass
class MaeImageClassificationConfig(FairseqDataclass):
    model_path: str = MISSING
    no_pretrained_weights: bool = False
    linear_classifier: bool = False
    num_classes: int = 1000
    mixup: float = 0.8
    cutmix: float = 1.0
    label_smoothing: float = 0.1

    drop_path_rate: float = 0.1
    layer_decay: float = 0.65

    mixup_prob: float = 1.0
    mixup_switch_prob: float = 0.5
    mixup_mode: str = "batch"

    pretrained_model_args: Any = None
    data: str = II("task.data")

    norm_eps: Optional[float] = None

    remove_alibi: bool = False

    # regularization overwrites
    encoder_dropout: float = 0
    post_mlp_drop: float = 0
    attention_dropout: float = 0
    activation_dropout: float = 0.0
    dropout_input: float = 0.0
    layerdrop: float = 0.0

    prenet_layerdrop: float = 0
    prenet_dropout: float = 0

    use_fc_norm: bool = True
    prediction_mode: PredictionMode = PredictionMode.MEAN_POOLING

    no_decay_blocks: bool = True

def get_layer_id_for_vit(name, num_layers):
    """
    Assign a parameter with its layer id
    Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
    """
    if name in ["cls_token", "pos_embed"]:
        return 0
    elif name.startswith("patch_embed"):
        return 0
    elif name.startswith("rel_pos_bias"):
        return num_layers - 1
    elif name.startswith("blocks"):
        return int(name.split(".")[1]) + 1
    else:
        return num_layers

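
# Illustrative sketch (not part of the original file): the layer id above indexes into a
# geometric learning-rate schedule, so earlier blocks receive smaller lr scales than the
# classification head. This mirrors how layer_decay is expanded later in this file.
def _sketch_layer_decay_scales(num_layers=13, layer_decay=0.65):
    scales = [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]
    return scales  # scales[0] is the smallest (patch embed / pos_embed), scales[-1] == 1.0
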
@
register_model
(
"mae_image_classification"
,
dataclass
=
MaeImageClassificationConfig
)
class
MaeImageClassificationModel
(
BaseFairseqModel
):
def
__init__
(
self
,
cfg
:
MaeImageClassificationConfig
):
super
().
__init__
()
self
.
cfg
=
cfg
if
cfg
.
pretrained_model_args
is
None
:
state
=
checkpoint_utils
.
load_checkpoint_to_cpu
(
cfg
.
model_path
,
{})
pretrained_args
=
state
.
get
(
"cfg"
,
None
)
pretrained_args
.
criterion
=
None
pretrained_args
.
lr_scheduler
=
None
logger
.
info
(
pretrained_args
.
model
)
with
open_dict
(
pretrained_args
.
model
):
pretrained_args
.
model
.
drop_path_rate
=
cfg
.
drop_path_rate
if
cfg
.
norm_eps
is
not
None
:
pretrained_args
.
model
.
norm_eps
=
cfg
.
norm_eps
cfg
.
pretrained_model_args
=
pretrained_args
logger
.
info
(
pretrained_args
)
else
:
state
=
None
pretrained_args
=
cfg
.
pretrained_model_args
if
"data"
in
pretrained_args
.
task
:
pretrained_args
.
task
.
data
=
cfg
.
data
elif
"image"
in
pretrained_args
.
task
:
pretrained_args
.
task
.
image
.
data
=
cfg
.
data
if
"modalities"
in
pretrained_args
.
model
:
prenet_blocks
=
pretrained_args
.
model
[
"modalities"
][
"image"
][
"prenet_depth"
]
model_blocks
=
pretrained_args
.
model
[
"depth"
]
with
open_dict
(
pretrained_args
):
dpr
=
np
.
linspace
(
0
,
cfg
.
drop_path_rate
,
model_blocks
).
tolist
()
pretrained_args
.
model
[
"modalities"
][
"image"
][
"start_drop_path_rate"
]
=
dpr
[
0
]
pretrained_args
.
model
[
"modalities"
][
"image"
][
"end_drop_path_rate"
]
=
max
(
0
,
dpr
[
prenet_blocks
-
1
])
pretrained_args
.
model
[
"start_drop_path_rate"
]
=
dpr
[
prenet_blocks
]
pretrained_args
.
model
[
"end_drop_path_rate"
]
=
dpr
[
-
1
]
if
"mae_masking"
in
pretrained_args
.
model
[
"modalities"
][
"image"
]:
del
pretrained_args
.
model
[
"modalities"
][
"image"
][
"mae_masking"
]
if
cfg
.
remove_alibi
:
pretrained_args
.
model
[
"modalities"
][
"image"
][
"use_alibi_encoder"
]
=
False
if
(
state
is
not
None
and
"modality_encoders.IMAGE.alibi_bias"
in
state
[
"model"
]
):
del
state
[
"model"
][
"modality_encoders.IMAGE.alibi_bias"
]
pretrained_args
.
model
[
"encoder_dropout"
]
=
cfg
.
encoder_dropout
pretrained_args
.
model
[
"post_mlp_drop"
]
=
cfg
.
post_mlp_drop
pretrained_args
.
model
[
"attention_dropout"
]
=
cfg
.
attention_dropout
pretrained_args
.
model
[
"activation_dropout"
]
=
cfg
.
activation_dropout
pretrained_args
.
model
[
"dropout_input"
]
=
cfg
.
dropout_input
pretrained_args
.
model
[
"layerdrop"
]
=
cfg
.
layerdrop
pretrained_args
.
model
[
"modalities"
][
"image"
][
"prenet_layerdrop"
]
=
cfg
.
prenet_layerdrop
pretrained_args
.
model
[
"modalities"
][
"image"
][
"prenet_dropout"
]
=
cfg
.
prenet_dropout
else
:
# not d2v multi
with
open_dict
(
pretrained_args
):
pretrained_args
.
model
[
"drop_path_rate"
]
=
cfg
.
drop_path_rate
pretrained_args
.
model
[
"block_dropout"
]
=
cfg
.
encoder_dropout
pretrained_args
.
model
[
"attention_dropout"
]
=
cfg
.
attention_dropout
pretrained_args
.
model
[
"activation_dropout"
]
=
cfg
.
activation_dropout
task
=
tasks
.
setup_task
(
pretrained_args
.
task
)
model
=
task
.
build_model
(
pretrained_args
.
model
,
from_checkpoint
=
True
)
self
.
d2v_multi
=
"data2vec_multi"
in
pretrained_args
.
model
.
_name
self
.
linear_classifier
=
cfg
.
linear_classifier
self
.
model
=
model
if
state
is
not
None
and
not
cfg
.
no_pretrained_weights
:
interpolate_pos_embed
(
model
,
state
)
if
"modality_encoders.IMAGE.positional_encoder.pos_embed"
in
state
[
"model"
]:
state
[
"model"
][
"modality_encoders.IMAGE.positional_encoder.positions"
]
=
state
[
"model"
][
"modality_encoders.IMAGE.positional_encoder.pos_embed"
]
del
state
[
"model"
][
"modality_encoders.IMAGE.positional_encoder.pos_embed"
]
if
"modality_encoders.IMAGE.encoder_mask"
in
state
[
"model"
]:
del
state
[
"model"
][
"modality_encoders.IMAGE.encoder_mask"
]
model
.
load_state_dict
(
state
[
"model"
],
strict
=
True
)
if
self
.
d2v_multi
:
model
.
remove_pretraining_modules
(
modality
=
"image"
)
else
:
model
.
remove_pretraining_modules
()
if
self
.
linear_classifier
:
model
.
requires_grad_
(
False
)
self
.
fc_norm
=
None
if
self
.
cfg
.
use_fc_norm
:
self
.
fc_norm
=
nn
.
LayerNorm
(
pretrained_args
.
model
.
embed_dim
,
eps
=
1e-6
)
nn
.
init
.
constant_
(
self
.
fc_norm
.
bias
,
0
)
nn
.
init
.
constant_
(
self
.
fc_norm
.
weight
,
1.0
)
self
.
head
=
nn
.
Linear
(
pretrained_args
.
model
.
embed_dim
,
cfg
.
num_classes
)
nn
.
init
.
trunc_normal_
(
self
.
head
.
weight
,
std
=
0.02
)
nn
.
init
.
constant_
(
self
.
head
.
bias
,
0
)
self
.
mixup_fn
=
None
if
cfg
.
mixup
>
0
or
cfg
.
cutmix
>
0
:
from
timm.data
import
Mixup
self
.
mixup_fn
=
Mixup
(
mixup_alpha
=
cfg
.
mixup
,
cutmix_alpha
=
cfg
.
cutmix
,
cutmix_minmax
=
None
,
prob
=
cfg
.
mixup_prob
,
switch_prob
=
cfg
.
mixup_switch_prob
,
mode
=
cfg
.
mixup_mode
,
label_smoothing
=
cfg
.
label_smoothing
,
num_classes
=
cfg
.
num_classes
,
)
if
self
.
model
.
norm
is
not
None
:
for
pn
,
p
in
self
.
model
.
norm
.
named_parameters
():
if
len
(
p
.
shape
)
==
1
or
pn
.
endswith
(
".bias"
):
p
.
optim_overrides
=
{
"optimizer"
:
{
"weight_decay_scale"
:
0
}}
if
self
.
fc_norm
is
not
None
:
for
pn
,
p
in
self
.
fc_norm
.
named_parameters
():
if
len
(
p
.
shape
)
==
1
or
pn
.
endswith
(
".bias"
):
p
.
optim_overrides
=
{
"optimizer"
:
{
"weight_decay_scale"
:
0
}}
for
pn
,
p
in
self
.
head
.
named_parameters
():
if
len
(
p
.
shape
)
==
1
or
pn
.
endswith
(
".bias"
):
p
.
optim_overrides
=
{
"optimizer"
:
{
"weight_decay_scale"
:
0
}}
if
self
.
d2v_multi
:
mod_encs
=
list
(
model
.
modality_encoders
.
values
())
assert
len
(
mod_encs
)
==
1
,
len
(
mod_encs
)
blocks
=
list
(
mod_encs
[
0
].
context_encoder
.
blocks
)
+
list
(
model
.
blocks
)
else
:
blocks
=
model
.
blocks
num_layers
=
len
(
blocks
)
+
1
layer_scales
=
list
(
cfg
.
layer_decay
**
(
num_layers
-
i
)
for
i
in
range
(
num_layers
+
1
)
)
if
self
.
d2v_multi
:
for
n
,
p
in
self
.
model
.
named_parameters
():
optimizer_override_dict
=
{}
if
len
(
p
.
shape
)
==
1
or
n
.
endswith
(
".bias"
):
optimizer_override_dict
[
"weight_decay_scale"
]
=
0
p
.
optim_overrides
=
{
"optimizer"
:
optimizer_override_dict
}
if
cfg
.
layer_decay
>
0
:
for
i
,
b
in
enumerate
(
blocks
):
lid
=
i
+
1
if
layer_scales
[
lid
]
==
1.0
:
continue
for
n
,
p
in
b
.
named_parameters
():
optim_override
=
getattr
(
p
,
"optim_overrides"
,
{})
if
"optimizer"
not
in
optim_override
:
optim_override
[
"optimizer"
]
=
{}
if
cfg
.
no_decay_blocks
:
optim_override
[
"optimizer"
][
"lr_scale"
]
=
layer_scales
[
lid
]
p
.
optim_overrides
=
optim_override
else
:
optim_override
[
"optimizer"
]
=
{
"lr_scale"
:
layer_scales
[
lid
]
}
p
.
optim_overrides
=
optim_override
else
:
for
n
,
p
in
self
.
model
.
named_parameters
():
optimizer_override_dict
=
{}
layer_id
=
get_layer_id_for_vit
(
n
,
num_layers
)
if
len
(
p
.
shape
)
==
1
or
n
.
endswith
(
".bias"
):
optimizer_override_dict
[
"weight_decay_scale"
]
=
0
if
cfg
.
layer_decay
>
0
:
optimizer_override_dict
[
"lr_scale"
]
=
layer_scales
[
layer_id
]
p
.
optim_overrides
=
{
"optimizer"
:
optimizer_override_dict
}
@
classmethod
def
build_model
(
cls
,
cfg
:
MaeImageClassificationConfig
,
task
=
None
):
"""Build a new model instance."""
return
cls
(
cfg
)
def
forward
(
self
,
imgs
,
labels
=
None
,
):
if
self
.
training
and
self
.
mixup_fn
is
not
None
and
labels
is
not
None
:
imgs
,
labels
=
self
.
mixup_fn
(
imgs
,
labels
)
if
self
.
linear_classifier
:
with
torch
.
no_grad
():
x
=
self
.
model_forward
(
imgs
)
else
:
x
=
self
.
model_forward
(
imgs
)
if
self
.
cfg
.
prediction_mode
==
PredictionMode
.
MEAN_POOLING
:
x
=
x
.
mean
(
dim
=
1
)
elif
self
.
cfg
.
prediction_mode
==
PredictionMode
.
CLS_TOKEN
:
x
=
x
[:,
0
]
elif
self
.
cfg
.
prediction_mode
==
PredictionMode
.
LIN_SOFTMAX
:
dtype
=
x
.
dtype
x
=
F
.
logsigmoid
(
x
.
float
())
x
=
torch
.
logsumexp
(
x
+
x
,
dim
=
1
)
-
torch
.
logsumexp
(
x
+
1e-6
,
dim
=
1
)
x
=
x
.
clamp
(
max
=
0
)
x
=
x
-
torch
.
log
(
-
(
torch
.
expm1
(
x
)))
x
=
torch
.
nan_to_num
(
x
,
nan
=
0
,
posinf
=
0
,
neginf
=
0
)
x
=
x
.
to
(
dtype
=
dtype
)
else
:
raise
Exception
(
f
"unknown prediction mode
{
self
.
cfg
.
prediction_mode
.
name
}
"
)
if
self
.
fc_norm
is
not
None
:
x
=
self
.
fc_norm
(
x
)
x
=
self
.
head
(
x
)
if
labels
is
None
:
return
x
if
self
.
training
and
self
.
mixup_fn
is
not
None
:
loss
=
-
labels
*
F
.
log_softmax
(
x
.
float
(),
dim
=-
1
)
else
:
loss
=
F
.
cross_entropy
(
x
.
float
(),
labels
,
label_smoothing
=
self
.
cfg
.
label_smoothing
if
self
.
training
else
0
,
reduction
=
"none"
,
)
result
=
{
"losses"
:
{
"regression"
:
loss
},
"sample_size"
:
imgs
.
size
(
0
),
}
if
not
self
.
training
:
with
torch
.
no_grad
():
pred
=
x
.
argmax
(
-
1
)
correct
=
(
pred
==
labels
).
sum
()
result
[
"correct"
]
=
correct
return
result
def
model_forward
(
self
,
imgs
):
if
self
.
d2v_multi
:
x
=
self
.
model
.
extract_features
(
imgs
,
mode
=
"IMAGE"
,
mask
=
False
,
remove_extra_tokens
=
(
self
.
cfg
.
prediction_mode
!=
PredictionMode
.
CLS_TOKEN
),
)[
"x"
]
else
:
x
=
self
.
model
(
imgs
,
predictions_only
=
True
)
if
(
"no_cls"
not
in
self
.
model
.
cfg
or
not
self
.
model
.
cfg
.
no_cls
)
and
not
self
.
cfg
.
prediction_mode
==
PredictionMode
.
CLS_TOKEN
:
x
=
x
[:,
1
:]
return
x
examples/data2vec/models/modalities/__init__.py 0 → 100644
examples/data2vec/models/modalities/audio.py 0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from functools import partial

import torch
import torch.nn as nn
import numpy as np

from dataclasses import dataclass, field
from typing import Callable, Dict, Optional

from fairseq.models.wav2vec import ConvFeatureExtractionModel
from fairseq.modules import (
    LayerNorm,
    SamePad,
    TransposeLast,
)
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality


@dataclass
class D2vAudioConfig(D2vModalityConfig):
    type: Modality = Modality.AUDIO
    extractor_mode: str = "layer_norm"
    feature_encoder_spec: str = field(
        default="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
        metadata={
            "help": "string describing convolutional feature extraction layers in form of a python list that contains "
            "[(dim, kernel_size, stride), ...]"
        },
    )
    conv_pos_width: int = field(
        default=95,
        metadata={"help": "number of filters for convolutional positional embeddings"},
    )
    conv_pos_groups: int = field(
        default=16,
        metadata={"help": "number of groups for convolutional positional embedding"},
    )
    conv_pos_depth: int = field(
        default=5,
        metadata={"help": "depth of positional encoder network"},
    )
    conv_pos_pre_ln: bool = False
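

# Illustrative sketch (not part of the original file): feature_encoder_spec is eval'd
# into [(dim, kernel, stride), ...]; the product of the strides gives the waveform
# downsampling factor (320 samples per frame, i.e. 20 ms at 16 kHz, for the default spec).
def _sketch_feature_extractor_stride(
    spec="[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] + [(512,2,2)]",
):
    layers = eval(spec)
    total_stride = 1
    for _, _, stride in layers:
        total_stride *= stride
    return total_stride  # 5 * 2**4 * 2 * 2 = 320
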
class
AudioEncoder
(
ModalitySpecificEncoder
):
modality_cfg
:
D2vAudioConfig
def
__init__
(
self
,
modality_cfg
:
D2vAudioConfig
,
embed_dim
:
int
,
make_block
:
Callable
[[
float
],
nn
.
ModuleList
],
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
],
layer_norm_first
:
bool
,
alibi_biases
:
Dict
,
task
:
Optional
[
FairseqTask
],
):
self
.
feature_enc_layers
=
eval
(
modality_cfg
.
feature_encoder_spec
)
feature_embed_dim
=
self
.
feature_enc_layers
[
-
1
][
0
]
local_encoder
=
ConvFeatureExtractionModel
(
conv_layers
=
self
.
feature_enc_layers
,
dropout
=
0.0
,
mode
=
modality_cfg
.
extractor_mode
,
conv_bias
=
False
,
)
project_features
=
nn
.
Sequential
(
TransposeLast
(),
nn
.
LayerNorm
(
feature_embed_dim
),
nn
.
Linear
(
feature_embed_dim
,
embed_dim
),
)
num_pos_layers
=
modality_cfg
.
conv_pos_depth
k
=
max
(
3
,
modality_cfg
.
conv_pos_width
//
num_pos_layers
)
positional_encoder
=
nn
.
Sequential
(
TransposeLast
(),
*
[
nn
.
Sequential
(
nn
.
Conv1d
(
embed_dim
,
embed_dim
,
kernel_size
=
k
,
padding
=
k
//
2
,
groups
=
modality_cfg
.
conv_pos_groups
,
),
SamePad
(
k
),
TransposeLast
(),
LayerNorm
(
embed_dim
,
elementwise_affine
=
False
),
TransposeLast
(),
nn
.
GELU
(),
)
for
_
in
range
(
num_pos_layers
)
],
TransposeLast
(),
)
if
modality_cfg
.
conv_pos_pre_ln
:
positional_encoder
=
nn
.
Sequential
(
LayerNorm
(
embed_dim
),
positional_encoder
)
dpr
=
np
.
linspace
(
modality_cfg
.
start_drop_path_rate
,
modality_cfg
.
end_drop_path_rate
,
modality_cfg
.
prenet_depth
,
)
context_encoder
=
BlockEncoder
(
nn
.
ModuleList
(
make_block
(
dpr
[
i
])
for
i
in
range
(
modality_cfg
.
prenet_depth
)),
norm_layer
(
embed_dim
)
if
not
layer_norm_first
else
None
,
layer_norm_first
,
modality_cfg
.
prenet_layerdrop
,
modality_cfg
.
prenet_dropout
,
)
decoder
=
(
Decoder1d
(
modality_cfg
.
decoder
,
embed_dim
)
if
modality_cfg
.
decoder
is
not
None
else
None
)
alibi_bias_fn
=
partial
(
get_alibi_bias
,
alibi_biases
=
alibi_biases
)
super
().
__init__
(
modality_cfg
=
modality_cfg
,
embed_dim
=
embed_dim
,
local_encoder
=
local_encoder
,
project_features
=
project_features
,
fixed_positional_encoder
=
None
,
relative_positional_encoder
=
positional_encoder
,
context_encoder
=
context_encoder
,
decoder
=
decoder
,
get_alibi_bias
=
alibi_bias_fn
,
)
def
convert_padding_mask
(
self
,
x
,
padding_mask
):
def
get_feat_extract_output_lengths
(
input_lengths
:
torch
.
LongTensor
):
"""
Computes the output length of the convolutional layers
"""
def
_conv_out_length
(
input_length
,
kernel_size
,
stride
):
return
torch
.
floor
((
input_length
-
kernel_size
)
/
stride
+
1
)
for
i
in
range
(
len
(
self
.
feature_enc_layers
)):
input_lengths
=
_conv_out_length
(
input_lengths
,
self
.
feature_enc_layers
[
i
][
1
],
self
.
feature_enc_layers
[
i
][
2
],
)
return
input_lengths
.
to
(
torch
.
long
)
if
padding_mask
is
not
None
:
input_lengths
=
(
1
-
padding_mask
.
long
()).
sum
(
-
1
)
# apply conv formula to get real output_lengths
output_lengths
=
get_feat_extract_output_lengths
(
input_lengths
)
if
padding_mask
.
any
():
padding_mask
=
torch
.
zeros
(
x
.
shape
[:
2
],
dtype
=
x
.
dtype
,
device
=
x
.
device
)
# these two operations makes sure that all values
# before the output lengths indices are attended to
padding_mask
[
(
torch
.
arange
(
padding_mask
.
shape
[
0
],
device
=
padding_mask
.
device
),
output_lengths
-
1
,
)
]
=
1
padding_mask
=
(
1
-
padding_mask
.
flip
([
-
1
]).
cumsum
(
-
1
).
flip
([
-
1
])
).
bool
()
else
:
padding_mask
=
torch
.
zeros
(
x
.
shape
[:
2
],
dtype
=
torch
.
bool
,
device
=
x
.
device
)
return
padding_mask
def
reset_parameters
(
self
):
super
().
reset_parameters
()
for
mod
in
self
.
project_features
.
children
():
if
isinstance
(
mod
,
nn
.
Linear
):
mod
.
reset_parameters
()
if
self
.
decoder
is
not
None
:
self
.
decoder
.
reset_parameters
()
examples/data2vec/models/modalities/base.py 0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from collections import namedtuple
from dataclasses import dataclass
from functools import partial
from omegaconf import MISSING, II
from typing import Optional, Callable
from fairseq.data.data_utils import compute_mask_indices
from fairseq.modules import GradMultiply
from fairseq.utils import index_put
from examples.data2vec.data.modality import Modality
from .modules import D2vDecoderConfig

logger = logging.getLogger(__name__)


@dataclass
class D2vModalityConfig:
    type: Modality = MISSING
    prenet_depth: int = 4
    prenet_layerdrop: float = 0
    prenet_dropout: float = 0
    start_drop_path_rate: float = 0
    end_drop_path_rate: float = 0

    num_extra_tokens: int = 0
    init_extra_token_zero: bool = True

    mask_noise_std: float = 0.01
    mask_prob_min: Optional[float] = None
    mask_prob: float = 0.7
    inverse_mask: bool = False
    mask_prob_adjust: float = 0
    keep_masked_pct: float = 0

    mask_length: int = 5
    add_masks: bool = False
    remove_masks: bool = False
    mask_dropout: float = 0.0
    encoder_zero_mask: bool = True

    mask_channel_prob: float = 0.0
    mask_channel_length: int = 64

    ema_local_encoder: bool = False  # used in data2vec_multi
    local_grad_mult: float = 1.0

    use_alibi_encoder: bool = False
    alibi_scale: float = 1.0
    learned_alibi: bool = False
    alibi_max_pos: Optional[int] = None
    learned_alibi_scale: bool = False
    learned_alibi_scale_per_head: bool = False
    learned_alibi_scale_per_layer: bool = False

    num_alibi_heads: int = II("model.num_heads")
    model_depth: int = II("model.depth")

    decoder: Optional[D2vDecoderConfig] = D2vDecoderConfig()


MaskSeed = namedtuple("MaskSeed", ["seed", "update", "ids"])
MaskInfo = namedtuple("MaskInfo", ["x_unmasked", "mask", "ids_restore", "ids_keep"])
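

# Illustrative sketch (not part of the original file): a rough account of how fairseq's
# compute_mask_indices sizes a block mask for the defaults above. Approximately
# mask_prob of all timesteps end up masked; overlapping spans make the realized
# fraction somewhat lower.
def _sketch_expected_mask_coverage(seq_len=500, mask_prob=0.7, mask_length=5):
    num_spans = int(mask_prob * seq_len / mask_length)  # ~70 spans of length 5
    return min(seq_len, num_spans * mask_length) / seq_len  # ~0.7 before span overlap
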
class
ModalitySpecificEncoder
(
nn
.
Module
):
def
__init__
(
self
,
modality_cfg
:
D2vModalityConfig
,
embed_dim
:
int
,
local_encoder
:
nn
.
Module
,
project_features
:
nn
.
Module
,
fixed_positional_encoder
:
Optional
[
nn
.
Module
],
relative_positional_encoder
:
Optional
[
nn
.
Module
],
context_encoder
:
nn
.
Module
,
decoder
:
nn
.
Module
,
get_alibi_bias
:
Optional
[
Callable
[[
int
,
int
,
str
,
str
],
torch
.
Tensor
]],
):
super
().
__init__
()
self
.
modality_cfg
=
modality_cfg
self
.
local_encoder
=
local_encoder
self
.
project_features
=
project_features
self
.
fixed_positional_encoder
=
fixed_positional_encoder
self
.
relative_positional_encoder
=
relative_positional_encoder
self
.
context_encoder
=
context_encoder
self
.
decoder
=
decoder
self
.
get_alibi_bias
=
get_alibi_bias
if
modality_cfg
.
use_alibi_encoder
else
None
self
.
local_grad_mult
=
self
.
modality_cfg
.
local_grad_mult
self
.
extra_tokens
=
None
if
modality_cfg
.
num_extra_tokens
>
0
:
self
.
extra_tokens
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
modality_cfg
.
num_extra_tokens
,
embed_dim
)
)
if
not
modality_cfg
.
init_extra_token_zero
:
nn
.
init
.
normal_
(
self
.
extra_tokens
)
elif
self
.
extra_tokens
.
size
(
1
)
>
1
:
nn
.
init
.
normal_
(
self
.
extra_tokens
[:,
1
:])
self
.
alibi_scale
=
None
if
self
.
get_alibi_bias
is
not
None
:
self
.
alibi_scale
=
nn
.
Parameter
(
torch
.
full
(
(
(
modality_cfg
.
prenet_depth
+
modality_cfg
.
model_depth
)
if
modality_cfg
.
learned_alibi_scale_per_layer
else
1
,
1
,
self
.
modality_cfg
.
num_alibi_heads
if
modality_cfg
.
learned_alibi_scale_per_head
else
1
,
1
,
1
,
),
modality_cfg
.
alibi_scale
,
dtype
=
torch
.
float
,
),
requires_grad
=
modality_cfg
.
learned_alibi_scale
,
)
if
modality_cfg
.
learned_alibi
and
self
.
get_alibi_bias
is
not
None
:
assert
modality_cfg
.
alibi_max_pos
is
not
None
alibi_bias
=
self
.
get_alibi_bias
(
batch_size
=
1
,
time_steps
=
modality_cfg
.
alibi_max_pos
,
heads
=
modality_cfg
.
num_alibi_heads
,
scale
=
1.0
,
dtype
=
torch
.
float
,
device
=
"cpu"
,
)
self
.
alibi_bias
=
nn
.
Parameter
(
alibi_bias
)
self
.
get_alibi_bias
=
partial
(
_learned_alibi_bias
,
alibi_bias
=
self
.
alibi_bias
)
def
upgrade_state_dict_named
(
self
,
state_dict
,
name
):
k
=
f
"
{
name
}
.alibi_scale"
if
k
in
state_dict
and
state_dict
[
k
].
dim
()
==
4
:
state_dict
[
k
]
=
state_dict
[
k
].
unsqueeze
(
0
)
return
state_dict
def
convert_padding_mask
(
self
,
x
,
padding_mask
):
return
padding_mask
def
decoder_input
(
self
,
x
,
mask_info
:
MaskInfo
):
inp_drop
=
self
.
modality_cfg
.
decoder
.
input_dropout
if
inp_drop
>
0
:
x
=
F
.
dropout
(
x
,
inp_drop
,
training
=
self
.
training
,
inplace
=
True
)
num_extra
=
self
.
modality_cfg
.
num_extra_tokens
if
mask_info
is
not
None
:
num_masked
=
mask_info
.
ids_restore
.
shape
[
1
]
-
x
.
shape
[
1
]
+
num_extra
mask_tokens
=
x
.
new_empty
(
x
.
size
(
0
),
num_masked
,
x
.
size
(
-
1
),
).
normal_
(
0
,
self
.
modality_cfg
.
mask_noise_std
)
x_
=
torch
.
cat
([
x
[:,
num_extra
:],
mask_tokens
],
dim
=
1
)
x
=
torch
.
gather
(
x_
,
dim
=
1
,
index
=
mask_info
.
ids_restore
)
if
self
.
modality_cfg
.
decoder
.
add_positions_masked
:
assert
self
.
fixed_positional_encoder
is
not
None
pos
=
self
.
fixed_positional_encoder
(
x
,
None
)
x
=
x
+
(
pos
*
mask_info
.
mask
.
unsqueeze
(
-
1
))
else
:
x
=
x
[:,
num_extra
:]
if
self
.
modality_cfg
.
decoder
.
add_positions_all
:
assert
self
.
fixed_positional_encoder
is
not
None
x
=
x
+
self
.
fixed_positional_encoder
(
x
,
None
)
return
x
,
mask_info
def
local_features
(
self
,
features
):
if
self
.
local_grad_mult
>
0
:
if
self
.
local_grad_mult
==
1.0
:
x
=
self
.
local_encoder
(
features
)
else
:
x
=
GradMultiply
.
apply
(
self
.
local_encoder
(
features
),
self
.
local_grad_mult
)
else
:
with
torch
.
no_grad
():
x
=
self
.
local_encoder
(
features
)
x
=
self
.
project_features
(
x
)
return
x
def
contextualized_features
(
self
,
x
,
padding_mask
,
mask
,
remove_masked
,
clone_batch
:
int
=
1
,
mask_seeds
:
Optional
[
torch
.
Tensor
]
=
None
,
precomputed_mask
=
None
,
):
if
padding_mask
is
not
None
:
padding_mask
=
self
.
convert_padding_mask
(
x
,
padding_mask
)
local_features
=
x
if
mask
and
clone_batch
==
1
:
local_features
=
local_features
.
clone
()
orig_B
,
orig_T
,
_
=
x
.
shape
pre_mask_B
=
orig_B
mask_info
=
None
x_pos
=
None
if
self
.
fixed_positional_encoder
is
not
None
:
x
=
x
+
self
.
fixed_positional_encoder
(
x
,
padding_mask
)
if
mask
:
if
clone_batch
>
1
:
x
=
x
.
repeat_interleave
(
clone_batch
,
0
)
if
mask_seeds
is
not
None
:
clone_hash
=
[
int
(
hash
((
mask_seeds
.
seed
,
ind
))
%
1e10
)
for
ind
in
range
(
clone_batch
-
1
)
]
clone_hash
=
torch
.
tensor
([
0
]
+
clone_hash
).
long
().
view
(
1
,
-
1
)
id
=
mask_seeds
.
ids
id
=
id
.
repeat_interleave
(
clone_batch
,
0
)
id
=
id
.
view
(
-
1
,
clone_batch
)
+
clone_hash
.
to
(
id
)
id
=
id
.
view
(
-
1
)
mask_seeds
=
MaskSeed
(
seed
=
mask_seeds
.
seed
,
update
=
mask_seeds
.
update
,
ids
=
id
)
if
padding_mask
is
not
None
:
padding_mask
=
padding_mask
.
repeat_interleave
(
clone_batch
,
0
)
x
,
mask_info
=
self
.
compute_mask
(
x
,
padding_mask
,
mask_seed
=
mask_seeds
,
apply
=
self
.
relative_positional_encoder
is
not
None
or
not
remove_masked
,
precomputed_mask
=
precomputed_mask
,
)
if
self
.
relative_positional_encoder
is
not
None
:
x_pos
=
self
.
relative_positional_encoder
(
x
)
masked_padding_mask
=
padding_mask
if
mask
and
remove_masked
:
x
=
mask_info
.
x_unmasked
if
x_pos
is
not
None
:
x
=
x
+
gather_unmasked
(
x_pos
,
mask_info
)
if
padding_mask
is
not
None
and
padding_mask
.
any
():
masked_padding_mask
=
gather_unmasked_mask
(
padding_mask
,
mask_info
)
if
not
masked_padding_mask
.
any
():
masked_padding_mask
=
None
else
:
masked_padding_mask
=
None
elif
x_pos
is
not
None
:
x
=
x
+
x_pos
alibi_bias
=
None
alibi_scale
=
self
.
alibi_scale
if
self
.
get_alibi_bias
is
not
None
:
alibi_bias
=
self
.
get_alibi_bias
(
batch_size
=
pre_mask_B
,
time_steps
=
orig_T
,
heads
=
self
.
modality_cfg
.
num_alibi_heads
,
dtype
=
torch
.
float32
,
device
=
x
.
device
,
)
if
alibi_scale
is
not
None
:
alibi_scale
=
alibi_scale
.
clamp_min
(
0
)
if
alibi_scale
.
size
(
0
)
==
1
:
alibi_bias
=
alibi_bias
*
alibi_scale
.
squeeze
(
0
).
type_as
(
alibi_bias
)
alibi_scale
=
None
if
clone_batch
>
1
:
alibi_bias
=
alibi_bias
.
repeat_interleave
(
clone_batch
,
0
)
if
mask_info
is
not
None
and
remove_masked
:
alibi_bias
=
masked_alibi
(
alibi_bias
,
mask_info
)
if
self
.
extra_tokens
is
not
None
:
num
=
self
.
extra_tokens
.
size
(
1
)
x
=
torch
.
cat
([
self
.
extra_tokens
.
expand
(
x
.
size
(
0
),
-
1
,
-
1
),
x
],
dim
=
1
)
if
masked_padding_mask
is
not
None
:
# B x T
masked_padding_mask
=
F
.
pad
(
masked_padding_mask
,
(
num
,
0
))
if
alibi_bias
is
not
None
:
# B x H x T x T
alibi_bias
=
F
.
pad
(
alibi_bias
,
(
num
,
0
,
num
,
0
))
x
=
self
.
context_encoder
(
x
,
masked_padding_mask
,
alibi_bias
,
alibi_scale
[:
self
.
modality_cfg
.
prenet_depth
]
if
alibi_scale
is
not
None
else
None
,
)
return
{
"x"
:
x
,
"local_features"
:
local_features
,
"padding_mask"
:
masked_padding_mask
,
"alibi_bias"
:
alibi_bias
,
"alibi_scale"
:
alibi_scale
[
self
.
modality_cfg
.
prenet_depth
:]
if
alibi_scale
is
not
None
and
alibi_scale
.
size
(
0
)
>
1
else
alibi_scale
,
"encoder_mask"
:
mask_info
,
}
def
forward
(
self
,
features
,
padding_mask
,
mask
:
bool
,
remove_masked
:
bool
,
clone_batch
:
int
=
1
,
mask_seeds
:
Optional
[
torch
.
Tensor
]
=
None
,
precomputed_mask
=
None
,
):
x
=
self
.
local_features
(
features
)
return
self
.
contextualized_features
(
x
,
padding_mask
,
mask
,
remove_masked
,
clone_batch
,
mask_seeds
,
precomputed_mask
,
)
def
reset_parameters
(
self
):
pass
def
compute_mask
(
self
,
x
,
padding_mask
,
mask_seed
:
Optional
[
MaskSeed
],
apply
,
precomputed_mask
,
):
if
precomputed_mask
is
not
None
:
mask
=
precomputed_mask
mask_info
=
self
.
make_maskinfo
(
x
,
mask
)
else
:
B
,
T
,
C
=
x
.
shape
cfg
=
self
.
modality_cfg
mask_prob
=
cfg
.
mask_prob
if
(
cfg
.
mask_prob_min
is
not
None
and
cfg
.
mask_prob_min
>=
0
and
cfg
.
mask_prob_min
<
mask_prob
):
mask_prob
=
np
.
random
.
uniform
(
cfg
.
mask_prob_min
,
mask_prob
)
if
mask_prob
>
0
:
if
cfg
.
mask_length
==
1
:
mask_info
=
random_masking
(
x
,
mask_prob
,
mask_seed
)
else
:
if
self
.
modality_cfg
.
inverse_mask
:
mask_prob
=
1
-
mask_prob
mask
=
compute_mask_indices
(
(
B
,
T
),
padding_mask
,
mask_prob
,
cfg
.
mask_length
,
min_masks
=
1
,
require_same_masks
=
True
,
mask_dropout
=
cfg
.
mask_dropout
,
add_masks
=
cfg
.
add_masks
,
seed
=
mask_seed
.
seed
if
mask_seed
is
not
None
else
None
,
epoch
=
mask_seed
.
update
if
mask_seed
is
not
None
else
None
,
indices
=
mask_seed
.
ids
if
mask_seed
is
not
None
else
None
,
)
mask
=
torch
.
from_numpy
(
mask
).
to
(
device
=
x
.
device
)
if
self
.
modality_cfg
.
inverse_mask
:
mask
=
1
-
mask
mask_info
=
self
.
make_maskinfo
(
x
,
mask
)
else
:
mask_info
=
None
if
apply
:
x
=
self
.
apply_mask
(
x
,
mask_info
)
return
x
,
mask_info
def
make_maskinfo
(
self
,
x
,
mask
,
shape
=
None
):
if
shape
is
None
:
B
,
T
,
D
=
x
.
shape
else
:
B
,
T
,
D
=
shape
mask
=
mask
.
to
(
torch
.
uint8
)
ids_shuffle
=
mask
.
argsort
(
dim
=
1
)
ids_restore
=
ids_shuffle
.
argsort
(
dim
=
1
).
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
D
)
len_keep
=
T
-
mask
[
0
].
sum
()
if
self
.
modality_cfg
.
keep_masked_pct
>
0
:
len_keep
+=
round
((
T
-
int
(
len_keep
))
*
self
.
modality_cfg
.
keep_masked_pct
)
ids_keep
=
ids_shuffle
[:,
:
len_keep
]
if
shape
is
not
None
:
x_unmasked
=
None
else
:
ids_keep
=
ids_keep
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
D
)
x_unmasked
=
torch
.
gather
(
x
,
dim
=
1
,
index
=
ids_keep
)
mask_info
=
MaskInfo
(
x_unmasked
=
x_unmasked
,
mask
=
mask
,
ids_restore
=
ids_restore
,
ids_keep
=
ids_keep
,
)
return
mask_info
def
apply_mask
(
self
,
x
,
mask_info
):
cfg
=
self
.
modality_cfg
B
,
T
,
C
=
x
.
shape
if
mask_info
is
not
None
:
mask
=
mask_info
.
mask
if
cfg
.
encoder_zero_mask
:
x
=
x
*
(
1
-
mask
.
type_as
(
x
).
unsqueeze
(
-
1
))
else
:
num_masks
=
mask
.
sum
().
item
()
masks
=
x
.
new_empty
(
num_masks
,
x
.
size
(
-
1
)).
normal_
(
0
,
cfg
.
mask_noise_std
)
x
=
index_put
(
x
,
mask
,
masks
)
if
cfg
.
mask_channel_prob
>
0
:
mask_channel
=
compute_mask_indices
(
(
B
,
C
),
None
,
cfg
.
mask_channel_prob
,
cfg
.
mask_channel_length
,
)
mask_channel
=
(
torch
.
from_numpy
(
mask_channel
)
.
to
(
x
.
device
)
.
unsqueeze
(
1
)
.
expand
(
-
1
,
T
,
-
1
)
)
x
=
index_put
(
x
,
mask_channel
,
0
)
return
x
def
remove_pretraining_modules
(
self
,
keep_decoder
=
False
):
if
not
keep_decoder
:
self
.
decoder
=
None
def
get_annealed_rate
(
start
,
end
,
curr_step
,
total_steps
):
if
curr_step
>=
total_steps
:
return
end
r
=
end
-
start
pct_remaining
=
1
-
curr_step
/
total_steps
return
end
-
r
*
pct_remaining
# adapted from MAE
def
random_masking
(
x
,
mask_ratio
,
mask_seed
:
Optional
[
MaskSeed
]):
N
,
L
,
D
=
x
.
shape
# batch, length, dim
len_keep
=
int
(
L
*
(
1
-
mask_ratio
))
generator
=
None
if
mask_seed
is
not
None
:
seed
=
int
(
hash
((
mask_seed
.
seed
,
mask_seed
.
update
,
mask_seed
.
ids
.
sum
().
item
()))
%
1e6
)
generator
=
torch
.
Generator
(
device
=
x
.
device
)
generator
.
manual_seed
(
seed
)
noise
=
torch
.
rand
(
N
,
L
,
generator
=
generator
,
device
=
x
.
device
)
# noise in [0, 1]
# sort noise for each sample
ids_shuffle
=
noise
.
argsort
(
dim
=
1
)
# ascend: small is keep, large is remove
ids_restore
=
ids_shuffle
.
argsort
(
dim
=
1
)
# keep the first subset
ids_keep
=
ids_shuffle
[:,
:
len_keep
]
ids_keep
=
ids_keep
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
D
)
x_unmasked
=
torch
.
gather
(
x
,
dim
=
1
,
index
=
ids_keep
)
# generate the binary mask: 0 is keep, 1 is remove
mask
=
torch
.
ones
([
N
,
L
],
dtype
=
x
.
dtype
,
device
=
x
.
device
)
mask
[:,
:
len_keep
]
=
0
# unshuffle to get the binary mask
mask
=
torch
.
gather
(
mask
,
dim
=
1
,
index
=
ids_restore
)
ids_restore
=
ids_restore
.
unsqueeze
(
-
1
).
expand
(
-
1
,
-
1
,
D
)
return
MaskInfo
(
x_unmasked
=
x_unmasked
,
mask
=
mask
,
ids_restore
=
ids_restore
,
ids_keep
=
ids_keep
)
def
gather_unmasked
(
x
:
torch
.
Tensor
,
mask_info
:
MaskInfo
)
->
torch
.
Tensor
:
return
torch
.
gather
(
x
,
dim
=
1
,
index
=
mask_info
.
ids_keep
,
)
def
gather_unmasked_mask
(
x
:
torch
.
Tensor
,
mask_info
:
MaskInfo
)
->
torch
.
Tensor
:
return
torch
.
gather
(
x
,
dim
=
1
,
index
=
mask_info
.
ids_keep
[...,
0
],
# ignore the feature dimension
)
def
get_alibi
(
max_positions
:
int
,
attention_heads
:
int
,
dims
:
int
=
1
,
distance
:
str
=
"manhattan"
,
):
def
get_slopes
(
n
):
def
get_slopes_power_of_2
(
n
):
start
=
2
**
(
-
(
2
**
-
(
math
.
log2
(
n
)
-
3
)))
ratio
=
start
return
[
start
*
ratio
**
i
for
i
in
range
(
n
)]
# In the paper, we only train models that have 2^a heads for some
# a. This function has some good properties that only occur when
# the input is a power of 2. To maintain that even when the number
# of heads is not a power of 2, we use this workaround.
if
math
.
log2
(
n
).
is_integer
():
return
get_slopes_power_of_2
(
n
)
else
:
closest_power_of_2
=
2
**
math
.
floor
(
math
.
log2
(
n
))
return
(
get_slopes_power_of_2
(
closest_power_of_2
)
+
get_slopes
(
2
*
closest_power_of_2
)[
0
::
2
][:
n
-
closest_power_of_2
]
)
maxpos
=
max_positions
attn_heads
=
attention_heads
slopes
=
torch
.
Tensor
(
get_slopes
(
attn_heads
))
if
dims
==
1
:
# prepare alibi position linear bias. Note that wav2vec2 is non
# autoregressive model so we want a symmetric mask with 0 on the
# diagonal and other wise linear decreasing valuees
pos_bias
=
(
torch
.
abs
(
torch
.
arange
(
maxpos
).
unsqueeze
(
0
)
-
torch
.
arange
(
maxpos
).
unsqueeze
(
1
)
)
*
-
1
)
elif
dims
==
2
:
if
distance
==
"manhattan"
:
df
=
lambda
x1
,
y1
,
x2
,
y2
:
abs
(
x1
-
x2
)
+
abs
(
y1
-
y2
)
elif
distance
==
"euclidean"
:
df
=
lambda
x1
,
y1
,
x2
,
y2
:
math
.
sqrt
((
x1
-
x2
)
**
2
+
(
y1
-
y2
)
**
2
)
n
=
math
.
sqrt
(
max_positions
)
assert
n
.
is_integer
(),
n
n
=
int
(
n
)
pos_bias
=
torch
.
zeros
((
max_positions
,
max_positions
))
for
i
in
range
(
n
):
for
j
in
range
(
n
):
for
k
in
range
(
n
):
for
l
in
range
(
n
):
new_x
=
i
*
n
+
j
new_y
=
k
*
n
+
l
pos_bias
[
new_x
,
new_y
]
=
-
df
(
i
,
j
,
k
,
l
)
else
:
raise
Exception
(
f
"unsupported number of alibi dims:
{
dims
}
"
)
alibi_bias
=
slopes
.
unsqueeze
(
1
).
unsqueeze
(
1
)
*
pos_bias
.
unsqueeze
(
0
).
expand
(
attn_heads
,
-
1
,
-
1
)
return
alibi_bias
def
get_alibi_bias
(
alibi_biases
,
batch_size
,
time_steps
,
heads
,
dtype
,
device
,
dims
=
1
,
distance
=
"manhattan"
,
):
cache_key
=
f
"
{
dims
}
_
{
heads
}
_
{
distance
}
"
buffered
=
alibi_biases
.
get
(
cache_key
,
None
)
target_size
=
heads
*
batch_size
if
(
buffered
is
None
or
buffered
.
size
(
0
)
<
target_size
or
buffered
.
size
(
1
)
<
time_steps
or
buffered
.
dtype
!=
dtype
or
buffered
.
device
!=
device
):
bt
=
max
(
time_steps
,
buffered
.
size
(
1
)
if
buffered
is
not
None
else
0
)
bn
=
max
(
target_size
,
buffered
.
size
(
0
)
if
buffered
is
not
None
else
0
)
//
heads
buffered
=
(
get_alibi
(
bt
,
heads
,
dims
=
dims
,
distance
=
distance
)
.
to
(
dtype
=
dtype
,
device
=
device
)
.
repeat
(
bn
,
1
,
1
)
)
alibi_biases
[
cache_key
]
=
buffered
b
=
buffered
[:
target_size
,
:
time_steps
,
:
time_steps
]
b
=
b
.
view
(
batch_size
,
heads
,
time_steps
,
time_steps
)
return
b
def
_learned_alibi_bias
(
alibi_bias
,
batch_size
,
time_steps
,
heads
,
scale
,
dtype
,
device
,
):
assert
alibi_bias
.
size
(
1
)
==
heads
,
alibi_bias
.
shape
assert
alibi_bias
.
dtype
==
dtype
,
alibi_bias
.
dtype
assert
alibi_bias
.
device
==
device
,
alibi_bias
.
device
if
alibi_bias
.
size
(
-
1
)
<
time_steps
:
psz
=
math
.
ceil
((
time_steps
-
alibi_bias
.
size
(
-
1
))
/
2
)
alibi_bias
=
F
.
pad
(
alibi_bias
,
(
psz
,
psz
,
psz
,
psz
),
mode
=
"replicate"
)
alibi_bias
=
alibi_bias
.
expand
(
batch_size
,
-
1
,
-
1
,
-
1
)
*
scale
return
alibi_bias
[...,
:
time_steps
,
:
time_steps
]
def
masked_alibi
(
alibi_bias
,
mask_info
):
H
=
alibi_bias
.
size
(
1
)
orig_bias
=
alibi_bias
index
=
mask_info
.
ids_keep
.
unsqueeze
(
1
)[...,
0
].
unsqueeze
(
-
1
)
alibi_bias
=
torch
.
gather
(
orig_bias
,
dim
=-
2
,
index
=
index
.
expand
(
-
1
,
H
,
-
1
,
mask_info
.
ids_restore
.
size
(
1
)),
)
alibi_bias
=
torch
.
gather
(
alibi_bias
,
dim
=-
1
,
index
=
index
.
transpose
(
-
1
,
-
2
).
expand
(
-
1
,
H
,
alibi_bias
.
size
(
-
2
),
-
1
),
)
return
alibi_bias
examples/data2vec/models/modalities/images.py 0 → 100644
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from functools import partial
from dataclasses import dataclass
from typing import Callable, Dict, Optional
from timm.models.layers import to_2tuple
from fairseq.tasks import FairseqTask
from examples.data2vec.models.mae import get_2d_sincos_pos_embed, PatchEmbed
from .base import (
    D2vModalityConfig,
    ModalitySpecificEncoder,
    get_alibi_bias,
    MaskSeed,
)
from .modules import (
    BlockEncoder,
    Decoder2d,
    FixedPositionalEncoder,
    TransformerDecoder,
    EncDecTransformerDecoder,
)
from examples.data2vec.data.modality import Modality


@dataclass
class D2vImageConfig(D2vModalityConfig):
    type: Modality = Modality.IMAGE

    input_size: int = 224
    in_chans: int = 3
    patch_size: int = 16
    embed_dim: int = 768

    alibi_dims: int = 2
    alibi_distance: str = "manhattan"

    fixed_positions: bool = True

    transformer_decoder: bool = False
    enc_dec_transformer: bool = False
class
ImageEncoder
(
ModalitySpecificEncoder
):
modality_cfg
:
D2vImageConfig
def
__init__
(
self
,
modality_cfg
:
D2vImageConfig
,
embed_dim
:
int
,
make_block
:
Callable
[[
float
,
Optional
[
int
],
Optional
[
int
]],
nn
.
ModuleList
],
norm_layer
:
Callable
[[
int
],
nn
.
LayerNorm
],
layer_norm_first
:
bool
,
alibi_biases
:
Dict
,
task
:
Optional
[
FairseqTask
],
):
img_size
=
to_2tuple
(
modality_cfg
.
input_size
)
patch_size
=
to_2tuple
(
modality_cfg
.
patch_size
)
num_patches
=
(
img_size
[
1
]
//
patch_size
[
1
])
*
(
img_size
[
0
]
//
patch_size
[
0
])
local_encoder
=
PatchEmbed
(
modality_cfg
.
input_size
,
modality_cfg
.
patch_size
,
modality_cfg
.
in_chans
,
modality_cfg
.
embed_dim
,
)
w
=
local_encoder
.
proj
.
weight
.
data
torch
.
nn
.
init
.
xavier_uniform_
(
w
.
view
([
w
.
shape
[
0
],
-
1
]))
if
modality_cfg
.
embed_dim
!=
embed_dim
:
local_encoder
=
nn
.
Sequential
(
local_encoder
,
nn
.
Linear
(
modality_cfg
.
embed_dim
,
embed_dim
),
)
project_features
=
nn
.
Identity
()
pos_embed
=
nn
.
Parameter
(
torch
.
zeros
(
1
,
num_patches
,
embed_dim
),
requires_grad
=
False
)
side_n
=
int
(
num_patches
**
0.5
)
emb
=
get_2d_sincos_pos_embed
(
pos_embed
.
shape
[
-
1
],
side_n
,
cls_token
=
False
,
)
pos_embed
.
data
.
copy_
(
torch
.
from_numpy
(
emb
).
float
().
unsqueeze
(
0
))
fixed_positional_encoder
=
(
FixedPositionalEncoder
(
pos_embed
)
if
modality_cfg
.
fixed_positions
else
None
)
dpr
=
np
.
linspace
(
modality_cfg
.
start_drop_path_rate
,
modality_cfg
.
end_drop_path_rate
,
modality_cfg
.
prenet_depth
,
)
context_encoder
=
BlockEncoder
(
nn
.
ModuleList
(
make_block
(
dpr
[
i
])
for
i
in
range
(
modality_cfg
.
prenet_depth
)),
norm_layer
(
embed_dim
)
if
not
layer_norm_first
else
None
,
layer_norm_first
,
modality_cfg
.
prenet_layerdrop
,
modality_cfg
.
prenet_dropout
,
)
if
modality_cfg
.
transformer_decoder
:
if
modality_cfg
.
enc_dec_transformer
:
decoder
=
EncDecTransformerDecoder
(
modality_cfg
.
decoder
,
embed_dim
)
else
:
dec_enc
=
BlockEncoder
(
nn
.
ModuleList
(
make_block
(
0
,
modality_cfg
.
decoder
.
decoder_dim
,
8
)
for
_
in
range
(
modality_cfg
.
decoder
.
decoder_layers
)
),
None
,
layer_norm_first
,
0
,
0
,
)
decoder
=
TransformerDecoder
(
modality_cfg
.
decoder
,
embed_dim
,
dec_enc
)
else
:
decoder
=
(
Decoder2d
(
modality_cfg
.
decoder
,
embed_dim
,
side_n
,
side_n
)
if
modality_cfg
.
decoder
is
not
None
else
None
)
alibi_bias_fn
=
partial
(
get_alibi_bias
,
alibi_biases
=
alibi_biases
,
heads
=
modality_cfg
.
num_alibi_heads
,
dims
=
modality_cfg
.
alibi_dims
,
distance
=
modality_cfg
.
alibi_distance
,
)
super
().
__init__
(
modality_cfg
=
modality_cfg
,
embed_dim
=
embed_dim
,
local_encoder
=
local_encoder
,
project_features
=
project_features
,
fixed_positional_encoder
=
fixed_positional_encoder
,
relative_positional_encoder
=
None
,
context_encoder
=
context_encoder
,
decoder
=
decoder
,
get_alibi_bias
=
alibi_bias_fn
,
)
def
reset_parameters
(
self
):
super
().
reset_parameters
()
if
self
.
decoder
is
not
None
:
self
.
decoder
.
reset_parameters
()
@
torch
.
no_grad
()
def
patchify
(
self
,
imgs
):
"""
imgs: (N, 3, H, W)
x: (N, L, patch_size**2 *3)
"""
p
=
self
.
modality_cfg
.
patch_size
h
=
w
=
imgs
.
shape
[
2
]
//
p
x
=
imgs
.
reshape
(
shape
=
(
imgs
.
shape
[
0
],
3
,
h
,
p
,
w
,
p
))
x
=
torch
.
einsum
(
"nchpwq->nhwpqc"
,
x
)
x
=
x
.
reshape
(
shape
=
(
imgs
.
shape
[
0
],
h
*
w
,
p
**
2
*
3
))
return
x
@
torch
.
no_grad
()
def
unpatchify
(
self
,
x
):
"""
x: (N, L, patch_size**2 *3)
imgs: (N, 3, H, W)
"""
p
=
self
.
modality_cfg
.
patch_size
h
=
w
=
int
(
x
.
shape
[
1
]
**
0.5
)
assert
h
*
w
==
x
.
shape
[
1
]
x
=
x
.
reshape
(
shape
=
(
x
.
shape
[
0
],
h
,
w
,
p
,
p
,
3
))
x
=
torch
.
einsum
(
"nhwpqc->nchpwq"
,
x
)
imgs
=
x
.
reshape
(
shape
=
(
x
.
shape
[
0
],
3
,
h
*
p
,
h
*
p
))
return
imgs
def
compute_mask
(
self
,
x
,
padding_mask
,
mask_seed
:
Optional
[
MaskSeed
],
apply
,
shape
=
None
,
precomputed_mask
=
None
,
):
mlen
=
self
.
modality_cfg
.
mask_length
if
mlen
<=
1
:
return
super
().
compute_mask
(
x
,
padding_mask
,
mask_seed
,
apply
,
precomputed_mask
)
if
precomputed_mask
is
not
None
:
mask
=
precomputed_mask
else
:
from
fairseq.data.data_utils
import
compute_block_mask_2d
if
shape
is
not
None
:
B
,
L
,
D
=
shape
else
:
B
,
L
,
D
=
x
.
shape
mask
=
compute_block_mask_2d
(
shape
=
(
B
,
L
),
mask_prob
=
self
.
modality_cfg
.
mask_prob
,
mask_length
=
self
.
modality_cfg
.
mask_length
,
mask_prob_adjust
=
self
.
modality_cfg
.
mask_prob_adjust
,
inverse_mask
=
self
.
modality_cfg
.
inverse_mask
,
require_same_masks
=
True
,
mask_dropout
=
self
.
modality_cfg
.
mask_dropout
,
)
mask_info
=
self
.
make_maskinfo
(
x
,
mask
,
shape
)
if
apply
:
x
=
self
.
apply_mask
(
x
,
mask_info
)
return
x
,
mask_info
def
decoder_input
(
self
,
x
,
mask_info
):
if
(
not
self
.
modality_cfg
.
transformer_decoder
or
not
self
.
modality_cfg
.
enc_dec_transformer
):
return
super
().
decoder_input
(
x
,
mask_info
)
inp_drop
=
self
.
modality_cfg
.
decoder
.
input_dropout
if
inp_drop
>
0
:
x
=
F
.
dropout
(
x
,
inp_drop
,
training
=
self
.
training
,
inplace
=
True
)
kv
=
x
[:,
self
.
modality_cfg
.
num_extra_tokens
:]
assert
self
.
fixed_positional_encoder
is
not
None
pos
=
self
.
fixed_positional_encoder
(
x
,
None
).
expand
(
x
.
size
(
0
),
-
1
,
-
1
)
mask
=
mask_info
.
mask
.
bool
()
if
self
.
modality_cfg
.
decoder
.
add_positions_all
:
kv
=
kv
+
pos
[
~
mask
].
view
(
kv
.
shape
)
q
=
pos
[
mask
].
view
(
x
.
size
(
0
),
-
1
,
x
.
size
(
-
1
))
return
q
,
kv
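The patchify/unpatchify pair above follows the MAE-style rearrangement; a standalone sketch (illustrative values, not from the commit) showing that the einsum round trip is lossless for a 224x224 input with 16x16 patches:

import torch

p = 16
imgs = torch.randn(2, 3, 224, 224)
h = w = imgs.shape[2] // p                     # 14 patches per side

x = imgs.reshape(2, 3, h, p, w, p)
x = torch.einsum("nchpwq->nhwpqc", x).reshape(2, h * w, p * p * 3)  # (2, 196, 768)

y = x.reshape(2, h, w, p, p, 3)
y = torch.einsum("nhwpqc->nchpwq", y).reshape(2, 3, h * p, h * p)
assert torch.equal(imgs, y)                    # exact reconstruction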
examples/data2vec/models/modalities/modules.py
0 → 100644
View file @ 72f5785f
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from dataclasses import dataclass
from fairseq.modules import (
    LayerNorm,
    SamePad,
    SamePad2d,
    TransposeLast,
)


@dataclass
class D2vDecoderConfig:
    decoder_dim: int = 384
    decoder_groups: int = 16
    decoder_kernel: int = 5
    decoder_layers: int = 5
    input_dropout: float = 0.1

    add_positions_masked: bool = False
    add_positions_all: bool = False

    decoder_residual: bool = True
    projection_layers: int = 1
    projection_ratio: float = 2.0


class FixedPositionalEncoder(nn.Module):
    def __init__(self, pos_embed):
        super().__init__()
        self.positions = pos_embed

    def forward(self, x, padding_mask):
        return self.positions


class TextFeatPositionalEncoder(nn.Module):
    """
    Original encoder expects (B, T) long input. This module wraps it to take
    local_encoder outputs, which are (B, T, D) float tensors
    """

    def __init__(self, pos_encoder):
        super().__init__()
        self.pos_encoder = pos_encoder

    def forward(self, x, padding_mask):
        # assume padded token embeddings are 0s
        # TODO: consider using padding_mask as input
        return self.pos_encoder(x[..., 0])


class BlockEncoder(nn.Module):
    def __init__(self, blocks, norm_layer, layer_norm_first, layerdrop, dropout):
        super().__init__()
        self.blocks = blocks
        self.norm = norm_layer
        self.layer_norm_first = layer_norm_first
        self.layerdrop = layerdrop
        self.dropout = nn.Dropout(dropout, inplace=True)

    def forward(self, x, padding_mask, alibi_bias, alibi_scale):
        if self.norm is not None and not self.layer_norm_first:
            x = self.norm(x)

        x = self.dropout(x)

        for i, blk in enumerate(self.blocks):
            if (
                not self.training
                or self.layerdrop == 0
                or (np.random.random() > self.layerdrop)
            ):
                ab = alibi_bias
                if ab is not None and alibi_scale is not None:
                    scale = (
                        alibi_scale[i]
                        if alibi_scale.size(0) > 1
                        else alibi_scale.squeeze(0)
                    )
                    ab = ab * scale.type_as(ab)
                x, _ = blk(x, padding_mask, ab)

        if self.norm is not None and self.layer_norm_first:
            x = self.norm(x)

        return x


class DecoderBase(nn.Module):
    decoder_cfg: D2vDecoderConfig

    def __init__(self, cfg: D2vDecoderConfig):
        super().__init__()

        self.decoder_cfg = cfg

    def reset_parameters(self):
        for mod in self.proj.modules():
            if isinstance(mod, nn.Linear):
                mod.reset_parameters()

    def add_residual(self, x, residual, i, mask_info):
        if (
            residual is None
            or not self.decoder_cfg.decoder_residual
            or residual.size(1) != x.size(1)
        ):
            return x

        ret = x + residual

        return ret


class Decoder1d(DecoderBase):
    def __init__(self, cfg: D2vDecoderConfig, input_dim):
        super().__init__(cfg)

        def make_block(in_dim):
            block = [
                nn.Conv1d(
                    in_dim,
                    cfg.decoder_dim,
                    kernel_size=cfg.decoder_kernel,
                    padding=cfg.decoder_kernel // 2,
                    groups=cfg.decoder_groups,
                ),
                SamePad(cfg.decoder_kernel),
                TransposeLast(),
                LayerNorm(cfg.decoder_dim, elementwise_affine=False),
                TransposeLast(),
                nn.GELU(),
            ]

            return nn.Sequential(*block)

        self.blocks = nn.Sequential(
            *[
                make_block(input_dim if i == 0 else cfg.decoder_dim)
                for i in range(cfg.decoder_layers)
            ]
        )

        projs = []
        curr_dim = cfg.decoder_dim
        for i in range(cfg.projection_layers - 1):
            next_dim = int(curr_dim * cfg.projection_ratio) if i == 0 else curr_dim
            projs.append(nn.Linear(curr_dim, next_dim))
            projs.append(nn.GELU())
            curr_dim = next_dim
        projs.append(nn.Linear(curr_dim, input_dim))
        if len(projs) == 1:
            self.proj = projs[0]
        else:
            self.proj = nn.Sequential(*projs)

    def forward(self, x, mask_info):
        x = x.transpose(1, 2)

        residual = x

        for i, layer in enumerate(self.blocks):
            x = layer(x)
            x = self.add_residual(x, residual, i, mask_info)
            residual = x

        x = x.transpose(1, 2)
        x = self.proj(x)
        return x


class Decoder2d(DecoderBase):
    def __init__(self, cfg: D2vDecoderConfig, input_dim, h_size, w_size):
        super().__init__(cfg)

        self.h_size = h_size
        self.w_size = w_size

        def make_block(in_dim):
            block = [
                nn.Conv2d(
                    in_dim,
                    cfg.decoder_dim,
                    kernel_size=cfg.decoder_kernel,
                    padding=cfg.decoder_kernel // 2,
                    groups=cfg.decoder_groups,
                ),
                SamePad2d(cfg.decoder_kernel),
                TransposeLast(tranpose_dim=-3),
                LayerNorm(cfg.decoder_dim, elementwise_affine=False),
                TransposeLast(tranpose_dim=-3),
                nn.GELU(),
            ]

            return nn.Sequential(*block)

        self.blocks = nn.Sequential(
            *[
                make_block(input_dim if i == 0 else cfg.decoder_dim)
                for i in range(cfg.decoder_layers)
            ]
        )

        self.proj = nn.Linear(cfg.decoder_dim, input_dim)

    def forward(self, x, mask_info):
        B, T, C = x.shape

        x = x.transpose(1, 2).reshape(B, C, self.h_size, self.w_size)

        residual = x

        for i, layer in enumerate(self.blocks):
            x = layer(x)
            x = self.add_residual(x, residual, i, mask_info)
            residual = x

        x = x.reshape(B, -1, T).transpose(1, 2)
        x = self.proj(x)
        return x


class TransformerDecoder(nn.Module):
    decoder_cfg: D2vDecoderConfig

    def __init__(self, cfg: D2vDecoderConfig, input_dim, encoder):
        super().__init__()

        self.decoder_cfg = cfg

        self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)

        self.encoder = encoder

        self.proj = nn.Linear(cfg.decoder_dim, input_dim)

    def reset_parameters(self):
        from fairseq.modules.transformer_sentence_encoder import init_bert_params

        self.apply(init_bert_params)

    def forward(self, x, mask_info):
        x = self.input_proj(x)
        x = self.encoder(x, None, None, 1)
        x = self.proj(x)
        return x


class AltBlock(nn.Module):
    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None):
        if self.layer_norm_first:
            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            r = x = self.mlp(self.norm2(x))
            t = x
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x

        return x, t


class AltAttention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention

        if cosine_attention:
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, x, padding_mask=None, alibi_bias=None):
        B, N, C = x.shape
        qkv = (
            self.qkv(x)
            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)  # qkv x B x H x L x D
        )
        q, k, v = (
            qkv[0],
            qkv[1],
            qkv[2],
        )  # make torchscript happy (cannot use tensor as tuple)

        dtype = q.dtype

        if self.cosine_attention:
            # cosine attention
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            attn = attn * logit_scale
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

        if alibi_bias is not None:
            attn = attn.type_as(alibi_bias)
            attn[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            attn = attn.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class EncDecAttention(nn.Module):
    def __init__(
        self,
        q_dim,
        kv_dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = q_dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q_proj = nn.Linear(q_dim, q_dim, bias=qkv_bias)
        self.kv_proj = nn.Linear(kv_dim, 2 * q_dim, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(q_dim, q_dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention

        if cosine_attention:
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, q, kv, padding_mask=None, alibi_bias=None):
        B, N, C = q.shape

        q = (
            self.q_proj(q)
            .reshape(B, N, self.num_heads, C // self.num_heads)
            .permute(0, 2, 1, 3)
        )  # B x H x L x D
        kv = (
            self.kv_proj(kv)
            .reshape(B, -1, 2, self.num_heads, C // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )  # kv x B x H x L x D
        k, v = (
            kv[0],
            kv[1],
        )  # make torchscript happy (cannot use tensor as tuple)

        dtype = q.dtype

        if self.cosine_attention:
            # cosine attention
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            attn = attn * logit_scale
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)

        if alibi_bias is not None:
            attn = attn.type_as(alibi_bias)
            attn[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            attn = attn.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2)
        x = x.reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x


class EncDecBlock(nn.Module):
    def __init__(
        self,
        q_dim,
        kv_dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        cosine_attention=False,
        first_residual=True,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first

        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(q_dim)
        self.attn = EncDecAttention(
            q_dim,
            kv_dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(q_dim)
        mlp_hidden_dim = int(q_dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=q_dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
        self.first_residual = first_residual

    def forward(self, q, kv, padding_mask=None, alibi_bias=None):
        r = q if self.first_residual else 0
        if self.layer_norm_first:
            x = r + self.drop_path(
                self.attn(self.norm1(q), kv, padding_mask, alibi_bias)
            )
            r = x = self.mlp(self.norm2(x))
            x = r + self.drop_path(self.post_mlp_dropout(x))
        else:
            x = r + self.drop_path(self.attn(q, kv, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))

        return x


class EncDecTransformerDecoder(nn.Module):
    def __init__(self, cfg: D2vDecoderConfig, input_dim):
        super().__init__()

        self.input_proj = nn.Linear(input_dim, cfg.decoder_dim)

        self.blocks = nn.Sequential(
            *[
                EncDecBlock(
                    q_dim=cfg.decoder_dim,
                    kv_dim=input_dim,
                    num_heads=8,
                    mlp_ratio=4.0,
                    qkv_bias=True,
                    qk_scale=None,
                    drop=0.0,
                    attn_drop=0.0,
                    mlp_drop=0.0,
                    post_mlp_drop=0.0,
                    drop_path=0.0,
                    act_layer=nn.GELU,
                    norm_layer=nn.LayerNorm,
                    layer_norm_first=False,
                    cosine_attention=False,
                    first_residual=i > 0,
                )
                for i in range(cfg.decoder_layers)
            ]
        )

        self.proj = nn.Linear(cfg.decoder_dim, input_dim)

    def reset_parameters(self):
        from fairseq.modules.transformer_sentence_encoder import init_bert_params

        self.apply(init_bert_params)

    def forward(self, x, kv):
        x = self.input_proj(x)
        for i, layer in enumerate(self.blocks):
            x = layer(x, kv)

        x = self.proj(x)
        return x
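A shape sketch for the convolutional decoder above (assumes fairseq is installed so the imports at the top of this file resolve; the sizes are illustrative):

import torch

cfg = D2vDecoderConfig()            # decoder_dim=384, kernel=5, 5 layers, single projection
dec = Decoder1d(cfg, input_dim=768)
x = torch.randn(2, 100, 768)        # (batch, time, channels)
y = dec(x, mask_info=None)          # mask_info is accepted but unused by Decoder1d
assert y.shape == (2, 100, 768)     # projected back to the input dimension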
examples/data2vec/models/modalities/text.py
0 → 100644
View file @ 72f5785f
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import math
from dataclasses import dataclass
from functools import partial
from typing import Callable, Dict, Optional

import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from fairseq.modules import PositionalEmbedding, FairseqDropout, LayerNorm
from fairseq.tasks import FairseqTask
from .base import D2vModalityConfig, ModalitySpecificEncoder, get_alibi_bias
from .modules import BlockEncoder, Decoder1d
from examples.data2vec.data.modality import Modality


@dataclass
class D2vTextConfig(D2vModalityConfig):
    type: Modality = Modality.TEXT
    max_source_positions: int = 512
    learned_pos: bool = True
    dropout: float = 0.1  # used for both local_encoder and contextualized encoder. tied with global transformer in data2vec_text

    no_scale_embedding: bool = True
    layernorm_embedding: bool = True
    no_token_positional_embeddings: bool = False


class TextEncoder(ModalitySpecificEncoder):

    modality_cfg: D2vTextConfig

    def __init__(
        self,
        modality_cfg: D2vTextConfig,
        embed_dim: int,
        make_block: Callable[[float], nn.ModuleList],
        norm_layer: Callable[[int], nn.LayerNorm],
        layer_norm_first: bool,
        alibi_biases: Dict,
        task: Optional[FairseqTask],
    ):
        self.pad_idx = task.source_dictionary.pad()
        self.vocab_size = len(task.source_dictionary)

        local_encoder = TextLocalEncoder(
            vocab_size=self.vocab_size,
            embed_dim=embed_dim,
            max_source_positions=modality_cfg.max_source_positions,
            pad_idx=self.pad_idx,
            no_scale_embedding=modality_cfg.no_scale_embedding,
            layernorm_embedding=modality_cfg.layernorm_embedding,
            dropout=modality_cfg.dropout,
            no_token_positional_embeddings=modality_cfg.no_token_positional_embeddings,
            learned_pos=modality_cfg.learned_pos,
        )
        dpr = np.linspace(
            modality_cfg.start_drop_path_rate,
            modality_cfg.end_drop_path_rate,
            modality_cfg.prenet_depth,
        )
        context_encoder = BlockEncoder(
            nn.ModuleList(make_block(dpr[i]) for i in range(modality_cfg.prenet_depth)),
            norm_layer(embed_dim)
            if not layer_norm_first and modality_cfg.prenet_depth > 0
            else None,
            layer_norm_first,
            modality_cfg.prenet_layerdrop,
            modality_cfg.prenet_dropout if modality_cfg.prenet_depth > 0 else 0.0,
        )

        decoder = (
            Decoder1d(modality_cfg.decoder, embed_dim)
            if modality_cfg.decoder is not None
            else None
        )

        alibi_bias_fn = partial(get_alibi_bias, alibi_biases=alibi_biases)

        super().__init__(
            modality_cfg=modality_cfg,
            embed_dim=embed_dim,
            local_encoder=local_encoder,
            project_features=nn.Identity(),
            fixed_positional_encoder=None,
            relative_positional_encoder=None,
            context_encoder=context_encoder,
            decoder=decoder,
            get_alibi_bias=alibi_bias_fn,
        )

    def reset_parameters(self):
        super().reset_parameters()

    def convert_padding_mask(self, x, padding_mask):
        if padding_mask is None or padding_mask.size(1) == x.size(1):
            return padding_mask

        diff = self.downsample - padding_mask.size(1) % self.downsample
        if 0 < diff < self.downsample:
            padding_mask = F.pad(padding_mask, (0, diff), value=True)

        padding_mask = padding_mask.view(padding_mask.size(0), -1, self.downsample)
        padding_mask = padding_mask.all(-1)
        if padding_mask.size(1) > x.size(1):
            padding_mask = padding_mask[:, : x.size(1)]

        assert x.size(1) == padding_mask.size(1), (
            f"{x.size(1), padding_mask.size(1), diff, self.downsample}"
        )

        return padding_mask


class TextLocalEncoder(nn.Module):
    def __init__(
        self,
        vocab_size,
        embed_dim,
        max_source_positions,
        pad_idx,
        no_scale_embedding,
        layernorm_embedding,
        dropout,
        no_token_positional_embeddings,
        learned_pos,
    ):
        super().__init__()
        self.pad_idx = pad_idx
        self.dropout_module = FairseqDropout(dropout)

        self.embed_tokens = nn.Embedding(vocab_size, embed_dim, pad_idx)
        self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)
        self.embed_positions = (
            PositionalEmbedding(
                max_source_positions,
                embed_dim,
                pad_idx,
                learned=learned_pos,
            )
            if not no_token_positional_embeddings
            else None
        )
        self.embed_scale = 1.0 if no_scale_embedding else math.sqrt(embed_dim)

        self.layernorm_embedding = None
        if layernorm_embedding:
            self.layernorm_embedding = LayerNorm(embed_dim)

    def forward(self, src_tokens):
        x = self.embed_scale * self.embed_tokens(src_tokens)
        if self.embed_positions is not None:
            x = x + self.embed_positions(src_tokens)

        if self.layernorm_embedding is not None:
            x = self.layernorm_embedding(x)

        x = self.dropout_module(x)
        return x
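A quick sketch of the token-level frontend above (assumes fairseq is installed; the vocabulary size and pad index are made-up values):

import torch

enc = TextLocalEncoder(
    vocab_size=1000,
    embed_dim=64,
    max_source_positions=512,
    pad_idx=1,
    no_scale_embedding=True,
    layernorm_embedding=True,
    dropout=0.1,
    no_token_positional_embeddings=False,
    learned_pos=True,
)
src_tokens = torch.randint(2, 1000, (2, 20))   # (batch, time) token ids
x = enc(src_tokens)
assert x.shape == (2, 20, 64)                  # embedded and position-encoded features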
examples/data2vec/models/utils.py
0 → 100644
View file @ 72f5785f
import math

import torch


def get_alibi(
    max_positions: int,
    attention_heads: int,
):
    def get_slopes(n):
        def get_slopes_power_of_2(n):
            start = 2 ** (-(2 ** -(math.log2(n) - 3)))
            ratio = start
            return [start * ratio ** i for i in range(n)]

        # In the paper, we only train models that have 2^a heads for some
        # a. This function has some good properties that only occur when
        # the input is a power of 2. To maintain that even when the number
        # of heads is not a power of 2, we use this workaround.
        if math.log2(n).is_integer():
            return get_slopes_power_of_2(n)
        else:
            closest_power_of_2 = 2 ** math.floor(math.log2(n))
            return (
                get_slopes_power_of_2(closest_power_of_2)
                + get_slopes(2 * closest_power_of_2)[0::2][: n - closest_power_of_2]
            )

    maxpos = max_positions
    attn_heads = attention_heads
    slopes = torch.Tensor(get_slopes(attn_heads))

    # prepare alibi position linear bias. Note that wav2vec2 is a non-
    # autoregressive model, so we want a symmetric mask with 0 on the
    # diagonal and otherwise linearly decreasing values
    pos_bias = (
        torch.abs(
            torch.arange(maxpos).unsqueeze(0) - torch.arange(maxpos).unsqueeze(1)
        )
        * -1
    )
    alibi_bias = slopes.unsqueeze(1).unsqueeze(1) * pos_bias.unsqueeze(0).expand(
        attn_heads, -1, -1
    )
    return alibi_bias


def masked_alibi(alibi_bias, mask_indices, orig_B, orig_T):
    alibi_bias = alibi_bias.view(orig_B, -1, orig_T, orig_T)
    H = alibi_bias.size(1)
    alibi_mask = mask_indices.unsqueeze(1)
    alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-1))
    alibi_bias = alibi_bias.view(orig_B, H, -1, orig_T)
    M = alibi_bias.size(-2)
    alibi_bias = alibi_bias.masked_select(alibi_mask.unsqueeze(-2))
    alibi_bias = alibi_bias.view(-1, M, M)
    return alibi_bias
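A worked example (a sketch, not from the commit) of the bias produced by get_alibi for 8 heads and 4 positions; the matrix is symmetric, zero on the diagonal, and scaled per head by a geometric slope:

import torch

bias = get_alibi(max_positions=4, attention_heads=8)
assert bias.shape == (8, 4, 4)
assert torch.all(bias.diagonal(dim1=-2, dim2=-1) == 0)
# first head: start = 2 ** -(2 ** -(log2(8) - 3)) = 0.5, so its first row is [0, -0.5, -1.0, -1.5]
assert torch.allclose(bias[0, 0], torch.tensor([0.0, -0.5, -1.0, -1.5]))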
examples/data2vec/scripts/convert_audioset_labels.py
0 → 100644
View file @ 72f5785f
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import os


def get_parser():
    parser = argparse.ArgumentParser(description="convert audioset labels")
    # fmt: off
    parser.add_argument('in_file', help='audioset csv file to convert')
    parser.add_argument('--manifest', required=True, metavar='PATH', help='wav2vec-like manifest')
    parser.add_argument('--descriptors', required=True, metavar='PATH', help='path to label descriptor file')
    parser.add_argument('--output', required=True, metavar='PATH', help='where to output converted labels')
    # fmt: on

    return parser


def main():
    parser = get_parser()
    args = parser.parse_args()

    label_descriptors = {}
    with open(args.descriptors, "r") as ldf:
        next(ldf)
        for line in ldf:
            if line.strip() == "":
                continue

            items = line.split(",")
            assert len(items) > 2, line
            idx = items[0]
            lbl = items[1]
            assert lbl not in label_descriptors, lbl
            label_descriptors[lbl] = idx

    labels = {}
    with open(args.in_file, "r") as ifd:
        for line in ifd:
            if line.lstrip().startswith("#"):
                continue

            items = line.rstrip().split(",")
            id = items[0].strip()
            start = items[1].strip()
            end = items[2].strip()
            lbls = [label_descriptors[it.strip(' "')] for it in items[3:]]
            labels[id] = [start, end, ",".join(lbls)]

    with open(args.manifest, "r") as mf, open(args.output, "w") as of:
        next(mf)
        for line in mf:
            path, _ = line.split("\t")
            id = os.path.splitext(os.path.basename(path))[0]
            lbl = labels[id]
            print("\t".join(lbl), file=of)


if __name__ == "__main__":
    main()
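A hypothetical invocation of the script above (all file names are illustrative assumptions, not paths from this commit):

import sys

sys.argv = [
    "convert_audioset_labels.py",
    "eval_segments.csv",                          # AudioSet segments csv (assumed name)
    "--manifest", "eval.tsv",                     # wav2vec-style manifest (assumed name)
    "--descriptors", "class_labels_indices.csv",  # label descriptor file (assumed name)
    "--output", "eval.lbl",
]
main()  # writes one "start<TAB>end<TAB>comma-separated label ids" line per manifest entry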