ModelZoo / SpeechT5_pytorch, commit 12c90639

Authored Sep 28, 2024 by "change"
Commit message: init
Parent: 417b607b
Changes: 350 files in the commit; showing 20 changed files with 4035 additions and 0 deletions (+4035, -0)
Speech2S/speech2s/data/load_langpair_dataset.py (+172, -0)
Speech2S/speech2s/data/multimodal_corpus_dataset.py (+368, -0)
Speech2S/speech2s/models/__init__.py (+0, -0)
Speech2S/speech2s/models/speechut.py (+785, -0)
Speech2S/speech2s/models/speechut_asr.py (+165, -0)
Speech2S/speech2s/models/speechut_st.py (+221, -0)
Speech2S/speech2s/models/t5_transformer_lm.py (+25, -0)
Speech2S/speech2s/modules/__init__.py (+27, -0)
Speech2S/speech2s/modules/ctc_prefix_score.py (+93, -0)
Speech2S/speech2s/modules/learned_positional_embedding.py (+69, -0)
Speech2S/speech2s/modules/multihead_attention.py (+346, -0)
Speech2S/speech2s/modules/relative_pos_enc.py (+33, -0)
Speech2S/speech2s/modules/transformer_decoder.py (+543, -0)
Speech2S/speech2s/modules/transformer_encoder.py (+401, -0)
Speech2S/speech2s/modules/transformer_layer.py (+330, -0)
Speech2S/speech2s/modules/w2v_encoder.py (+281, -0)
Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_asr.sh (+40, -0)
Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st.sh (+47, -0)
Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st_enfr.sh (+48, -0)
Speech2S/speech2s/scripts copy/pretrain_speechut/large_speechut_for_asr.sh (+41, -0)
Speech2S/speech2s/data/load_langpair_dataset.py (new file, mode 0 → 100644)
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/272c4c5197250997148fb12c0db6306035f166a4/fairseq/tasks/translation.py
1. Add custom lang_format in function load_langpair_dataset
2. If truncate_source (default no), use RandomCropDataset instead of TruncateDataset
"""
import itertools
import logging
import os

from fairseq.data import (
    AppendTokenDataset,
    LanguagePairDataset,
    PrependTokenDataset,
    StripTokenDataset,
    TruncateDataset,
    RandomCropDataset,
    data_utils,
    indexed_dataset,
)
from speechut.data.concat_dataset import ConcatDataset

EVAL_BLEU_ORDER = 4

logger = logging.getLogger(__name__)


def load_langpair_dataset(
    data_path,
    split,
    src,
    src_dict,
    tgt,
    tgt_dict,
    combine,
    dataset_impl,
    upsample_primary,
    left_pad_source,
    left_pad_target,
    max_source_positions,
    max_target_positions,
    prepend_bos=False,
    load_alignments=False,
    truncate_source=False,
    append_source_id=False,
    num_buckets=0,
    shuffle=True,
    pad_to_multiple=1,
    prepend_bos_src=None,
    lang_format="[{}]",
    input_feeding=True,
):
    def split_exists(split, src, tgt, lang, data_path):
        filename = os.path.join(data_path, "{}.{}-{}.{}".format(split, src, tgt, lang))
        return indexed_dataset.dataset_exists(filename, impl=dataset_impl)

    src_datasets = []
    tgt_datasets = []

    for k in itertools.count():
        split_k = split + (str(k) if k > 0 else "")

        # infer langcode
        if split_exists(split_k, src, tgt, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, src, tgt))
        elif split_exists(split_k, tgt, src, src, data_path):
            prefix = os.path.join(data_path, "{}.{}-{}.".format(split_k, tgt, src))
        else:
            if k > 0:
                break
            else:
                raise FileNotFoundError(
                    "Dataset not found: {} ({})".format(split, data_path)
                )

        src_dataset = data_utils.load_indexed_dataset(
            prefix + src, src_dict, dataset_impl
        )
        if truncate_source:
            src_dataset = AppendTokenDataset(
                RandomCropDataset(
                    StripTokenDataset(src_dataset, src_dict.eos()),
                    max_source_positions - 1,
                ),
                src_dict.eos(),
            )
        src_datasets.append(src_dataset)

        tgt_dataset = data_utils.load_indexed_dataset(
            prefix + tgt, tgt_dict, dataset_impl
        )
        if tgt_dataset is not None:
            tgt_datasets.append(tgt_dataset)

        logger.info(
            "{} {} {}-{} {} examples".format(
                data_path, split_k, src, tgt, len(src_datasets[-1])
            )
        )

        if not combine:
            break

    assert len(src_datasets) == len(tgt_datasets) or len(tgt_datasets) == 0

    if len(src_datasets) == 1:
        src_dataset = src_datasets[0]
        tgt_dataset = tgt_datasets[0] if len(tgt_datasets) > 0 else None
    else:
        sample_ratios = [1] * len(src_datasets)
        sample_ratios[0] = upsample_primary
        src_dataset = ConcatDataset(src_datasets, sample_ratios)
        if len(tgt_datasets) > 0:
            tgt_dataset = ConcatDataset(tgt_datasets, sample_ratios)
        else:
            tgt_dataset = None

    if prepend_bos:
        assert hasattr(src_dict, "bos_index") and hasattr(tgt_dict, "bos_index")
        src_dataset = PrependTokenDataset(src_dataset, src_dict.bos())
        if tgt_dataset is not None:
            tgt_dataset = PrependTokenDataset(tgt_dataset, tgt_dict.bos())
    elif prepend_bos_src is not None:
        logger.info(f"prepending src bos: {prepend_bos_src}")
        src_dataset = PrependTokenDataset(src_dataset, prepend_bos_src)

    eos = None
    if append_source_id:
        src_dataset = AppendTokenDataset(
            src_dataset, src_dict.index(lang_format.format(src))
        )
        if tgt_dataset is not None:
            tgt_dataset = AppendTokenDataset(
                tgt_dataset, tgt_dict.index(lang_format.format(tgt))
            )
        eos = tgt_dict.index(lang_format.format(tgt))

    align_dataset = None
    if load_alignments:
        align_path = os.path.join(data_path, "{}.align.{}-{}".format(split, src, tgt))
        if indexed_dataset.dataset_exists(align_path, impl=dataset_impl):
            align_dataset = data_utils.load_indexed_dataset(
                align_path, None, dataset_impl
            )

    tgt_dataset_sizes = tgt_dataset.sizes if tgt_dataset is not None else None
    return LanguagePairDataset(
        src_dataset,
        src_dataset.sizes,
        src_dict,
        tgt_dataset,
        tgt_dataset_sizes,
        tgt_dict,
        left_pad_source=left_pad_source,
        left_pad_target=left_pad_target,
        align_dataset=align_dataset,
        eos=eos,
        num_buckets=num_buckets,
        shuffle=shuffle,
        pad_to_multiple=pad_to_multiple,
        input_feeding=input_feeding,
    )
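
A minimal usage sketch (not part of this commit): calling load_langpair_dataset for an en-de "train" split. The data directory, dictionary paths, and dataset_impl value below are assumptions for illustration; the binarized data files are assumed to already exist on disk.

# Hypothetical usage sketch: paths and dataset_impl are assumed, not taken from this repo.
from fairseq.data import Dictionary

src_dict = Dictionary.load("data-bin/en-de/dict.en.txt")   # assumed path
tgt_dict = Dictionary.load("data-bin/en-de/dict.de.txt")   # assumed path

dataset = load_langpair_dataset(
    data_path="data-bin/en-de",        # assumed path to binarized data
    split="train",
    src="en", src_dict=src_dict,
    tgt="de", tgt_dict=tgt_dict,
    combine=True,
    dataset_impl="mmap",               # assumed indexed-dataset format
    upsample_primary=1,
    left_pad_source=True,
    left_pad_target=False,
    max_source_positions=1024,
    max_target_positions=1024,
    truncate_source=True,              # uses RandomCropDataset, per the docstring above
)
print(len(dataset))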
Speech2S/speech2s/data/multimodal_corpus_dataset.py (new file, mode 0 → 100644)
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import logging
from os import replace
import time
from collections import OrderedDict
from typing import Any, Dict, List, Optional

import numpy as np
from fairseq.data import data_utils

from fairseq.data import FairseqDataset

logger = logging.getLogger(__name__)


class MultiCorpusDataset(FairseqDataset):
    """
    see fairseq/fairseq/data/multi_corpus_dataset.__doc__

    Args:
        datasets: a OrderedDict of FairseqDataset instances.
        distribution: a List containing the probability of getting an utterance from
                        corresponding dataset
        seed: random seed for sampling the datsets
        sort_indices: if true, will sort the ordered indices by size
        batch_sample: if true, will ensure each batch is from a single dataset
    """

    def __init__(
        self,
        datasets: Dict[str, FairseqDataset],
        max_positions: Dict,
        distribution: List[float],
        max_tokens_ratio: List[float],
        seed: int = 1234,
        sort_indices: bool = False,
        check_length: bool = False,
    ):
        super().__init__()
        assert isinstance(datasets, OrderedDict)
        assert len(datasets) == len(distribution)
        # assert sum(distribution) == 1
        self.datasets = datasets
        self.distribution = distribution
        self.max_tokens_ratio = max_tokens_ratio
        self.seed = seed
        self.sort_indices = sort_indices
        self.max_positions = max_positions
        self.check_length = check_length

        # Avoid repeated conversions to list later
        self.dataset_list = list(datasets.values())
        self.total_num_instances = 0

        # first_dataset = self.dataset_list[0]

        self.num_instances_per_dataset = []
        self.dataset_offsets = []
        for i, dataset in enumerate(self.dataset_list):
            assert isinstance(dataset, FairseqDataset)
            # assert type(dataset) is type(first_dataset)
            self.num_instances_per_dataset.append(
                0 if self.distribution[i] == 0 else len(dataset)
            )
            self.dataset_offsets.append(self.total_num_instances)
            self.total_num_instances += self.num_instances_per_dataset[i]

    def ordered_indices(self):
        start = time.time()
        with data_utils.numpy_seed(self.seed, self.epoch):
            logger.info(
                f"sampling new dataset with seed {self.seed} epoch {self.epoch}"
            )
            sampled_indices = {}

            # For each dataset i, sample self.distribution[i] * self.total_num_instances
            for i, key in enumerate(self.datasets):
                tp = time.time()
                if self.distribution[i] == 0:
                    # skip dataset if sampling probability is 0
                    continue

                if i < len(self.datasets) - 1:
                    num_instances = int(self.distribution[i] * self.total_num_instances)
                    high = self.dataset_offsets[i + 1]
                else:
                    num_instances = int(self.distribution[i] * self.total_num_instances)
                    high = self.total_num_instances

                logger.info(f"sampling {num_instances} from {key} dataset")

                # First, add k copies of the dataset where k = num_instances // len(dataset).
                # This ensures an equal distribution of the data points as much as possible.
                # For the remaining entries randomly sample them
                dataset_size = len(self.datasets[key])
                num_copies = num_instances // dataset_size
                dataset_indices = np.random.permutation(high - self.dataset_offsets[i])[
                    : num_instances - num_copies * dataset_size
                ]
                if num_copies > 0:
                    dataset_indices = np.concatenate(
                        (
                            np.repeat(
                                np.arange(high - self.dataset_offsets[i]), num_copies
                            ),
                            dataset_indices,
                        )
                    )

                # filter by size, we should ignore it by setting check_length=False
                # , as it is very time-consuming on large dadaset
                if self.max_positions[key] is not None and self.check_length:
                    dataset_indices, ignored = self.datasets[key].filter_indices_by_size(
                        dataset_indices,
                        self.max_positions[key],
                    )
                    if len(ignored) > 0:
                        logger.warning(
                            (
                                "{:,} samples have invalid sizes and will be skipped, "
                                "max_positions={}, first few sample ids={}"
                            ).format(len(ignored), self.max_positions[key], ignored[:10])
                        )

                if self.sort_indices:
                    logger.info(" - sampled indices took {}s".format(time.time() - tp))
                    tp = time.time()
                    dataset_indices = np.sort(dataset_indices)
                    ordered_indices = self.datasets[key].ordered_indices()
                    if isinstance(ordered_indices[0], np.ndarray):
                        # chunked audio data
                        dataset_indices = [
                            order_idx + self.dataset_offsets[i]
                            for order_idx in ordered_indices
                        ]
                        assert self.dataset_offsets[i] == 0
                        # TODO for chunked audio data, now assume len(dataset_indices) == len(dataset). Don't filter any data.
                    else:
                        dataset_indices = (
                            ordered_indices[dataset_indices] + self.dataset_offsets[i]
                        )
                    logger.info(" - ordered_indices took {}s".format(time.time() - tp))
                else:
                    np.random.shuffle(dataset_indices)

                sampled_indices[key] = dataset_indices

            logger.info(
                "multi_corpus_dataset ordered_indices took {}s".format(
                    time.time() - start
                )
            )
            return sampled_indices

    def _map_index(self, index: int):
        """
        If dataset A has length N and dataset B has length M
        then index 1 maps to index 1 of dataset A, and index N + 1
        maps to index 1 of B.
        """
        counter = 0
        for num_instances, key in zip(self.num_instances_per_dataset, self.datasets):
            if index < counter + num_instances:
                return index - counter, key
            counter += num_instances
        raise ValueError(
            "Invalid index: {}, max: {}".format(index, self.total_num_instances)
        )

    def __len__(self):
        """
        Length of this dataset is the sum of individual datasets
        """
        return self.total_num_instances

    def __getitem__(self, index):
        new_index, key = self._map_index(index)
        try:
            item = self.datasets[key][new_index]
            item["full_id"] = index
            return item
        except Exception as e:
            e.args = (f"Error from {key} dataset", *e.args)
            raise

    def collater(self, samples):
        """
        If we are doing batch sampling, then pick the right collater to use.
        Otherwise we assume all collaters are the same.
        """
        if len(samples) == 0:
            return None

        samples_dict = {key: [] for key in self.datasets}
        for s in samples:
            _, key = self._map_index(s["full_id"])
            samples_dict[key].append(s)

        batch = {}
        for key in samples_dict:
            if len(samples_dict[key]) == 0:
                continue
            batch[key] = self.datasets[key].collater(samples_dict[key])

        return batch

    def num_tokens(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].num_tokens(index)

    def size(self, index: int):
        index, key = self._map_index(index)
        return self.datasets[key].size(index)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return False

    def set_epoch(self, epoch, **unused):
        super().set_epoch(epoch)
        logger.info(f"setting epoch of multi_corpus_dataset to {epoch}")
        for ds in self.dataset_list:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
        self.epoch = epoch

    @property
    def supports_prefetch(self):
        return False

    @property
    def supports_fetch_outside_dataloader(self):
        return all(
            self.datasets[key].supports_fetch_outside_dataloader
            for key in self.datasets
        )

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        dataset_indices = indices
        batches_dict = {}
        for n, key in enumerate(dataset_indices):
            max_tokens_ratio = self.max_tokens_ratio[n]
            if isinstance(dataset_indices[key][0], np.ndarray):
                # chunked audio data
                cur_batches = self.datasets[key].batch_by_size(
                    dataset_indices[key],
                    round(max_tokens * max_tokens_ratio),
                    max_sentences,
                    required_batch_size_multiple,
                )
                logger.info(
                    f"Created {sum([len(b) for b in cur_batches])} [{len(cur_batches)}] batches for dataset {key}"
                )
            else:
                cur_batches = super().batch_by_size(
                    np.array(dataset_indices[key], dtype=np.int64),
                    round(max_tokens * max_tokens_ratio),
                    max_sentences,
                    required_batch_size_multiple,
                )
                logger.info(f"Created {len(cur_batches)} batches for dataset {key}")
            batches_dict[key] = cur_batches

        return batches_dict

    def get_batch_sampler(
        self,
        indices,
        num_shards,
        seed,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
        split_modality_batch=False,
    ):
        def batch_sampler(dataset, epoch):
            start = time.time()
            batches_dict = dataset.batch_by_size(
                indices,
                max_tokens=max_tokens,
                max_sentences=max_sentences,
                required_batch_size_multiple=required_batch_size_multiple,
            )
            logger.info(f"multi_corpus_dataset, batch_by_size took {time.time() - start}s")
            start = time.time()
            new_batches = []

            ### shuffle inner group size, split into speech/text batches
            shuffled_batches_list = []
            speech_batches = []
            ### we should specify the speech_batches because: we need concatenate different speech datasets
            # (e.g. ltr or km) instead of loading them parellelly.
            for name, batches in batches_dict.items():
                if name.startswith("speech"):
                    if isinstance(batches[0], list):
                        # chunked audio data
                        batches = self.datasets[name].shuffle_batches(list(batches), seed + epoch)
                        shuffled_batches_list.append(batches)
                    else:
                        batches = inner_bucket_shuffle(batches, seed + epoch, num_shards * 10)
                        batches = batches[: (len(batches) // num_shards) * num_shards]
                        if len(batches) == 0:
                            logger.warning(
                                f"Sample 0 batch for {name}, you should ensure that no {name} data provided."
                            )
                        else:
                            speech_batches += batches
                else:
                    batches = inner_bucket_shuffle(batches, seed + epoch, num_shards * 10)
                    batches = batches[: (len(batches) // num_shards) * num_shards]
                    if len(batches) == 0:
                        logger.warning(
                            f"Sample 0 batch for {name}, you should ensure that no {name} data provided."
                        )
                    else:
                        batches = shuffle_buckets(batches, seed=seed + epoch, inner_shuf=False)
                        shuffled_batches_list.append(batches)
            if len(speech_batches) > 0:
                speech_batches = shuffle_buckets(speech_batches, seed=seed + epoch, inner_shuf=False)
                shuffled_batches_list.append(speech_batches)

            ### create the final new_batches
            num_batch = min(len(batches) for batches in shuffled_batches_list)
            if split_modality_batch:
                for i in range(0, num_batch, num_shards):
                    for batches in shuffled_batches_list:
                        new_batches += batches[i : i + num_shards]
            else:
                for i in range(num_batch):
                    new_batches.append(
                        np.concatenate([batches[i] for batches in shuffled_batches_list])
                    )

            logger.info(f"multi_corpus_dataset sample {len(new_batches)} batches, took {time.time() - start}s")
            return new_batches

        def inner_bucket_shuffle(batches, seed, bucket_size=10, thr=0):
            """we assert batches is sorted form long to short.
            shuffle samples in a buctet(e.g. 10 batches).
            batches: a list of numpy array"""
            num_batch = len(batches)
            new_batches = []
            num_buckets = len(batches) // bucket_size
            i = 0
            while i < num_batch:
                if (
                    i < bucket_size * thr
                    or i >= bucket_size * (num_buckets - thr)
                ):
                    new_batches.append(batches[i])
                    i += 1
                else:
                    group = np.concatenate(batches[i : i + bucket_size])
                    with data_utils.numpy_seed(seed):
                        np.random.shuffle(group)
                    new_batches += np.array_split(group, bucket_size)
                    i += bucket_size
            assert all([len(batch) > 0 for batch in new_batches])
            return new_batches

        def shuffle_buckets(batches, seed, inner_shuf=True):
            if inner_shuf:
                batches = inner_bucket_shuffle(batches, seed, num_shards * 10)
            batches = [
                batches[i : i + num_shards]
                for i in range(0, len(batches) - num_shards + 1, num_shards)
            ]
            assert len(batches[-1]) == num_shards
            new_batches = []
            with data_utils.numpy_seed(seed):
                np.random.shuffle(batches)
            for group in batches:
                new_batches += group
            return new_batches

        return batch_sampler
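
A standalone sketch (not part of this commit) of the global-to-local index mapping that MultiCorpusDataset._map_index implements: a flat index over the concatenated corpora is turned into (local index, dataset key) by walking the cumulative per-dataset instance counts. The dataset names and sizes here are illustrative only.

from collections import OrderedDict

# Toy per-corpus instance counts, mirroring self.num_instances_per_dataset / self.datasets.
num_instances_per_dataset = OrderedDict([("speech_km", 5), ("text_bpe", 3)])

def map_index(index):
    counter = 0
    for key, num_instances in num_instances_per_dataset.items():
        if index < counter + num_instances:
            return index - counter, key
        counter += num_instances
    raise ValueError(f"Invalid index: {index}, max: {counter}")

assert map_index(0) == (0, "speech_km")
assert map_index(4) == (4, "speech_km")
assert map_index(5) == (0, "text_bpe")   # first item of the second corpus
assert map_index(7) == (2, "text_bpe")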
Speech2S/speech2s/models/__init__.py (new file, mode 0 → 100644, empty)

Speech2S/speech2s/models/speechut.py (new file, mode 0 → 100644)
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import logging
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils, checkpoint_utils
from fairseq.data.data_utils import compute_mask_indices
from fairseq.data.dictionary import Dictionary
from fairseq.dataclass import ChoiceEnum
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.transformer import Embedding
from fairseq.file_io import PathManager
from torch import Tensor
from fairseq.models.wav2vec.wav2vec2 import ConvFeatureExtractionModel
from fairseq.modules import GradMultiply, LayerNorm
from fairseq.tasks.hubert_pretraining import (
    HubertPretrainingConfig,
    HubertPretrainingTask,
)
from fairseq.models.hubert import HubertConfig
from fairseq.models.transformer import TransformerConfig
from speechut.modules import TransformerEncoder
from speechut.modules import TransformerEncoderBase
from speechut.modules import TransformerDecoderBaseScriptable

logger = logging.getLogger(__name__)

EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])


@dataclass
class SpeechutConfig(HubertConfig):
    use_rel_pos_enc: bool = field(
        default=False,
        metadata={"help": "whether to use relative positional encoding"},
    )
    scaling_for_att: float = field(
        default=1.0,
        metadata={"help": "scaling for attention weights to prevent overflow issue (for large model)"},
    )

    # unit encoder-decoder
    text_transformer: TransformerConfig = TransformerConfig()
    reset_decoder_embedding_config: bool = field(
        default=False,
        metadata={"help": "reset the no_scale_embedding/layernorm_embedding to default for the decoder"},
    )
    add_unit_encoder: bool = field(
        default=False,
        metadata={"help": "add unit encoder"},
    )
    add_decoder: bool = field(
        default=True,
        metadata={"help": "add decoder"},
    )
    add_text_ctc: bool = field(
        default=False,
        metadata={"help": "add_text_ctc head"},
    )
    text_ctc_conv_kernel: int = field(
        default=2,
        metadata={"help": "text_ctc_conv kernel size"},
    )
    mask_u2t: bool = field(
        default=True,
        metadata={"help": "mask the unit input in unit-to-text task"},
    )

    # embedding mixing
    mix_with_unit: bool = field(
        default=True,
        metadata={"help": "mix with the unit embeddings"},
    )
    use_pred_unit: bool = field(
        default=False,
        metadata={"help": "use the embeddings of predicted units"},
    )
    l2_embedding: bool = field(
        default=False,
        metadata={"help": "compute l2 loss between unit embedding and unit hidden state"},
    )

    # Finetune related
    encoder_dict_size: int = field(
        default=-1,
        metadata={"help": "text encoder dictionary dimension"},
    )
    decoder_dict_size: int = field(
        default=-1,
        metadata={"help": "decoder dictionary dimension"},
    )


@register_model("speechut", dataclass=SpeechutConfig)
class SpeechutModel(BaseFairseqModel):
    def __init__(
        self,
        cfg: SpeechutConfig,
        task_cfg: HubertPretrainingConfig,
        dictionaries: List[Dictionary],
        unit_dictionary: Dictionary = None,
        text_tgt_dictionary: Dictionary = None,
    ) -> None:
        super().__init__()
        logger.info(f"SpeechutModel Config: {cfg}")

        feature_enc_layers = eval(cfg.conv_feature_layers)  # noqa
        self.embed = feature_enc_layers[-1][0]

        self.feature_extractor = ConvFeatureExtractionModel(
            conv_layers=feature_enc_layers,
            dropout=0.0,
            mode=cfg.extractor_mode,
            conv_bias=cfg.conv_bias,
        )
        feature_ds_rate = np.prod([s for _, _, s in feature_enc_layers])
        self.feat2tar_ratio = cfg.label_rate * feature_ds_rate / task_cfg.sample_rate

        self.post_extract_proj = (
            nn.Linear(self.embed, cfg.encoder_embed_dim)
            if self.embed != cfg.encoder_embed_dim
            else None
        )

        self.mask_prob = cfg.mask_prob
        self.mask_selection = cfg.mask_selection
        self.mask_other = cfg.mask_other
        self.mask_length = cfg.mask_length
        self.no_mask_overlap = cfg.no_mask_overlap
        self.mask_min_space = cfg.mask_min_space

        self.mask_channel_prob = cfg.mask_channel_prob
        self.mask_channel_selection = cfg.mask_channel_selection
        self.mask_channel_other = cfg.mask_channel_other
        self.mask_channel_length = cfg.mask_channel_length
        self.no_mask_channel_overlap = cfg.no_mask_channel_overlap
        self.mask_channel_min_space = cfg.mask_channel_min_space

        self.dropout_input = nn.Dropout(cfg.dropout_input)
        self.dropout_features = nn.Dropout(cfg.dropout_features)

        self.feature_grad_mult = cfg.feature_grad_mult
        self.logit_temp = cfg.logit_temp
        self.skip_masked = cfg.skip_masked
        self.skip_nomask = cfg.skip_nomask

        final_dim = cfg.final_dim if cfg.final_dim > 0 else cfg.encoder_embed_dim

        self.mask_emb = nn.Parameter(
            torch.FloatTensor(cfg.encoder_embed_dim).uniform_()
        )

        self.encoder = TransformerEncoder(cfg)
        self.layer_norm = LayerNorm(self.embed)

        self.target_glu = None
        if cfg.target_glu:
            self.target_glu = nn.Sequential(
                nn.Linear(final_dim, final_dim * 2), nn.GLU()
            )

        self.final_dim = final_dim
        assert len(dictionaries) <= 2, f"Only support <=2 kinds of targets, get {len(dictionaries)} dictionaries"
        if len(dictionaries) == 1:
            dictionaries = [dictionaries[0], dictionaries[0]]
        self.num_classes = [len(d) for d in dictionaries]

        self.final_proj = nn.Linear(cfg.encoder_embed_dim, final_dim)
        self.code_encoder_proj = nn.Linear(cfg.text_transformer.encoder.embed_dim, self.num_classes[-1])
        self.final_proj_list = [self.final_proj, self.code_encoder_proj]

        self.label_embs_concat = nn.Parameter(torch.FloatTensor(self.num_classes[0], final_dim))
        self.label_embs_list = [self.label_embs_concat]
        for p in self.label_embs_list:
            nn.init.uniform_(p)

        ### build unit encoder:
        self.mask_u2t = cfg.mask_u2t
        self.add_text_ctc = cfg.add_text_ctc
        self.text_ctc_conv_kernel = cfg.text_ctc_conv_kernel
        self.padding_idx = unit_dictionary.pad()
        self.unit_mask_idx = unit_dictionary.index("<mask>")

        self.add_unit_encoder = cfg.add_unit_encoder
        self.mix_with_unit = cfg.mix_with_unit
        self.use_pred_unit = cfg.use_pred_unit
        self.l2_embedding = cfg.l2_embedding
        if self.add_unit_encoder:
            assert len(unit_dictionary) == self.num_classes[0], f"unit_dictionary: {len(unit_dictionary)}, self.num_classes[0]: {self.num_classes[0]}"
            ### build unit pre-net, and shared with hubert label_embs if needed (default: False)
            self.unit_embed_tokens = self.build_embedding(
                unit_dictionary,
                cfg.text_transformer.encoder.embed_dim,
            )
            if self.final_dim == cfg.text_transformer.encoder.embed_dim:
                logger.info("Share label_embs[0] with unit_embed_tokens ...")
                nn.init.uniform_(self.unit_embed_tokens.weight)
                self.label_embs_list[0] = self.unit_embed_tokens.weight

            ### build unit encoder
            self.unit_encoder = TransformerEncoderBase(
                cfg.text_transformer,
                unit_dictionary,
                self.unit_embed_tokens,
                use_rel_pos_enc=cfg.use_rel_pos_enc,
                scaling_for_att=cfg.scaling_for_att,
            )

            ### build text ctc head
            if self.add_text_ctc:
                conv = nn.Conv1d(
                    cfg.text_transformer.encoder.embed_dim,
                    cfg.text_transformer.encoder.embed_dim,
                    self.text_ctc_conv_kernel,
                    stride=self.text_ctc_conv_kernel // 2,
                    bias=False,
                    padding=self.text_ctc_conv_kernel // 2,
                )
                nn.init.kaiming_normal_(conv.weight)
                self.unit_encoder_ctc_head = nn.Sequential(
                    Rotate3D(),
                    conv,
                    nn.Dropout(p=0.1),
                    nn.Sequential(
                        Rotate3D(),
                        Rotate3D(),
                        LayerNorm(cfg.text_transformer.encoder.embed_dim),
                    ),
                    nn.GELU(),
                    nn.Linear(cfg.text_transformer.encoder.embed_dim, len(text_tgt_dictionary)),
                )

        ### build unit2text decoder, not available for now
        self.add_decoder = cfg.add_decoder
        self.text_transformer_cfg = cfg.text_transformer
        if self.add_decoder:
            # To make sure that the decoder dict size is the same as the fine-tuning tgt_dict size or bpe code dict size
            dec_dictionary = self.cutting_dictionary(text_tgt_dictionary, cfg.decoder_dict_size)
            decoder_embed_tokens = self.build_embedding(
                dec_dictionary, cfg.text_transformer.decoder.embed_dim
            )
            if cfg.reset_decoder_embedding_config:
                cfg.text_transformer.no_scale_embedding = False
                cfg.text_transformer.layernorm_embedding = False
                cfg.text_transformer.no_token_positional_embeddings = False
            self.decoder = TransformerDecoderBaseScriptable(cfg.text_transformer, dec_dictionary, decoder_embed_tokens, use_rel_pos_enc=cfg.use_rel_pos_enc)

    def cutting_dictionary(self, dictionary, dict_size):
        if dictionary is None or dict_size <= 0:
            return dictionary
        else:
            import copy
            cut_dictionary = copy.deepcopy(dictionary)
            if dict_size > len(cut_dictionary):
                for i in range(dict_size - len(cut_dictionary)):
                    cut_dictionary.symbols.append(f'_{i}_')
            else:
                cut_dictionary.symbols = cut_dictionary.symbols[:dict_size]
            return cut_dictionary

    def build_embedding(self, dictionary, embed_dim):
        num_embeddings = len(dictionary)
        padding_idx = dictionary.pad()
        return Embedding(num_embeddings, embed_dim, padding_idx)

    def upgrade_state_dict_named(self, state_dict, name):
        """Upgrade a (possibly old) state dict for new versions of fairseq."""

        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: SpeechutConfig, task: HubertPretrainingTask):
        """Build a new model instance."""
        unit_dictionary = getattr(task, "text_src_dictionary", None)
        text_tgt_dictionary = getattr(task, "text_dictionary", None)
        model = SpeechutModel(cfg, task.cfg, task.dictionaries, unit_dictionary, text_tgt_dictionary)
        return model

    def apply_mask(self, x, padding_mask, target_list):
        B, T, C = x.shape
        if self.mask_prob > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob,
                self.mask_length,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            mask_indices = torch.from_numpy(mask_indices).to(x.device)
            x[mask_indices] = self.mask_emb
        else:
            mask_indices = None

        if self.mask_channel_prob > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_channel_prob,
                self.mask_channel_length,
                self.mask_channel_selection,
                self.mask_channel_other,
                no_overlap=self.no_mask_channel_overlap,
                min_space=self.mask_channel_min_space,
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices)
                .to(x.device)
                .unsqueeze(1)
                .expand(-1, T, -1)
            )
            x[mask_channel_indices] = 0

        return x, mask_indices

    def forward_features(self, source: torch.Tensor) -> torch.Tensor:
        if self.feature_grad_mult > 0:
            features = self.feature_extractor(source)
            if self.feature_grad_mult != 1.0:
                features = GradMultiply.apply(features, self.feature_grad_mult)
        else:
            with torch.no_grad():
                features = self.feature_extractor(source)
        return features

    def forward_targets(
        self,
        features: torch.Tensor,
        target_list: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Trim features to ensure labels exist and then get aligned labels
        feat_tsz = features.size(2)
        targ_tsz = min([t.size(1) for t in target_list])
        if self.feat2tar_ratio * feat_tsz > targ_tsz:
            feat_tsz = int(targ_tsz / self.feat2tar_ratio)
            features = features[..., :feat_tsz]
        target_inds = torch.arange(feat_tsz).float() * self.feat2tar_ratio
        target_inds += np.random.choice(int(self.feat2tar_ratio))
        target_list = [t[:, target_inds.long()] for t in target_list]
        return features, target_list

    def forward_padding_mask(
        self,
        features: torch.Tensor,
        padding_mask: torch.Tensor,
    ) -> torch.Tensor:
        extra = padding_mask.size(1) % features.size(1)
        if extra > 0:
            padding_mask = padding_mask[:, :-extra]
        padding_mask = padding_mask.view(padding_mask.size(0), features.size(1), -1)
        padding_mask = padding_mask.all(-1)
        return padding_mask

    def get_normalized_probs(
        self,
        net_output: Tuple[Tensor, Optional[Dict[str, List[Optional[Tensor]]]]],
        log_probs: bool,
        sample: Optional[Dict[str, Tensor]] = None,
    ):
        lprobs = self.get_normalized_probs_scriptable(net_output, log_probs, sample)
        lprobs.batch_first = True
        return lprobs

    def downsample_ctc_padding_mask(self, padding_mask):
        """
        padding_mask: (B, T)
        """
        stride = self.text_ctc_conv_kernel // 2
        return padding_mask[:, ::stride]

    def compute_pred(self, proj_x, label_embs):
        if self.target_glu:
            label_embs = self.target_glu(label_embs)
        x = F.normalize(proj_x.float(), dim=-1)                 # (S, D)
        label_embs = F.normalize(label_embs.float(), dim=-1)    # (C, D)
        logits = torch.matmul(x, label_embs.T).type_as(proj_x)  # (S, C)
        logits /= self.logit_temp
        return logits

    def compute_hubert_logits(self, x, target, proj, label_embs, padding_mask, mask_indices):
        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            proj_x_m = proj(x[masked_indices])
            logit_m_list = [(self.compute_pred(proj_x_m, label_embs), target[masked_indices])]
        else:
            logit_m_list = [None]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            proj_x_u = proj(x[nomask_indices])
            logit_u_list = [(self.compute_pred(proj_x_u, label_embs), target[nomask_indices])]
        else:
            logit_u_list = [None]

        return logit_m_list, logit_u_list

    def compute_ce_logits(self, x, target, proj, padding_mask, mask_indices):
        if not self.skip_masked:
            masked_indices = torch.logical_and(~padding_mask, mask_indices)
            logit_m_list = [(proj(x[masked_indices]), target[masked_indices])]
        else:
            logit_m_list = [None]

        if not self.skip_nomask:
            nomask_indices = torch.logical_and(~padding_mask, ~mask_indices)
            logit_u_list = [(proj(x[nomask_indices]), target[nomask_indices])]
        else:
            logit_u_list = [None]

        return logit_m_list, logit_u_list

    def convert_embeddings(
        self,
        x,
        padding_mask,
        target=None,
        mask_indices=None,
        mix_with_unit=False,
        use_pred_unit=False,
        l2_embedding=False,
        remask=False,
    ):
        """
        1. Mix with units if needed (default: True)
        2. Prepare for unit_encoder inputs
        Inputs:
            x, (B, T, D)
        Return:
            src_tokens, (B, T)
            soft_embeddings, (B, T, D)
            l2_loss, a loss
        """
        soft_embeddings = self.final_proj_list[0](x) if x.size(-1) == self.final_dim else x
        if padding_mask is None:
            padding_mask = soft_embeddings.new_zeros(soft_embeddings.size(0), soft_embeddings.size(1), dtype=torch.long)
        if use_pred_unit:
            src_tokens = self.compute_pred(self.final_proj_list[0](x), self.label_embs_list[0]).argmax(dim=-1)
            src_tokens[padding_mask] = self.padding_idx
        elif target is not None:
            src_tokens = target
        else:
            src_tokens = padding_mask.long()

        if l2_embedding | mix_with_unit:
            unit_embeddings = self.unit_embed_tokens(src_tokens)    # (B, T, D)

        l2_loss = 0
        if l2_embedding:
            if mask_indices is not None:
                l2_loss = (soft_embeddings - unit_embeddings)[mask_indices].float().pow(2).mean(dim=-1)
                scale = unit_embeddings[mask_indices].float().pow(2).sum(dim=-1)
            else:
                l2_loss = (soft_embeddings - unit_embeddings).float().pow(2).mean(dim=-1)
                scale = unit_embeddings.float().pow(2).sum(dim=-1)
            l2_loss = (l2_loss / scale).mean()

        if mix_with_unit:
            B, T, D = x.shape
            selected_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_prob / 2,
                self.mask_length // 2,
                self.mask_selection,
                self.mask_other,
                min_masks=2,
                no_overlap=self.no_mask_overlap,
                min_space=self.mask_min_space,
            )
            selected_indices = torch.from_numpy(selected_indices).to(x.device)
            if mask_indices is not None:
                if remask:
                    remask_indices = torch.logical_and(selected_indices, mask_indices)
                    soft_embeddings[remask_indices] = self.mask_emb
                swap_indices = torch.logical_and(selected_indices, ~mask_indices)
            else:
                swap_indices = selected_indices
            soft_embeddings[swap_indices] = unit_embeddings[swap_indices]

        soft_embeddings = soft_embeddings * (1 - padding_mask.unsqueeze(-1).type_as(x))
        return src_tokens, soft_embeddings, l2_loss

    def forward(
        self,
        source: torch.Tensor = None,
        src_tokens: torch.Tensor = None,
        src_lengths: torch.Tensor = None,
        prev_output_tokens: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        assert source is not None or src_tokens is not None
        if source is not None:
            return self.forward_speech(
                source=source,
                target_list=target_list,
                padding_mask=padding_mask,
                mask=mask,
                features_only=features_only,
                output_layer=output_layer,
            )
        else:
            return self.forward_text(
                src_tokens=src_tokens,
                src_lengths=src_lengths,
                prev_output_tokens=prev_output_tokens,
                mask=self.mask_u2t,
                features_only=features_only,
                output_layer=output_layer,
            )

    def forward_speech(
        self,
        source: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """output layer is 1-based"""
        features = self.forward_features(source)
        if target_list is not None:
            features, target_list = self.forward_targets(features, target_list)

        features_pen = features.float().pow(2).mean()

        features = features.transpose(1, 2)
        features = self.layer_norm(features)
        unmasked_features = features.clone()

        if padding_mask is not None:
            padding_mask = self.forward_padding_mask(features, padding_mask)

        if self.post_extract_proj is not None:
            features = self.post_extract_proj(features)

        features = self.dropout_input(features)
        unmasked_features = self.dropout_features(unmasked_features)

        if mask:
            x, mask_indices = self.apply_mask(features, padding_mask, target_list)
        else:
            x = features
            mask_indices = None

        # feature: (B, T, D), float
        # target: (B, T), long
        # x: (B, T, D), float
        # padding_mask: (B, T), bool
        # mask_indices: (B, T), bool
        x, _ = self.encoder(
            x,
            padding_mask=padding_mask,
            layer=None if output_layer is None else output_layer - 1,
        )

        if features_only:
            return {"x": x, "padding_mask": padding_mask, "features": features}

        logit_m_list, logit_u_list = self.compute_hubert_logits(
            x,
            target_list[0],
            self.final_proj_list[0],
            self.label_embs_list[0],
            padding_mask,
            mask_indices,
        )

        result = {
            "logit_m_list": logit_m_list,
            "logit_u_list": logit_u_list,
            "padding_mask": padding_mask,
            "features_pen": features_pen,
        }

        if self.add_unit_encoder:
            src_tokens, x_emb, l2_loss = self.convert_embeddings(
                x,
                padding_mask,
                target_list[0],
                mask_indices=mask_indices,
                mix_with_unit=self.mix_with_unit,
                use_pred_unit=self.use_pred_unit,
                l2_embedding=self.l2_embedding,
            )
            encoder_out = self.unit_encoder(src_tokens, token_embeddings=x_emb)

            result['encoder_out'] = encoder_out['encoder_out']                  # [(T, B, D)]
            result['encoder_padding_mask'] = encoder_out['encoder_padding_mask']  # [(B, T)]
            if self.l2_embedding:
                result['embedding_l2_loss'] = l2_loss

            code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
                encoder_out['encoder_out'][0].transpose(0, 1),  # -> (B, T, C)
                target_list[-1],
                self.final_proj_list[1],
                padding_mask,
                mask_indices,
            )
            result['logit_m_list'] += code_logit_m_list
            result['logit_u_list'] += code_logit_u_list
        return result

    def forward_text(
        self,
        src_tokens: torch.Tensor = None,
        src_lengths: torch.Tensor = None,
        prev_output_tokens: torch.Tensor = None,
        target_list: Optional[List[torch.Tensor]] = None,
        mask: bool = True,
        features_only: bool = False,
        output_layer: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        assert self.add_unit_encoder, f"Can not forward unit-text branch without unit_encoder!"

        padding_mask = src_tokens == self.padding_idx
        unit_embeddings = self.unit_embed_tokens(src_tokens)
        if mask:
            unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, [src_tokens])

        encoder_out = self.unit_encoder(
            src_tokens,
            token_embeddings=unit_embeddings,
            return_all_hiddens=output_layer is not None,
        )

        result = {}
        result["encoder_out"] = encoder_out["encoder_out"]
        result["encoder_states"] = encoder_out["encoder_states"]
        result["padding_mask"] = padding_mask

        if self.add_text_ctc:
            result["encoder_out_ctc"] = [self.unit_encoder_ctc_head(x) for x in encoder_out['encoder_out']]
            result["encoder_padding_mask"] = [
                self.downsample_ctc_padding_mask(padding_mask) for padding_mask in encoder_out['encoder_padding_mask']
            ]

        if features_only:
            return result

        if self.add_decoder:
            assert prev_output_tokens is not None
            decoder_out = self.decoder(
                prev_output_tokens=prev_output_tokens,
                encoder_out=encoder_out,
            )
            result['decoder_out'] = decoder_out
        return result

    def forward_mum(self, src_tokens, target, mask=True):
        target_list = [target]
        padding_mask = src_tokens.eq(self.unit_encoder.padding_idx)
        unit_embeddings = self.unit_embed_tokens(src_tokens)
        if mask:
            unit_embeddings, mask_indices = self.apply_mask(unit_embeddings, padding_mask, target_list)
        else:
            ### If already applied mask on src_tokens, then the target_list should contains many padding_idx
            mask_indices = target_list[-1] != self.padding_idx
            unit_embeddings[mask_indices] = self.mask_emb

        encoder_out = self.unit_encoder(
            src_tokens,
            token_embeddings=unit_embeddings,
        )
        code_logit_m_list, code_logit_u_list = self.compute_ce_logits(
            encoder_out["encoder_out"][0].transpose(0, 1),
            target_list[-1],
            self.final_proj_list[1],
            padding_mask,
            mask_indices,
        )

        result = {}
        result["logit_m_list"] = code_logit_m_list
        result["logit_u_list"] = code_logit_u_list
        result["padding_mask"] = padding_mask
        return result

    def extract_features(
        self,
        source: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        mask: bool = False,
        ret_conv: bool = False,
        output_layer: Optional[int] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Extract encoder features for only speech input"""
        res = self.forward(
            source,
            padding_mask=padding_mask,
            mask=mask,
            features_only=True,
            output_layer=output_layer,
        )

        x = res["x"]  # B x T x D
        padding_mask = res["padding_mask"]

        if self.add_unit_encoder:
            src_tokens, x, _ = self.convert_embeddings(
                x,
                padding_mask,
                mix_with_unit=False,
                use_pred_unit=False,
            )
            encoder_out = self.unit_encoder(
                src_tokens,
                token_embeddings=x,
                return_all_hiddens=output_layer is not None
            )

            res["x"] = encoder_out['encoder_out'][0].transpose(0, 1)    # (B, T, D)

        feature = res["features"] if ret_conv else res["x"]
        if output_layer is not None:
            feature = encoder_out['encoder_states']

        return feature, padding_mask

    def get_logits(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        logits_list = [x[0].float() for x in logits_list if x is not None]
        return logits_list

    def get_targets(self, net_output, is_masked=True):
        if is_masked:
            logits_list = net_output["logit_m_list"]
        else:
            logits_list = net_output["logit_u_list"]
        targets_list = [x[1].long() for x in logits_list if x is not None]
        return targets_list

    def get_extra_losses(self, net_output):
        extra_losses = []
        names = []
        if "features_pen" in net_output:
            extra_losses.append(net_output["features_pen"])
            names.append("features_pen")

        if "embedding_l2_loss" in net_output:
            extra_losses.append(net_output["embedding_l2_loss"])
            names.append("embedding_l2_loss")

        return extra_losses, names

    def remove_pretraining_modules(self, step2=False):
        self.target_glu = None

    def load_checkpoint(self, checkpoint: str):
        if not PathManager.exists(checkpoint):
            raise IOError("Model file not found: {}".format(checkpoint))
        state = checkpoint_utils.load_checkpoint_to_cpu(checkpoint)
        return state


class Rotate3D(nn.Module):
    """
    (T, B, D) --> (B, D, T) --> (D, T, B) --> (T, B, D)
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return x.permute(1, 2, 0)
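
A standalone sketch (not part of this commit) of the cosine-similarity logits that SpeechutModel.compute_pred produces for the HuBERT-style prediction loss: projected frames and unit label embeddings are L2-normalized, compared by dot product, and scaled by 1 / logit_temp. The shapes and the temperature value below are toy values, not taken from a real config.

import torch
import torch.nn.functional as F

S, C, D = 6, 10, 8              # frames, codebook size, embedding dim (illustrative)
logit_temp = 0.1                # illustrative temperature

proj_x = torch.randn(S, D)      # projected masked/unmasked frames, (S, D)
label_embs = torch.randn(C, D)  # one embedding per target unit, (C, D)

x = F.normalize(proj_x.float(), dim=-1)
e = F.normalize(label_embs.float(), dim=-1)
logits = torch.matmul(x, e.T) / logit_temp   # (S, C): cosine similarity over temperature

print(logits.shape)             # torch.Size([6, 10])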
Speech2S/speech2s/models/speechut_asr.py (new file, mode 0 → 100644)
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import contextlib
import torch
from dataclasses import dataclass, field
from fairseq import utils
from fairseq.models import BaseFairseqModel, register_model
from fairseq.models.fairseq_encoder import FairseqEncoder
from fairseq.models.hubert import HubertAsrConfig, HubertEncoder
from fairseq.tasks import FairseqTask


@dataclass
class SpeechUTASRConfig(HubertAsrConfig):
    add_decoder: bool = field(
        default=True,
        metadata={"help": "add decoder for fine-tune"},
    )


@register_model("speechut_asr", dataclass=SpeechUTASRConfig)
class SpeechUTASR(BaseFairseqModel):
    """
    A encoder-ctc-decoder model if cfg.add_decoder is True, or a encoder-ctc model
    """
    def __init__(self, cfg: SpeechUTASRConfig, encoder: FairseqEncoder):
        super().__init__()
        self.cfg = cfg
        self.encoder = encoder
        if not cfg.add_decoder:
            self.encoder.w2v_model.decoder = None

    def upgrade_state_dict_named(self, state_dict, name):
        super().upgrade_state_dict_named(state_dict, name)
        return state_dict

    @classmethod
    def build_model(cls, cfg: SpeechUTASRConfig, task: FairseqTask):
        """Build a new model instance."""
        encoder = SpeechUTEncoder(cfg, task)
        return cls(cfg, encoder)

    def forward(self, source, padding_mask, prev_output_tokens, **kwargs):
        encoder_out = self.encoder(source, padding_mask, **kwargs)
        x = self.encoder.final_dropout(encoder_out['encoder_out'][0])  # (T, B, C)
        if self.encoder.proj:
            x = self.encoder.proj(x)

        if self.encoder.conv_ctc_proj:
            padding_mask = self.encoder.w2v_model.downsample_ctc_padding_mask(encoder_out["encoder_padding_mask"][0])
        else:
            padding_mask = encoder_out["encoder_padding_mask"]

        decoder_out = self.decoder(
            prev_output_tokens,
            encoder_out=encoder_out,
            **kwargs
        ) if self.cfg.add_decoder else None

        return {
            "encoder_out_ctc": x,           # (T, B, C), for CTC loss
            "padding_mask": padding_mask,   # (B, T), for CTC loss
            "decoder_out": decoder_out,     # for ED loss
        }

    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.decoder(prev_output_tokens, **kwargs)

    def get_logits(self, net_output):
        """For CTC decoding"""
        logits = net_output["encoder_out"]
        padding = net_output["encoder_padding_mask"]
        if padding is not None and padding.any():
            padding = padding.T
            logits[padding][..., 0] = 0
            logits[padding][..., 1:] = float("-inf")

        return logits

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """For 1) computing CTC loss, 2) decoder decoding."""
        if "encoder_out_ctc" in net_output:
            logits = net_output["encoder_out_ctc"]
        else:
            return self.decoder.get_normalized_probs(net_output, log_probs, sample)

        if isinstance(logits, list):
            logits = logits[0]

        if log_probs:
            return utils.log_softmax(logits.float(), dim=-1)
        else:
            return utils.softmax(logits.float(), dim=-1)

    @property
    def decoder(self):
        return self.encoder.w2v_model.decoder


class SpeechUTEncoder(HubertEncoder):
    """
    Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
    1. make it compatible with encoder-decoder model
    """
    def __init__(self, cfg: HubertAsrConfig, task):
        super().__init__(cfg, task)

        if (task.target_dictionary is not None) and (
            hasattr(self.w2v_model, "unit_encoder_ctc_head")
        ):
            self.proj = self.w2v_model.unit_encoder_ctc_head
            self.conv_ctc_proj = True
        else:
            self.conv_ctc_proj = False

    def forward(self, source, padding_mask, tbc=True, **kwargs):
        w2v_args = {
            "source": source,
            "padding_mask": padding_mask,
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            if tbc:
                # B x T x C -> T x B x C
                x = x.transpose(0, 1)

        return {
            "encoder_out": [x],                      # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
        }

    def forward_torchscript(self, net_input):
        """A TorchScript-compatible version of forward.

        Forward the encoder out.
        """
        x, padding_mask = self.w2v_model.extract_features(**net_input, mask=False)
        # B x T x C -> T x B x C
        x = x.transpose(0, 1)
        encoder_out = {
            "encoder_out": [x],
            "encoder_padding_mask": [padding_mask],
        }
        if self.proj:
            x = self.proj(x)
            encoder_out["encoder_out_ctc"] = x

        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = [
                x.index_select(1, new_order) for x in encoder_out["encoder_out"]
            ]
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = [
                x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
            ]
        return encoder_out
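
A standalone sketch (not part of this commit) of how SpeechUTASR.get_normalized_probs turns the (T, B, C) CTC head output into log-probabilities for the CTC loss. The shapes are toy values, and torch's log_softmax on the float tensor is used here as a stand-in for fairseq's utils.log_softmax.

import torch
import torch.nn.functional as F

T, B, C = 12, 2, 30                       # frames, batch, vocab size (illustrative)
encoder_out_ctc = torch.randn(T, B, C)    # what forward() returns as "encoder_out_ctc"

log_probs = F.log_softmax(encoder_out_ctc.float(), dim=-1)  # (T, B, C), fed to the CTC loss
print(log_probs.shape)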
Speech2S/speech2s/models/speechut_st.py (new file, mode 0 → 100644)
# ----------------------------------------------------------------------------
# SpeechUT: Bridging Speech and Text with Hidden-Unit for Encoder-Decoder Based Speech-Text Pre-training (https://arxiv.org/abs/2210.03730)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/SpeechUT
# Code based on fairseq: https://github.com/facebookresearch/fairseq/tree/272c4c5197250997148fb12c0db6306035f166a4
#
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# ----------------------------------------------------------------------------
import
logging
import
contextlib
import
torch
import
torch.nn
as
nn
from
argparse
import
Namespace
from
dataclasses
import
dataclass
from
typing
import
Any
from
fairseq
import
checkpoint_utils
,
tasks
from
fairseq.models
import
BaseFairseqModel
,
register_model
from
fairseq.models.fairseq_encoder
import
FairseqEncoder
from
fairseq.tasks
import
FairseqTask
from
fairseq.dataclass.utils
import
convert_namespace_to_omegaconf
from
fairseq.data.data_utils
import
lengths_to_padding_mask
from
fairseq.models.hubert
import
HubertAsrConfig
logger
=
logging
.
getLogger
(
__name__
)
@
dataclass
class
SpeechUTS2TConfig
(
HubertAsrConfig
):
### the following config is only for the compatibility to fairseq speech_to_text task
input_feat_per_channel
:
Any
=
None
input_channels
:
Any
=
None
speaker_to_id
:
Any
=
None
@
register_model
(
"speechut_st_legacy"
,
dataclass
=
SpeechUTS2TConfig
)
class
SpeechUTS2T
(
BaseFairseqModel
):
"""An encoder-decoder model."""
def
__init__
(
self
,
cfg
:
SpeechUTS2TConfig
,
encoder
:
FairseqEncoder
):
super
().
__init__
()
self
.
cfg
=
cfg
self
.
encoder
=
encoder
def
upgrade_state_dict_named
(
self
,
state_dict
,
name
):
super
().
upgrade_state_dict_named
(
state_dict
,
name
)
return
state_dict
@
classmethod
def
build_model
(
cls
,
cfg
:
SpeechUTS2TConfig
,
task
:
FairseqTask
):
"""Build a new model instance."""
encoder
=
SpeechUTEncoder
(
cfg
,
task
)
return
cls
(
cfg
,
encoder
)
def
forward
(
self
,
src_tokens
,
src_lengths
,
prev_output_tokens
,
**
kwargs
):
encoder_out
=
self
.
encoder
(
src_tokens
,
src_lengths
,
**
kwargs
)
decoder_out
=
self
.
encoder
.
w2v_model
.
decoder
(
prev_output_tokens
,
encoder_out
=
encoder_out
,
**
kwargs
)
return
decoder_out
    def forward_decoder(self, prev_output_tokens, **kwargs):
        return self.encoder.w2v_model.decoder(prev_output_tokens, **kwargs)

    def get_normalized_probs(self, net_output, log_probs, sample=None):
        """For decoder decoding."""
        return self.encoder.w2v_model.decoder.get_normalized_probs(net_output, log_probs, sample)

    @property
    def decoder(self):
        return self.encoder.w2v_model.decoder


class SpeechUTEncoder(FairseqEncoder):
    """
    Modified from fairseq.models.hubert.hubert_asr.HubertEncoder
    1. make it compatible with fairseq speech_to_text task
    2. make it compatible with encoder-decoder model
    """

    def __init__(self, cfg: SpeechUTS2TConfig, task):
        self.apply_mask = cfg.apply_mask

        arg_overrides = {
            "dropout": cfg.dropout,
            "activation_dropout": cfg.activation_dropout,
            "dropout_input": cfg.dropout_input,
            "attention_dropout": cfg.attention_dropout,
            "mask_length": cfg.mask_length,
            "mask_prob": cfg.mask_prob,
            "mask_selection": cfg.mask_selection,
            "mask_other": cfg.mask_other,
            "no_mask_overlap": cfg.no_mask_overlap,
            "mask_channel_length": cfg.mask_channel_length,
            "mask_channel_prob": cfg.mask_channel_prob,
            "mask_channel_selection": cfg.mask_channel_selection,
            "mask_channel_other": cfg.mask_channel_other,
            "no_mask_channel_overlap": cfg.no_mask_channel_overlap,
            "encoder_layerdrop": cfg.layerdrop,
            "feature_grad_mult": cfg.feature_grad_mult,
        }

        if cfg.w2v_args is None:
            state = checkpoint_utils.load_checkpoint_to_cpu(cfg.w2v_path, arg_overrides)
            w2v_args = state.get("cfg", None)
            if w2v_args is None:
                w2v_args = convert_namespace_to_omegaconf(state["args"])
            cfg.w2v_args = w2v_args
        else:
            state = None
            w2v_args = cfg.w2v_args
            if isinstance(w2v_args, Namespace):
                cfg.w2v_args = w2v_args = convert_namespace_to_omegaconf(w2v_args)

        assert task.data_cfg.standardize_audio() == w2v_args.task.normalize, (
            "Fine-tuning works best when data normalization is the same. "
            "Please check that --normalize is set or unset for "
            "both pre-training and here"
        )

        pretrain_task = tasks.setup_task(w2v_args.task, load_local_states=False)
        assert state is not None and "task_state" in state, "the stored dictionaries not found in checkpoint!"
        # This will load the stored "dictionaries" object
        pretrain_task.load_state_dict(state["task_state"])

        model = pretrain_task.build_model(w2v_args.model, from_checkpoint=True)
        if state is not None and not cfg.no_pretrained_weights:
            try:
                model.load_state_dict(state["model"], strict=True)
            except Exception as e:
                logger.warn(e)
                model.load_state_dict(state["model"], strict=False)

        model.remove_pretraining_modules()

        super().__init__(pretrain_task.source_dictionary)

        d = w2v_args.model.encoder_embed_dim

        self.w2v_model = model

        self.final_dropout = nn.Dropout(cfg.final_dropout)
        self.freeze_finetune_updates = cfg.freeze_finetune_updates
        self.num_updates = 0

    def set_num_updates(self, num_updates):
        """Set the number of parameters updates."""
        super().set_num_updates(num_updates)
        self.num_updates = num_updates

    def forward(self, src_tokens=None, src_lengths=None, **kwargs):
        w2v_args = {
            "source": src_tokens,
            "padding_mask": lengths_to_padding_mask(src_lengths),
            "mask": self.apply_mask and self.training,
        }

        ft = self.freeze_finetune_updates <= self.num_updates

        with torch.no_grad() if not ft else contextlib.ExitStack():
            x, padding_mask = self.w2v_model.extract_features(**w2v_args)

            # B x T x C -> T x B x C
            x = x.transpose(0, 1)

        return {
            "encoder_out": [x],  # T x B x C
            "encoder_padding_mask": [padding_mask],  # B x T
            "padding_mask": [padding_mask],
        }

    def forward_torchscript(self, net_input):
        """A TorchScript-compatible version of forward.

        Forward the encoder out.
        """
        _net_input = {
            "source": net_input["src_tokens"],
            "padding_mask": lengths_to_padding_mask(net_input["src_lengths"]),
            "mask": False,
        }

        x, padding_mask = self.w2v_model.extract_features(**_net_input)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        encoder_out = {
            "encoder_out": [x],
            "encoder_padding_mask": [padding_mask],
        }
        return encoder_out

    def reorder_encoder_out(self, encoder_out, new_order):
        if encoder_out["encoder_out"] is not None:
            encoder_out["encoder_out"] = [
                x.index_select(1, new_order) for x in encoder_out["encoder_out"]
            ]
        if encoder_out["encoder_padding_mask"] is not None:
            encoder_out["encoder_padding_mask"] = [
                x.index_select(0, new_order) for x in encoder_out["encoder_padding_mask"]
            ]
        return encoder_out

    def max_positions(self):
        """Maximum input length supported by the encoder."""
        return None

    def upgrade_state_dict_named(self, state_dict, name):
        return state_dict


def Embedding(num_embeddings, embedding_dim, padding_idx):
    m = nn.Embedding(num_embeddings, embedding_dim, padding_idx=padding_idx)
    nn.init.normal_(m.weight, mean=0, std=embedding_dim ** -0.5)
    nn.init.constant_(m.weight[padding_idx], 0)
    return m


def Linear(in_features, out_features, bias=True):
    m = nn.Linear(in_features, out_features, bias)
    nn.init.xavier_uniform_(m.weight)
    if bias:
        nn.init.constant_(m.bias, 0.0)
    return m
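The encoder's forward pass relies on fairseq's lengths_to_padding_mask to turn a batch of utterance lengths into the B x T boolean mask expected by the wav2vec-style backbone (True at padded frames). The snippet below is only a minimal, self-contained sketch of that convention written with plain torch; the module itself imports the utility from fairseq, and the helper name here is hypothetical.

import torch

def lengths_to_padding_mask_sketch(lengths: torch.Tensor) -> torch.Tensor:
    # lengths: (B,) long tensor of valid frame counts per utterance
    max_len = int(lengths.max())
    # positions >= length are padding -> True
    return torch.arange(max_len, device=lengths.device)[None, :] >= lengths[:, None]

# example: two utterances of 3 and 5 frames
mask = lengths_to_padding_mask_sketch(torch.tensor([3, 5]))
# mask[0] is [False, False, False, True, True]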
Speech2S/speech2s/models/t5_transformer_lm.py
0 → 100644
View file @
12c90639
# --------------------------------------------------------
# Pre-Training Transformer Decoder for End-to-End ASR Model with Unpaired Speech Data (https://arxiv.org/abs/2203.17113)
# Github source: https://github.com/microsoft/SpeechT5/tree/main/Speech2C
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/pytorch/fairseq
# --------------------------------------------------------

from fairseq.models import (
    register_model_architecture,
)
from fairseq.models.transformer_lm import base_lm_architecture


@register_model_architecture(model_name="transformer_lm", arch_name="transformer_lm_t5")
def transformer_lm_t5(args):
    args.decoder_embed_dim = getattr(args, "decoder_embed_dim", 1280)
    args.decoder_ffn_embed_dim = getattr(args, "decoder_ffn_embed_dim", 6144)
    args.decoder_layers = getattr(args, "decoder_layers", 20)
    args.decoder_attention_heads = getattr(args, "decoder_attention_heads", 16)
    args.dropout = getattr(args, "dropout", 0.1)
    args.attention_dropout = getattr(args, "attention_dropout", 0.1)
    args.activation_fn = getattr(args, "activation_fn", "gelu")
    base_lm_architecture(args)
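As a quick illustration of what the registration above does, applying transformer_lm_t5 to a bare namespace fills in the T5-sized decoder defaults before handing off to base_lm_architecture. This is a sketch under the assumption that base_lm_architecture (as in current fairseq) only fills missing attributes via getattr and therefore tolerates an otherwise empty Namespace.

from argparse import Namespace

args = Namespace()
transformer_lm_t5(args)  # defined above; also runs base_lm_architecture(args)

assert args.decoder_embed_dim == 1280
assert args.decoder_ffn_embed_dim == 6144
assert args.decoder_layers == 20
assert args.decoder_attention_heads == 16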
Speech2S/speech2s/modules/__init__.py
0 → 100644
View file @
12c90639
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

from .learned_positional_embedding import LearnedPositionalEmbedding
from .multihead_attention import MultiheadAttention
from .relative_pos_enc import RelativePositionalEncoding
from .transformer_layer import TransformerEncoderLayerBase, TransformerDecoderLayerBase
from .w2v_encoder import TransformerEncoder, TransformerSentenceEncoderLayer
from .transformer_encoder import TransformerEncoderBase
from .transformer_decoder import TransformerDecoderScriptable, TransformerDecoderBaseScriptable

__all__ = [
    "MultiheadAttention",
    "RelativePositionalEncoding",
    "LearnedPositionalEmbedding",
    "TransformerEncoderLayerBase",
    "TransformerDecoderLayerBase",
    "TransformerEncoder",
    "TransformerSentenceEncoderLayer",
    "TransformerEncoderBase",
    "TransformerDecoderScriptable",
    "TransformerDecoderBaseScriptable",
]
Speech2S/speech2s/modules/ctc_prefix_score.py
0 → 100644
View file @
12c90639
#!/usr/bin/env python3

# Copyright 2018 Mitsubishi Electric Research Labs (Takaaki Hori)
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)

import numpy as np
import six


class CTCPrefixScore(object):
    """Compute CTC label sequence scores

    which is based on Algorithm 2 in WATANABE et al.
    "HYBRID CTC/ATTENTION ARCHITECTURE FOR END-TO-END SPEECH RECOGNITION,"
    but extended to efficiently compute the probabilities of multiple labels
    simultaneously
    """

    def __init__(self, x, blank, eos, xp):
        self.xp = xp
        self.logzero = -10000000000.0
        self.blank = blank
        self.eos = eos
        self.input_length = len(x)
        self.x = x

    def initial_state(self):
        """Obtain an initial CTC state

        :return: CTC state
        """
        # initial CTC state is made of a frame x 2 tensor that corresponds to
        # r_t^n(<sos>) and r_t^b(<sos>), where 0 and 1 of axis=1 represent
        # superscripts n and b (non-blank and blank), respectively.
        r = self.xp.full((self.input_length, 2), self.logzero, dtype=np.float32)
        r[0, 1] = self.x[0, self.blank]
        for i in six.moves.range(1, self.input_length):
            r[i, 1] = r[i - 1, 1] + self.x[i, self.blank]
        return r

    def __call__(self, y, cs, r_prev):
        """Compute CTC prefix scores for next labels

        :param y     : prefix label sequence
        :param cs    : array of next labels
        :param r_prev: previous CTC state
        :return ctc_scores, ctc_states
        """
        # initialize CTC states
        output_length = len(y) - 1  # ignore sos
        # new CTC states are prepared as a frame x (n or b) x n_labels tensor
        # that corresponds to r_t^n(h) and r_t^b(h).
        r = self.xp.ndarray((self.input_length, 2, len(cs)), dtype=np.float32)
        xs = self.x[:, cs]
        if output_length == 0:
            r[0, 0] = xs[0]
            r[0, 1] = self.logzero
        else:
            r[output_length - 1] = self.logzero

        # prepare forward probabilities for the last label
        r_sum = self.xp.logaddexp(r_prev[:, 0], r_prev[:, 1])  # log(r_t^n(g) + r_t^b(g))
        last = y[-1]
        if output_length > 0 and last in cs:
            log_phi = self.xp.ndarray((self.input_length, len(cs)), dtype=np.float32)
            for i in six.moves.range(len(cs)):
                log_phi[:, i] = r_sum if cs[i] != last else r_prev[:, 1]
        else:
            log_phi = r_sum

        # compute forward probabilities log(r_t^n(h)), log(r_t^b(h)),
        # and log prefix probabilities log(psi)
        start = max(output_length, 1)
        log_psi = r[start - 1, 0]
        for t in six.moves.range(start, self.input_length):
            r[t, 0] = self.xp.logaddexp(r[t - 1, 0], log_phi[t - 1]) + xs[t]
            r[t, 1] = (
                self.xp.logaddexp(r[t - 1, 0], r[t - 1, 1]) + self.x[t, self.blank]
            )
            log_psi = self.xp.logaddexp(log_psi, log_phi[t - 1] + xs[t])

        # get P(...eos|X) that ends with the prefix itself
        eos_pos = self.xp.where(cs == self.eos)[0]
        if len(eos_pos) > 0:
            log_psi[eos_pos] = r_sum[-1]  # log(r_T^n(g) + r_T^b(g))

        # exclude blank probs
        blank_pos = self.xp.where(cs == self.blank)[0]
        if len(blank_pos) > 0:
            log_psi[blank_pos] = self.logzero

        # return the log prefix probability and CTC states, where the label axis
        # of the CTC states is moved to the first axis to slice it easily
        return log_psi, self.xp.rollaxis(r, 2)
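A small usage sketch of the scorer above, using random log-probabilities and numpy as the array backend (xp). The blank/eos/sos ids are arbitrary placeholders chosen only to show the calling convention, not values taken from the repo's dictionaries.

import numpy as np

T, V = 20, 10                                  # frames, vocabulary size
logp = np.log(np.random.dirichlet(np.ones(V), size=T)).astype(np.float32)

blank, eos, sos = 0, 1, 1
scorer = CTCPrefixScore(logp, blank, eos, np)

r_init = scorer.initial_state()                # (T, 2) forward variables for the empty prefix
y = np.array([sos])                            # current prefix, starting with <sos>
cs = np.arange(V)                              # candidate next labels
scores, states = scorer(y, cs, r_init)         # scores: (V,), states: (V, T, 2)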
Speech2S/speech2s/modules/learned_positional_embedding.py
0 → 100644
View file @
12c90639
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/learned_positional_embedding.py
1. Add clamping if the input length exceeds the max-source-tokens
"""

from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor


class LearnedPositionalEmbedding(nn.Embedding):
    """
    This module learns positional embeddings up to a fixed maximum size.
    Padding ids are ignored by either offsetting based on padding_idx
    or by setting padding_idx to None and ensuring that the appropriate
    position ids are passed to the forward function.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int):
        super().__init__(num_embeddings, embedding_dim, padding_idx)
        self.onnx_trace = False
        if self.padding_idx is not None:
            self.max_positions = self.num_embeddings - self.padding_idx - 1
        else:
            self.max_positions = self.num_embeddings

    def forward(
        self,
        input: Tensor,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        positions: Optional[Tensor] = None,
    ):
        """Input is expected to be of size [bsz x seqlen]."""
        assert (positions is None) or (
            self.padding_idx is None
        ), "If positions is pre-computed then padding_idx should not be set."

        if positions is None:
            if incremental_state is not None:
                # positions is the same for every token when decoding a single step
                # Without the int() cast, it doesn't work in some cases when exporting to ONNX
                positions = torch.zeros(
                    (1, 1), device=input.device, dtype=input.dtype
                ).fill_(int(self.padding_idx + input.size(1)))
            else:
                positions = utils.make_positions(
                    input, self.padding_idx, onnx_trace=self.onnx_trace
                )
            positions = torch.clamp(positions, max=self.padding_idx + self.max_positions)
        return F.embedding(
            positions,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
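The only change relative to the upstream fairseq module is the torch.clamp call, which keeps position ids in range when an input is longer than the table built from max-source-positions. A minimal sketch of the effect, assuming a padding index of 1 as in fairseq dictionaries:

import torch

padding_idx, max_positions = 1, 4                  # table covers positions 2..5
positions = torch.tensor([[2, 3, 4, 5, 6, 7]])     # last two exceed the table
clamped = torch.clamp(positions, max=padding_idx + max_positions)
# tensor([[2, 3, 4, 5, 5, 5]]) -> overflowing tokens reuse the last learned position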
Speech2S/speech2s/modules/multihead_attention.py
0 → 100644
View file @
12c90639
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

from typing import Dict, Optional, Tuple

import torch
import torch.nn.functional as F
from fairseq import utils
from torch import Tensor
from fairseq.modules import MultiheadAttention as FairseqMultiheadAttention


class MultiheadAttention(FairseqMultiheadAttention):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        kdim=None,
        vdim=None,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        self_attention=False,
        encoder_decoder_attention=False,
        q_noise=0.0,
        qn_block_size=8,
        scaling_for_att=1.0,
    ):
        super().__init__(
            embed_dim,
            num_heads,
            kdim,
            vdim,
            dropout,
            bias,
            add_bias_kv,
            add_zero_attn,
            self_attention,
            encoder_decoder_attention,
            q_noise,
            qn_block_size,
        )
        self.scaling_for_att = scaling_for_att

    def forward(
        self,
        query,
        key: Optional[Tensor],
        value: Optional[Tensor],
        key_padding_mask: Optional[Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        need_weights: bool = True,
        static_kv: bool = False,
        attn_mask: Optional[Tensor] = None,
        before_softmax: bool = False,
        need_head_weights: bool = False,
        position_bias: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tensor]]:
        """Input shape: Time x Batch x Channel

        Args:
            key_padding_mask (ByteTensor, optional): mask to exclude
                keys that are pads, of shape `(batch, src_len)`, where
                padding elements are indicated by 1s.
            need_weights (bool, optional): return the attention weights,
                averaged over heads (default: False).
            attn_mask (ByteTensor, optional): typically used to
                implement causal attention, where the mask prevents the
                attention from looking forward in time (default: None).
            before_softmax (bool, optional): return the raw attention
                weights and values before the attention softmax.
            need_head_weights (bool, optional): return the attention
                weights for each head. Implies *need_weights*. Default:
                return the average attention weights over all heads.
        """
        if need_head_weights:
            need_weights = True

        is_tpu = query.device.type == "xla"

        tgt_len, bsz, embed_dim = query.size()
        src_len = tgt_len
        assert embed_dim == self.embed_dim, f"query dim {embed_dim} != {self.embed_dim}"
        assert list(query.size()) == [tgt_len, bsz, embed_dim]
        if key is not None:
            src_len, key_bsz, _ = key.size()
            if not torch.jit.is_scripting():
                assert key_bsz == bsz
                assert value is not None
                assert src_len, bsz == value.shape[:2]

        if (
            not self.onnx_trace
            and not is_tpu  # don't use PyTorch version on TPUs
            and incremental_state is None
            and not static_kv
            # A workaround for quantization to work. Otherwise JIT compilation
            # treats bias in linear module as method.
            and not torch.jit.is_scripting()
            and position_bias is None
        ):
            assert key is not None and value is not None
            return F.multi_head_attention_forward(
                query,
                key,
                value,
                self.embed_dim,
                self.num_heads,
                torch.empty([0]),
                torch.cat((self.q_proj.bias, self.k_proj.bias, self.v_proj.bias)),
                self.bias_k,
                self.bias_v,
                self.add_zero_attn,
                self.dropout_module.p,
                self.out_proj.weight,
                self.out_proj.bias,
                self.training or self.dropout_module.apply_during_inference,
                key_padding_mask,
                need_weights,
                attn_mask,
                use_separate_proj_weight=True,
                q_proj_weight=self.q_proj.weight,
                k_proj_weight=self.k_proj.weight,
                v_proj_weight=self.v_proj.weight,
            )

        if incremental_state is not None:
            saved_state = self._get_input_buffer(incremental_state)
            if saved_state is not None and "prev_key" in saved_state:
                # previous time steps are cached - no need to recompute
                # key and value if they are static
                if static_kv:
                    assert self.encoder_decoder_attention and not self.self_attention
                    key = value = None
        else:
            saved_state = None

        if self.self_attention:
            q = self.q_proj(query)
            k = self.k_proj(query)
            v = self.v_proj(query)
        elif self.encoder_decoder_attention:
            # encoder-decoder attention
            q = self.q_proj(query)
            if key is None:
                assert value is None
                k = v = None
            else:
                k = self.k_proj(key)
                v = self.v_proj(key)
        else:
            assert key is not None and value is not None
            q = self.q_proj(query)
            k = self.k_proj(key)
            v = self.v_proj(value)
        q *= self.scaling
        q *= (1 / self.scaling_for_att)

        if self.bias_k is not None:
            assert self.bias_v is not None
            k = torch.cat([k, self.bias_k.repeat(1, bsz, 1)])
            v = torch.cat([v, self.bias_v.repeat(1, bsz, 1)])
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        key_padding_mask.new_zeros(key_padding_mask.size(0), 1),
                    ],
                    dim=1,
                )

        q = (
            q.contiguous()
            .view(tgt_len, bsz * self.num_heads, self.head_dim)
            .transpose(0, 1)
        )
        if k is not None:
            k = (
                k.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )
        if v is not None:
            v = (
                v.contiguous()
                .view(-1, bsz * self.num_heads, self.head_dim)
                .transpose(0, 1)
            )

        if saved_state is not None:
            # saved states are stored with shape (bsz, num_heads, seq_len, head_dim)
            if "prev_key" in saved_state:
                _prev_key = saved_state["prev_key"]
                assert _prev_key is not None
                prev_key = _prev_key.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    k = prev_key
                else:
                    assert k is not None
                    k = torch.cat([prev_key, k], dim=1)
                src_len = k.size(1)
            if "prev_value" in saved_state:
                _prev_value = saved_state["prev_value"]
                assert _prev_value is not None
                prev_value = _prev_value.view(bsz * self.num_heads, -1, self.head_dim)
                if static_kv:
                    v = prev_value
                else:
                    assert v is not None
                    v = torch.cat([prev_value, v], dim=1)
            prev_key_padding_mask: Optional[Tensor] = None
            if "prev_key_padding_mask" in saved_state:
                prev_key_padding_mask = saved_state["prev_key_padding_mask"]
            assert k is not None and v is not None
            key_padding_mask = MultiheadAttention._append_prev_key_padding_mask(
                key_padding_mask=key_padding_mask,
                prev_key_padding_mask=prev_key_padding_mask,
                batch_size=bsz,
                src_len=k.size(1),
                static_kv=static_kv,
            )

            saved_state["prev_key"] = k.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_value"] = v.view(bsz, self.num_heads, -1, self.head_dim)
            saved_state["prev_key_padding_mask"] = key_padding_mask
            # In this branch incremental_state is never None
            assert incremental_state is not None
            incremental_state = self._set_input_buffer(incremental_state, saved_state)
        assert k is not None
        assert k.size(1) == src_len

        # This is part of a workaround to get around fork/join parallelism
        # not supporting Optional types.
        if key_padding_mask is not None and key_padding_mask.dim() == 0:
            key_padding_mask = None

        if key_padding_mask is not None:
            assert key_padding_mask.size(0) == bsz
            assert key_padding_mask.size(1) == src_len

        if self.add_zero_attn:
            assert v is not None
            src_len += 1
            k = torch.cat([k, k.new_zeros((k.size(0), 1) + k.size()[2:])], dim=1)
            v = torch.cat([v, v.new_zeros((v.size(0), 1) + v.size()[2:])], dim=1)
            if attn_mask is not None:
                attn_mask = torch.cat(
                    [attn_mask, attn_mask.new_zeros(attn_mask.size(0), 1)], dim=1
                )
            if key_padding_mask is not None:
                key_padding_mask = torch.cat(
                    [
                        key_padding_mask,
                        torch.zeros(key_padding_mask.size(0), 1).type_as(
                            key_padding_mask
                        ),
                    ],
                    dim=1,
                )

        attn_weights = torch.bmm(q, k.transpose(1, 2))
        attn_weights = self.apply_sparse_mask(attn_weights, tgt_len, src_len, bsz)

        if position_bias is not None:  ## first order
            ## position_bias: [241, 241, 64]
            #print ("attn_weights: ", attn_weights.size()) # [492, 241, 241]
            reshape_q = q.contiguous().view(bsz * self.num_heads, -1, self.head_dim).transpose(0, 1)  #[241, 492, 64]
            #print ("reshape_q: ", reshape_q.size())
            B = torch.matmul(reshape_q, position_bias.transpose(-2, -1))
            #print ("B: ", B.size()) ## [241, 492, 241]
            #B = B.transpose(0, 1).view(bsz, self.num_heads, position_bias.size(0), position_bias.size(1))
            B = B.transpose(0, 1).view(bsz * self.num_heads, position_bias.size(0), position_bias.size(1))
            #print ("B 2: ", B.size())
            attn_weights += B

        attn_weights *= self.scaling_for_att

        assert list(attn_weights.size()) == [bsz * self.num_heads, tgt_len, src_len]

        if attn_mask is not None:
            attn_mask = attn_mask.unsqueeze(0)
            if self.onnx_trace:
                attn_mask = attn_mask.repeat(attn_weights.size(0), 1, 1)
            attn_weights += attn_mask

        if key_padding_mask is not None:
            # don't attend to padding symbols
            attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
            if not is_tpu:
                attn_weights = attn_weights.masked_fill(
                    key_padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                    float("-inf"),
                )
            else:
                attn_weights = attn_weights.transpose(0, 2)
                attn_weights = attn_weights.masked_fill(key_padding_mask, float("-inf"))
                attn_weights = attn_weights.transpose(0, 2)
            attn_weights = attn_weights.view(bsz * self.num_heads, tgt_len, src_len)

        if self.scaling_for_att > 1.0:
            attn_weights = attn_weights - attn_weights.detach().max(dim=-1, keepdim=True)[0]

        if before_softmax:
            return attn_weights, v

        attn_weights_float = utils.softmax(
            attn_weights, dim=-1, onnx_trace=self.onnx_trace
        )
        attn_weights = attn_weights_float.type_as(attn_weights)
        attn_probs = self.dropout_module(attn_weights)

        assert v is not None
        attn = torch.bmm(attn_probs, v)
        assert list(attn.size()) == [bsz * self.num_heads, tgt_len, self.head_dim]
        if self.onnx_trace and attn.size(1) == 1:
            # when ONNX tracing a single decoder step (sequence length == 1)
            # the transpose is a no-op copy before view, thus unnecessary
            attn = attn.contiguous().view(tgt_len, bsz, embed_dim)
        else:
            attn = attn.transpose(0, 1).contiguous().view(tgt_len, bsz, embed_dim)
        attn = self.out_proj(attn)
        attn_weights: Optional[Tensor] = None
        if need_weights:
            attn_weights = attn_weights_float.view(
                bsz, self.num_heads, tgt_len, src_len
            ).transpose(1, 0)
            if not need_head_weights:
                # average attention weights over heads
                attn_weights = attn_weights.mean(dim=0)

        return attn, attn_weights
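To make the relative-position branch above easier to follow, here is the shape bookkeeping it performs, written out with small symbolic sizes. This is only an illustrative walkthrough of the bias matmul, not part of the module; B is the batch size, H the number of heads, T the sequence length and D the per-head dimension, all chosen arbitrarily.

import torch

B, H, T, D = 2, 4, 5, 8
q = torch.randn(B * H, T, D)                # queries after the head reshape
position_bias = torch.randn(T, T, D)         # pe_k output: one D-dim key per (i, j) offset

reshape_q = q.contiguous().view(B * H, -1, D).transpose(0, 1)     # (T, B*H, D)
bias = torch.matmul(reshape_q, position_bias.transpose(-2, -1))   # (T, B*H, T)
bias = bias.transpose(0, 1).view(B * H, T, T)                     # aligned with attn_weights

# content scores plus the relative-position term, as in the forward above
attn_weights = torch.bmm(q, q.transpose(1, 2)) + bias             # (B*H, T, T)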
Speech2S/speech2s/modules/relative_pos_enc.py
0 → 100644
View file @
12c90639
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------

import torch


class RelativePositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, maxlen=1000, embed_v=False):
        super(RelativePositionalEncoding, self).__init__()

        self.d_model = d_model
        self.maxlen = maxlen
        self.pe_k = torch.nn.Embedding(2 * maxlen, d_model)
        if embed_v:
            self.pe_v = torch.nn.Embedding(2 * maxlen, d_model)
        self.embed_v = embed_v

    def forward(self, pos_seq, incremental_state=None):
        pos_seq[pos_seq < -self.maxlen] = -self.maxlen
        pos_seq[pos_seq >= self.maxlen] = self.maxlen - 1
        pos_seq = pos_seq + self.maxlen

        if incremental_state is not None:
            pos_seq = pos_seq[-1:]

        if self.embed_v:
            return self.pe_k(pos_seq), self.pe_v(pos_seq)
        else:
            return self.pe_k(pos_seq), None
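A short usage sketch of the module above, mirroring how the encoder and decoder in this repo feed it: they build a matrix of signed offsets between positions and receive one key embedding per offset (the value embedding stays None unless embed_v=True). The sizes here are illustrative only.

import torch

pos_emb = RelativePositionalEncoding(d_model=64, maxlen=160)

seq_len = 5
pos_seq = torch.arange(seq_len).long()
pos_seq = pos_seq[:, None] - pos_seq[None, :]    # (seq_len, seq_len) signed offsets

pos_k, pos_v = pos_emb(pos_seq)                   # pos_k: (seq_len, seq_len, 64), pos_v: None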
Speech2S/speech2s/modules/transformer_decoder.py
0 → 100644
View file @
12c90639
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/transformer/transformer_decoder.py
"""
import
math
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch.nn
as
nn
from
fairseq
import
utils
from
fairseq.distributed
import
fsdp_wrap
from
fairseq.models
import
FairseqIncrementalDecoder
from
fairseq.models.transformer
import
TransformerConfig
from
fairseq.modules
import
(
AdaptiveSoftmax
,
BaseLayer
,
FairseqDropout
,
LayerDropModuleList
,
LayerNorm
,
PositionalEmbedding
,
SinusoidalPositionalEmbedding
,
)
from
fairseq.modules.checkpoint_activations
import
checkpoint_wrapper
from
fairseq.modules.quant_noise
import
quant_noise
as
apply_quant_noise_
from
torch
import
Tensor
from
speechut.modules
import
transformer_layer
from
speechut.modules
import
RelativePositionalEncoding
# rewrite name for backward compatibility in `make_generation_fast_`
def
module_name_fordropout
(
module_name
:
str
)
->
str
:
if
module_name
==
"TransformerDecoderBase"
:
return
"TransformerDecoder"
else
:
return
module_name
class
TransformerDecoderBase
(
FairseqIncrementalDecoder
):
"""
Transformer decoder consisting of *cfg.decoder.layers* layers. Each layer
is a :class:`TransformerDecoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): decoding dictionary
embed_tokens (torch.nn.Embedding): output embedding
no_encoder_attn (bool, optional): whether to attend to encoder outputs
(default: False).
"""
def
__init__
(
self
,
cfg
,
dictionary
,
embed_tokens
,
no_encoder_attn
=
False
,
output_projection
=
None
,
use_rel_pos_enc
=
False
,
):
self
.
cfg
=
cfg
super
().
__init__
(
dictionary
)
self
.
register_buffer
(
"version"
,
torch
.
Tensor
([
3
]))
self
.
_future_mask
=
torch
.
empty
(
0
)
self
.
dropout_module
=
FairseqDropout
(
cfg
.
dropout
,
module_name
=
module_name_fordropout
(
self
.
__class__
.
__name__
)
)
self
.
decoder_layerdrop
=
cfg
.
decoder
.
layerdrop
self
.
share_input_output_embed
=
cfg
.
share_decoder_input_output_embed
input_embed_dim
=
embed_tokens
.
embedding_dim
embed_dim
=
cfg
.
decoder
.
embed_dim
self
.
embed_dim
=
embed_dim
self
.
output_embed_dim
=
cfg
.
decoder
.
output_dim
self
.
padding_idx
=
embed_tokens
.
padding_idx
self
.
max_target_positions
=
cfg
.
max_target_positions
self
.
embed_tokens
=
embed_tokens
self
.
embed_scale
=
1.0
if
cfg
.
no_scale_embedding
else
math
.
sqrt
(
embed_dim
)
if
not
cfg
.
adaptive_input
and
cfg
.
quant_noise
.
pq
>
0
:
self
.
quant_noise
=
apply_quant_noise_
(
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias
=
False
),
cfg
.
quant_noise
.
pq
,
cfg
.
quant_noise
.
pq_block_size
,
)
else
:
self
.
quant_noise
=
None
self
.
project_in_dim
=
(
Linear
(
input_embed_dim
,
embed_dim
,
bias
=
False
)
if
embed_dim
!=
input_embed_dim
else
None
)
self
.
embed_positions
=
(
PositionalEmbedding
(
self
.
max_target_positions
,
embed_dim
,
self
.
padding_idx
,
learned
=
cfg
.
decoder
.
learned_pos
,
)
if
not
cfg
.
no_token_positional_embeddings
else
None
)
if
cfg
.
layernorm_embedding
:
self
.
layernorm_embedding
=
LayerNorm
(
embed_dim
,
export
=
cfg
.
export
)
else
:
self
.
layernorm_embedding
=
None
self
.
cross_self_attention
=
cfg
.
cross_self_attention
if
self
.
decoder_layerdrop
>
0.0
:
self
.
layers
=
LayerDropModuleList
(
p
=
self
.
decoder_layerdrop
)
else
:
self
.
layers
=
nn
.
ModuleList
([])
self
.
use_rel_pos_enc
=
use_rel_pos_enc
self
.
layers
.
extend
(
[
self
.
build_decoder_layer
(
cfg
,
no_encoder_attn
)
for
_
in
range
(
cfg
.
decoder
.
layers
)
]
)
self
.
num_layers
=
len
(
self
.
layers
)
if
cfg
.
decoder
.
normalize_before
and
not
cfg
.
no_decoder_final_norm
:
self
.
layer_norm
=
LayerNorm
(
embed_dim
,
export
=
cfg
.
export
)
else
:
self
.
layer_norm
=
None
self
.
project_out_dim
=
(
Linear
(
embed_dim
,
self
.
output_embed_dim
,
bias
=
False
)
if
embed_dim
!=
self
.
output_embed_dim
and
not
cfg
.
tie_adaptive_weights
else
None
)
self
.
adaptive_softmax
=
None
self
.
output_projection
=
output_projection
if
self
.
output_projection
is
None
:
self
.
build_output_projection
(
cfg
,
dictionary
,
embed_tokens
)
if
self
.
use_rel_pos_enc
:
self
.
pos_emb
=
RelativePositionalEncoding
(
embed_dim
//
cfg
.
decoder
.
attention_heads
,
24
)
def
build_output_projection
(
self
,
cfg
,
dictionary
,
embed_tokens
):
if
cfg
.
adaptive_softmax_cutoff
is
not
None
:
self
.
adaptive_softmax
=
AdaptiveSoftmax
(
len
(
dictionary
),
self
.
output_embed_dim
,
utils
.
eval_str_list
(
cfg
.
adaptive_softmax_cutoff
,
type
=
int
),
dropout
=
cfg
.
adaptive_softmax_dropout
,
adaptive_inputs
=
embed_tokens
if
cfg
.
tie_adaptive_weights
else
None
,
factor
=
cfg
.
adaptive_softmax_factor
,
tie_proj
=
cfg
.
tie_adaptive_proj
,
)
elif
self
.
share_input_output_embed
:
self
.
output_projection
=
nn
.
Linear
(
self
.
embed_tokens
.
weight
.
shape
[
1
],
self
.
embed_tokens
.
weight
.
shape
[
0
],
bias
=
False
,
)
self
.
output_projection
.
weight
=
self
.
embed_tokens
.
weight
else
:
self
.
output_projection
=
nn
.
Linear
(
self
.
output_embed_dim
,
len
(
dictionary
),
bias
=
False
)
nn
.
init
.
normal_
(
self
.
output_projection
.
weight
,
mean
=
0
,
std
=
self
.
output_embed_dim
**
-
0.5
)
num_base_layers
=
cfg
.
base_layers
for
i
in
range
(
num_base_layers
):
self
.
layers
.
insert
(
((
i
+
1
)
*
cfg
.
decoder
.
layers
)
//
(
num_base_layers
+
1
),
BaseLayer
(
cfg
),
)
def
build_decoder_layer
(
self
,
cfg
,
no_encoder_attn
=
False
):
layer
=
transformer_layer
.
TransformerDecoderLayerBase
(
cfg
,
no_encoder_attn
,
has_relative_attention_bias
=
self
.
use_rel_pos_enc
)
checkpoint
=
cfg
.
checkpoint_activations
if
checkpoint
:
offload_to_cpu
=
cfg
.
offload_activations
layer
=
checkpoint_wrapper
(
layer
,
offload_to_cpu
=
offload_to_cpu
)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap
=
cfg
.
min_params_to_wrap
if
not
checkpoint
else
0
layer
=
fsdp_wrap
(
layer
,
min_num_params
=
min_params_to_wrap
)
return
layer
def
forward
(
self
,
prev_output_tokens
,
encoder_out
:
Optional
[
Dict
[
str
,
List
[
Tensor
]]]
=
None
,
incremental_state
:
Optional
[
Dict
[
str
,
Dict
[
str
,
Optional
[
Tensor
]]]]
=
None
,
features_only
:
bool
=
False
,
full_context_alignment
:
bool
=
False
,
alignment_layer
:
Optional
[
int
]
=
None
,
alignment_heads
:
Optional
[
int
]
=
None
,
src_lengths
:
Optional
[
Any
]
=
None
,
return_all_hiddens
:
bool
=
False
,
):
"""
Args:
prev_output_tokens (LongTensor): previous decoder outputs of shape
`(batch, tgt_len)`, for teacher forcing
encoder_out (optional): output from the encoder, used for
encoder-side attention, should be of size T x B x C
incremental_state (dict): dictionary used for storing state during
:ref:`Incremental decoding`
features_only (bool, optional): only return features without
applying output layer (default: False).
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
Returns:
tuple:
- the decoder's output of shape `(batch, tgt_len, vocab)`
- a dictionary with any model-specific outputs
"""
x
,
extra
=
self
.
extract_features
(
prev_output_tokens
,
encoder_out
=
encoder_out
,
incremental_state
=
incremental_state
,
full_context_alignment
=
full_context_alignment
,
alignment_layer
=
alignment_layer
,
alignment_heads
=
alignment_heads
,
)
if
not
features_only
:
x
=
self
.
output_layer
(
x
)
return
x
,
extra
def
extract_features
(
self
,
prev_output_tokens
,
encoder_out
:
Optional
[
Dict
[
str
,
List
[
Tensor
]]],
incremental_state
:
Optional
[
Dict
[
str
,
Dict
[
str
,
Optional
[
Tensor
]]]]
=
None
,
full_context_alignment
:
bool
=
False
,
alignment_layer
:
Optional
[
int
]
=
None
,
alignment_heads
:
Optional
[
int
]
=
None
,
):
return
self
.
extract_features_scriptable
(
prev_output_tokens
,
encoder_out
,
incremental_state
,
full_context_alignment
,
alignment_layer
,
alignment_heads
,
)
"""
A scriptable subclass of this class has an extract_features method and calls
super().extract_features, but super() is not supported in torchscript. A copy of
this function is made to be used in the subclass instead.
"""
def
extract_features_scriptable
(
self
,
prev_output_tokens
,
encoder_out
:
Optional
[
Dict
[
str
,
List
[
Tensor
]]],
incremental_state
:
Optional
[
Dict
[
str
,
Dict
[
str
,
Optional
[
Tensor
]]]]
=
None
,
full_context_alignment
:
bool
=
False
,
alignment_layer
:
Optional
[
int
]
=
None
,
alignment_heads
:
Optional
[
int
]
=
None
,
):
"""
Similar to *forward* but only return features.
Includes several features from "Jointly Learning to Align and
Translate with Transformer Models" (Garg et al., EMNLP 2019).
Args:
full_context_alignment (bool, optional): don't apply
auto-regressive mask to self-attention (default: False).
alignment_layer (int, optional): return mean alignment over
heads at this layer (default: last layer).
alignment_heads (int, optional): only average alignment over
this many heads (default: all heads).
Returns:
tuple:
- the decoder's features of shape `(batch, tgt_len, embed_dim)`
- a dictionary with any model-specific outputs
"""
bs
,
slen
=
prev_output_tokens
.
size
()
if
alignment_layer
is
None
:
alignment_layer
=
self
.
num_layers
-
1
enc
:
Optional
[
Tensor
]
=
None
padding_mask
:
Optional
[
Tensor
]
=
None
if
encoder_out
is
not
None
and
len
(
encoder_out
[
"encoder_out"
])
>
0
:
enc
=
encoder_out
[
"encoder_out"
][
0
]
assert
(
enc
.
size
()[
1
]
==
bs
),
f
"Expected enc.shape == (t,
{
bs
}
, c) got
{
enc
.
shape
}
"
if
encoder_out
is
not
None
and
len
(
encoder_out
[
"encoder_padding_mask"
])
>
0
:
padding_mask
=
encoder_out
[
"encoder_padding_mask"
][
0
]
# embed positions
positions
=
None
if
self
.
embed_positions
is
not
None
:
positions
=
self
.
embed_positions
(
prev_output_tokens
,
incremental_state
=
incremental_state
)
if
incremental_state
is
not
None
:
prev_output_tokens
=
prev_output_tokens
[:,
-
1
:]
if
positions
is
not
None
:
positions
=
positions
[:,
-
1
:]
# embed tokens and positions
x
=
self
.
embed_scale
*
self
.
embed_tokens
(
prev_output_tokens
)
if
self
.
quant_noise
is
not
None
:
x
=
self
.
quant_noise
(
x
)
if
self
.
project_in_dim
is
not
None
:
x
=
self
.
project_in_dim
(
x
)
if
positions
is
not
None
:
x
+=
positions
if
self
.
layernorm_embedding
is
not
None
:
x
=
self
.
layernorm_embedding
(
x
)
x
=
self
.
dropout_module
(
x
)
# B x T x C -> T x B x C
x
=
x
.
transpose
(
0
,
1
)
if
self
.
use_rel_pos_enc
:
pos_seq
=
torch
.
arange
(
0
,
slen
).
long
().
to
(
x
.
device
)
pos_seq
=
pos_seq
[:,
None
]
-
pos_seq
[
None
,
:]
pos_k
,
_
=
self
.
pos_emb
(
pos_seq
,
incremental_state
)
else
:
pos_k
=
None
self_attn_padding_mask
:
Optional
[
Tensor
]
=
None
if
self
.
cross_self_attention
or
prev_output_tokens
.
eq
(
self
.
padding_idx
).
any
():
self_attn_padding_mask
=
prev_output_tokens
.
eq
(
self
.
padding_idx
)
# decoder layers
attn
:
Optional
[
Tensor
]
=
None
inner_states
:
List
[
Optional
[
Tensor
]]
=
[
x
]
for
idx
,
layer
in
enumerate
(
self
.
layers
):
if
incremental_state
is
None
and
not
full_context_alignment
:
self_attn_mask
=
self
.
buffered_future_mask
(
x
)
else
:
self_attn_mask
=
None
x
,
layer_attn
,
_
=
layer
(
x
,
enc
,
padding_mask
,
incremental_state
,
self_attn_mask
=
self_attn_mask
,
self_attn_padding_mask
=
self_attn_padding_mask
,
need_attn
=
bool
((
idx
==
alignment_layer
)),
need_head_weights
=
bool
((
idx
==
alignment_layer
)),
pos_bias
=
pos_k
,
)
inner_states
.
append
(
x
)
if
layer_attn
is
not
None
and
idx
==
alignment_layer
:
attn
=
layer_attn
.
float
().
to
(
x
)
if
attn
is
not
None
:
if
alignment_heads
is
not
None
:
attn
=
attn
[:
alignment_heads
]
# average probabilities over heads
attn
=
attn
.
mean
(
dim
=
0
)
if
self
.
layer_norm
is
not
None
:
x
=
self
.
layer_norm
(
x
)
# T x B x C -> B x T x C
x
=
x
.
transpose
(
0
,
1
)
if
self
.
project_out_dim
is
not
None
:
x
=
self
.
project_out_dim
(
x
)
return
x
,
{
"attn"
:
[
attn
],
"inner_states"
:
inner_states
}
def
output_layer
(
self
,
features
):
"""Project features to the vocabulary size."""
if
self
.
adaptive_softmax
is
None
:
# project back to size of vocabulary
return
self
.
output_projection
(
features
)
else
:
return
features
def
max_positions
(
self
):
"""Maximum output length supported by the decoder."""
if
self
.
embed_positions
is
None
:
return
self
.
max_target_positions
return
min
(
self
.
max_target_positions
,
self
.
embed_positions
.
max_positions
)
def
buffered_future_mask
(
self
,
tensor
):
dim
=
tensor
.
size
(
0
)
# self._future_mask.device != tensor.device is not working in TorchScript. This is a workaround.
if
(
self
.
_future_mask
.
size
(
0
)
==
0
or
(
not
self
.
_future_mask
.
device
==
tensor
.
device
)
or
self
.
_future_mask
.
size
(
0
)
<
dim
):
self
.
_future_mask
=
torch
.
triu
(
utils
.
fill_with_neg_inf
(
torch
.
zeros
([
dim
,
dim
])),
1
)
self
.
_future_mask
=
self
.
_future_mask
.
to
(
tensor
)
return
self
.
_future_mask
[:
dim
,
:
dim
]
def
upgrade_state_dict_named
(
self
,
state_dict
,
name
):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if
isinstance
(
self
.
embed_positions
,
SinusoidalPositionalEmbedding
):
weights_key
=
"{}.embed_positions.weights"
.
format
(
name
)
if
weights_key
in
state_dict
:
del
state_dict
[
weights_key
]
state_dict
[
"{}.embed_positions._float_tensor"
.
format
(
name
)
]
=
torch
.
FloatTensor
(
1
)
if
f
"
{
name
}
.output_projection.weight"
not
in
state_dict
:
if
self
.
share_input_output_embed
:
embed_out_key
=
f
"
{
name
}
.embed_tokens.weight"
else
:
embed_out_key
=
f
"
{
name
}
.embed_out"
if
embed_out_key
in
state_dict
:
state_dict
[
f
"
{
name
}
.output_projection.weight"
]
=
state_dict
[
embed_out_key
]
if
not
self
.
share_input_output_embed
:
del
state_dict
[
embed_out_key
]
for
i
in
range
(
self
.
num_layers
):
# update layer norms
layer_norm_map
=
{
"0"
:
"self_attn_layer_norm"
,
"1"
:
"encoder_attn_layer_norm"
,
"2"
:
"final_layer_norm"
,
}
for
old
,
new
in
layer_norm_map
.
items
():
for
m
in
(
"weight"
,
"bias"
):
k
=
"{}.layers.{}.layer_norms.{}.{}"
.
format
(
name
,
i
,
old
,
m
)
if
k
in
state_dict
:
state_dict
[
"{}.layers.{}.{}.{}"
.
format
(
name
,
i
,
new
,
m
)
]
=
state_dict
[
k
]
del
state_dict
[
k
]
version_key
=
"{}.version"
.
format
(
name
)
if
utils
.
item
(
state_dict
.
get
(
version_key
,
torch
.
Tensor
([
1
]))[
0
])
<=
2
:
# earlier checkpoints did not normalize after the stack of layers
self
.
layer_norm
=
None
self
.
normalize
=
False
state_dict
[
version_key
]
=
torch
.
Tensor
([
1
])
return
state_dict
def
Linear
(
in_features
,
out_features
,
bias
=
True
):
m
=
nn
.
Linear
(
in_features
,
out_features
,
bias
)
nn
.
init
.
xavier_uniform_
(
m
.
weight
)
if
bias
:
nn
.
init
.
constant_
(
m
.
bias
,
0.0
)
return
m
class
TransformerDecoderBaseScriptable
(
TransformerDecoderBase
):
def
extract_features
(
self
,
prev_output_tokens
,
encoder_out
:
Optional
[
Dict
[
str
,
List
[
Tensor
]]]
=
None
,
incremental_state
:
Optional
[
Dict
[
str
,
Dict
[
str
,
Optional
[
Tensor
]]]]
=
None
,
full_context_alignment
:
bool
=
False
,
alignment_layer
:
Optional
[
int
]
=
None
,
alignment_heads
:
Optional
[
int
]
=
None
,
):
# call scriptable method from parent class
x
,
_
=
self
.
extract_features_scriptable
(
prev_output_tokens
,
encoder_out
,
incremental_state
,
full_context_alignment
,
alignment_layer
,
alignment_heads
,
)
return
x
,
None
class
TransformerDecoder
(
TransformerDecoderBase
):
def
__init__
(
self
,
args
,
dictionary
,
embed_tokens
,
no_encoder_attn
=
False
,
output_projection
=
None
,
):
self
.
args
=
args
super
().
__init__
(
TransformerConfig
.
from_namespace
(
args
),
dictionary
,
embed_tokens
,
no_encoder_attn
=
no_encoder_attn
,
output_projection
=
output_projection
,
use_rel_pos_enc
=
getattr
(
args
,
"use_rel_pos_enc"
,
False
),
)
def
build_output_projection
(
self
,
args
,
dictionary
,
embed_tokens
):
super
().
build_output_projection
(
TransformerConfig
.
from_namespace
(
args
),
dictionary
,
embed_tokens
)
def
build_decoder_layer
(
self
,
args
,
no_encoder_attn
=
False
):
return
super
().
build_decoder_layer
(
TransformerConfig
.
from_namespace
(
args
),
no_encoder_attn
=
no_encoder_attn
)
class
TransformerDecoderScriptable
(
TransformerDecoder
):
def
extract_features
(
self
,
prev_output_tokens
,
encoder_out
:
Optional
[
Dict
[
str
,
List
[
Tensor
]]]
=
None
,
incremental_state
:
Optional
[
Dict
[
str
,
Dict
[
str
,
Optional
[
Tensor
]]]]
=
None
,
full_context_alignment
:
bool
=
False
,
alignment_layer
:
Optional
[
int
]
=
None
,
alignment_heads
:
Optional
[
int
]
=
None
,
):
# call scriptable method from parent class
x
,
_
=
self
.
extract_features_scriptable
(
prev_output_tokens
,
encoder_out
,
incremental_state
,
full_context_alignment
,
alignment_layer
,
alignment_heads
,
)
return
x
,
None
Speech2S/speech2s/modules/transformer_encoder.py
0 → 100644
View file @
12c90639
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
import
math
from
typing
import
Dict
,
List
,
Optional
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
fairseq
import
utils
from
fairseq.distributed
import
fsdp_wrap
from
fairseq.models
import
FairseqEncoder
from
fairseq.modules
import
(
FairseqDropout
,
LayerDropModuleList
,
LayerNorm
,
SinusoidalPositionalEmbedding
,
)
from
fairseq.modules.checkpoint_activations
import
checkpoint_wrapper
from
fairseq.modules.quant_noise
import
quant_noise
as
apply_quant_noise_
from
torch
import
Tensor
from
fairseq.models.transformer
import
(
TransformerConfig
,
)
from
speechut.modules
import
transformer_layer
,
LearnedPositionalEmbedding
from
speechut.modules
import
RelativePositionalEncoding
# rewrite name for backward compatibility in `make_generation_fast_`
def
module_name_fordropout
(
module_name
:
str
)
->
str
:
if
module_name
==
"TransformerEncoderBase"
:
return
"TransformerEncoder"
else
:
return
module_name
class
TransformerEncoderBase
(
FairseqEncoder
):
"""
Transformer encoder consisting of *cfg.encoder.layers* layers. Each layer
is a :class:`TransformerEncoderLayer`.
Args:
args (argparse.Namespace): parsed command-line arguments
dictionary (~fairseq.data.Dictionary): encoding dictionary
embed_tokens (torch.nn.Embedding): input embedding
"""
def
__init__
(
self
,
cfg
,
dictionary
,
embed_tokens
,
use_rel_pos_enc
=
False
,
scaling_for_att
=
1.0
):
self
.
cfg
=
cfg
super
().
__init__
(
dictionary
)
self
.
register_buffer
(
"version"
,
torch
.
Tensor
([
3
]))
self
.
dropout_module
=
FairseqDropout
(
cfg
.
dropout
,
module_name
=
module_name_fordropout
(
self
.
__class__
.
__name__
)
)
self
.
encoder_layerdrop
=
cfg
.
encoder
.
layerdrop
embed_dim
=
embed_tokens
.
embedding_dim
self
.
padding_idx
=
embed_tokens
.
padding_idx
self
.
max_source_positions
=
cfg
.
max_source_positions
self
.
embed_tokens
=
embed_tokens
self
.
embed_scale
=
1.0
if
cfg
.
no_scale_embedding
else
math
.
sqrt
(
embed_dim
)
self
.
embed_positions
=
(
PositionalEmbedding
(
cfg
.
max_source_positions
,
embed_dim
,
self
.
padding_idx
,
learned
=
cfg
.
encoder
.
learned_pos
,
)
if
not
cfg
.
no_token_positional_embeddings
else
None
)
if
cfg
.
layernorm_embedding
:
self
.
layernorm_embedding
=
LayerNorm
(
embed_dim
,
export
=
cfg
.
export
)
else
:
self
.
layernorm_embedding
=
None
if
not
cfg
.
adaptive_input
and
cfg
.
quant_noise
.
pq
>
0
:
self
.
quant_noise
=
apply_quant_noise_
(
nn
.
Linear
(
embed_dim
,
embed_dim
,
bias
=
False
),
cfg
.
quant_noise
.
pq
,
cfg
.
quant_noise
.
pq_block_size
,
)
else
:
self
.
quant_noise
=
None
if
self
.
encoder_layerdrop
>
0.0
:
self
.
layers
=
LayerDropModuleList
(
p
=
self
.
encoder_layerdrop
)
else
:
self
.
layers
=
nn
.
ModuleList
([])
self
.
use_rel_pos_enc
=
use_rel_pos_enc
self
.
scaling_for_att
=
scaling_for_att
self
.
layers
.
extend
(
[
self
.
build_encoder_layer
(
cfg
)
for
i
in
range
(
cfg
.
encoder
.
layers
)]
)
self
.
num_layers
=
len
(
self
.
layers
)
if
cfg
.
encoder
.
normalize_before
:
self
.
layer_norm
=
LayerNorm
(
embed_dim
,
export
=
cfg
.
export
)
else
:
self
.
layer_norm
=
None
if
self
.
use_rel_pos_enc
:
self
.
pos_emb
=
RelativePositionalEncoding
(
embed_dim
//
cfg
.
encoder
.
attention_heads
,
160
)
def
build_encoder_layer
(
self
,
cfg
):
layer
=
transformer_layer
.
TransformerEncoderLayerBase
(
cfg
,
has_relative_attention_bias
=
self
.
use_rel_pos_enc
,
scaling_for_att
=
self
.
scaling_for_att
)
checkpoint
=
cfg
.
checkpoint_activations
if
checkpoint
:
offload_to_cpu
=
cfg
.
offload_activations
layer
=
checkpoint_wrapper
(
layer
,
offload_to_cpu
=
offload_to_cpu
)
# if we are checkpointing, enforce that FSDP always wraps the
# checkpointed layer, regardless of layer size
min_params_to_wrap
=
cfg
.
min_params_to_wrap
if
not
checkpoint
else
0
layer
=
fsdp_wrap
(
layer
,
min_num_params
=
min_params_to_wrap
)
return
layer
def
forward_embedding
(
self
,
src_tokens
,
token_embedding
:
Optional
[
torch
.
Tensor
]
=
None
):
# embed tokens and positions
if
token_embedding
is
None
:
token_embedding
=
self
.
embed_tokens
(
src_tokens
)
x
=
embed
=
self
.
embed_scale
*
token_embedding
if
self
.
embed_positions
is
not
None
:
x
=
embed
+
self
.
embed_positions
(
src_tokens
)
if
self
.
layernorm_embedding
is
not
None
:
x
=
self
.
layernorm_embedding
(
x
)
x
=
self
.
dropout_module
(
x
)
if
self
.
quant_noise
is
not
None
:
x
=
self
.
quant_noise
(
x
)
return
x
,
embed
def
forward
(
self
,
src_tokens
,
src_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
return_all_hiddens
:
bool
=
False
,
token_embeddings
:
Optional
[
torch
.
Tensor
]
=
None
,
uniformity_layers
:
Optional
[
List
[
int
]]
=
None
,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
return
self
.
forward_scriptable
(
src_tokens
,
src_lengths
,
return_all_hiddens
,
token_embeddings
,
uniformity_layers
)
# TorchScript doesn't support super() method so that the scriptable Subclass
# can't access the base class model in Torchscript.
# Current workaround is to add a helper function with different name and
# call the helper function from scriptable Subclass.
def
forward_scriptable
(
self
,
src_tokens
,
src_lengths
:
Optional
[
torch
.
Tensor
]
=
None
,
return_all_hiddens
:
bool
=
False
,
token_embeddings
:
Optional
[
torch
.
Tensor
]
=
None
,
uniformity_layers
:
Optional
[
List
[
int
]]
=
None
,
):
"""
Args:
src_tokens (LongTensor): tokens in the source language of shape
`(batch, src_len)`
src_lengths (torch.LongTensor): lengths of each source sentence of
shape `(batch)`
return_all_hiddens (bool, optional): also return all of the
intermediate hidden states (default: False).
token_embeddings (torch.Tensor, optional): precomputed embeddings
default `None` will recompute embeddings
Returns:
dict:
- **encoder_out** (Tensor): the last encoder layer's output of
shape `(src_len, batch, embed_dim)`
- **encoder_padding_mask** (ByteTensor): the positions of
padding elements of shape `(batch, src_len)`
- **encoder_embedding** (Tensor): the (scaled) embedding lookup
of shape `(batch, src_len, embed_dim)`
- **encoder_states** (List[Tensor]): all intermediate
hidden states of shape `(src_len, batch, embed_dim)`.
Only populated if *return_all_hiddens* is True.
"""
# compute padding mask
encoder_padding_mask
=
src_tokens
.
eq
(
self
.
padding_idx
)
has_pads
=
src_tokens
.
device
.
type
==
"xla"
or
encoder_padding_mask
.
any
()
x
,
encoder_embedding
=
self
.
forward_embedding
(
src_tokens
,
token_embeddings
)
# account for padding while computing the representation
if
has_pads
:
x
=
x
*
(
1
-
encoder_padding_mask
.
unsqueeze
(
-
1
).
type_as
(
x
))
# B x T x C -> T x B x C
x
=
x
.
transpose
(
0
,
1
)
if
self
.
use_rel_pos_enc
:
x_len
=
x
.
shape
[
0
]
pos_seq
=
torch
.
arange
(
0
,
x_len
).
long
().
to
(
x
.
device
)
pos_seq
=
pos_seq
[:,
None
]
-
pos_seq
[
None
,
:]
pos_k
,
pos_v
=
self
.
pos_emb
(
pos_seq
)
else
:
pos_k
=
None
encoder_states
=
[]
uniformity_hiddens
=
[]
if
return_all_hiddens
:
encoder_states
.
append
(
x
)
if
uniformity_layers
is
not
None
and
0
in
uniformity_layers
:
x
=
F
.
normalize
(
x
.
float
(),
dim
=-
1
).
type_as
(
x
)
uniformity_hiddens
.
append
(
x
)
# encoder layers
for
i
,
layer
in
enumerate
(
self
.
layers
):
x
=
layer
(
x
,
encoder_padding_mask
=
encoder_padding_mask
if
has_pads
else
None
,
pos_bias
=
pos_k
,
)
if
uniformity_layers
is
not
None
and
i
+
1
in
uniformity_layers
:
x
=
F
.
normalize
(
x
.
float
(),
dim
=-
1
).
type_as
(
x
)
uniformity_hiddens
.
append
(
x
)
if
return_all_hiddens
:
assert
encoder_states
is
not
None
encoder_states
.
append
(
x
)
if
self
.
layer_norm
is
not
None
:
x
=
self
.
layer_norm
(
x
)
# The Pytorch Mobile lite interpreter does not supports returning NamedTuple in
# `forward` so we use a dictionary instead.
# TorchScript does not support mixed values so the values are all lists.
# The empty list is equivalent to None.
src_lengths
=
(
src_tokens
.
ne
(
self
.
padding_idx
)
.
sum
(
dim
=
1
,
dtype
=
torch
.
int32
)
.
reshape
(
-
1
,
1
)
.
contiguous
()
)
return
{
"encoder_out"
:
[
x
],
# T x B x C
"encoder_padding_mask"
:
[
encoder_padding_mask
],
# B x T
"encoder_embedding"
:
[
encoder_embedding
],
# B x T x C
"encoder_states"
:
encoder_states
,
# List[T x B x C]
"uniformity_hiddens"
:
uniformity_hiddens
,
# List[T x B x C]
"src_tokens"
:
[],
"src_lengths"
:
[
src_lengths
],
}
@
torch
.
jit
.
export
def
reorder_encoder_out
(
self
,
encoder_out
:
Dict
[
str
,
List
[
Tensor
]],
new_order
):
"""
Reorder encoder output according to *new_order*.
Args:
encoder_out: output from the ``forward()`` method
new_order (LongTensor): desired order
Returns:
*encoder_out* rearranged according to *new_order*
"""
if
len
(
encoder_out
[
"encoder_out"
])
==
0
:
new_encoder_out
=
[]
else
:
new_encoder_out
=
[
encoder_out
[
"encoder_out"
][
0
].
index_select
(
1
,
new_order
)]
if
len
(
encoder_out
[
"encoder_padding_mask"
])
==
0
:
new_encoder_padding_mask
=
[]
else
:
new_encoder_padding_mask
=
[
encoder_out
[
"encoder_padding_mask"
][
0
].
index_select
(
0
,
new_order
)
]
if
len
(
encoder_out
[
"encoder_embedding"
])
==
0
:
new_encoder_embedding
=
[]
else
:
new_encoder_embedding
=
[
encoder_out
[
"encoder_embedding"
][
0
].
index_select
(
0
,
new_order
)
]
if
len
(
encoder_out
[
"src_tokens"
])
==
0
:
src_tokens
=
[]
else
:
src_tokens
=
[(
encoder_out
[
"src_tokens"
][
0
]).
index_select
(
0
,
new_order
)]
if
len
(
encoder_out
[
"src_lengths"
])
==
0
:
src_lengths
=
[]
else
:
src_lengths
=
[(
encoder_out
[
"src_lengths"
][
0
]).
index_select
(
0
,
new_order
)]
encoder_states
=
encoder_out
[
"encoder_states"
]
if
len
(
encoder_states
)
>
0
:
for
idx
,
state
in
enumerate
(
encoder_states
):
encoder_states
[
idx
]
=
state
.
index_select
(
1
,
new_order
)
return
{
"encoder_out"
:
new_encoder_out
,
# T x B x C
"encoder_padding_mask"
:
new_encoder_padding_mask
,
# B x T
"encoder_embedding"
:
new_encoder_embedding
,
# B x T x C
"encoder_states"
:
encoder_states
,
# List[T x B x C]
"src_tokens"
:
src_tokens
,
# B x T
"src_lengths"
:
src_lengths
,
# B x 1
}
def
max_positions
(
self
):
"""Maximum input length supported by the encoder."""
if
self
.
embed_positions
is
None
:
return
self
.
max_source_positions
return
min
(
self
.
max_source_positions
,
self
.
embed_positions
.
max_positions
)
def
upgrade_state_dict_named
(
self
,
state_dict
,
name
):
"""Upgrade a (possibly old) state dict for new versions of fairseq."""
if
isinstance
(
self
.
embed_positions
,
SinusoidalPositionalEmbedding
):
weights_key
=
"{}.embed_positions.weights"
.
format
(
name
)
if
weights_key
in
state_dict
:
print
(
"deleting {0}"
.
format
(
weights_key
))
del
state_dict
[
weights_key
]
state_dict
[
"{}.embed_positions._float_tensor"
.
format
(
name
)
]
=
torch
.
FloatTensor
(
1
)
for
i
in
range
(
self
.
num_layers
):
# update layer norms
self
.
layers
[
i
].
upgrade_state_dict_named
(
state_dict
,
"{}.layers.{}"
.
format
(
name
,
i
)
)
version_key
=
"{}.version"
.
format
(
name
)
if
utils
.
item
(
state_dict
.
get
(
version_key
,
torch
.
Tensor
([
1
]))[
0
])
<
2
:
# earlier checkpoints did not normalize after the stack of layers
self
.
layer_norm
=
None
self
.
normalize
=
False
state_dict
[
version_key
]
=
torch
.
Tensor
([
1
])
return
state_dict
class
TransformerEncoder
(
TransformerEncoderBase
):
def
__init__
(
self
,
args
,
dictionary
,
embed_tokens
):
self
.
args
=
args
super
().
__init__
(
TransformerConfig
.
from_namespace
(
args
),
dictionary
,
embed_tokens
,
use_rel_pos_enc
=
getattr
(
args
,
"use_rel_pos_enc"
,
False
),
scaling_for_att
=
getattr
(
args
,
"scaling_for_att"
,
1.0
),
)
def
build_encoder_layer
(
self
,
args
):
return
super
().
build_encoder_layer
(
TransformerConfig
.
from_namespace
(
args
),
)
def
PositionalEmbedding
(
num_embeddings
:
int
,
embedding_dim
:
int
,
padding_idx
:
int
,
learned
:
bool
=
False
,
):
if
learned
:
# if padding_idx is specified then offset the embedding ids by
# this index and adjust num_embeddings appropriately
# TODO: The right place for this offset would be inside
# LearnedPositionalEmbedding. Move this there for a cleaner implementation.
if
padding_idx
is
not
None
:
num_embeddings
=
num_embeddings
+
padding_idx
+
1
m
=
LearnedPositionalEmbedding
(
num_embeddings
,
embedding_dim
,
padding_idx
)
nn
.
init
.
normal_
(
m
.
weight
,
mean
=
0
,
std
=
embedding_dim
**-
0.5
)
if
padding_idx
is
not
None
:
nn
.
init
.
constant_
(
m
.
weight
[
padding_idx
],
0
)
else
:
m
=
SinusoidalPositionalEmbedding
(
embedding_dim
,
padding_idx
,
init_size
=
num_embeddings
+
padding_idx
+
1
,
)
return
m
Speech2S/speech2s/modules/transformer_layer.py
0 → 100644
View file @
12c90639
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
Modified from https://github.com/facebookresearch/fairseq/blob/main/fairseq/modules/transformer_layer.py
https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_decoder_layer.py
"""
from typing import Dict, List, Optional

import torch
from torch import Tensor
from fairseq.modules import LayerNorm
from fairseq.modules.transformer_layer import TransformerEncoderLayerBase as FairseqTransformerEncoderLayerBase
from fairseq.modules.transformer_layer import TransformerDecoderLayerBase as FairseqTransformerDecoderLayerBase

from speechut.modules import MultiheadAttention


class TransformerEncoderLayerBase(FairseqTransformerEncoderLayerBase):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.encoder.normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, cfg, has_relative_attention_bias=False, scaling_for_att=1.0):
        self.scaling_for_att = scaling_for_att
        super().__init__(cfg)
        if has_relative_attention_bias:
            self.norm_k = LayerNorm(self.embed_dim // cfg.encoder.attention_heads)

    def build_self_attention(self, embed_dim, cfg, scaling_for_att=1.0):
        return MultiheadAttention(
            embed_dim,
            cfg.encoder.attention_heads,
            dropout=cfg.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scaling_for_att=self.scaling_for_att,
        )

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
        pos_bias=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # anything in original attn_mask = 1, becomes -1e8
        # anything in original attn_mask = 0, becomes 0
        # Note that we cannot use -inf here, because at some edge cases,
        # the attention weight (before softmax) for some padded element in query
        # will become -inf, which results in NaN in model parameters
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(
                attn_mask.to(torch.bool), -1e8 if x.dtype == torch.float32 else -1e4
            )

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if pos_bias is not None:
            pos_bias = self.norm_k(pos_bias)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
            position_bias=pos_bias,
        )
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x


class TransformerDecoderLayerBase(FairseqTransformerDecoderLayerBase):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *cfg.decoder.normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self,
        cfg,
        no_encoder_attn=False,
        add_bias_kv=False,
        add_zero_attn=False,
        has_relative_attention_bias=False,
        scaling_for_att=1.0,
    ):
        self.scaling_for_att = scaling_for_att
        super().__init__(
            cfg,
            no_encoder_attn,
            add_bias_kv,
            add_zero_attn,
        )
        if has_relative_attention_bias:
            self.norm_k = LayerNorm(self.embed_dim // cfg.decoder.attention_heads)

    def build_self_attention(
        self, embed_dim, cfg, add_bias_kv=False, add_zero_attn=False
    ):
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            dropout=cfg.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not cfg.cross_self_attention,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scaling_for_att=self.scaling_for_att,
        )

    def build_encoder_attention(self, embed_dim, cfg):
        return MultiheadAttention(
            embed_dim,
            cfg.decoder.attention_heads,
            kdim=cfg.encoder.embed_dim,
            vdim=cfg.encoder.embed_dim,
            dropout=cfg.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scaling_for_att=self.scaling_for_att,
        )

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        pos_bias=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if pos_bias is not None:
            pos_bias = self.norm_k(pos_bias)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
            position_bias=pos_bias,
        )
        if self.c_attn is not None:
            tgt_len, bsz = x.size(0), x.size(1)
            x = x.view(tgt_len, bsz, self.nh, self.head_dim)
            x = torch.einsum("tbhd,h->tbhd", x, self.c_attn)
            x = x.reshape(tgt_len, bsz, self.embed_dim)
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
            )
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn
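A side note on the attn_mask handling in TransformerEncoderLayerBase.forward: the boolean mask is converted into an additive bias of -1e8 (or -1e4 under fp16) rather than -inf, so a fully masked query row cannot produce NaNs after softmax. A minimal standalone sketch of the same trick, with illustrative tensor names that are not part of the repo:

import torch

# toy attention logits and a 0/1 mask where 1 marks positions to exclude
scores = torch.zeros(2, 2)
attn_mask = torch.tensor([[0.0, 1.0], [0.0, 0.0]])

# same conversion as in forward(): masked entries become a large negative
# bias (-1e8 in fp32, -1e4 in fp16) instead of -inf, avoiding NaN rows
neg = -1e8 if scores.dtype == torch.float32 else -1e4
additive = attn_mask.masked_fill(attn_mask.to(torch.bool), neg)
probs = torch.softmax(scores + additive, dim=-1)
print(probs)  # the masked column gets ~0 probability, all rows stay finite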
Speech2S/speech2s/modules/w2v_encoder.py
0 → 100644
View file @ 12c90639
# --------------------------------------------------------
# Copyright (c) 2022 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Based on fairseq code bases
# https://github.com/facebookresearch/fairseq
# --------------------------------------------------------
"""
wav2vec encoder with relative position bias, modified from
https://github.com/microsoft/SpeechT5/blob/main/Speech2C/speech2c/models/modules/transformer_encoder.py
https://github.com/facebookresearch/fairseq/blob/main/fairseq/models/wav2vec/wav2vec2.py
"""
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from fairseq import utils
from fairseq.dataclass import ChoiceEnum
from fairseq.modules import (
    LayerNorm,
    SamePad,
)
from fairseq.modules.checkpoint_activations import checkpoint_wrapper
from fairseq.modules.transformer_sentence_encoder import init_bert_params
from fairseq.utils import index_put
from fairseq.distributed import fsdp_wrap
from fairseq.models.wav2vec.utils import pad_to_multiple

## reload multi-head attention with rel-pos-bias
from fairseq.models.wav2vec.wav2vec2 import TransformerEncoder as W2vTransformerEncoder
from speechut.modules import RelativePositionalEncoding
from speechut.modules import MultiheadAttention

EXTRACTOR_MODE_CHOICES = ChoiceEnum(["default", "layer_norm"])
MASKING_DISTRIBUTION_CHOICES = ChoiceEnum(["static", "uniform", "normal", "poisson"])


class TransformerEncoder(W2vTransformerEncoder):
    def __init__(self, args):
        super().__init__(args)
        self.dropout = args.dropout
        self.embedding_dim = args.encoder_embed_dim
        self.required_seq_len_multiple = args.required_seq_len_multiple
        self.use_rel_pos_enc = getattr(args, "use_rel_pos_enc", False)

        self.pos_conv = nn.Conv1d(
            self.embedding_dim,
            self.embedding_dim,
            kernel_size=args.conv_pos,
            padding=args.conv_pos // 2,
            groups=args.conv_pos_groups,
        )
        dropout = 0
        std = math.sqrt((4 * (1.0 - dropout)) / (args.conv_pos * self.embedding_dim))
        nn.init.normal_(self.pos_conv.weight, mean=0, std=std)
        nn.init.constant_(self.pos_conv.bias, 0)

        self.pos_conv = nn.utils.weight_norm(self.pos_conv, name="weight", dim=2)
        self.pos_conv = nn.Sequential(self.pos_conv, SamePad(args.conv_pos), nn.GELU())

        layers = []
        for _ in range(args.encoder_layers):
            layer = TransformerSentenceEncoderLayer(
                embedding_dim=self.embedding_dim,
                ffn_embedding_dim=args.encoder_ffn_embed_dim,
                num_attention_heads=args.encoder_attention_heads,
                dropout=self.dropout,
                attention_dropout=args.attention_dropout,
                activation_dropout=args.activation_dropout,
                activation_fn=args.activation_fn,
                layer_norm_first=args.layer_norm_first,
                has_relative_attention_bias=self.use_rel_pos_enc,
            )
            if args.checkpoint_activations:
                layer = fsdp_wrap(layer)
                layer = checkpoint_wrapper(layer)
            layers.append(layer)
        self.layers = nn.ModuleList(layers)

        self.layer_norm_first = args.layer_norm_first
        self.layer_norm = LayerNorm(self.embedding_dim)
        self.layerdrop = args.encoder_layerdrop
        if self.use_rel_pos_enc:
            self.pos_emb = RelativePositionalEncoding(
                args.encoder_embed_dim // args.encoder_attention_heads, 160
            )

        self.apply(init_bert_params)

    def forward(self, x, padding_mask=None, layer=None):
        x, layer_results = self.extract_features(x, padding_mask, layer)

        if self.layer_norm_first and layer is None:
            x = self.layer_norm(x)

        return x, layer_results

    def extract_features(self, x, padding_mask=None, tgt_layer=None):
        if padding_mask is not None:
            x = index_put(x, padding_mask, 0)

        x_conv = self.pos_conv(x.transpose(1, 2))
        x_conv = x_conv.transpose(1, 2)
        x = x + x_conv

        if not self.layer_norm_first:
            x = self.layer_norm(x)

        # pad to the sequence length dimension
        x, pad_length = pad_to_multiple(
            x, self.required_seq_len_multiple, dim=-2, value=0
        )
        if pad_length > 0 and padding_mask is None:
            padding_mask = x.new_zeros((x.size(0), x.size(1)), dtype=torch.bool)
            padding_mask[:, -pad_length:] = True
        else:
            padding_mask, _ = pad_to_multiple(
                padding_mask, self.required_seq_len_multiple, dim=-1, value=True
            )

        x = F.dropout(x, p=self.dropout, training=self.training)

        # B x T x C -> T x B x C
        x = x.transpose(0, 1)

        if self.use_rel_pos_enc:
            x_len = x.shape[0]
            pos_seq = torch.arange(0, x_len).long().to(x.device)
            pos_seq = pos_seq[:, None] - pos_seq[None, :]
            pos_k, pos_v = self.pos_emb(pos_seq)
        else:
            pos_k = None

        layer_results = []
        r = None
        for i, layer in enumerate(self.layers):
            dropout_probability = np.random.random()
            if not self.training or (dropout_probability > self.layerdrop):
                x, z = layer(
                    x,
                    self_attn_padding_mask=padding_mask,
                    need_weights=False,
                    pos_bias=pos_k,
                )
                if tgt_layer is not None:
                    # unpad if needed
                    if pad_length > 0:
                        layer_results.append(
                            (
                                x[:-pad_length],
                                z[:, :-pad_length, :-pad_length]
                                if z is not None
                                else z,
                            )
                        )
                    else:
                        layer_results.append((x, z))
            if i == tgt_layer:
                r = x
                break

        if r is not None:
            x = r

        # T x B x C -> B x T x C
        x = x.transpose(0, 1)

        # undo padding
        if pad_length > 0:
            x = x[:, :-pad_length]

        return x, layer_results


class TransformerSentenceEncoderLayer(nn.Module):
    """
    Implements a Transformer Encoder Layer used in BERT/XLM style pre-trained
    models.
    """

    def __init__(
        self,
        embedding_dim: float = 768,
        ffn_embedding_dim: float = 3072,
        num_attention_heads: float = 8,
        dropout: float = 0.1,
        attention_dropout: float = 0.1,
        activation_dropout: float = 0.1,
        activation_fn: str = "relu",
        layer_norm_first: bool = False,
        has_relative_attention_bias: bool = False,
    ) -> None:

        super().__init__()
        # Initialize parameters
        self.embedding_dim = embedding_dim
        self.dropout = dropout
        self.activation_dropout = activation_dropout

        # Initialize blocks
        self.activation_fn = utils.get_activation_fn(activation_fn)
        self.self_attn = MultiheadAttention(
            self.embedding_dim,
            num_attention_heads,
            dropout=attention_dropout,
            self_attention=True,
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(self.activation_dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.layer_norm_first = layer_norm_first

        # layer norm associated with the self attention layer
        self.self_attn_layer_norm = LayerNorm(self.embedding_dim)
        self.fc1 = nn.Linear(self.embedding_dim, ffn_embedding_dim)
        self.fc2 = nn.Linear(ffn_embedding_dim, self.embedding_dim)

        # layer norm associated with the position wise feed-forward NN
        self.final_layer_norm = LayerNorm(self.embedding_dim)
        if has_relative_attention_bias:
            self.norm_k = LayerNorm(self.embedding_dim // num_attention_heads)

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: torch.Tensor = None,
        self_attn_padding_mask: torch.Tensor = None,
        need_weights: bool = False,
        att_args=None,
        pos_bias=None,
    ):
        """
        LayerNorm is applied either before or after the self-attention/ffn
        modules similar to the original Transformer implementation.
        """
        residual = x

        if self.layer_norm_first:
            x = self.self_attn_layer_norm(x)
            if pos_bias is not None:
                pos_bias = self.norm_k(pos_bias)
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                attn_mask=self_attn_mask,
                position_bias=pos_bias,
            )
            x = self.dropout1(x)
            x = residual + x

            residual = x
            x = self.final_layer_norm(x)
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=self_attn_padding_mask,
                position_bias=pos_bias,
            )

            x = self.dropout1(x)
            x = residual + x

            x = self.self_attn_layer_norm(x)

            residual = x
            x = self.activation_fn(self.fc1(x))
            x = self.dropout2(x)
            x = self.fc2(x)
            x = self.dropout3(x)
            x = residual + x
            x = self.final_layer_norm(x)

        return x, attn
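As a side note on the relative-position branch above: extract_features builds a T x T matrix of signed offsets, entry [i, j] = i - j, and hands it to RelativePositionalEncoding, which maps each offset to per-head key/value biases. A small sketch of just the index construction, in plain torch with a toy length chosen only for illustration:

import torch

T = 5  # toy sequence length
pos_seq = torch.arange(0, T).long()
# signed relative offsets, exactly as in extract_features:
# rel[i, j] = i - j, ranging from -(T-1) to T-1
rel = pos_seq[:, None] - pos_seq[None, :]
print(rel[0])  # tensor([ 0, -1, -2, -3, -4])
print(rel[1])  # tensor([ 1,  0, -1, -2, -3])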
Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_asr.sh
0 → 100644
View file @ 12c90639
# ####################################
# SpeechUT Base model #
# ####################################
[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1

DATA_DIR=$1
TEXT_DATA_DIR=$2
mount=$3
world_size=$4
update_freq=$5
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=1

CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/base_speechut4asr_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechut/config/pretrain \
    --config-name speechut_base_librispeech \
    common.user_dir=$CODE_ROOT/speechut \
    \
    task.labels='["km"]' \
    model.label_rate=50 \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    task.text_cfg.text_data=$TEXT_DATA_DIR \
    \
    dataset.train_subset=\"train_960+pseudo_libritext.kmu-ltr+merge_960.kmu-none\" \
    dataset.valid_subset=\"dev_clean+dev.kmu-ltr+dev.kmu-none\" \
    dataset.num_workers=0 \
    dataset.max_tokens=1400000 \
    distributed_training.distributed_world_size=${world_size} \
    optimization.update_freq=[${update_freq}] \
    \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=base_speechut4asr_${world_size}gpu_${update_freq}accum
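The dotted key=value arguments passed to hydra_train.py above are ordinary Hydra/OmegaConf overrides applied on top of the speechut_base_librispeech config. A rough sketch of how such a dotlist is parsed, using omegaconf's dotlist helper (Hydra's own override grammar is richer, but these simple values behave the same); the keys are copied from the script, while the snippet itself is only an illustration of the override syntax, not part of the training code:

from omegaconf import OmegaConf

# a few of the overrides from the command above
overrides = [
    "model.label_rate=50",
    "dataset.max_tokens=1400000",
    "optimization.update_freq=[1]",
]
cfg = OmegaConf.from_dotlist(overrides)
print(cfg.model.label_rate)          # 50
print(cfg.optimization.update_freq)  # [1]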
Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st.sh
0 → 100644
View file @ 12c90639
# ####################################
# SpeechUT Base model #
# ####################################
[ $# -lt 3 ] && echo "Usage: $0 <data_dir> <text_data_dir> <lang=de/es> [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1

DATA_DIR=$1
TEXT_DATA_DIR=$2
lang=$3
mount=$4
world_size=$5
update_freq=$6
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=1

CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/base_speechut4en${lang}_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechut/config/pretrain \
    --config-name speechut_base_librispeech \
    common.user_dir=$CODE_ROOT/speechut \
    \
    task.labels='["km"]' \
    model.label_rate=50 \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    task.text_cfg.text_data=$TEXT_DATA_DIR \
    \
    model.add_text_ctc=false \
    model.text_transformer.share_decoder_input_output_embed=true \
    criterion.u2t_ed_weight=1.0 \
    criterion.u2t_ctc_weight=0 \
    \
    dataset.train_subset=\"train_960,mustcuns_${lang}+pseudo_wmt_en${lang}.kmu-spm+train_960.kmu-none,mustcuns_${lang}.kmu-none\" \
    dataset.valid_subset=\"dev_clean+pseudo_valid.kmu-spm+dev.kmu-none\" \
    dataset.num_workers=0 \
    dataset.max_tokens=1400000 \
    distributed_training.distributed_world_size=${world_size} \
    optimization.update_freq=[${update_freq}] \
    \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=base_speechut4en${lang}_${world_size}gpu_${update_freq}accum
Speech2S/speech2s/scripts copy/pretrain_speechut/base_speechut_for_st_enfr.sh
0 → 100644
View file @ 12c90639
# ####################################
# SpeechUT Base model #
# ####################################
[ $# -lt 3 ] && echo "Usage: $0 <data_dir> <text_data_dir> [lang=fr] [mount=${PWD}] [world_size=32] [update_freq=1]" && exit 1
[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1

DATA_DIR=$1
TEXT_DATA_DIR=$2
lang=$3
mount=$4
world_size=$5
update_freq=$6
[ -z $lang ] && lang=fr
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=1

CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/base_speechut4en${lang}_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechut/config/pretrain \
    --config-name speechut_base_librispeech \
    common.user_dir=$CODE_ROOT/speechut \
    \
    task.labels='["km"]' \
    model.label_rate=50 \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    task.text_cfg.text_data=$TEXT_DATA_DIR \
    \
    model.add_text_ctc=false \
    criterion.u2t_ed_weight=1.0 \
    criterion.u2t_ctc_weight=0 \
    \
    dataset.train_subset=\"train_960,pretrain_mustc+pseudo_wmt14_enfr.kmu-spm+train_960.kmu-none,pretrain_mustc.kmu-none\" \
    dataset.valid_subset=\"dev_clean+pseudo_valid.kmu-spm+dev.kmu-none\" \
    dataset.num_workers=0 \
    dataset.max_tokens=1400000 \
    optimization.max_update=600000 \
    distributed_training.distributed_world_size=${world_size} \
    optimization.update_freq=[${update_freq}] \
    \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=base_speechut4en${lang}_${world_size}gpu_${update_freq}accum
Speech2S/speech2s/scripts copy/pretrain_speechut/large_speechut_for_asr.sh
0 → 100644
View file @ 12c90639
# ####################################
# SpeechUT Large model #
# ####################################
[ $# -lt 2 ] && echo "Usage: $0 <data_dir> <text_data_dir> [mount=${PWD}] [world_size=32] [update_freq=4]" && exit 1
[ ${PWD##*/} != SpeechUT ] && echo "Error: dir not match! Switch to SpeechUT/ and run it again!" && exit 1

DATA_DIR=$1
TEXT_DATA_DIR=$2
mount=$3
world_size=$4
update_freq=$5
[ -z $mount ] && mount=${PWD}
[ -z $world_size ] && world_size=32
[ -z $update_freq ] && update_freq=4

CODE_ROOT=${PWD}
MODEL_DIR="${mount}/exp/pretrain/large_speechut4asr_${world_size}gpu_${update_freq}accum"
[ -d $MODEL_DIR ] || mkdir -p $MODEL_DIR

python $CODE_ROOT/fairseq/fairseq_cli/hydra_train.py \
    --config-dir $CODE_ROOT/speechut/config/pretrain \
    --config-name speechut_large_librilight \
    common.user_dir=$CODE_ROOT/speechut \
    \
    task.labels='["km"]' \
    model.label_rate=50 \
    task.data=$DATA_DIR \
    task.label_dir=$DATA_DIR \
    task.text_cfg.text_data=$TEXT_DATA_DIR \
    \
    dataset.train_subset=\"train_small+pseudo_libritext.kmu-ltr\" \
    dataset.valid_subset=\"dev_clean+dev.kmu-ltr\" \
    dataset.num_workers=0 \
    dataset.max_tokens=900000 \
    distributed_training.distributed_world_size=${world_size} \
    optimization.update_freq=[${update_freq}] \
    \
    common.tensorboard_logdir=$MODEL_DIR \
    checkpoint.save_dir=$MODEL_DIR \
    hydra.run.dir=$MODEL_DIR \
    hydra.job.name=large_speechut4asr_${world_size}gpu_${update_freq}accum