ModelZoo / fairseq-wav2vec_pytorch / Commits

Commit 18d27e00 ("initial commit")
Authored Aug 27, 2024 by wangwei990215
Parent: 541f4c7a
Changes: 789

Showing 9 changed files with 729 additions and 0 deletions (+729 -0)
fairseq/fairseq/models/bart/__pycache__/__init__.cpython-38.pyc          +0    -0
fairseq/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc    +0    -0
fairseq/fairseq/models/bart/__pycache__/hub_interface.cpython-38.pyc     +0    -0
fairseq/fairseq/models/bart/__pycache__/model.cpython-310.pyc            +0    -0
fairseq/fairseq/models/bart/__pycache__/model.cpython-38.pyc             +0    -0
fairseq/fairseq/models/bart/hub_interface.py                             +201  -0
fairseq/fairseq/models/bart/model.py                                     +368  -0
fairseq/fairseq/models/composite_encoder.py                              +57   -0
fairseq/fairseq/models/distributed_fairseq_model.py                      +103  -0
Too many changes to show. To preserve performance, only 789 of 789+ files are displayed.
fairseq/fairseq/models/bart/__pycache__/__init__.cpython-38.pyc (new file, mode 0 → 100644): File added
fairseq/fairseq/models/bart/__pycache__/hub_interface.cpython-310.pyc (new file, mode 0 → 100644): File added
fairseq/fairseq/models/bart/__pycache__/hub_interface.cpython-38.pyc (new file, mode 0 → 100644): File added
fairseq/fairseq/models/bart/__pycache__/model.cpython-310.pyc (new file, mode 0 → 100644): File added
fairseq/fairseq/models/bart/__pycache__/model.cpython-38.pyc (new file, mode 0 → 100644): File added
fairseq/fairseq/models/bart/hub_interface.py (new file, mode 0 → 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from typing import List

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.data import encoders


logger = logging.getLogger(__name__)


class BARTHubInterface(nn.Module):
    """A simple PyTorch Hub interface to BART.

    Usage: https://github.com/pytorch/fairseq/tree/master/examples/bart
    """

    def __init__(self, args, task, model):
        super().__init__()
        self.args = args
        self.task = task
        self.model = model

        self.bpe = encoders.build_bpe(args)

        self.max_positions = min(
            utils.resolve_max_positions(
                self.task.max_positions(),
                self.model.max_positions(),
            )
        )

        # this is useful for determining the device
        self.register_buffer("_float_tensor", torch.tensor([0], dtype=torch.float))

    @property
    def device(self):
        return self._float_tensor.device

    def encode(
        self, sentence: str, *addl_sentences, no_separator=True
    ) -> torch.LongTensor:
        """
        BPE-encode a sentence (or multiple sentences).

        Every sequence begins with a beginning-of-sentence (`<s>`) symbol.
        Every sentence ends with an end-of-sentence (`</s>`).

        Example (single sentence): `<s> a b c </s>`
        Example (sentence pair): `<s> d e f </s> 1 2 3 </s>`

        The BPE encoding follows GPT-2. One subtle detail is that the GPT-2 BPE
        requires leading spaces. For example::

            >>> bart.encode('Hello world').tolist()
            [0, 31414, 232, 2]
            >>> bart.encode(' world').tolist()
            [0, 232, 2]
            >>> bart.encode('world').tolist()
            [0, 8331, 2]
        """
        tokens = self.bpe.encode(sentence)
        if len(tokens.split(" ")) > self.max_positions - 2:
            tokens = " ".join(tokens.split(" ")[: self.max_positions - 2])
        bpe_sentence = "<s> " + tokens + " </s>"
        for s in addl_sentences:
            bpe_sentence += " </s>" if not no_separator else ""
            bpe_sentence += " " + self.bpe.encode(s) + " </s>"
        tokens = self.task.source_dictionary.encode_line(
            bpe_sentence, append_eos=False
        )
        return tokens.long()

    def decode(self, tokens: torch.LongTensor):
        assert tokens.dim() == 1
        tokens = tokens.cpu().numpy()
        if tokens[0] == self.task.source_dictionary.bos():
            tokens = tokens[1:]  # remove <s>
        eos_mask = tokens == self.task.source_dictionary.eos()
        doc_mask = eos_mask[1:] & eos_mask[:-1]
        sentences = np.split(tokens, doc_mask.nonzero()[0] + 1)
        sentences = [
            self.bpe.decode(self.task.source_dictionary.string(s)) for s in sentences
        ]
        if len(sentences) == 1:
            return sentences[0]
        return sentences

    def _build_sample(self, src_tokens: List[torch.LongTensor]):
        # assert torch.is_tensor(src_tokens)
        dataset = self.task.build_dataset_for_inference(
            src_tokens,
            [x.numel() for x in src_tokens],
        )
        sample = dataset.collater(dataset)
        sample = utils.apply_to_sample(lambda tensor: tensor.to(self.device), sample)
        return sample

    def sample(
        self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs
    ) -> str:
        input = [self.encode(sentence) for sentence in sentences]
        hypos = self.generate(input, beam, verbose, **kwargs)
        return [self.decode(x["tokens"]) for x in hypos]

    def generate(
        self,
        tokens: List[torch.LongTensor],
        beam: int = 5,
        verbose: bool = False,
        **kwargs
    ) -> torch.LongTensor:
        sample = self._build_sample(tokens)

        # build generator using current args as well as any kwargs
        gen_args = copy.copy(self.args)
        gen_args.beam = beam
        for k, v in kwargs.items():
            setattr(gen_args, k, v)
        generator = self.task.build_generator([self.model], gen_args)
        translations = self.task.inference_step(
            generator,
            [self.model],
            sample,
            prefix_tokens=sample["net_input"]["src_tokens"]
            .new_zeros((len(tokens), 1))
            .fill_(self.task.source_dictionary.bos()),
        )

        if verbose:
            src_str_with_unk = self.string(tokens)
            logger.info("S\t{}".format(src_str_with_unk))

        def getarg(name, default):
            return getattr(gen_args, name, getattr(self.args, name, default))

        # Process top predictions
        hypos = [x[0] for x in translations]
        hypos = [v for _, v in sorted(zip(sample["id"].tolist(), hypos))]

        return hypos

    def extract_features(
        self, tokens: torch.LongTensor, return_all_hiddens: bool = False
    ) -> torch.Tensor:
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        if tokens.size(-1) > min(self.model.max_positions()):
            raise ValueError(
                "tokens exceeds maximum length: {} > {}".format(
                    tokens.size(-1), self.model.max_positions()
                )
            )
        tokens.to(device=self.device),
        prev_output_tokens = tokens.clone()

        prev_output_tokens[:, 0] = tokens.gather(
            1,
            (tokens.ne(self.task.source_dictionary.pad()).sum(dim=1) - 1).unsqueeze(-1),
        ).squeeze()

        prev_output_tokens[:, 1:] = tokens[:, :-1]
        features, extra = self.model(
            src_tokens=tokens,
            src_lengths=None,
            prev_output_tokens=prev_output_tokens,
            features_only=True,
            return_all_hiddens=return_all_hiddens,
        )
        if return_all_hiddens:
            # convert from T x B x C -> B x T x C
            inner_states = extra["inner_states"]
            return [inner_state.transpose(0, 1) for inner_state in inner_states]
        else:
            return features  # just the last layer's features

    def register_classification_head(
        self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
    ):
        self.model.register_classification_head(
            name, num_classes=num_classes, embedding_size=embedding_size, **kwargs
        )

    def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        features = self.extract_features(tokens.to(device=self.device))
        sentence_representation = features[
            tokens.eq(self.task.source_dictionary.eos()), :
        ].view(features.size(0), -1, features.size(-1))[:, -1, :]
        logits = self.model.classification_heads[head](sentence_representation)
        if return_logits:
            return logits
        return F.log_softmax(logits, dim=-1)
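
The hub interface above is normally obtained through `BARTModel.from_pretrained` (defined in model.py further down). The following usage sketch is not part of the commit; it assumes fairseq is installed and that one of the checkpoints listed in `hub_models()` has been extracted locally, with "/path/to/bart.large" standing in as a placeholder path.

# Usage sketch for BARTHubInterface (not part of the commit).
# Assumes an extracted BART checkpoint at the placeholder path below.
import torch
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained("/path/to/bart.large", checkpoint_file="model.pt")
bart.eval()  # BARTHubInterface is an nn.Module; disable dropout for inference

# encode() prepends <s> and appends </s>, as described in its docstring
tokens = bart.encode("Hello world")   # e.g. tensor([0, 31414, 232, 2])
print(bart.decode(tokens))            # "Hello world"

# sample() chains encode -> generate -> decode for a batch of sentences;
# extra kwargs (e.g. lenpen) are forwarded to the generator via gen_args
with torch.no_grad():
    print(bart.sample(["BART is a denoising sequence-to-sequence model."], beam=5))

The `_float_tensor` buffer registered in `__init__` is what lets `sample()` move inputs to the model's device without an explicit device argument: the `device` property simply reads that buffer's location.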
fairseq/fairseq/models/bart/model.py (new file, mode 0 → 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
BART: Denoising Sequence-to-Sequence Pre-training for
Natural Language Generation, Translation, and Comprehension
"""
import logging

import torch
import torch.nn as nn

from fairseq import utils
from fairseq.models import register_model, register_model_architecture
from fairseq.models.transformer import TransformerModel
from fairseq.modules.transformer_sentence_encoder import init_bert_params

from .hub_interface import BARTHubInterface


logger = logging.getLogger(__name__)


@register_model("bart")
class BARTModel(TransformerModel):
    @classmethod
    def hub_models(cls):
        return {
            "bart.base": "http://dl.fbaipublicfiles.com/fairseq/models/bart.base.tar.gz",
            "bart.large": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.tar.gz",
            "bart.large.mnli": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.mnli.tar.gz",
            "bart.large.cnn": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.cnn.tar.gz",
            "bart.large.xsum": "http://dl.fbaipublicfiles.com/fairseq/models/bart.large.xsum.tar.gz",
        }

    def __init__(self, args, encoder, decoder):
        super().__init__(args, encoder, decoder)

        # We follow BERT's random weight initialization
        self.apply(init_bert_params)

        self.classification_heads = nn.ModuleDict()

    @staticmethod
    def add_args(parser):
        super(BARTModel, BARTModel).add_args(parser)
        parser.add_argument(
            "--pooler-dropout",
            type=float,
            metavar="D",
            help="dropout probability in the masked_lm pooler layers",
        )
        parser.add_argument(
            "--pooler-activation-fn",
            choices=utils.get_available_activation_fns(),
            help="activation function to use for pooler layer",
        )
        parser.add_argument(
            "--spectral-norm-classification-head",
            action="store_true",
            help="Apply spectral normalization on the classification head",
        )

    @property
    def supported_targets(self):
        return {"self"}

    def forward(
        self,
        src_tokens,
        src_lengths,
        prev_output_tokens,
        features_only=False,
        classification_head_name=None,
        token_embeddings=None,
        **kwargs,
    ):
        if classification_head_name is not None:
            features_only = True

        encoder_out = self.encoder(
            src_tokens,
            src_lengths=src_lengths,
            token_embeddings=token_embeddings,
            **kwargs,
        )
        x, extra = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            features_only=features_only,
            **kwargs,
        )

        if classification_head_name is not None:
            sentence_representation = x[
                src_tokens.eq(self.encoder.dictionary.eos()), :
            ].view(x.size(0), -1, x.size(-1))[:, -1, :]
            x = self.classification_heads[classification_head_name](
                sentence_representation
            )
        return x, extra

    @classmethod
    def from_pretrained(
        cls,
        model_name_or_path,
        checkpoint_file="model.pt",
        data_name_or_path=".",
        bpe="gpt2",
        **kwargs,
    ):
        from fairseq import hub_utils

        x = hub_utils.from_pretrained(
            model_name_or_path,
            checkpoint_file,
            data_name_or_path,
            archive_map=cls.hub_models(),
            bpe=bpe,
            load_checkpoint_heads=True,
            **kwargs,
        )
        return BARTHubInterface(x["args"], x["task"], x["models"][0])

    def register_classification_head(
        self, name, num_classes=None, inner_dim=None, **kwargs
    ):
        """Register a classification head."""
        logger.info("Registering classification head: {0}".format(name))
        if name in self.classification_heads:
            prev_num_classes = self.classification_heads[name].out_proj.out_features
            prev_inner_dim = self.classification_heads[name].dense.out_features
            if num_classes != prev_num_classes or inner_dim != prev_inner_dim:
                logger.warning(
                    're-registering head "{}" with num_classes {} (prev: {}) '
                    "and inner_dim {} (prev: {})".format(
                        name, num_classes, prev_num_classes, inner_dim, prev_inner_dim
                    )
                )
        self.classification_heads[name] = BARTClassificationHead(
            input_dim=self.args.encoder_embed_dim,
            inner_dim=inner_dim or self.args.encoder_embed_dim,
            num_classes=num_classes,
            activation_fn=self.args.pooler_activation_fn,
            pooler_dropout=self.args.pooler_dropout,
            do_spectral_norm=self.args.spectral_norm_classification_head,
        )

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)

        prefix = name + "." if name != "" else ""
        current_head_names = (
            []
            if not hasattr(self, "classification_heads")
            else self.classification_heads.keys()
        )

        # Handle new classification heads present in the state dict.
        keys_to_delete = []
        for k in state_dict.keys():
            if not k.startswith(prefix + "classification_heads."):
                continue

            head_name = k[len(prefix + "classification_heads.") :].split(".")[0]
            num_classes = state_dict[
                prefix + "classification_heads." + head_name + ".out_proj.weight"
            ].size(0)
            inner_dim = state_dict[
                prefix + "classification_heads." + head_name + ".dense.weight"
            ].size(0)

            if getattr(self.args, "load_checkpoint_heads", False):
                if head_name not in current_head_names:
                    self.register_classification_head(head_name, num_classes, inner_dim)
            else:
                if head_name not in current_head_names:
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "not present in current model: {}".format(head_name, k)
                    )
                    keys_to_delete.append(k)
                elif (
                    num_classes
                    != self.classification_heads[head_name].out_proj.out_features
                    or inner_dim
                    != self.classification_heads[head_name].dense.out_features
                ):
                    logger.warning(
                        "deleting classification head ({}) from checkpoint "
                        "with different dimensions than current model: {}".format(
                            head_name, k
                        )
                    )
                    keys_to_delete.append(k)
        for k in keys_to_delete:
            del state_dict[k]

        def truncate_emb(key):
            if key in state_dict:
                state_dict[key] = state_dict[key][:-1, :]

        # When finetuning on translation task, remove last row of
        # embedding matrix that corresponds to mask_idx token.
        loaded_dict_size = state_dict["encoder.embed_tokens.weight"].size(0)
        if (
            loaded_dict_size == len(self.encoder.dictionary) + 1
            and "<mask>" not in self.encoder.dictionary
        ):
            truncate_emb("encoder.embed_tokens.weight")
            truncate_emb("decoder.embed_tokens.weight")
            truncate_emb("encoder.output_projection.weight")
            truncate_emb("decoder.output_projection.weight")

        # When continued pretraining on new set of languages for mbart,
        # add extra lang embeddings at the end of embed_tokens.
        # Note: newly added languages are assumed to have been added at the end.
        if self.args.task == "multilingual_denoising" and loaded_dict_size < len(
            self.encoder.dictionary
        ):
            logger.info(
                "Adding extra language embeddings not found in pretrained model for "
                "continued pretraining of MBART on new set of languages."
            )
            loaded_mask_token_embedding = state_dict["encoder.embed_tokens.weight"][
                -1, :
            ]

            num_langids_to_add = len(self.encoder.dictionary) - loaded_dict_size
            embed_dim = state_dict["encoder.embed_tokens.weight"].size(1)

            new_lang_embed_to_add = torch.zeros(num_langids_to_add, embed_dim)
            nn.init.normal_(new_lang_embed_to_add, mean=0, std=embed_dim ** -0.5)
            new_lang_embed_to_add = new_lang_embed_to_add.to(
                dtype=state_dict["encoder.embed_tokens.weight"].dtype,
            )

            state_dict["encoder.embed_tokens.weight"] = torch.cat(
                [
                    state_dict["encoder.embed_tokens.weight"][
                        : loaded_dict_size - 1, :
                    ],
                    new_lang_embed_to_add,
                    loaded_mask_token_embedding.unsqueeze(0),
                ]
            )
            state_dict["decoder.embed_tokens.weight"] = torch.cat(
                [
                    state_dict["decoder.embed_tokens.weight"][
                        : loaded_dict_size - 1, :
                    ],
                    new_lang_embed_to_add,
                    loaded_mask_token_embedding.unsqueeze(0),
                ]
            )

        # Copy any newly-added classification heads into the state dict
        # with their current weights.
        if hasattr(self, "classification_heads"):
            cur_state = self.classification_heads.state_dict()
            for k, v in cur_state.items():
                if prefix + "classification_heads." + k not in state_dict:
                    logger.info("Overwriting", prefix + "classification_heads." + k)
                    state_dict[prefix + "classification_heads." + k] = v


class BARTClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim,
        inner_dim,
        num_classes,
        activation_fn,
        pooler_dropout,
        do_spectral_norm=False,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

        if do_spectral_norm:
            self.out_proj = torch.nn.utils.spectral_norm(self.out_proj)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = self.activation_fn(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


@register_model_architecture("bart", "bart_large")
def bart_large_architecture(args):
    args.encoder_embed_path = getattr(args, "encoder_embed_path", None)
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 1024)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 1024)
    args.encoder_layers = getattr(args, "encoder_layers", 12)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 16)
    args.encoder_normalize_before = getattr(args, "encoder_normalize_before", False)
    args.encoder_learned_pos = getattr(args, "encoder_learned_pos", True)
    args.decoder_embed_path = getattr(args, "decoder_embed_path", None)
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", args.encoder_embed_dim)
    args.decoder_ffn_embed_dim = getattr(
        args, "decoder_ffn_embed_dim", args.encoder_ffn_embed_dim
    )
    args.decoder_layers = getattr(args, "decoder_layers", 12)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.decoder_normalize_before = getattr(args, "decoder_normalize_before", False)
    args.decoder_learned_pos = getattr(args, "decoder_learned_pos", True)
    args.attention_dropout = getattr(args, "attention_dropout", 0.0)
    args.relu_dropout = getattr(args, "relu_dropout", 0.0)
    args.dropout = getattr(args, "dropout", 0.1)
    args.max_target_positions = getattr(args, "max_target_positions", 1024)
    args.max_source_positions = getattr(args, "max_source_positions", 1024)
    args.adaptive_softmax_cutoff = getattr(args, "adaptive_softmax_cutoff", None)
    args.adaptive_softmax_dropout = getattr(args, "adaptive_softmax_dropout", 0)
    args.share_decoder_input_output_embed = getattr(
        args, "share_decoder_input_output_embed", True
    )
    args.share_all_embeddings = getattr(args, "share_all_embeddings", True)

    args.decoder_output_dim = getattr(
        args, "decoder_output_dim", args.decoder_embed_dim
    )
    args.decoder_input_dim = getattr(args, "decoder_input_dim", args.decoder_embed_dim)

    args.no_scale_embedding = getattr(args, "no_scale_embedding", True)
    args.layernorm_embedding = getattr(args, "layernorm_embedding", True)

    args.activation_fn = getattr(args, "activation_fn", "gelu")
    args.pooler_activation_fn = getattr(args, "pooler_activation_fn", "tanh")
    args.pooler_dropout = getattr(args, "pooler_dropout", 0.0)


@register_model_architecture("bart", "bart_base")
def bart_base_architecture(args):
    args.encoder_embed_dim = getattr(args, "encoder_embed_dim", 768)
    args.encoder_ffn_embed_dim = getattr(args, "encoder_ffn_embed_dim", 4 * 768)
    args.encoder_layers = getattr(args, "encoder_layers", 6)
    args.encoder_attention_heads = getattr(args, "encoder_attention_heads", 12)
    args.decoder_layers = getattr(args, "decoder_layers", 6)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 12)
    bart_large_architecture(args)


@register_model_architecture("bart", "mbart_large")
def mbart_large_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    bart_large_architecture(args)


@register_model_architecture("bart", "mbart_base")
def mbart_base_architecture(args):
    args.no_scale_embedding = getattr(args, "no_scale_embedding", False)
    bart_base_architecture(args)


@register_model_architecture("bart", "mbart_base_wmt20")
def mbart_base_wmt20_architecture(args):
    args.layernorm_embedding = getattr(args, "layernorm_embedding", False)
    mbart_base_architecture(args)
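
For classification, `from_pretrained` passes `load_checkpoint_heads=True`, so heads stored in a fine-tuned checkpoint (for example `bart.large.mnli` from `hub_models()` above) are re-registered by `upgrade_state_dict_named` when the model is loaded. Below is a hedged sketch of that path, not part of the commit; "/path/to/bart.large.mnli" is a placeholder for a locally extracted checkpoint.

# Sketch of the classification path (not part of the commit).
# Assumes the bart.large.mnli checkpoint has been extracted to the
# placeholder path; its "mnli" head is restored via load_checkpoint_heads=True.
from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained("/path/to/bart.large.mnli", checkpoint_file="model.pt")
bart.eval()

# Sentence pairs are joined with </s> separators by encode()
tokens = bart.encode(
    "BART is a sequence-to-sequence model.",
    "BART is not a sequence-to-sequence model.",
)
# predict() pools the final <eos> hidden state and applies the named head
log_probs = bart.predict("mnli", tokens)
print(log_probs.argmax(dim=-1))

# A fresh head for a new task can be registered before fine-tuning
bart.register_classification_head("my_new_task", num_classes=3)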
fairseq/fairseq/models/composite_encoder.py (new file, mode 0 → 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .fairseq_encoder import FairseqEncoder


class CompositeEncoder(FairseqEncoder):
    """
    A wrapper around a dictionary of :class:`FairseqEncoder` objects.

    We run forward on each encoder and return a dictionary of outputs. The first
    encoder's dictionary is used for initialization.

    Args:
        encoders (dict): a dictionary of :class:`FairseqEncoder` objects.
    """

    def __init__(self, encoders):
        super().__init__(next(iter(encoders.values())).dictionary)
        self.encoders = encoders
        for key in self.encoders:
            self.add_module(key, self.encoders[key])

    def forward(self, src_tokens, src_lengths):
        """
        Args:
            src_tokens (LongTensor): tokens in the source language of shape
                `(batch, src_len)`
            src_lengths (LongTensor): lengths of each source sentence of shape
                `(batch)`

        Returns:
            dict:
                the outputs from each Encoder
        """
        encoder_out = {}
        for key in self.encoders:
            encoder_out[key] = self.encoders[key](src_tokens, src_lengths)
        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        """Reorder encoder output according to new_order."""
        for key in self.encoders:
            encoder_out[key] = self.encoders[key].reorder_encoder_out(
                encoder_out[key], new_order
            )
        return encoder_out

    def max_positions(self):
        return min(self.encoders[key].max_positions() for key in self.encoders)

    def upgrade_state_dict(self, state_dict):
        for key in self.encoders:
            self.encoders[key].upgrade_state_dict(state_dict)
        return state_dict
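
CompositeEncoder simply fans `forward` out over its wrapped encoders and keys the outputs by name, while `max_positions` takes the minimum over all of them. The toy sketch below is not part of the commit: the encoder class and dictionary are made up for illustration, and it assumes the rest of the fairseq package exposes `CompositeEncoder`, `FairseqEncoder`, and `Dictionary` as in upstream fairseq.

# Illustrative sketch (not part of the commit) of wrapping two toy encoders.
import torch
from fairseq.data import Dictionary
from fairseq.models import CompositeEncoder, FairseqEncoder


class ToyEncoder(FairseqEncoder):
    """Hypothetical encoder used only to exercise CompositeEncoder."""

    def __init__(self, dictionary, scale):
        super().__init__(dictionary)
        self.scale = scale

    def forward(self, src_tokens, src_lengths):
        # a stand-in "encoding": just scale the token ids
        return {"encoder_out": src_tokens.float() * self.scale}

    def max_positions(self):
        return 1024


d = Dictionary()
composite = CompositeEncoder({"small": ToyEncoder(d, 0.5), "large": ToyEncoder(d, 2.0)})

src_tokens = torch.tensor([[4, 5, 6]])
src_lengths = torch.tensor([3])
out = composite(src_tokens, src_lengths)
print(out.keys())                 # dict_keys(['small', 'large'])
print(composite.max_positions())  # 1024, the minimum over the wrapped encoders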
fairseq/fairseq/models/distributed_fairseq_model.py (new file, mode 0 → 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import inspect

import torch.nn as nn

from fairseq.legacy_distributed_data_parallel import LegacyDistributedDataParallel


_GOSSIP_DISABLED = False
try:
    import gossip
except ImportError:
    _GOSSIP_DISABLED = True


def DistributedFairseqModel(args, model, process_group=None):
    """
    Wrap a *model* to support distributed data parallel training.

    This is similar to the built-in DistributedDataParallel, but allows
    additional configuration of the DistributedDataParallel class to
    use, and also provides easier access to the wrapped model by
    forwarding requests for missing attributes to the wrapped model.

    Args:
        args (argparse.Namespace): fairseq args
        model (BaseFairseqModel): model to wrap
    """
    # determine which DDP class to extend
    assert isinstance(model, nn.Module)
    if args.distributed_wrapper == "DDP" and args.ddp_backend == "c10d":
        ddp_class = nn.parallel.DistributedDataParallel
        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            bucket_cap_mb=args.bucket_cap_mb,
            process_group=process_group,
        )
        # Maintain backward compatibility
        if "check_reduction" in inspect.getfullargspec(ddp_class)[0]:
            init_kwargs["check_reduction"] = True
        if "find_unused_parameters" in inspect.getfullargspec(ddp_class)[0]:
            init_kwargs["find_unused_parameters"] = args.find_unused_parameters
    elif args.distributed_wrapper == "DDP" and args.ddp_backend == "no_c10d":
        ddp_class = LegacyDistributedDataParallel
        init_kwargs = dict(
            module=model,
            world_size=args.distributed_world_size,
            buffer_size=2 ** 28,
            process_group=process_group,
        )
    elif args.distributed_wrapper == "SlowMo":
        if _GOSSIP_DISABLED:
            raise ImportError(
                "Cannot find gossip library. Please install from: "
                "github.com/facebookresearch/stochastic_gradient_push"
            )
        ddp_class = gossip.GossipDataParallel

        # The values of slowmo_momentum below were obtained by tuning on the
        # En-De 16 dataset by training the transformer_wmt_en_de_large model
        if args.slowmo_momentum is None:
            if args.distributed_world_size <= 16:
                args.slowmo_momentum = 0.0
            elif args.distributed_world_size <= 32:
                args.slowmo_momentum = 0.2
            elif args.distributed_world_size <= 64:
                args.slowmo_momentum = 0.5
            else:
                args.slowmo_momentum = 0.6

        init_kwargs = dict(
            module=model,
            device_ids=[args.device_id],
            output_device=args.device_id,
            broadcast_buffers=args.broadcast_buffers,
            nprocs_per_node=args.nprocs_per_node,
            slowmo_momentum=args.slowmo_momentum,
            localsgd=(args.slowmo_algorithm == "LocalSGD"),
            localsgd_frequency=args.localsgd_frequency,
        )
    else:
        raise ValueError("Unknown --ddp-backend: " + args.ddp_backend)

    class _DistributedFairseqModel(ddp_class):
        """Extend DistributedDataParallel to check for missing
        attributes in the wrapped module."""

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def __getattr__(self, name):
            wrapped_module = super().__getattr__("module")
            if hasattr(wrapped_module, name):
                return getattr(wrapped_module, name)
            return super().__getattr__(name)

    return _DistributedFairseqModel(**init_kwargs)
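
DistributedFairseqModel is a factory rather than a class: it picks a DDP implementation based on the args and returns an instance of a locally defined subclass that forwards unknown attribute lookups to the wrapped model. The sketch below is not part of the commit and is only a hedged illustration: in real training these fields come from fairseq's argument parser, and the c10d branch additionally assumes torch.distributed has been initialized (e.g. via torch.distributed.init_process_group) and that GPU 0 is available.

# Hedged invocation sketch (not part of the commit); the Namespace stands in
# for fairseq's parsed command-line args.
import argparse
import torch.nn as nn
from fairseq.models.distributed_fairseq_model import DistributedFairseqModel

args = argparse.Namespace(
    distributed_wrapper="DDP",     # selects the c10d branch above
    ddp_backend="c10d",
    device_id=0,                   # assumes a CUDA device and an initialized process group
    broadcast_buffers=False,
    bucket_cap_mb=25,
    find_unused_parameters=False,
    distributed_world_size=2,      # only read by the no_c10d / SlowMo branches
)

model = nn.Linear(16, 16)          # stand-in for a BaseFairseqModel
wrapped = DistributedFairseqModel(args, model, process_group=None)

# Attribute lookups fall through to the wrapped module via __getattr__:
print(wrapped.out_features)        # 16, taken from the underlying nn.Linear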