Project: dcuai / dlexamples

Commit c0f05c10, authored Nov 29, 2022 by hepj
Commit message: Update the Transformer code
Parent commit: c056df78
Changes: 321 files
Showing 20 changed files with 5,102 additions and 0 deletions (+5102 / −0)
PyTorch/NLP/new-Transformer/fairseq/data/audio/multi_modality_dataset.py        +266  −0
PyTorch/NLP/new-Transformer/fairseq/data/audio/raw_audio_dataset.py             +393  −0
PyTorch/NLP/new-Transformer/fairseq/data/audio/speech_to_speech_dataset.py      +428  −0
PyTorch/NLP/new-Transformer/fairseq/data/audio/speech_to_text_dataset.py        +561  −0
PyTorch/NLP/new-Transformer/fairseq/data/audio/speech_to_text_joint_dataset.py  +359  −0
PyTorch/NLP/new-Transformer/fairseq/data/audio/text_to_speech_dataset.py        +248  −0
PyTorch/NLP/new-Transformer/fairseq/data/backtranslation_dataset.py             +165  −0
PyTorch/NLP/new-Transformer/fairseq/data/base_wrapper_dataset.py                +78   −0
PyTorch/NLP/new-Transformer/fairseq/data/bucket_pad_length_dataset.py           +78   −0
PyTorch/NLP/new-Transformer/fairseq/data/codedataset.py                         +576  −0
PyTorch/NLP/new-Transformer/fairseq/data/colorize_dataset.py                    +25   −0
PyTorch/NLP/new-Transformer/fairseq/data/concat_dataset.py                      +124  −0
PyTorch/NLP/new-Transformer/fairseq/data/concat_sentences_dataset.py            +54   −0
PyTorch/NLP/new-Transformer/fairseq/data/data_utils.py                          +604  −0
PyTorch/NLP/new-Transformer/fairseq/data/data_utils_fast.pyx                    +178  −0
PyTorch/NLP/new-Transformer/fairseq/data/denoising_dataset.py                   +436  −0
PyTorch/NLP/new-Transformer/fairseq/data/dictionary.py                          +401  −0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/__init__.py                   +29   −0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/byte_bpe.py                   +48   −0
PyTorch/NLP/new-Transformer/fairseq/data/encoders/byte_utils.py                 +51   −0
Too many changes to show; to preserve performance only 321 of 321+ files are displayed.
PyTorch/NLP/new-Transformer/fairseq/data/audio/multi_modality_dataset.py (new file, mode 100644)
# Copyright (c) 2021-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import logging
import math
from typing import List, Optional, NamedTuple

import numpy as np
import torch
from fairseq.data import (
    ConcatDataset,
    LanguagePairDataset,
    FileAudioDataset,
    data_utils,
)
from fairseq.data import FairseqDataset

logger = logging.getLogger(__name__)


class ModalityDatasetItem(NamedTuple):
    datasetname: str
    dataset: any
    max_positions: List[int]
    max_tokens: Optional[int] = None
    max_sentences: Optional[int] = None


# MultiModalityDataset: it concatenates multiple datasets with different modalities.
# Compared with ConcatDataset it can 1) sample data given the ratios for different datasets
# 2) it adds a mode to indicate what type of dataset the samples come from.
# It will be used together with GroupedEpochBatchIterator to generate mini-batches with samples
# from the same type of dataset.
# If only one dataset is used, it behaves like the original dataset with the mode added.
class MultiModalityDataset(ConcatDataset):
    def __init__(self, datasets: List[ModalityDatasetItem]):
        id_to_mode = []
        dsets = []
        max_tokens = []
        max_sentences = []
        max_positions = []
        for dset in datasets:
            id_to_mode.append(dset.datasetname)
            dsets.append(dset.dataset)
            max_tokens.append(dset.max_tokens)
            max_positions.append(dset.max_positions)
            max_sentences.append(dset.max_sentences)
        weights = [1.0 for s in dsets]
        super().__init__(dsets, weights)
        self.max_tokens = max_tokens
        self.max_positions = max_positions
        self.max_sentences = max_sentences
        self.id_to_mode = id_to_mode
        self.raw_sub_batch_samplers = []
        self._cur_epoch = 0

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        self._cur_epoch = epoch

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        sample = self.datasets[dataset_idx][sample_idx]
        return (dataset_idx, sample)

    def collater(self, samples):
        if len(samples) == 0:
            return {}
        dataset_idx = samples[0][0]
        # make sure all samples in samples are from same dataset
        assert sum([0 if dataset_idx == s[0] else 1 for s in samples]) == 0
        samples = self.datasets[dataset_idx].collater([x[1] for x in samples])
        # add mode
        samples["net_input"]["mode"] = self.id_to_mode[dataset_idx]
        return samples

    def size(self, index: int):
        if len(self.datasets) == 1:
            return self.datasets[0].size(index)
        return super().size(index)

    @property
    def sizes(self):
        if len(self.datasets) == 1:
            return self.datasets[0].sizes
        return super().sizes

    def ordered_indices(self):
        """
        Returns indices sorted by length. So less padding is needed.
        """
        if len(self.datasets) == 1:
            return self.datasets[0].ordered_indices()
        indices_group = []
        for d_idx, ds in enumerate(self.datasets):
            sample_num = self.cumulative_sizes[d_idx]
            if d_idx > 0:
                sample_num = sample_num - self.cumulative_sizes[d_idx - 1]
            assert sample_num == len(ds)
            indices_group.append(ds.ordered_indices())
        return indices_group

    def get_raw_batch_samplers(self, required_batch_size_multiple, seed):
        if len(self.raw_sub_batch_samplers) > 0:
            logger.info(" raw_sub_batch_samplers exists. No action is taken")
            return
        with data_utils.numpy_seed(seed):
            indices = self.ordered_indices()
        for i, ds in enumerate(self.datasets):
            indices[i] = ds.filter_indices_by_size(
                indices[i],
                self.max_positions[i],
            )[0]
            sub_batch_sampler = ds.batch_by_size(
                indices[i],
                max_tokens=self.max_tokens[i],
                max_sentences=self.max_sentences[i],
                required_batch_size_multiple=required_batch_size_multiple,
            )
            self.raw_sub_batch_samplers.append(sub_batch_sampler)

    def get_batch_samplers(self, mult_ratios, required_batch_size_multiple, seed):
        self.get_raw_batch_samplers(required_batch_size_multiple, seed)
        batch_samplers = []
        for i, _ in enumerate(self.datasets):
            if i > 0:
                sub_batch_sampler = [
                    [y + self.cumulative_sizes[i - 1] for y in x]
                    for x in self.raw_sub_batch_samplers[i]
                ]
            else:
                sub_batch_sampler = list(self.raw_sub_batch_samplers[i])
            smp_r = mult_ratios[i]
            if smp_r != 1:
                is_increase = "increased" if smp_r > 1 else "decreased"
                logger.info(
                    "number of batch for the dataset {} is {} from {} to {}".format(
                        self.id_to_mode[i],
                        is_increase,
                        len(sub_batch_sampler),
                        int(len(sub_batch_sampler) * smp_r),
                    )
                )
                mul_samplers = []
                for _ in range(math.floor(smp_r)):
                    mul_samplers = mul_samplers + sub_batch_sampler
                if math.floor(smp_r) != smp_r:
                    with data_utils.numpy_seed(seed + self._cur_epoch):
                        np.random.shuffle(sub_batch_sampler)
                        smp_num = int(
                            (smp_r - math.floor(smp_r)) * len(sub_batch_sampler)
                        )
                        mul_samplers = mul_samplers + sub_batch_sampler[:smp_num]
                sub_batch_sampler = mul_samplers
            else:
                logger.info(
                    "dataset {} batch number is {} ".format(
                        self.id_to_mode[i], len(sub_batch_sampler)
                    )
                )
            batch_samplers.append(sub_batch_sampler)
        return batch_samplers


class LangPairMaskDataset(FairseqDataset):
    def __init__(
        self,
        dataset: LanguagePairDataset,
        src_eos: int,
        src_bos: Optional[int] = None,
        noise_id: Optional[int] = -1,
        mask_ratio: Optional[float] = 0,
        mask_type: Optional[str] = "random",
    ):
        self.dataset = dataset
        self.src_eos = src_eos
        self.src_bos = src_bos
        self.noise_id = noise_id
        self.mask_ratio = mask_ratio
        self.mask_type = mask_type
        assert mask_type in ("random", "tail")

    @property
    def src_sizes(self):
        return self.dataset.src_sizes

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes

    @property
    def sizes(self):
        # dataset.sizes can be a dynamically computed sizes:
        return self.dataset.sizes

    def get_batch_shapes(self):
        if hasattr(self.dataset, "get_batch_shapes"):
            return self.dataset.get_batch_shapes()
        return self.dataset.buckets

    def num_tokens_vec(self, indices):
        return self.dataset.num_tokens_vec(indices)

    def __len__(self):
        return len(self.dataset)

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.dataset.prefetch(indices)

    def mask_src_tokens(self, sample):
        src_item = sample["source"]
        mask = None
        if self.mask_type == "random":
            mask = torch.rand(len(src_item)).le(self.mask_ratio)
        else:
            mask = torch.ones(len(src_item))
            mask[: int(len(src_item) * (1 - self.mask_ratio))] = 0
            mask = mask.eq(1)
        if src_item[0] == self.src_bos:
            mask[0] = False
        if src_item[-1] == self.src_eos:
            mask[-1] = False
        mask_src_item = src_item.masked_fill(mask, self.noise_id)
        smp = {"id": sample["id"], "source": mask_src_item, "target": sample["target"]}
        return smp

    def __getitem__(self, index):
        sample = self.dataset[index]
        if self.mask_ratio > 0:
            sample = self.mask_src_tokens(sample)
        return sample

    def collater(self, samples, pad_to_length=None):
        return self.dataset.collater(samples, pad_to_length)


class FileAudioDatasetWrapper(FileAudioDataset):
    def collater(self, samples):
        samples = super().collater(samples)
        if len(samples) == 0:
            return {}
        samples["net_input"]["src_tokens"] = samples["net_input"]["source"]
        samples["net_input"]["prev_output_tokens"] = None
        del samples["net_input"]["source"]
        samples["net_input"]["src_lengths"] = None
        samples["net_input"]["alignment"] = None
        return samples
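The header comments above describe how MultiModalityDataset groups several datasets by modality and samples their batches by ratio. The snippet below is a minimal usage sketch, not part of this commit: `speech_ds` and `text_ds` stand in for FairseqDataset instances built elsewhere (for example a FileAudioDatasetWrapper and a LangPairMaskDataset), and the position/token limits and ratios are illustrative placeholders.

# Hypothetical usage sketch for MultiModalityDataset (illustrative values only).
from fairseq.data.audio.multi_modality_dataset import (
    ModalityDatasetItem,
    MultiModalityDataset,
)

# speech_ds and text_ds are assumed to have been constructed beforehand.
items = [
    ModalityDatasetItem("speech", speech_ds, [3_000_000], max_tokens=1_000_000),
    ModalityDatasetItem("text", text_ds, [1024, 1024], max_tokens=4096),
]
dataset = MultiModalityDataset(items)

# One batch sampler per modality; here "speech" batches are kept as-is and
# "text" batches are subsampled to half their count via the 0.5 ratio.
batch_samplers = dataset.get_batch_samplers(
    mult_ratios=[1.0, 0.5], required_batch_size_multiple=8, seed=1
)

The collater then tags every mini-batch with its originating modality through samples["net_input"]["mode"], which is what lets a downstream iterator keep batches modality-pure.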
PyTorch/NLP/new-Transformer/fairseq/data/audio/raw_audio_dataset.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
import io

import numpy as np
import torch
import torch.nn.functional as F

from .. import FairseqDataset
from ..data_utils import compute_mask_indices, get_buckets, get_bucketed_sizes
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)
from fairseq.data.text_compressor import TextCompressor, TextCompressionLevel

logger = logging.getLogger(__name__)


class RawAudioDataset(FairseqDataset):
    def __init__(
        self,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__()

        self.sample_rate = sample_rate
        self.sizes = []
        self.max_sample_size = (
            max_sample_size if max_sample_size is not None else sys.maxsize
        )
        self.min_sample_size = min_sample_size
        self.pad = pad
        self.shuffle = shuffle
        self.normalize = normalize
        self.compute_mask_indices = compute_mask_indices
        if self.compute_mask_indices:
            self.mask_compute_kwargs = mask_compute_kwargs
            self._features_size_map = {}
            self._C = mask_compute_kwargs["encoder_embed_dim"]
            self._conv_feature_layers = eval(mask_compute_kwargs["conv_feature_layers"])

    def __getitem__(self, index):
        raise NotImplementedError()

    def __len__(self):
        return len(self.sizes)

    def postprocess(self, feats, curr_sample_rate):
        if feats.dim() == 2:
            feats = feats.mean(-1)

        if curr_sample_rate != self.sample_rate:
            raise Exception(f"sample rate: {curr_sample_rate}, need {self.sample_rate}")

        assert feats.dim() == 1, feats.dim()

        if self.normalize:
            with torch.no_grad():
                feats = F.layer_norm(feats, feats.shape)
        return feats

    def crop_to_max_size(self, wav, target_size):
        size = len(wav)
        diff = size - target_size
        if diff <= 0:
            return wav

        start = np.random.randint(0, diff + 1)
        end = size - diff + start
        return wav[start:end]

    def _compute_mask_indices(self, dims, padding_mask):
        B, T, C = dims
        mask_indices, mask_channel_indices = None, None
        if self.mask_compute_kwargs["mask_prob"] > 0:
            mask_indices = compute_mask_indices(
                (B, T),
                padding_mask,
                self.mask_compute_kwargs["mask_prob"],
                self.mask_compute_kwargs["mask_length"],
                self.mask_compute_kwargs["mask_selection"],
                self.mask_compute_kwargs["mask_other"],
                min_masks=2,
                no_overlap=self.mask_compute_kwargs["no_mask_overlap"],
                min_space=self.mask_compute_kwargs["mask_min_space"],
            )
            mask_indices = torch.from_numpy(mask_indices)
        if self.mask_compute_kwargs["mask_channel_prob"] > 0:
            mask_channel_indices = compute_mask_indices(
                (B, C),
                None,
                self.mask_compute_kwargs["mask_channel_prob"],
                self.mask_compute_kwargs["mask_channel_length"],
                self.mask_compute_kwargs["mask_channel_selection"],
                self.mask_compute_kwargs["mask_channel_other"],
                no_overlap=self.mask_compute_kwargs["no_mask_channel_overlap"],
                min_space=self.mask_compute_kwargs["mask_channel_min_space"],
            )
            mask_channel_indices = (
                torch.from_numpy(mask_channel_indices).unsqueeze(1).expand(-1, T, -1)
            )

        return mask_indices, mask_channel_indices

    @staticmethod
    def _bucket_tensor(tensor, num_pad, value):
        return F.pad(tensor, (0, num_pad), value=value)

    def collater(self, samples):
        samples = [s for s in samples if s["source"] is not None]
        if len(samples) == 0:
            return {}

        sources = [s["source"] for s in samples]
        sizes = [len(s) for s in sources]

        if self.pad:
            target_size = min(max(sizes), self.max_sample_size)
        else:
            target_size = min(min(sizes), self.max_sample_size)

        collated_sources = sources[0].new_zeros(len(sources), target_size)
        padding_mask = (
            torch.BoolTensor(collated_sources.shape).fill_(False) if self.pad else None
        )
        for i, (source, size) in enumerate(zip(sources, sizes)):
            diff = size - target_size
            if diff == 0:
                collated_sources[i] = source
            elif diff < 0:
                assert self.pad
                collated_sources[i] = torch.cat(
                    [source, source.new_full((-diff,), 0.0)]
                )
                padding_mask[i, diff:] = True
            else:
                collated_sources[i] = self.crop_to_max_size(source, target_size)

        input = {"source": collated_sources}
        out = {"id": torch.LongTensor([s["id"] for s in samples])}
        if self.pad:
            input["padding_mask"] = padding_mask

        if hasattr(self, "num_buckets") and self.num_buckets > 0:
            assert self.pad, "Cannot bucket without padding first."
            bucket = max(self._bucketed_sizes[s["id"]] for s in samples)
            num_pad = bucket - collated_sources.size(-1)
            if num_pad:
                input["source"] = self._bucket_tensor(collated_sources, num_pad, 0)
                input["padding_mask"] = self._bucket_tensor(padding_mask, num_pad, True)

        if self.compute_mask_indices:
            B = input["source"].size(0)
            T = self._get_mask_indices_dims(input["source"].size(-1))
            padding_mask_reshaped = input["padding_mask"].clone()
            extra = padding_mask_reshaped.size(1) % T
            if extra > 0:
                padding_mask_reshaped = padding_mask_reshaped[:, :-extra]
            padding_mask_reshaped = padding_mask_reshaped.view(
                padding_mask_reshaped.size(0), T, -1
            )
            padding_mask_reshaped = padding_mask_reshaped.all(-1)
            input["padding_count"] = padding_mask_reshaped.sum(-1).max().item()
            mask_indices, mask_channel_indices = self._compute_mask_indices(
                (B, T, self._C),
                padding_mask_reshaped,
            )
            input["mask_indices"] = mask_indices
            input["mask_channel_indices"] = mask_channel_indices
            out["sample_size"] = mask_indices.sum().item()

        out["net_input"] = input
        return out

    def _get_mask_indices_dims(self, size, padding=0, dilation=1):
        if size not in self._features_size_map:
            L_in = size
            for (_, kernel_size, stride) in self._conv_feature_layers:
                L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
                L_out = 1 + L_out // stride
                L_in = L_out
            self._features_size_map[size] = L_out
        return self._features_size_map[size]

    def num_tokens(self, index):
        return self.size(index)

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        if self.pad:
            return self.sizes[index]
        return min(self.sizes[index], self.max_sample_size)

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""

        if self.shuffle:
            order = [np.random.permutation(len(self))]
            order.append(
                np.minimum(
                    np.array(self.sizes),
                    self.max_sample_size,
                )
            )
            return np.lexsort(order)[::-1]
        else:
            return np.arange(len(self))

    def set_bucket_info(self, num_buckets):
        self.num_buckets = num_buckets
        if self.num_buckets > 0:
            self._collated_sizes = np.minimum(
                np.array(self.sizes),
                self.max_sample_size,
            )
            self.buckets = get_buckets(
                self._collated_sizes,
                self.num_buckets,
            )
            self._bucketed_sizes = get_bucketed_sizes(
                self._collated_sizes, self.buckets
            )
            logger.info(
                f"{len(self.buckets)} bucket(s) for the audio dataset: "
                f"{self.buckets}"
            )


class FileAudioDataset(RawAudioDataset):
    def __init__(
        self,
        manifest_path,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        text_compression_level=TextCompressionLevel.none,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        self.text_compressor = TextCompressor(level=text_compression_level)

        skipped = 0
        self.fnames = []
        sizes = []
        self.skipped_indices = set()

        with open(manifest_path, "r") as f:
            self.root_dir = f.readline().strip()
            for i, line in enumerate(f):
                items = line.strip().split("\t")
                assert len(items) == 2, line
                sz = int(items[1])
                if min_sample_size is not None and sz < min_sample_size:
                    skipped += 1
                    self.skipped_indices.add(i)
                    continue
                self.fnames.append(self.text_compressor.compress(items[0]))
                sizes.append(sz)
        logger.info(f"loaded {len(self.fnames)}, skipped {skipped} samples")

        self.sizes = np.array(sizes, dtype=np.int64)

        try:
            import pyarrow

            self.fnames = pyarrow.array(self.fnames)
        except:
            logger.debug(
                "Could not create a pyarrow array. Please install pyarrow for better performance"
            )
            pass

        self.set_bucket_info(num_buckets)

    def __getitem__(self, index):
        import soundfile as sf

        fn = self.fnames[index]
        fn = fn if isinstance(self.fnames, list) else fn.as_py()
        fn = self.text_compressor.decompress(fn)
        path_or_fp = os.path.join(self.root_dir, fn)
        _path, slice_ptr = parse_path(path_or_fp)
        if len(slice_ptr) == 2:
            byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
            assert is_sf_audio_data(byte_data)
            path_or_fp = io.BytesIO(byte_data)

        wav, curr_sample_rate = sf.read(path_or_fp, dtype="float32")

        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}


class BinarizedAudioDataset(RawAudioDataset):
    def __init__(
        self,
        data_dir,
        split,
        sample_rate,
        max_sample_size=None,
        min_sample_size=0,
        shuffle=True,
        pad=False,
        normalize=False,
        num_buckets=0,
        compute_mask_indices=False,
        **mask_compute_kwargs,
    ):
        super().__init__(
            sample_rate=sample_rate,
            max_sample_size=max_sample_size,
            min_sample_size=min_sample_size,
            shuffle=shuffle,
            pad=pad,
            normalize=normalize,
            compute_mask_indices=compute_mask_indices,
            **mask_compute_kwargs,
        )

        from fairseq.data import data_utils, Dictionary

        self.fnames_dict = Dictionary.load(os.path.join(data_dir, "dict.txt"))

        root_path = os.path.join(data_dir, f"{split}.root")
        if os.path.exists(root_path):
            with open(root_path, "r") as f:
                self.root_dir = next(f).strip()
        else:
            self.root_dir = None

        fnames_path = os.path.join(data_dir, split)
        self.fnames = data_utils.load_indexed_dataset(fnames_path, self.fnames_dict)
        lengths_path = os.path.join(data_dir, f"{split}.lengths")

        with open(lengths_path, "r") as f:
            for line in f:
                sz = int(line.rstrip())
                assert (
                    sz >= min_sample_size
                ), f"Min sample size is not supported for binarized dataset, but found a sample with size {sz}"
                self.sizes.append(sz)

        self.sizes = np.array(self.sizes, dtype=np.int64)

        self.set_bucket_info(num_buckets)
        logger.info(f"loaded {len(self.fnames)} samples")

    def __getitem__(self, index):
        import soundfile as sf

        fname = self.fnames_dict.string(self.fnames[index], separator="")
        if self.root_dir:
            fname = os.path.join(self.root_dir, fname)

        wav, curr_sample_rate = sf.read(fname)

        feats = torch.from_numpy(wav).float()
        feats = self.postprocess(feats, curr_sample_rate)
        return {"id": index, "source": feats}
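FileAudioDataset.__init__ above parses a tab-separated manifest whose first line is the audio root directory and whose remaining lines each give a relative path and a length in samples. The following sketch shows that layout and one way to construct the dataset; the file names, sizes, and limits are made-up placeholders rather than anything from this commit.

# Hypothetical manifest plus construction of FileAudioDataset (illustrative only).
manifest = "\n".join(
    [
        "/data/librispeech/wavs",   # first line: root directory
        "train/a01.flac\t240000",   # <relative path>\t<num samples>
        "train/a02.flac\t198400",
    ]
)
with open("/tmp/train.tsv", "w") as f:
    f.write(manifest + "\n")

from fairseq.data.audio.raw_audio_dataset import FileAudioDataset

dataset = FileAudioDataset(
    manifest_path="/tmp/train.tsv",
    sample_rate=16_000,      # a mismatched file rate raises in postprocess()
    max_sample_size=250_000, # longer clips are randomly cropped in collater()
    min_sample_size=32_000,  # shorter clips are skipped while reading the manifest
    pad=True,
    normalize=False,
)
# Each item is {"id": index, "source": 1-D float waveform tensor}; collater()
# pads or crops sources to a common length and adds a padding_mask when pad=True.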
PyTorch/NLP/new-Transformer/fairseq/data/audio/speech_to_speech_dataset.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch

from fairseq.data import ConcatDataset, Dictionary
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.data_cfg import S2SDataConfig
from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    _collate_frames,
    get_features_or_waveform,
)

logger = logging.getLogger(__name__)


@dataclass
class SpeechToSpeechDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    target_speaker: Optional[torch.Tensor] = None
    tgt_lang_tag: Optional[int] = None


class SpeechToSpeechDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        data_cfg: S2SDataConfig,
        src_audio_paths: List[str],
        src_n_frames: List[int],
        tgt_audio_paths: List[str],
        tgt_n_frames: List[int],
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        target_is_code: bool = False,
        tgt_dict: Dictionary = None,
        n_frames_per_step: int = 1,
    ):
        tgt_texts = tgt_audio_paths if target_is_code else None
        super().__init__(
            split,
            is_train_split,
            data_cfg,
            src_audio_paths,
            src_n_frames,
            ids=ids,
            tgt_dict=tgt_dict,
            tgt_texts=tgt_texts,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            n_frames_per_step=n_frames_per_step,
        )

        self.tgt_audio_paths = tgt_audio_paths
        self.tgt_lens = [t // self.n_frames_per_step for t in tgt_n_frames]

        assert not target_is_code or tgt_dict is not None
        self.target_is_code = target_is_code

        assert len(tgt_audio_paths) == self.n_samples
        assert len(tgt_n_frames) == self.n_samples

        self.tgt_speakers = None
        if self.cfg.target_speaker_embed:
            samples = SpeechToTextDatasetCreator._load_samples_from_tsv(
                self.cfg.target_speaker_embed, split
            )
            spk_emb_dict = {s["id"]: s["speaker_embed"] for s in samples}
            self.tgt_speakers = [spk_emb_dict[id] for id in self.ids]
            assert len(self.tgt_speakers) == self.n_samples

        logger.info(self.__repr__())

    def pack_units(self, input: torch.Tensor) -> torch.Tensor:
        if self.n_frames_per_step <= 1:
            return input

        offset = 4
        vocab_size = (
            len(self.tgt_dict) - offset
        )  # remove offset from <bos>, <pad>, <eos>, <unk>, which is specific to fairseq dictionary

        assert input.dim() == 1
        stacked_input = (
            input[:-1].view(-1, self.n_frames_per_step) - offset
        )  # remove <eos>
        scale = [
            pow(vocab_size, self.n_frames_per_step - 1 - i)
            for i in range(self.n_frames_per_step)
        ]
        scale = torch.LongTensor(scale).squeeze(0)
        res = input.new((len(input) - 1) // self.n_frames_per_step + 1).fill_(input[-1])
        res[:-1] = (stacked_input * scale).sum(dim=1) + offset

        return res

    def __getitem__(self, index: int) -> SpeechToSpeechDatasetItem:
        source = self._get_source_audio(index)

        tgt_lang_tag = None
        if self.cfg.prepend_tgt_lang_tag_as_bos:
            # prepend_tgt_lang_tag_as_bos: put tgt_lang_tag as bos of target
            tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)

        if not self.target_is_code:
            target = get_features_or_waveform(self.tgt_audio_paths[index])
            target = torch.from_numpy(target).float()
            target = self.pack_frames(target)
        else:
            target = self.tgt_dict.encode_line(
                self.tgt_audio_paths[index],
                add_if_not_exist=False,
                append_eos=True,
            ).long()
            if self.n_frames_per_step > 1:
                n_tgt_frame = target.size(0) - 1  # exclude <eos>
                keep_n_tgt_frame = n_tgt_frame - n_tgt_frame % self.n_frames_per_step
                target = torch.cat(
                    (
                        target[:keep_n_tgt_frame],
                        target.new_full((1,), self.tgt_dict.eos()),
                    ),
                    dim=0,
                )

        if self.tgt_speakers:
            tgt_spk = get_features_or_waveform(self.tgt_speakers[index])
            tgt_spk = torch.from_numpy(tgt_spk).float()
        else:
            tgt_spk = torch.FloatTensor([])

        return SpeechToSpeechDatasetItem(
            index=index,
            source=source,
            target=target,
            target_speaker=tgt_spk,
            tgt_lang_tag=tgt_lang_tag,
        )

    def _collate_target(self, samples: List[SpeechToSpeechDatasetItem]) -> torch.Tensor:
        if self.target_is_code:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            # convert stacked units to a single id
            pack_targets = [self.pack_units(x.target) for x in samples]
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                pack_targets,
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=True,
            )
            target_lengths = torch.tensor(
                [x.size(0) for x in pack_targets], dtype=torch.long
            )
        else:
            target = _collate_frames([x.target for x in samples], is_audio_input=False)
            bsz, _, d = target.size()
            prev_output_tokens = torch.cat(
                (target.new_full((bsz, 1, d), 0.0), target[:, :-1, :]), dim=1
            )
            target_lengths = torch.tensor(
                [x.target.size(0) for x in samples], dtype=torch.long
            )

        return target, prev_output_tokens, target_lengths

    def collater(
        self, samples: List[SpeechToSpeechDatasetItem], return_order: bool = False
    ) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)
        frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, prev_output_tokens, target_lengths = self._collate_target(samples)
        target = target.index_select(0, order)
        target_lengths = target_lengths.index_select(0, order)
        prev_output_tokens = prev_output_tokens.index_select(0, order)
        ntokens = sum(x.target.size(0) for x in samples)

        tgt_speakers = None
        if self.cfg.target_speaker_embed:
            tgt_speakers = _collate_frames(
                [x.target_speaker for x in samples], is_audio_input=True
            ).index_select(0, order)

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
            "tgt_speaker": tgt_speakers,  # TODO: unify "speaker" and "tgt_speaker"
        }
        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            for i in range(len(samples)):
                net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": tgt_speakers,  # to support Tacotron2 loss for speech-to-spectrogram model
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out


class TextTargetMultitaskData(object):
    # mandatory columns
    KEY_ID, KEY_TEXT = "id", "tgt_text"

    def __init__(self, args, split, tgt_dict):
        samples = SpeechToTextDatasetCreator._load_samples_from_tsv(args.data, split)
        self.data = {s[self.KEY_ID]: s[self.KEY_TEXT] for s in samples}
        self.dict = tgt_dict
        self.append_eos = args.decoder_type != "ctc"

    def get(self, sample_id):
        if sample_id in self.data:
            return self.dict.encode_line(
                self.data[sample_id],
                add_if_not_exist=False,
                append_eos=self.append_eos,
            )
        else:
            logger.warning(f"no target for {sample_id}")
            return torch.IntTensor([])

    def collater(self, samples: List[torch.Tensor]) -> torch.Tensor:
        out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            self.dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).long()

        prev_out = fairseq_data_utils.collate_tokens(
            samples,
            self.dict.pad(),
            self.dict.eos(),
            left_pad=False,
            move_eos_to_beginning=True,
        ).long()

        target_lengths = torch.tensor([t.size(0) for t in samples], dtype=torch.long)
        ntokens = sum(t.size(0) for t in samples)

        output = {
            "prev_output_tokens": prev_out,
            "target": out,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
        }

        return output


class SpeechToSpeechMultitaskDataset(SpeechToSpeechDataset):
    def __init__(self, *argv):
        super().__init__(*argv)
        self.multitask_data = {}

    def add_multitask_dataset(self, task_name, task_data):
        self.multitask_data[task_name] = task_data

    def __getitem__(
        self, index: int
    ) -> Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]:
        s2s_data = super().__getitem__(index)

        multitask_target = {}
        sample_id = self.ids[index]
        for task_name, task_dataset in self.multitask_data.items():
            multitask_target[task_name] = task_dataset.get(sample_id)

        return s2s_data, multitask_target

    def collater(
        self, samples: List[Tuple[SpeechToSpeechDatasetItem, Dict[str, torch.Tensor]]]
    ) -> Dict:
        if len(samples) == 0:
            return {}

        out = super().collater([s for s, _ in samples], return_order=True)
        order = out["order"]
        del out["order"]

        for task_name, task_dataset in self.multitask_data.items():
            if "multitask" not in out:
                out["multitask"] = {}
            d = [s[task_name] for _, s in samples]
            task_target = task_dataset.collater(d)
            out["multitask"][task_name] = {
                "target": task_target["target"].index_select(0, order),
                "target_lengths": task_target["target_lengths"].index_select(0, order),
                "ntokens": task_target["ntokens"],
            }
            out["multitask"][task_name]["net_input"] = {
                "prev_output_tokens": task_target["prev_output_tokens"].index_select(
                    0, order
                ),
            }

        return out


class SpeechToSpeechDatasetCreator(object):
    # mandatory columns
    KEY_ID, KEY_SRC_AUDIO, KEY_SRC_N_FRAMES = "id", "src_audio", "src_n_frames"
    KEY_TGT_AUDIO, KEY_TGT_N_FRAMES = "tgt_audio", "tgt_n_frames"
    # optional columns
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        data_cfg: S2SDataConfig,
        target_is_code: bool = False,
        target_dictionary: Dictionary = None,
        n_frames_per_step: int = 1,
        multitask: Optional[Dict] = None,
    ) -> SpeechToSpeechDataset:
        audio_root = Path(data_cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        src_audio_paths = [
            (audio_root / s[cls.KEY_SRC_AUDIO]).as_posix() for s in samples
        ]
        tgt_audio_paths = [
            s[cls.KEY_TGT_AUDIO]
            if target_is_code
            else (audio_root / s[cls.KEY_TGT_AUDIO]).as_posix()
            for s in samples
        ]
        src_n_frames = [int(s[cls.KEY_SRC_N_FRAMES]) for s in samples]
        tgt_n_frames = [int(s[cls.KEY_TGT_N_FRAMES]) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        has_multitask = len(multitask) > 0
        dataset_cls = (
            SpeechToSpeechMultitaskDataset if has_multitask else SpeechToSpeechDataset
        )

        ds = dataset_cls(
            split_name,
            is_train_split,
            data_cfg,
            src_audio_paths,
            src_n_frames,
            tgt_audio_paths,
            tgt_n_frames,
            src_langs,
            tgt_langs,
            ids,
            target_is_code,
            target_dictionary,
            n_frames_per_step,
        )

        if has_multitask:
            for task_name, task_obj in multitask.items():
                task_data = TextTargetMultitaskData(
                    task_obj.args, split_name, task_obj.target_dictionary
                )
                ds.add_multitask_dataset(task_name, task_data)
        return ds

    @classmethod
    def from_tsv(
        cls,
        root: str,
        data_cfg: S2SDataConfig,
        splits: str,
        is_train_split: bool,
        epoch: int,
        seed: int,
        target_is_code: bool = False,
        target_dictionary: Dictionary = None,
        n_frames_per_step: int = 1,
        multitask: Optional[Dict] = None,
    ) -> SpeechToSpeechDataset:
        datasets = []
        for split in splits.split(","):
            samples = SpeechToTextDatasetCreator._load_samples_from_tsv(root, split)
            ds = cls._from_list(
                split,
                is_train_split,
                samples,
                data_cfg,
                target_is_code,
                target_dictionary,
                n_frames_per_step,
                multitask,
            )
            datasets.append(ds)
        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
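pack_units() above collapses each group of n_frames_per_step unit tokens into a single id by treating the group as digits in base vocab_size, after removing the 4-symbol special-token offset, and it leaves the trailing <eos> untouched. The snippet below re-derives that arithmetic on hypothetical numbers (a dictionary of 104 entries, so vocab_size = 100, and n_frames_per_step = 2); none of the values come from this commit.

# Worked check of the pack_units() arithmetic with made-up numbers.
#   (4-4, 5-4) -> 0*100 + 1 + 4 = 5
#   (6-4, 7-4) -> 2*100 + 3 + 4 = 207
# so unit ids [4, 5, 6, 7] followed by <eos> (= 2) pack to [5, 207, 2].
import torch

offset, vocab_size, n_frames_per_step = 4, 100, 2
inp = torch.LongTensor([4, 5, 6, 7, 2])  # unit ids + <eos>
stacked = inp[:-1].view(-1, n_frames_per_step) - offset
scale = torch.LongTensor([vocab_size, 1])  # [vocab_size**1, vocab_size**0]
res = inp.new((len(inp) - 1) // n_frames_per_step + 1).fill_(inp[-1])
res[:-1] = (stacked * scale).sum(dim=1) + offset
assert res.tolist() == [5, 207, 2]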
PyTorch/NLP/new-Transformer/fairseq/data/audio/speech_to_text_dataset.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import io
import logging
import re
from collections import defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from fairseq.data import ConcatDataset, Dictionary, FairseqDataset, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.audio_utils import (
    FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS,
    get_fbank,
    get_waveform,
    is_npy_data,
    is_sf_audio_data,
    parse_path,
    read_from_stored_zip,
)
from fairseq.data.audio.data_cfg import S2TDataConfig
from fairseq.data.audio.feature_transforms import CompositeAudioFeatureTransform

logger = logging.getLogger(__name__)


def get_features_from_npy_or_audio(path):
    ext = Path(path).suffix
    if ext not in FEATURE_OR_SF_AUDIO_FILE_EXTENSIONS:
        raise ValueError(f'Unsupported file format for "{path}"')
    return np.load(path) if ext == ".npy" else get_fbank(path)


def get_features_or_waveform_from_stored_zip(
    path,
    byte_offset,
    byte_size,
    need_waveform=False,
    use_sample_rate=None,
):
    assert path.endswith(".zip")
    data = read_from_stored_zip(path, byte_offset, byte_size)
    f = io.BytesIO(data)
    if is_npy_data(data):
        features_or_waveform = np.load(f)
    elif is_sf_audio_data(data):
        features_or_waveform = (
            get_waveform(f, always_2d=False, output_sample_rate=use_sample_rate)[0]
            if need_waveform
            else get_fbank(f)
        )
    else:
        raise ValueError(f'Unknown file format for "{path}"')
    return features_or_waveform


def get_features_or_waveform(path: str, need_waveform=False, use_sample_rate=None):
    """Get speech features from .npy file or waveform from .wav/.flac file.
    The file may be inside an uncompressed ZIP file and is accessed via byte
    offset and length.

    Args:
        path (str): File path in the format of "<.npy/.wav/.flac path>" or
            "<zip path>:<byte offset>:<byte length>".
        need_waveform (bool): return waveform instead of features.
        use_sample_rate (int): change sample rate for the input wave file

    Returns:
        features_or_waveform (numpy.ndarray): speech features or waveform.
    """
    _path, slice_ptr = parse_path(path)
    if len(slice_ptr) == 0:
        if need_waveform:
            return get_waveform(
                _path, always_2d=False, output_sample_rate=use_sample_rate
            )[0]
        return get_features_from_npy_or_audio(_path)
    elif len(slice_ptr) == 2:
        features_or_waveform = get_features_or_waveform_from_stored_zip(
            _path,
            slice_ptr[0],
            slice_ptr[1],
            need_waveform=need_waveform,
            use_sample_rate=use_sample_rate,
        )
    else:
        raise ValueError(f"Invalid path: {path}")

    return features_or_waveform


def _collate_frames(
    frames: List[torch.Tensor], is_audio_input: bool = False
) -> torch.Tensor:
    """
    Convert a list of 2D frames into a padded 3D tensor
    Args:
        frames (list): list of 2D frames of size L[i]*f_dim. Where L[i] is
            length of i-th frame and f_dim is static dimension of features
    Returns:
        3D tensor of size len(frames)*len_max*f_dim where len_max is max of L[i]
    """
    max_len = max(frame.size(0) for frame in frames)
    if is_audio_input:
        out = frames[0].new_zeros((len(frames), max_len))
    else:
        out = frames[0].new_zeros((len(frames), max_len, frames[0].size(1)))
    for i, v in enumerate(frames):
        out[i, : v.size(0)] = v
    return out


@dataclass
class SpeechToTextDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None


class SpeechToTextDataset(FairseqDataset):
    LANG_TAG_TEMPLATE = "<lang:{}>"

    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        append_eos=True,
    ):
        self.split, self.is_train_split = split, is_train_split
        self.cfg = cfg
        self.audio_paths, self.n_frames = audio_paths, n_frames
        self.n_samples = len(audio_paths)
        assert len(n_frames) == self.n_samples > 0
        assert src_texts is None or len(src_texts) == self.n_samples
        assert tgt_texts is None or len(tgt_texts) == self.n_samples
        assert speakers is None or len(speakers) == self.n_samples
        assert src_langs is None or len(src_langs) == self.n_samples
        assert tgt_langs is None or len(tgt_langs) == self.n_samples
        assert ids is None or len(ids) == self.n_samples
        assert (tgt_dict is None and tgt_texts is None) or (
            tgt_dict is not None and tgt_texts is not None
        )
        self.src_texts, self.tgt_texts = src_texts, tgt_texts
        self.src_langs, self.tgt_langs = src_langs, tgt_langs
        self.speakers = speakers
        self.tgt_dict = tgt_dict
        self.check_tgt_lang_tag()
        self.ids = ids
        self.shuffle = cfg.shuffle if is_train_split else False

        self.feature_transforms = CompositeAudioFeatureTransform.from_config_dict(
            self.cfg.get_feature_transforms(split, is_train_split)
        )

        self.pre_tokenizer = pre_tokenizer
        self.bpe_tokenizer = bpe_tokenizer
        self.n_frames_per_step = n_frames_per_step
        self.speaker_to_id = speaker_to_id

        self.tgt_lens = self.get_tgt_lens_and_check_oov()
        self.append_eos = append_eos

        logger.info(self.__repr__())

    def get_tgt_lens_and_check_oov(self):
        if self.tgt_texts is None:
            return [0 for _ in range(self.n_samples)]
        tgt_lens = []
        n_tokens, n_oov_tokens = 0, 0
        for i in range(self.n_samples):
            tokenized = self.get_tokenized_tgt_text(i).split(" ")
            oov_tokens = [
                t
                for t in tokenized
                if self.tgt_dict.index(t) == self.tgt_dict.unk_index
            ]
            n_tokens += len(tokenized)
            n_oov_tokens += len(oov_tokens)
            tgt_lens.append(len(tokenized))
        logger.info(f"'{self.split}' has {n_oov_tokens / n_tokens * 100:.2f}% OOV")
        return tgt_lens

    def __repr__(self):
        return (
            self.__class__.__name__
            + f'(split="{self.split}", n_samples={self.n_samples:_}, '
            f"prepend_tgt_lang_tag={self.cfg.prepend_tgt_lang_tag}, "
            f"shuffle={self.shuffle}, transforms={self.feature_transforms}, "
            f"n_frames_per_step={self.n_frames_per_step}"
        )

    @classmethod
    def is_lang_tag(cls, token):
        pattern = cls.LANG_TAG_TEMPLATE.replace("{}", "(.*)")
        return re.match(pattern, token)

    def check_tgt_lang_tag(self):
        if self.cfg.prepend_tgt_lang_tag:
            assert self.tgt_langs is not None and self.tgt_dict is not None
            tgt_lang_tags = [
                self.LANG_TAG_TEMPLATE.format(t) for t in set(self.tgt_langs)
            ]
            assert all(t in self.tgt_dict for t in tgt_lang_tags)

    @classmethod
    def tokenize(cls, tokenizer, text: str):
        return text if tokenizer is None else tokenizer.encode(text)

    def get_tokenized_tgt_text(self, index: int):
        text = self.tokenize(self.pre_tokenizer, self.tgt_texts[index])
        text = self.tokenize(self.bpe_tokenizer, text)
        return text

    def pack_frames(self, feature: torch.Tensor):
        if self.n_frames_per_step == 1:
            return feature
        n_packed_frames = feature.shape[0] // self.n_frames_per_step
        feature = feature[: self.n_frames_per_step * n_packed_frames]
        return feature.reshape(n_packed_frames, -1)

    @classmethod
    def get_lang_tag_idx(cls, lang: str, dictionary: Dictionary):
        lang_tag_idx = dictionary.index(cls.LANG_TAG_TEMPLATE.format(lang))
        assert lang_tag_idx != dictionary.unk()
        return lang_tag_idx

    def _get_source_audio(self, index: int) -> torch.Tensor:
        source = get_features_or_waveform(
            self.audio_paths[index],
            need_waveform=self.cfg.use_audio_input,
            use_sample_rate=self.cfg.use_sample_rate,
        )
        if self.cfg.use_audio_input:
            source = torch.from_numpy(source).float()
            if self.cfg.standardize_audio:
                with torch.no_grad():
                    source = F.layer_norm(source, source.shape)
        else:
            if self.feature_transforms is not None:
                source = self.feature_transforms(source)
            source = torch.from_numpy(source).float()
        return source

    def __getitem__(self, index: int) -> SpeechToTextDatasetItem:
        source = self._get_source_audio(index)
        source = self.pack_frames(source)

        target = None
        if self.tgt_texts is not None:
            tokenized = self.get_tokenized_tgt_text(index)
            target = self.tgt_dict.encode_line(
                tokenized, add_if_not_exist=False, append_eos=self.append_eos
            ).long()
            if self.cfg.prepend_tgt_lang_tag:
                lang_tag_idx = self.get_lang_tag_idx(
                    self.tgt_langs[index], self.tgt_dict
                )
                target = torch.cat((torch.LongTensor([lang_tag_idx]), target), 0)

            if self.cfg.prepend_bos_and_append_tgt_lang_tag:
                bos = torch.LongTensor([self.tgt_dict.bos()])
                lang_tag_idx = self.get_lang_tag_idx(
                    self.tgt_langs[index], self.tgt_dict
                )
                assert lang_tag_idx != self.tgt_dict.unk()
                lang_tag_idx = torch.LongTensor([lang_tag_idx])
                target = torch.cat((bos, target, lang_tag_idx), 0)

        speaker_id = None
        if self.speaker_to_id is not None:
            speaker_id = self.speaker_to_id[self.speakers[index]]

        return SpeechToTextDatasetItem(
            index=index, source=source, target=target, speaker_id=speaker_id
        )

    def __len__(self):
        return self.n_samples

    def collater(
        self, samples: List[SpeechToTextDatasetItem], return_order: bool = False
    ) -> Dict:
        if len(samples) == 0:
            return {}
        indices = torch.tensor([x.index for x in samples], dtype=torch.long)
        frames = _collate_frames([x.source for x in samples], self.cfg.use_audio_input)
        # sort samples by descending number of frames
        n_frames = torch.tensor([x.source.size(0) for x in samples], dtype=torch.long)
        n_frames, order = n_frames.sort(descending=True)
        indices = indices.index_select(0, order)
        frames = frames.index_select(0, order)

        target, target_lengths = None, None
        prev_output_tokens = None
        ntokens = None
        if self.tgt_texts is not None:
            target = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                self.tgt_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            target = target.index_select(0, order)
            target_lengths = torch.tensor(
                [x.target.size(0) for x in samples], dtype=torch.long
            ).index_select(0, order)
            prev_output_tokens = fairseq_data_utils.collate_tokens(
                [x.target for x in samples],
                self.tgt_dict.pad(),
                eos_idx=None,
                left_pad=False,
                move_eos_to_beginning=True,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, order)
            ntokens = sum(x.target.size(0) for x in samples)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        net_input = {
            "src_tokens": frames,
            "src_lengths": n_frames,
            "prev_output_tokens": prev_output_tokens,
        }
        out = {
            "id": indices,
            "net_input": net_input,
            "speaker": speaker,
            "target": target,
            "target_lengths": target_lengths,
            "ntokens": ntokens,
            "nsentences": len(samples),
        }
        if return_order:
            out["order"] = order
        return out

    def num_tokens(self, index):
        return self.n_frames[index]

    def size(self, index):
        return self.n_frames[index], self.tgt_lens[index]

    @property
    def sizes(self):
        return np.array(self.n_frames)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True

    def ordered_indices(self):
        if self.shuffle:
            order = [np.random.permutation(len(self))]
        else:
            order = [np.arange(len(self))]
        # first by descending order of # of frames then by original/random order
        order.append([-n for n in self.n_frames])
        return np.lexsort(order)

    def prefetch(self, indices):
        raise False


class SpeechToTextDatasetCreator(object):
    # mandatory columns
    KEY_ID, KEY_AUDIO, KEY_N_FRAMES = "id", "audio", "n_frames"
    KEY_TGT_TEXT = "tgt_text"
    # optional columns
    KEY_SPEAKER, KEY_SRC_TEXT = "speaker", "src_text"
    KEY_SRC_LANG, KEY_TGT_LANG = "src_lang", "tgt_lang"
    # default values
    DEFAULT_SPEAKER = DEFAULT_SRC_TEXT = DEFAULT_LANG = ""

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
    ) -> SpeechToTextDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]
        return SpeechToTextDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )

    @classmethod
    def get_size_ratios(
        cls, datasets: List[SpeechToTextDataset], alpha: float = 1.0
    ) -> List[float]:
        """Size ratios for temperature-based sampling
        (https://arxiv.org/abs/1907.05019)"""

        id_to_lp, lp_to_sz = {}, defaultdict(int)
        for ds in datasets:
            lang_pairs = {f"{s}->{t}" for s, t in zip(ds.src_langs, ds.tgt_langs)}
            assert len(lang_pairs) == 1
            lang_pair = list(lang_pairs)[0]
            id_to_lp[ds.split] = lang_pair
            lp_to_sz[lang_pair] += sum(ds.n_frames)

        sz_sum = sum(v for v in lp_to_sz.values())
        lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
        lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
        prob_sum = sum(v for v in lp_to_tgt_prob.values())
        lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
        lp_to_sz_ratio = {
            k: (lp_to_tgt_prob[k] * sz_sum) / v for k, v in lp_to_sz.items()
        }
        size_ratio = [lp_to_sz_ratio[id_to_lp[ds.split]] for ds in datasets]

        p_formatted = {
            k: f"{lp_to_prob[k]:.3f}->{lp_to_tgt_prob[k]:.3f}" for k in lp_to_sz
        }
        logger.info(f"sampling probability balancing: {p_formatted}")
        sr_formatted = {ds.split: f"{r:.3f}" for ds, r in zip(datasets, size_ratio)}
        logger.info(f"balanced sampling size ratio: {sr_formatted}")
        return size_ratio

    @classmethod
    def _load_samples_from_tsv(cls, root: str, split: str):
        tsv_path = Path(root) / f"{split}.tsv"
        if not tsv_path.is_file():
            raise FileNotFoundError(f"Dataset not found: {tsv_path}")
        with open(tsv_path) as f:
            reader = csv.DictReader(
                f,
                delimiter="\t",
                quotechar=None,
                doublequote=False,
                lineterminator="\n",
                quoting=csv.QUOTE_NONE,
            )
            samples = [dict(e) for e in reader]
        if len(samples) == 0:
            raise ValueError(f"Empty manifest: {tsv_path}")
        return samples

    @classmethod
    def _from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        split: str,
        tgt_dict,
        is_train_split: bool,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
    ) -> SpeechToTextDataset:
        samples = cls._load_samples_from_tsv(root, split)
        return cls._from_list(
            split,
            is_train_split,
            samples,
            cfg,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
        )

    @classmethod
    def from_tsv(
        cls,
        root: str,
        cfg: S2TDataConfig,
        splits: str,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        is_train_split: bool,
        epoch: int,
        seed: int,
        n_frames_per_step: int = 1,
        speaker_to_id=None,
    ) -> SpeechToTextDataset:
        datasets = [
            cls._from_tsv(
                root,
                cfg,
                split,
                tgt_dict,
                is_train_split,
                pre_tokenizer,
                bpe_tokenizer,
                n_frames_per_step,
                speaker_to_id,
            )
            for split in splits.split(",")
        ]

        if is_train_split and len(datasets) > 1 and cfg.sampling_alpha != 1.0:
            # temperature-based sampling
            size_ratios = cls.get_size_ratios(datasets, alpha=cfg.sampling_alpha)
            datasets = [
                ResamplingDataset(
                    d, size_ratio=r, seed=seed, epoch=epoch, replace=(r >= 1.0)
                )
                for r, d in zip(size_ratios, datasets)
            ]
        return ConcatDataset(datasets) if len(datasets) > 1 else datasets[0]
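get_size_ratios() above implements the temperature-based sampling from the paper cited in its docstring: per-language-pair sizes are turned into probabilities, raised to the power alpha, renormalized, and converted back into up/down-sampling ratios. The snippet below re-derives those ratios on hypothetical frame counts to make the effect concrete; the language pairs and sizes are made up and not from this commit.

# Temperature-based sampling arithmetic, checked on hypothetical sizes
# (two language pairs with 900 and 100 frames, alpha = 0.5).
lp_to_sz = {"en->de": 900, "en->fr": 100}
alpha = 0.5

sz_sum = sum(lp_to_sz.values())
lp_to_prob = {k: v / sz_sum for k, v in lp_to_sz.items()}
lp_to_tgt_prob = {k: v**alpha for k, v in lp_to_prob.items()}
prob_sum = sum(lp_to_tgt_prob.values())
lp_to_tgt_prob = {k: v / prob_sum for k, v in lp_to_tgt_prob.items()}
ratios = {k: lp_to_tgt_prob[k] * sz_sum / v for k, v in lp_to_sz.items()}

# The low-resource pair is upsampled and the high-resource pair downsampled:
assert round(ratios["en->fr"], 2) == 2.50
assert round(ratios["en->de"], 2) == 0.83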
PyTorch/NLP/new-Transformer/fairseq/data/audio/speech_to_text_joint_dataset.py (new file, mode 100644)
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional

import torch

from fairseq.data import ConcatDataset, Dictionary, ResamplingDataset
from fairseq.data import data_utils as fairseq_data_utils
from fairseq.data.audio.speech_to_text_dataset import (
    S2TDataConfig,
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
)

logger = logging.getLogger(__name__)


class S2TJointDataConfig(S2TDataConfig):
    """Wrapper class for data config YAML"""

    @property
    def src_vocab_filename(self):
        """fairseq vocabulary file under data root"""
        return self.config.get("src_vocab_filename", "src_dict.txt")

    @property
    def src_pre_tokenizer(self) -> Dict:
        """Pre-tokenizer to apply before subword tokenization. Returning
        a dictionary with `tokenizer` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get("src_pre_tokenizer", {"tokenizer": None})

    @property
    def src_bpe_tokenizer(self) -> Dict:
        """Subword tokenizer to apply on source text after pre-tokenization.
        Returning a dictionary with `bpe` providing the tokenizer name and
        the other items providing the tokenizer-specific arguments.
        Tokenizers are defined in `fairseq.data.encoders.*`"""
        return self.config.get("src_bpe_tokenizer", {"bpe": None})

    @property
    def prepend_tgt_lang_tag_no_change(self) -> bool:
        """Prepend target lang ID token as the prev_output_tokens BOS (e.g. for
        to-many multilingual setting). No change needed during inference.
        This option is deprecated and replaced by prepend_tgt_lang_tag_as_bos.
        """
        value = self.config.get("prepend_tgt_lang_tag_no_change", None)
        if value is None:
            return self.config.get("prepend_tgt_lang_tag_as_bos", False)
        return value

    @property
    def sampling_text_alpha(self):
        """Hyper-parameter alpha = 1/T for temperature-based resampling. (text
        input only) (alpha = 1 for no resampling)"""
        return self.config.get("sampling_text_alpha", 1.0)


class SpeechToTextJointDatasetItem(NamedTuple):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    src_txt_tokens: Optional[torch.Tensor] = None
    tgt_lang_tag: Optional[int] = None
    src_lang_tag: Optional[int] = None
    tgt_alignment: Optional[torch.Tensor] = None


# use_src_lang_id:
#   0: don't use src_lang_id
#   1: attach src_lang_id to the src_txt_tokens as eos
class SpeechToTextJointDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TJointDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        src_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        src_pre_tokenizer=None,
        src_bpe_tokenizer=None,
        append_eos: Optional[bool] = True,
        alignment: Optional[List[str]] = None,
        use_src_lang_id: Optional[int] = 0,
    ):
        super().__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            append_eos=append_eos,
        )

        self.src_dict = src_dict
        self.src_pre_tokenizer = src_pre_tokenizer
        self.src_bpe_tokenizer = src_bpe_tokenizer
        self.alignment = None
        self.use_src_lang_id = use_src_lang_id
        if alignment is not None:
            self.alignment = [
                [float(s) for s in sample.split()] for sample in alignment
            ]

    def get_tokenized_src_text(self, index: int):
        text = self.tokenize(self.src_pre_tokenizer, self.src_texts[index])
        text = self.tokenize(self.src_bpe_tokenizer, text)
        return text

    def __getitem__(self, index: int) -> SpeechToTextJointDatasetItem:
        s2t_dataset_item = super().__getitem__(index)
        src_tokens = None
        src_lang_tag = None
        if self.src_texts is not None and self.src_dict is not None:
            src_tokens = self.get_tokenized_src_text(index)
            src_tokens = self.src_dict.encode_line(
                src_tokens, add_if_not_exist=False, append_eos=True
            ).long()
            if self.use_src_lang_id > 0:
                src_lang_tag = self.get_lang_tag_idx(
                    self.src_langs[index], self.src_dict
                )
        tgt_lang_tag = None
        if self.cfg.prepend_tgt_lang_tag_no_change:
            # prepend_tgt_lang_tag_no_change: modify prev_output_tokens instead
            tgt_lang_tag = self.get_lang_tag_idx(self.tgt_langs[index], self.tgt_dict)
        ali = None
        if self.alignment is not None:
            ali = torch.Tensor(self.alignment[index]).float()

        return SpeechToTextJointDatasetItem(
            index=index,
            source=s2t_dataset_item.source,
            target=s2t_dataset_item.target,
            src_txt_tokens=src_tokens,
            tgt_lang_tag=tgt_lang_tag,
            src_lang_tag=src_lang_tag,
            tgt_alignment=ali,
        )

    def __len__(self):
        return self.n_samples

    def collater(self, samples: List[SpeechToTextJointDatasetItem]) -> Dict:
        s2t_out = super().collater(samples, return_order=True)
        if s2t_out == {}:
            return s2t_out
        net_input, order = s2t_out["net_input"], s2t_out["order"]

        if self.src_texts is not None and self.src_dict is not None:
            src_txt_tokens = fairseq_data_utils.collate_tokens(
                [x.src_txt_tokens for x in samples],
                self.src_dict.pad(),
                self.src_dict.eos(),
                left_pad=False,
                move_eos_to_beginning=False,
            )
            src_txt_lengths = torch.tensor(
                [x.src_txt_tokens.size()[0] for x in samples], dtype=torch.long
            )
            if self.use_src_lang_id > 0:
                src_lang_idxs = torch.tensor(
                    [s.src_lang_tag for s in samples], dtype=src_txt_tokens.dtype
                )
                if self.use_src_lang_id == 1:  # replace eos with lang_id
                    eos_idx = src_txt_lengths - 1
                    src_txt_tokens.scatter_(
                        1, eos_idx.view(-1, 1), src_lang_idxs.view(-1, 1)
                    )
                else:
                    raise NotImplementedError("Implementation is required")

            src_txt_tokens = src_txt_tokens.index_select(0, order)
            src_txt_lengths = src_txt_lengths.index_select(0, order)
            net_input["src_txt_tokens"] = src_txt_tokens
            net_input["src_txt_lengths"] = src_txt_lengths

        net_input["alignment"] = None
        if self.alignment is not None:
            max_len = max([s.tgt_alignment.size(0) for s in samples])
            alignment = torch.ones(len(samples), max_len).float()
            for i, s in enumerate(samples):
                cur_len = s.tgt_alignment.size(0)
                alignment[i][:cur_len].copy_(s.tgt_alignment)
            net_input["alignment"] = alignment.index_select(0, order)

        if self.tgt_texts is not None and samples[0].tgt_lang_tag is not None:
            for i in range(len(samples)):
                net_input["prev_output_tokens"][i][0] = samples[order[i]].tgt_lang_tag
out
=
{
"id"
:
s2t_out
[
"id"
],
"net_input"
:
net_input
,
"target"
:
s2t_out
[
"target"
],
"target_lengths"
:
s2t_out
[
"target_lengths"
],
"ntokens"
:
s2t_out
[
"ntokens"
],
"nsentences"
:
len
(
samples
),
}
return
out
class
SpeechToTextJointDatasetCreator
(
SpeechToTextDatasetCreator
):
KEY_ALIGN
=
"align"
@
classmethod
def
_from_list
(
cls
,
split_name
:
str
,
is_train_split
,
samples
:
List
[
Dict
],
cfg
:
S2TJointDataConfig
,
tgt_dict
,
src_dict
,
pre_tokenizer
,
bpe_tokenizer
,
src_pre_tokenizer
,
src_bpe_tokenizer
,
append_eos
,
use_src_lang_id
,
)
->
SpeechToTextJointDataset
:
audio_root
=
Path
(
cfg
.
audio_root
)
ids
=
[
s
[
cls
.
KEY_ID
]
for
s
in
samples
]
audio_paths
=
[(
audio_root
/
s
[
cls
.
KEY_AUDIO
]).
as_posix
()
for
s
in
samples
]
n_frames
=
[
int
(
s
[
cls
.
KEY_N_FRAMES
])
for
s
in
samples
]
tgt_texts
=
[
s
[
cls
.
KEY_TGT_TEXT
]
for
s
in
samples
]
src_texts
=
[
s
.
get
(
cls
.
KEY_SRC_TEXT
,
cls
.
DEFAULT_SRC_TEXT
)
for
s
in
samples
]
speakers
=
[
s
.
get
(
cls
.
KEY_SPEAKER
,
cls
.
DEFAULT_SPEAKER
)
for
s
in
samples
]
src_langs
=
[
s
.
get
(
cls
.
KEY_SRC_LANG
,
cls
.
DEFAULT_LANG
)
for
s
in
samples
]
tgt_langs
=
[
s
.
get
(
cls
.
KEY_TGT_LANG
,
cls
.
DEFAULT_LANG
)
for
s
in
samples
]
tgt_alignment
=
None
if
cls
.
KEY_ALIGN
in
samples
[
0
].
keys
():
tgt_alignment
=
[
s
[
cls
.
KEY_ALIGN
]
for
s
in
samples
]
return
SpeechToTextJointDataset
(
split_name
,
is_train_split
,
cfg
,
audio_paths
,
n_frames
,
src_texts
=
src_texts
,
tgt_texts
=
tgt_texts
,
speakers
=
speakers
,
src_langs
=
src_langs
,
tgt_langs
=
tgt_langs
,
ids
=
ids
,
tgt_dict
=
tgt_dict
,
src_dict
=
src_dict
,
pre_tokenizer
=
pre_tokenizer
,
bpe_tokenizer
=
bpe_tokenizer
,
src_pre_tokenizer
=
src_pre_tokenizer
,
src_bpe_tokenizer
=
src_bpe_tokenizer
,
append_eos
=
append_eos
,
alignment
=
tgt_alignment
,
use_src_lang_id
=
use_src_lang_id
,
)
@
classmethod
def
_from_tsv
(
cls
,
root
:
str
,
cfg
:
S2TJointDataConfig
,
split
:
str
,
tgt_dict
,
src_dict
,
is_train_split
:
bool
,
pre_tokenizer
,
bpe_tokenizer
,
src_pre_tokenizer
,
src_bpe_tokenizer
,
append_eos
:
bool
,
use_src_lang_id
:
int
,
)
->
SpeechToTextJointDataset
:
samples
=
cls
.
_load_samples_from_tsv
(
root
,
split
)
return
cls
.
_from_list
(
split
,
is_train_split
,
samples
,
cfg
,
tgt_dict
,
src_dict
,
pre_tokenizer
,
bpe_tokenizer
,
src_pre_tokenizer
,
src_bpe_tokenizer
,
append_eos
,
use_src_lang_id
,
)
@
classmethod
def
from_tsv
(
cls
,
root
:
str
,
cfg
:
S2TJointDataConfig
,
splits
:
str
,
tgt_dict
,
src_dict
,
pre_tokenizer
,
bpe_tokenizer
,
src_pre_tokenizer
,
src_bpe_tokenizer
,
is_train_split
:
bool
,
epoch
:
int
,
seed
:
int
,
append_eos
:
Optional
[
bool
]
=
True
,
use_src_lang_id
:
Optional
[
int
]
=
0
,
)
->
SpeechToTextJointDataset
:
datasets
=
[
cls
.
_from_tsv
(
root
,
cfg
,
split
,
tgt_dict
,
src_dict
,
is_train_split
,
pre_tokenizer
,
bpe_tokenizer
,
src_pre_tokenizer
,
src_bpe_tokenizer
,
append_eos
=
append_eos
,
use_src_lang_id
=
use_src_lang_id
,
)
for
split
in
splits
.
split
(
","
)
]
if
is_train_split
and
len
(
datasets
)
>
1
and
cfg
.
sampling_alpha
!=
1.0
:
# temperature-based sampling
size_ratios
=
cls
.
get_size_ratios
(
datasets
,
alpha
=
cfg
.
sampling_alpha
)
datasets
=
[
ResamplingDataset
(
d
,
size_ratio
=
r
,
seed
=
seed
,
epoch
=
epoch
,
replace
=
(
r
>=
1.0
)
)
for
r
,
d
in
zip
(
size_ratios
,
datasets
)
]
return
ConcatDataset
(
datasets
)
if
len
(
datasets
)
>
1
else
datasets
[
0
]
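The temperature-based resampling mentioned in the `sampling_text_alpha` docstring and in `from_tsv` above is easiest to see with a toy computation. The actual ratios come from `SpeechToTextDatasetCreator.get_size_ratios`, which is defined in speech_to_text_dataset.py and not shown in this commit; the snippet below is only an illustrative sketch of the idea (made-up corpus sizes, plain NumPy), not the committed implementation.

import numpy as np

def toy_size_ratios(sizes, alpha=1.0):
    # original sampling distribution, proportional to corpus size
    probs = np.array(sizes, dtype=float) / sum(sizes)
    # temperature smoothing: alpha = 1/T, alpha = 1 means no resampling
    smoothed = probs ** alpha
    smoothed = smoothed / smoothed.sum()
    # up/down-sampling ratio per dataset so the mix follows the smoothed distribution
    return smoothed * sum(sizes) / np.array(sizes, dtype=float)

print(toy_size_ratios([1000, 100, 10], alpha=0.5))  # small corpora get ratios > 1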
PyTorch/NLP/new-Transformer/fairseq/data/audio/text_to_speech_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

from pathlib import Path
from typing import List, Dict, Optional, Any
from dataclasses import dataclass

import numpy as np
import torch

from fairseq.data.audio.speech_to_text_dataset import (
    SpeechToTextDataset,
    SpeechToTextDatasetCreator,
    S2TDataConfig,
    _collate_frames,
    get_features_or_waveform,
)
from fairseq.data import Dictionary, data_utils as fairseq_data_utils


@dataclass
class TextToSpeechDatasetItem(object):
    index: int
    source: torch.Tensor
    target: Optional[torch.Tensor] = None
    speaker_id: Optional[int] = None
    duration: Optional[torch.Tensor] = None
    pitch: Optional[torch.Tensor] = None
    energy: Optional[torch.Tensor] = None


class TextToSpeechDataset(SpeechToTextDataset):
    def __init__(
        self,
        split: str,
        is_train_split: bool,
        cfg: S2TDataConfig,
        audio_paths: List[str],
        n_frames: List[int],
        src_texts: Optional[List[str]] = None,
        tgt_texts: Optional[List[str]] = None,
        speakers: Optional[List[str]] = None,
        src_langs: Optional[List[str]] = None,
        tgt_langs: Optional[List[str]] = None,
        ids: Optional[List[str]] = None,
        tgt_dict: Optional[Dictionary] = None,
        pre_tokenizer=None,
        bpe_tokenizer=None,
        n_frames_per_step=1,
        speaker_to_id=None,
        durations: Optional[List[List[int]]] = None,
        pitches: Optional[List[str]] = None,
        energies: Optional[List[str]] = None,
    ):
        super(TextToSpeechDataset, self).__init__(
            split,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts=src_texts,
            tgt_texts=tgt_texts,
            speakers=speakers,
            src_langs=src_langs,
            tgt_langs=tgt_langs,
            ids=ids,
            tgt_dict=tgt_dict,
            pre_tokenizer=pre_tokenizer,
            bpe_tokenizer=bpe_tokenizer,
            n_frames_per_step=n_frames_per_step,
            speaker_to_id=speaker_to_id,
        )
        self.durations = durations
        self.pitches = pitches
        self.energies = energies

    def __getitem__(self, index: int) -> TextToSpeechDatasetItem:
        s2t_item = super().__getitem__(index)

        duration, pitch, energy = None, None, None
        if self.durations is not None:
            duration = torch.tensor(
                self.durations[index] + [0], dtype=torch.long  # pad 0 for EOS
            )
        if self.pitches is not None:
            pitch = get_features_or_waveform(self.pitches[index])
            pitch = torch.from_numpy(
                np.concatenate((pitch, [0]))  # pad 0 for EOS
            ).float()
        if self.energies is not None:
            energy = get_features_or_waveform(self.energies[index])
            energy = torch.from_numpy(
                np.concatenate((energy, [0]))  # pad 0 for EOS
            ).float()
        return TextToSpeechDatasetItem(
            index=index,
            source=s2t_item.source,
            target=s2t_item.target,
            speaker_id=s2t_item.speaker_id,
            duration=duration,
            pitch=pitch,
            energy=energy,
        )

    def collater(self, samples: List[TextToSpeechDatasetItem]) -> Dict[str, Any]:
        if len(samples) == 0:
            return {}

        src_lengths, order = torch.tensor(
            [s.target.shape[0] for s in samples], dtype=torch.long
        ).sort(descending=True)
        id_ = torch.tensor(
            [s.index for s in samples], dtype=torch.long
        ).index_select(0, order)
        feat = _collate_frames(
            [s.source for s in samples], self.cfg.use_audio_input
        ).index_select(0, order)
        target_lengths = torch.tensor(
            [s.source.shape[0] for s in samples], dtype=torch.long
        ).index_select(0, order)

        src_tokens = fairseq_data_utils.collate_tokens(
            [s.target for s in samples],
            self.tgt_dict.pad(),
            self.tgt_dict.eos(),
            left_pad=False,
            move_eos_to_beginning=False,
        ).index_select(0, order)

        speaker = None
        if self.speaker_to_id is not None:
            speaker = (
                torch.tensor([s.speaker_id for s in samples], dtype=torch.long)
                .index_select(0, order)
                .view(-1, 1)
            )

        bsz, _, d = feat.size()
        prev_output_tokens = torch.cat(
            (feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1
        )

        durations, pitches, energies = None, None, None
        if self.durations is not None:
            durations = fairseq_data_utils.collate_tokens(
                [s.duration for s in samples], 0
            ).index_select(0, order)
            assert src_tokens.shape[1] == durations.shape[1]
        if self.pitches is not None:
            pitches = _collate_frames([s.pitch for s in samples], True)
            pitches = pitches.index_select(0, order)
            assert src_tokens.shape[1] == pitches.shape[1]
        if self.energies is not None:
            energies = _collate_frames([s.energy for s in samples], True)
            energies = energies.index_select(0, order)
            assert src_tokens.shape[1] == energies.shape[1]
        src_texts = [self.tgt_dict.string(samples[i].target) for i in order]

        return {
            "id": id_,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "prev_output_tokens": prev_output_tokens,
            },
            "speaker": speaker,
            "target": feat,
            "durations": durations,
            "pitches": pitches,
            "energies": energies,
            "target_lengths": target_lengths,
            "ntokens": sum(target_lengths).item(),
            "nsentences": len(samples),
            "src_texts": src_texts,
        }


class TextToSpeechDatasetCreator(SpeechToTextDatasetCreator):
    KEY_DURATION = "duration"
    KEY_PITCH = "pitch"
    KEY_ENERGY = "energy"

    @classmethod
    def _from_list(
        cls,
        split_name: str,
        is_train_split,
        samples: List[Dict],
        cfg: S2TDataConfig,
        tgt_dict,
        pre_tokenizer,
        bpe_tokenizer,
        n_frames_per_step,
        speaker_to_id,
    ) -> TextToSpeechDataset:
        audio_root = Path(cfg.audio_root)
        ids = [s[cls.KEY_ID] for s in samples]
        audio_paths = [(audio_root / s[cls.KEY_AUDIO]).as_posix() for s in samples]
        n_frames = [int(s[cls.KEY_N_FRAMES]) for s in samples]
        tgt_texts = [s[cls.KEY_TGT_TEXT] for s in samples]
        src_texts = [s.get(cls.KEY_SRC_TEXT, cls.DEFAULT_SRC_TEXT) for s in samples]
        speakers = [s.get(cls.KEY_SPEAKER, cls.DEFAULT_SPEAKER) for s in samples]
        src_langs = [s.get(cls.KEY_SRC_LANG, cls.DEFAULT_LANG) for s in samples]
        tgt_langs = [s.get(cls.KEY_TGT_LANG, cls.DEFAULT_LANG) for s in samples]

        durations = [s.get(cls.KEY_DURATION, None) for s in samples]
        durations = [
            None if dd is None else [int(d) for d in dd.split(" ")] for dd in durations
        ]
        durations = None if any(dd is None for dd in durations) else durations

        pitches = [s.get(cls.KEY_PITCH, None) for s in samples]
        pitches = [
            None if pp is None else (audio_root / pp).as_posix() for pp in pitches
        ]
        pitches = None if any(pp is None for pp in pitches) else pitches

        energies = [s.get(cls.KEY_ENERGY, None) for s in samples]
        energies = [
            None if ee is None else (audio_root / ee).as_posix() for ee in energies
        ]
        energies = None if any(ee is None for ee in energies) else energies

        return TextToSpeechDataset(
            split_name,
            is_train_split,
            cfg,
            audio_paths,
            n_frames,
            src_texts,
            tgt_texts,
            speakers,
            src_langs,
            tgt_langs,
            ids,
            tgt_dict,
            pre_tokenizer,
            bpe_tokenizer,
            n_frames_per_step,
            speaker_to_id,
            durations,
            pitches,
            energies,
        )
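The "prev_output_tokens" built in the collater above is simply the target feature matrix shifted right by one frame with a zero frame prepended (teacher forcing for an autoregressive TTS decoder). A tiny check of that shift, with made-up shapes and only torch required (this is an editorial sketch, not part of the committed file):

import torch

feat = torch.arange(2 * 3 * 4, dtype=torch.float).view(2, 3, 4)  # (bsz, frames, dim)
bsz, _, d = feat.size()
prev = torch.cat((feat.new_zeros((bsz, 1, d)), feat[:, :-1, :]), dim=1)
print(prev[0, 0].tolist())                 # all zeros: the decoder sees silence before frame 0
print(torch.equal(prev[:, 1:], feat[:, :-1]))  # True: frames shifted right by one position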
PyTorch/NLP/new-Transformer/fairseq/data/backtranslation_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch
from fairseq import utils

from . import FairseqDataset


def backtranslate_samples(samples, collate_fn, generate_fn, cuda=True):
    """Backtranslate a list of samples.

    Given an input (*samples*) of the form:

        [{'id': 1, 'source': 'hallo welt'}]

    this will return:

        [{'id': 1, 'source': 'hello world', 'target': 'hallo welt'}]

    Args:
        samples (List[dict]): samples to backtranslate. Individual samples are
            expected to have a 'source' key, which will become the 'target'
            after backtranslation.
        collate_fn (callable): function to collate samples into a mini-batch
        generate_fn (callable): function to generate backtranslations
        cuda (bool): use GPU for generation (default: ``True``)

    Returns:
        List[dict]: an updated list of samples with a backtranslated source
    """
    collated_samples = collate_fn(samples)
    s = utils.move_to_cuda(collated_samples) if cuda else collated_samples
    generated_sources = generate_fn(s)

    id_to_src = {sample["id"]: sample["source"] for sample in samples}

    # Go through each tgt sentence in batch and its corresponding best
    # generated hypothesis and create a backtranslation data pair
    # {id: id, source: generated backtranslation, target: original tgt}
    return [
        {
            "id": id.item(),
            "target": id_to_src[id.item()],
            "source": hypos[0]["tokens"].cpu(),
        }
        for id, hypos in zip(collated_samples["id"], generated_sources)
    ]


class BacktranslationDataset(FairseqDataset):
    """
    Sets up a backtranslation dataset which takes a tgt batch, generates
    a src using a tgt-src backtranslation function (*backtranslation_fn*),
    and returns the corresponding `{generated src, input tgt}` batch.

    Args:
        tgt_dataset (~fairseq.data.FairseqDataset): the dataset to be
            backtranslated. Only the source side of this dataset will be used.
            After backtranslation, the source sentences in this dataset will be
            returned as the targets.
        src_dict (~fairseq.data.Dictionary): the dictionary of backtranslated
            sentences.
        tgt_dict (~fairseq.data.Dictionary, optional): the dictionary of
            sentences to be backtranslated.
        backtranslation_fn (callable, optional): function to call to generate
            backtranslations. This is typically the `generate` method of a
            :class:`~fairseq.sequence_generator.SequenceGenerator` object.
            Pass in None when it is not available at initialization time, and
            use set_backtranslation_fn function to set it when available.
        output_collater (callable, optional): function to call on the
            backtranslated samples to create the final batch
            (default: ``tgt_dataset.collater``).
        cuda: use GPU for generation
    """

    def __init__(
        self,
        tgt_dataset,
        src_dict,
        tgt_dict=None,
        backtranslation_fn=None,
        output_collater=None,
        cuda=True,
        **kwargs
    ):
        self.tgt_dataset = tgt_dataset
        self.backtranslation_fn = backtranslation_fn
        self.output_collater = (
            output_collater if output_collater is not None else tgt_dataset.collater
        )
        self.cuda = cuda if torch.cuda.is_available() else False
        self.src_dict = src_dict
        self.tgt_dict = tgt_dict

    def __getitem__(self, index):
        """
        Returns a single sample from *tgt_dataset*. Note that backtranslation is
        not applied in this step; use :func:`collater` instead to backtranslate
        a batch of samples.
        """
        return self.tgt_dataset[index]

    def __len__(self):
        return len(self.tgt_dataset)

    def set_backtranslation_fn(self, backtranslation_fn):
        self.backtranslation_fn = backtranslation_fn

    def collater(self, samples):
        """Merge and backtranslate a list of samples to form a mini-batch.

        Using the samples from *tgt_dataset*, load a collated target sample to
        feed to the backtranslation model. Then take the backtranslation with
        the best score as the source and the original input as the target.

        Note: we expect *tgt_dataset* to provide a function `collater()` that
        will collate samples into the format expected by *backtranslation_fn*.
        After backtranslation, we will feed the new list of samples (i.e., the
        `(backtranslated source, original source)` pairs) to *output_collater*
        and return the result.

        Args:
            samples (List[dict]): samples to backtranslate and collate

        Returns:
            dict: a mini-batch with keys coming from *output_collater*
        """
        if samples[0].get("is_dummy", False):
            return samples
        samples = backtranslate_samples(
            samples=samples,
            collate_fn=self.tgt_dataset.collater,
            generate_fn=(lambda net_input: self.backtranslation_fn(net_input)),
            cuda=self.cuda,
        )
        return self.output_collater(samples)

    def num_tokens(self, index):
        """Just use the tgt dataset num_tokens"""
        return self.tgt_dataset.num_tokens(index)

    def ordered_indices(self):
        """Just use the tgt dataset ordered_indices"""
        return self.tgt_dataset.ordered_indices()

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used
        when filtering a dataset with ``--max-positions``.

        Note: we use *tgt_dataset* to approximate the length of the source
        sentence, since we do not know the actual length until after
        backtranslation.
        """
        tgt_size = self.tgt_dataset.size(index)[0]
        return (tgt_size, tgt_size)

    @property
    def supports_prefetch(self):
        return getattr(self.tgt_dataset, "supports_prefetch", False)

    def prefetch(self, indices):
        return self.tgt_dataset.prefetch(indices)
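To make the contract of `backtranslate_samples` concrete, here is a minimal sketch that wires it up with stub collate/generate callables instead of a real SequenceGenerator. It assumes fairseq and torch are installed; the token values and the fake "translation" are made up, and the real code would pass `tgt_dataset.collater` and `SequenceGenerator.generate` instead.

import torch
from fairseq.data.backtranslation_dataset import backtranslate_samples

samples = [{"id": 0, "source": torch.tensor([7, 8, 2])}]  # arbitrary token ids, 2 standing in for eos

def collate_fn(batch):
    # stand-in for tgt_dataset.collater(): only the "id" field is needed by this sketch
    return {"id": torch.tensor([s["id"] for s in batch])}

def generate_fn(batch):
    # stand-in for SequenceGenerator.generate(): one hypothesis list per sentence,
    # best hypothesis first, each with a "tokens" tensor
    return [[{"tokens": torch.tensor([5, 6, 2])}] for _ in batch["id"]]

out = backtranslate_samples(samples, collate_fn, generate_fn, cuda=False)
print(out)  # [{'id': 0, 'target': tensor([7, 8, 2]), 'source': tensor([5, 6, 2])}]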
PyTorch/NLP/new-Transformer/fairseq/data/base_wrapper_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class BaseWrapperDataset(FairseqDataset):
    def __init__(self, dataset):
        super().__init__()
        self.dataset = dataset

    def __getitem__(self, index):
        return self.dataset[index]

    def __len__(self):
        return len(self.dataset)

    def collater(self, samples):
        if hasattr(self.dataset, "collater"):
            return self.dataset.collater(samples)
        else:
            return default_collate(samples)

    @property
    def sizes(self):
        return self.dataset.sizes

    def num_tokens(self, index):
        return self.dataset.num_tokens(index)

    def size(self, index):
        return self.dataset.size(index)

    def ordered_indices(self):
        return self.dataset.ordered_indices()

    @property
    def supports_prefetch(self):
        return getattr(self.dataset, "supports_prefetch", False)

    def attr(self, attr: str, index: int):
        return self.dataset.attr(attr, index)

    def prefetch(self, indices):
        self.dataset.prefetch(indices)

    def get_batch_shapes(self):
        return self.dataset.get_batch_shapes()

    def batch_by_size(
        self,
        indices,
        max_tokens=None,
        max_sentences=None,
        required_batch_size_multiple=1,
    ):
        return self.dataset.batch_by_size(
            indices,
            max_tokens=max_tokens,
            max_sentences=max_sentences,
            required_batch_size_multiple=required_batch_size_multiple,
        )

    def filter_indices_by_size(self, indices, max_sizes):
        return self.dataset.filter_indices_by_size(indices, max_sizes)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return self.dataset.can_reuse_epoch_itr_across_epochs

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        if hasattr(self.dataset, "set_epoch"):
            self.dataset.set_epoch(epoch)
PyTorch/NLP/new-Transformer/fairseq/data/bucket_pad_length_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np
import torch.nn.functional as F

from fairseq.data import BaseWrapperDataset
from fairseq.data.data_utils import get_buckets, get_bucketed_sizes


class BucketPadLengthDataset(BaseWrapperDataset):
    """
    Bucket and pad item lengths to the nearest bucket size. This can be used to
    reduce the number of unique batch shapes, which is important on TPUs since
    each new batch shape requires a recompilation.

    Args:
        dataset (FairseqDataset): dataset to bucket
        sizes (List[int]): all item sizes
        num_buckets (int): number of buckets to create
        pad_idx (int): padding symbol
        left_pad (bool): if True, pad on the left; otherwise right pad
    """

    def __init__(
        self,
        dataset,
        sizes,
        num_buckets,
        pad_idx,
        left_pad,
        tensor_key=None,
    ):
        super().__init__(dataset)
        self.pad_idx = pad_idx
        self.left_pad = left_pad

        assert num_buckets > 0
        self.buckets = get_buckets(sizes, num_buckets)
        self._bucketed_sizes = get_bucketed_sizes(sizes, self.buckets)
        self._tensor_key = tensor_key

    def _set_tensor(self, item, val):
        if self._tensor_key is None:
            return val
        item[self._tensor_key] = val
        return item

    def _get_tensor(self, item):
        if self._tensor_key is None:
            return item
        return item[self._tensor_key]

    def _pad(self, tensor, bucket_size, dim=-1):
        num_pad = bucket_size - tensor.size(dim)
        return F.pad(
            tensor,
            (num_pad if self.left_pad else 0, 0 if self.left_pad else num_pad),
            value=self.pad_idx,
        )

    def __getitem__(self, index):
        item = self.dataset[index]
        bucket_size = self._bucketed_sizes[index]
        tensor = self._get_tensor(item)
        padded = self._pad(tensor, bucket_size)
        return self._set_tensor(item, padded)

    @property
    def sizes(self):
        return self._bucketed_sizes

    def num_tokens(self, index):
        return self._bucketed_sizes[index]

    def size(self, index):
        return self._bucketed_sizes[index]
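A quick sketch of the bucketing effect described in the docstring above: item lengths are padded up to a handful of bucket sizes so batches reuse only a few shapes. This is an illustrative usage example, not part of the committed file; it assumes fairseq and torch are installed, uses made-up lengths, and stands in a plain list of tensors for a real FairseqDataset.

import torch
from fairseq.data.bucket_pad_length_dataset import BucketPadLengthDataset

lengths = [5, 7, 9, 12, 31]
toy_dataset = [torch.ones(n, dtype=torch.long) for n in lengths]
bucketed = BucketPadLengthDataset(
    toy_dataset, sizes=lengths, num_buckets=2, pad_idx=0, left_pad=False
)
print(bucketed.sizes)      # every length rounded up to one of two bucket sizes
print(bucketed[0].shape)   # item 0 right-padded with pad_idx up to its bucket size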
PyTorch/NLP/new-Transformer/fairseq/data/codedataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import json
import logging
import os
import random
from pathlib import Path

import numpy as np
import torch
import torch.utils.data

from . import data_utils
from fairseq.data.fairseq_dataset import FairseqDataset

F0_FRAME_SPACE = 0.005  # sec


logger = logging.getLogger(__name__)


class ExpressiveCodeDataConfig(object):
    def __init__(self, json_path):
        with open(json_path, "r") as f:
            self.config = json.load(f)
        self._manifests = self.config["manifests"]

    @property
    def manifests(self):
        return self._manifests

    @property
    def n_units(self):
        return self.config["n_units"]

    @property
    def sampling_rate(self):
        return self.config["sampling_rate"]

    @property
    def code_hop_size(self):
        return self.config["code_hop_size"]

    @property
    def f0_stats(self):
        """pre-computed f0 statistics path"""
        return self.config.get("f0_stats", None)

    @property
    def f0_vq_type(self):
        """naive or precomp"""
        return self.config["f0_vq_type"]

    @property
    def f0_vq_name(self):
        return self.config["f0_vq_name"]

    def get_f0_vq_naive_quantizer(self, log, norm_mean, norm_std):
        key = "log" if log else "linear"
        if norm_mean and norm_std:
            key += "_mean_std_norm"
        elif norm_mean:
            key += "_mean_norm"
        else:
            key += "_none_norm"
        return self.config["f0_vq_naive_quantizer"][key]

    @property
    def f0_vq_n_units(self):
        return self.config["f0_vq_n_units"]

    @property
    def multispkr(self):
        """how to parse speaker label from audio path"""
        return self.config.get("multispkr", None)


def get_f0(audio, rate=16000):
    try:
        import amfm_decompy.basic_tools as basic
        import amfm_decompy.pYAAPT as pYAAPT
        from librosa.util import normalize
    except ImportError:
        # raise a proper exception instead of a bare string (raising a str is a TypeError in Python 3)
        raise ImportError(
            "Please install amfm_decompy (`pip install AMFM-decompy`) and librosa (`pip install librosa`)."
        )

    assert audio.ndim == 1
    frame_length = 20.0  # ms
    to_pad = int(frame_length / 1000 * rate) // 2

    audio = normalize(audio) * 0.95
    audio = np.pad(audio, (to_pad, to_pad), "constant", constant_values=0)
    audio = basic.SignalObj(audio, rate)
    pitch = pYAAPT.yaapt(
        audio,
        frame_length=frame_length,
        frame_space=F0_FRAME_SPACE * 1000,
        nccf_thresh1=0.25,
        tda_frame_length=25.0,
    )
    f0 = pitch.samp_values
    return f0


def interpolate_f0(f0):
    try:
        from scipy.interpolate import interp1d
    except ImportError:
        raise ImportError("Please install scipy (`pip install scipy`)")

    orig_t = np.arange(f0.shape[0])
    f0_interp = f0[:]
    ii = f0_interp != 0
    if ii.sum() > 1:
        f0_interp = interp1d(
            orig_t[ii], f0_interp[ii], bounds_error=False, kind="linear", fill_value=0
        )(orig_t)
        f0_interp = torch.Tensor(f0_interp).type_as(f0).to(f0.device)
    return f0_interp


def naive_quantize(x, edges):
    bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
    return bin_idx


def load_wav(full_path):
    try:
        import soundfile as sf
    except ImportError:
        raise ImportError("Please install soundfile (`pip install SoundFile`)")

    data, sampling_rate = sf.read(full_path)
    return data, sampling_rate


def parse_code(code_str, dictionary, append_eos):
    code, duration = torch.unique_consecutive(
        torch.ShortTensor(list(map(int, code_str.split()))), return_counts=True
    )
    code = " ".join(map(str, code.tolist()))
    code = dictionary.encode_line(code, append_eos).short()

    if append_eos:
        duration = torch.cat((duration, duration.new_zeros((1,))), dim=0)  # eos
    duration = duration.short()
    return code, duration


def parse_manifest(manifest, dictionary):
    audio_files = []
    codes = []
    durations = []
    speakers = []

    with open(manifest) as info:
        for line in info.readlines():
            sample = eval(line.strip())
            if "cpc_km100" in sample:
                k = "cpc_km100"
            elif "hubert_km100" in sample:
                k = "hubert_km100"
            elif "phone" in sample:
                k = "phone"
            else:
                assert False, "unknown format"
            code = sample[k]
            code, duration = parse_code(code, dictionary, append_eos=True)

            codes.append(code)
            durations.append(duration)
            audio_files.append(sample["audio"])
            speakers.append(sample.get("speaker", None))

    return audio_files, codes, durations, speakers


def parse_speaker(path, method):
    if type(path) == str:
        path = Path(path)

    if method == "parent_name":
        return path.parent.name
    elif method == "parent_parent_name":
        return path.parent.parent.name
    elif method == "_":
        return path.name.split("_")[0]
    elif method == "single":
        return "A"
    elif callable(method):
        return method(path)
    else:
        raise NotImplementedError()


def get_f0_by_filename(filename, tgt_sampling_rate):
    audio, sampling_rate = load_wav(filename)
    if sampling_rate != tgt_sampling_rate:
        raise ValueError(
            "{} SR doesn't match target {} SR".format(sampling_rate, tgt_sampling_rate)
        )

    # compute un-interpolated f0, and use Ann's interp in __getitem__ if set
    f0 = get_f0(audio, rate=tgt_sampling_rate)
    f0 = torch.from_numpy(f0.astype(np.float32))
    return f0


def align_f0_to_durations(f0, durations, f0_code_ratio, tol=1):
    code_len = durations.sum()
    targ_len = int(f0_code_ratio * code_len)
    diff = f0.size(0) - targ_len
    assert abs(diff) <= tol, (
        f"Cannot subsample F0: |{f0.size(0)} - {f0_code_ratio}*{code_len}|"
        f" > {tol} (dur=\n{durations})"
    )
    if diff > 0:
        f0 = f0[:targ_len]
    elif diff < 0:
        f0 = torch.cat((f0, f0.new_full((-diff,), f0[-1])), 0)

    f0_offset = 0.0
    seg_f0s = []
    for dur in durations:
        f0_dur = dur.item() * f0_code_ratio
        seg_f0 = f0[int(f0_offset) : int(f0_offset + f0_dur)]
        seg_f0 = seg_f0[seg_f0 != 0]
        if len(seg_f0) == 0:
            seg_f0 = torch.tensor(0).type(seg_f0.type())
        else:
            seg_f0 = seg_f0.mean()
        seg_f0s.append(seg_f0)
        f0_offset += f0_dur

    assert int(f0_offset) == f0.size(0), f"{f0_offset} {f0.size()} {durations.sum()}"
    return torch.tensor(seg_f0s)


class Paddings(object):
    def __init__(self, code_val, dur_val=0, f0_val=-2.0):
        self.code = code_val
        self.dur = dur_val
        self.f0 = f0_val


class Shifts(object):
    def __init__(self, shifts_str, pads):
        self._shifts = list(map(int, shifts_str.split(",")))
        assert len(self._shifts) == 2, self._shifts
        assert all(s >= 0 for s in self._shifts)
        self.extra_length = max(s for s in self._shifts)
        self.pads = pads

    @property
    def dur(self):
        return self._shifts[0]

    @property
    def f0(self):
        return self._shifts[1]

    @staticmethod
    def shift_one(seq, left_pad_num, right_pad_num, pad):
        assert seq.ndim == 1
        bos = seq.new_full((left_pad_num,), pad)
        eos = seq.new_full((right_pad_num,), pad)
        seq = torch.cat([bos, seq, eos])
        mask = torch.ones_like(seq).bool()
        mask[left_pad_num : len(seq) - right_pad_num] = 0
        return seq, mask

    def __call__(self, code, dur, f0):
        if self.extra_length == 0:
            code_mask = torch.zeros_like(code).bool()
            dur_mask = torch.zeros_like(dur).bool()
            f0_mask = torch.zeros_like(f0).bool()
            return code, code_mask, dur, dur_mask, f0, f0_mask

        code, code_mask = self.shift_one(code, 0, self.extra_length, self.pads.code)
        dur, dur_mask = self.shift_one(
            dur, self.dur, self.extra_length - self.dur, self.pads.dur
        )
        f0, f0_mask = self.shift_one(
            f0, self.f0, self.extra_length - self.f0, self.pads.f0
        )
        return code, code_mask, dur, dur_mask, f0, f0_mask


class CodeDataset(FairseqDataset):
    def __init__(
        self,
        manifest,
        dictionary,
        dur_dictionary,
        f0_dictionary,
        config,
        discrete_dur,
        discrete_f0,
        log_f0,
        normalize_f0_mean,
        normalize_f0_std,
        interpolate_f0,
        return_filename=False,
        strip_filename=True,
        shifts="0,0",
        return_continuous_f0=False,
    ):
        random.seed(1234)
        self.dictionary = dictionary
        self.dur_dictionary = dur_dictionary
        self.f0_dictionary = f0_dictionary
        self.config = config

        # duration config
        self.discrete_dur = discrete_dur

        # pitch config
        self.discrete_f0 = discrete_f0
        self.log_f0 = log_f0
        self.normalize_f0_mean = normalize_f0_mean
        self.normalize_f0_std = normalize_f0_std
        self.interpolate_f0 = interpolate_f0

        self.return_filename = return_filename
        self.strip_filename = strip_filename
        self.f0_code_ratio = config.code_hop_size / (
            config.sampling_rate * F0_FRAME_SPACE
        )

        # use lazy loading to avoid sharing file handlers across workers
        self.manifest = manifest
        self._codes = None
        self._durs = None
        self._f0s = None
        with open(f"{manifest}.leng.txt", "r") as f:
            lengs = [int(line.rstrip()) for line in f]
            edges = np.cumsum([0] + lengs)
            self.starts, self.ends = edges[:-1], edges[1:]
        with open(f"{manifest}.path.txt", "r") as f:
            self.file_names = [line.rstrip() for line in f]
        logger.info(f"num entries: {len(self.starts)}")

        if os.path.exists(f"{manifest}.f0_stat.pt"):
            self.f0_stats = torch.load(f"{manifest}.f0_stat.pt")
        elif config.f0_stats:
            self.f0_stats = torch.load(config.f0_stats)

        self.multispkr = config.multispkr
        if config.multispkr:
            with open(f"{manifest}.speaker.txt", "r") as f:
                self.spkrs = [line.rstrip() for line in f]
            self.id_to_spkr = sorted(self.spkrs)
            self.spkr_to_id = {k: v for v, k in enumerate(self.id_to_spkr)}

        self.pads = Paddings(
            dictionary.pad(),
            0,  # use 0 for duration padding
            f0_dictionary.pad() if discrete_f0 else -5.0,
        )
        self.shifts = Shifts(shifts, pads=self.pads)
        self.return_continuous_f0 = return_continuous_f0

    def get_data_handlers(self):
        logging.info(f"loading data for {self.manifest}")
        self._codes = np.load(f"{self.manifest}.code.npy", mmap_mode="r")
        self._durs = np.load(f"{self.manifest}.dur.npy", mmap_mode="r")

        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                self._f0s = np.load(
                    f"{self.manifest}.{self.config.f0_vq_name}.npy", mmap_mode="r"
                )
            elif self.config.f0_vq_type == "naive":
                self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")
                quantizers_path = self.config.get_f0_vq_naive_quantizer(
                    self.log_f0, self.normalize_f0_mean, self.normalize_f0_std
                )
                quantizers = torch.load(quantizers_path)
                n_units = self.config.f0_vq_n_units
                self._f0_quantizer = torch.from_numpy(quantizers[n_units])
            else:
                raise ValueError(f"f0_vq_type {self.config.f0_vq_type} not supported")
        else:
            self._f0s = np.load(f"{self.manifest}.f0.npy", mmap_mode="r")

    def preprocess_f0(self, f0, stats):
        """
        1. interpolate
        2. log transform (keep unvoiced frame 0)
        """
        # TODO: change this to be dependent on config for naive quantizer
        f0 = f0.clone()
        if self.interpolate_f0:
            f0 = interpolate_f0(f0)

        mask = f0 != 0  # only process voiced frames
        if self.log_f0:
            f0[mask] = f0[mask].log()
        if self.normalize_f0_mean:
            mean = stats["logf0_mean"] if self.log_f0 else stats["f0_mean"]
            f0[mask] = f0[mask] - mean
        if self.normalize_f0_std:
            std = stats["logf0_std"] if self.log_f0 else stats["f0_std"]
            f0[mask] = f0[mask] / std
        return f0

    def _get_raw_item(self, index):
        start, end = self.starts[index], self.ends[index]
        if self._codes is None:
            self.get_data_handlers()
        code = torch.from_numpy(np.array(self._codes[start:end])).long()
        dur = torch.from_numpy(np.array(self._durs[start:end]))
        f0 = torch.from_numpy(np.array(self._f0s[start:end]))
        return code, dur, f0

    def __getitem__(self, index):
        code, dur, f0 = self._get_raw_item(index)
        code = torch.cat([code.new([self.dictionary.bos()]), code])

        # use 0 for eos and bos
        dur = torch.cat([dur.new([0]), dur])
        if self.discrete_dur:
            dur = self.dur_dictionary.encode_line(
                " ".join(map(str, dur.tolist())), append_eos=False
            ).long()
        else:
            dur = dur.float()

        # TODO: find a more elegant approach
        raw_f0 = None
        if self.discrete_f0:
            if self.config.f0_vq_type == "precomp":
                f0 = self.f0_dictionary.encode_line(
                    " ".join(map(str, f0.tolist())), append_eos=False
                ).long()
            else:
                f0 = f0.float()
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
                if self.return_continuous_f0:
                    raw_f0 = f0
                    raw_f0 = torch.cat(
                        [raw_f0.new([self.f0_dictionary.bos()]), raw_f0]
                    )
                f0 = naive_quantize(f0, self._f0_quantizer)
            f0 = torch.cat([f0.new([self.f0_dictionary.bos()]), f0])
        else:
            f0 = f0.float()
            if self.multispkr:
                f0 = self.preprocess_f0(f0, self.f0_stats[self.spkrs[index]])
            else:
                f0 = self.preprocess_f0(f0, self.f0_stats)
            f0 = torch.cat([f0.new([0]), f0])

        if raw_f0 is not None:
            *_, raw_f0, raw_f0_mask = self.shifts(code, dur, raw_f0)
        else:
            raw_f0_mask = None

        code, code_mask, dur, dur_mask, f0, f0_mask = self.shifts(code, dur, f0)
        if raw_f0_mask is not None:
            assert (raw_f0_mask == f0_mask).all()

        # is a padded frame if either input or output is padded
        feats = {
            "source": code[:-1],
            "target": code[1:],
            "mask": code_mask[1:].logical_or(code_mask[:-1]),
            "dur_source": dur[:-1],
            "dur_target": dur[1:],
            "dur_mask": dur_mask[1:].logical_or(dur_mask[:-1]),
            "f0_source": f0[:-1],
            "f0_target": f0[1:],
            "f0_mask": f0_mask[1:].logical_or(f0_mask[:-1]),
        }

        if raw_f0 is not None:
            feats["raw_f0"] = raw_f0[1:]

        if self.return_filename:
            fname = self.file_names[index]
            feats["filename"] = (
                fname if not self.strip_filename else Path(fname).with_suffix("").name
            )
        return feats

    def __len__(self):
        return len(self.starts)

    def size(self, index):
        return self.ends[index] - self.starts[index] + self.shifts.extra_length

    def num_tokens(self, index):
        return self.size(index)

    def collater(self, samples):
        pad_idx, eos_idx = self.dictionary.pad(), self.dictionary.eos()
        if len(samples) == 0:
            return {}

        src_tokens = data_utils.collate_tokens(
            [s["source"] for s in samples], pad_idx, eos_idx, left_pad=False
        )

        tgt_tokens = data_utils.collate_tokens(
            [s["target"] for s in samples],
            pad_idx=pad_idx,
            eos_idx=pad_idx,  # appending padding, eos is there already
            left_pad=False,
        )

        src_durs, tgt_durs = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.dur,
                eos_idx=self.pads.dur,
                left_pad=False,
            )
            for k in ["dur_source", "dur_target"]
        ]

        src_f0s, tgt_f0s = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            for k in ["f0_source", "f0_target"]
        ]

        mask, dur_mask, f0_mask = [
            data_utils.collate_tokens(
                [s[k] for s in samples],
                pad_idx=1,
                eos_idx=1,
                left_pad=False,
            )
            for k in ["mask", "dur_mask", "f0_mask"]
        ]

        src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
        n_tokens = sum(len(s["source"]) for s in samples)

        result = {
            "nsentences": len(samples),
            "ntokens": n_tokens,
            "net_input": {
                "src_tokens": src_tokens,
                "src_lengths": src_lengths,
                "dur_src": src_durs,
                "f0_src": src_f0s,
            },
            "target": tgt_tokens,
            "dur_target": tgt_durs,
            "f0_target": tgt_f0s,
            "mask": mask,
            "dur_mask": dur_mask,
            "f0_mask": f0_mask,
        }

        if "filename" in samples[0]:
            result["filename"] = [s["filename"] for s in samples]

        # TODO: remove this hack into the inference dataset
        if "prefix" in samples[0]:
            result["prefix"] = [s["prefix"] for s in samples]

        if "raw_f0" in samples[0]:
            raw_f0s = data_utils.collate_tokens(
                [s["raw_f0"] for s in samples],
                pad_idx=self.pads.f0,
                eos_idx=self.pads.f0,
                left_pad=False,
            )
            result["raw_f0"] = raw_f0s
        return result
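The `naive_quantize` helper used above for F0 binning simply counts how many bin edges each value exceeds. A quick check with made-up values, mirroring the function body exactly (only torch required; this is an editorial illustration, not part of the committed file):

import torch

x = torch.tensor([0.05, 0.4, 0.75, 0.99])
edges = torch.tensor([0.25, 0.5, 0.75])  # 3 edges -> bin ids 0..3
bin_idx = (x.view(-1, 1) > edges.view(1, -1)).long().sum(dim=1)
print(bin_idx)  # tensor([0, 1, 2, 3]); 0.75 is not > 0.75, so it lands in bin 2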
PyTorch/NLP/new-Transformer/fairseq/data/colorize_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import BaseWrapperDataset


class ColorizeDataset(BaseWrapperDataset):
    """Adds 'colors' property to net input that is obtained from the provided color getter for use by models"""

    def __init__(self, dataset, color_getter):
        super().__init__(dataset)
        self.color_getter = color_getter

    def collater(self, samples):
        base_collate = super().collater(samples)
        if len(base_collate) > 0:
            base_collate["net_input"]["colors"] = torch.tensor(
                list(self.color_getter(self.dataset, s["id"]) for s in samples),
                dtype=torch.long,
            )
        return base_collate
PyTorch/NLP/new-Transformer/fairseq/data/concat_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import bisect

import numpy as np
from torch.utils.data.dataloader import default_collate

from . import FairseqDataset


class ConcatDataset(FairseqDataset):
    @staticmethod
    def cumsum(sequence, sample_ratios):
        r, s = [], 0
        for e, ratio in zip(sequence, sample_ratios):
            curr_len = int(ratio * len(e))
            r.append(curr_len + s)
            s += curr_len
        return r

    def __init__(self, datasets, sample_ratios=1):
        super(ConcatDataset, self).__init__()
        assert len(datasets) > 0, "datasets should not be an empty iterable"
        self.datasets = list(datasets)
        if isinstance(sample_ratios, int):
            sample_ratios = [sample_ratios] * len(self.datasets)
        self.sample_ratios = sample_ratios
        self.cumulative_sizes = self.cumsum(self.datasets, sample_ratios)
        self.real_sizes = [len(d) for d in self.datasets]

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx][sample_idx]

    def _get_dataset_and_sample_index(self, idx: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        sample_idx = sample_idx % self.real_sizes[dataset_idx]
        return dataset_idx, sample_idx

    def collater(self, samples, **extra_args):
        # For now only supports datasets with same underlying collater implementations
        if hasattr(self.datasets[0], "collater"):
            return self.datasets[0].collater(samples, **extra_args)
        else:
            return default_collate(samples, **extra_args)

    def size(self, idx: int):
        """
        Return an example's size as a float or tuple.
        """
        dataset_idx, sample_idx = self._get_dataset_and_sample_index(idx)
        return self.datasets[dataset_idx].size(sample_idx)

    def num_tokens(self, index: int):
        return np.max(self.size(index))

    def attr(self, attr: str, index: int):
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, index)
        return getattr(self.datasets[dataset_idx], attr, None)

    @property
    def sizes(self):
        _dataset_sizes = []
        for ds, sr in zip(self.datasets, self.sample_ratios):
            if isinstance(ds.sizes, np.ndarray):
                _dataset_sizes.append(np.tile(ds.sizes, sr))
            else:
                # Only support underlying dataset with single size array.
                assert isinstance(ds.sizes, list)
                _dataset_sizes.append(np.tile(ds.sizes[0], sr))
        return np.concatenate(_dataset_sizes)

    @property
    def supports_prefetch(self):
        return all(d.supports_prefetch for d in self.datasets)

    def ordered_indices(self):
        """
        Returns indices sorted by length. So less padding is needed.
        """
        if isinstance(self.sizes, np.ndarray) and len(self.sizes.shape) > 1:
            # special handling for concatenating lang_pair_datasets
            indices = np.arange(len(self))
            sizes = self.sizes
            tgt_sizes = (
                sizes[:, 1] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else None
            )
            src_sizes = (
                sizes[:, 0] if len(sizes.shape) > 0 and sizes.shape[1] > 1 else sizes
            )
            # sort by target length, then source length
            if tgt_sizes is not None:
                indices = indices[np.argsort(tgt_sizes[indices], kind="mergesort")]
            return indices[np.argsort(src_sizes[indices], kind="mergesort")]
        else:
            return np.argsort(self.sizes)

    def prefetch(self, indices):
        frm = 0
        for to, ds in zip(self.cumulative_sizes, self.datasets):
            real_size = len(ds)
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch([(i - frm) % real_size for i in indices if frm <= i < to])
            frm = to

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return all(d.can_reuse_epoch_itr_across_epochs for d in self.datasets)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
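`ConcatDataset.cumsum` above scales each dataset's length by its sample ratio before accumulating, which is what makes oversampling work, and `__getitem__` then wraps the virtual index back into the real dataset with `sample_idx % real_size`. A tiny sketch with made-up lengths (pure Python, mirroring the static method; an editorial illustration, not part of the committed file):

def cumsum(lengths, sample_ratios):
    r, s = [], 0
    for n, ratio in zip(lengths, sample_ratios):
        curr_len = int(ratio * n)
        r.append(curr_len + s)
        s += curr_len
    return r

print(cumsum([100, 10], [1, 5]))  # [100, 150]: the small dataset is treated as 5x larger
# Global indices 100..149 all map to the second dataset and wrap around its 10 real
# items via `sample_idx % real_size`, so each item is visited about five times per epoch.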
PyTorch/NLP/new-Transformer/fairseq/data/concat_sentences_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import torch

from . import FairseqDataset


class ConcatSentencesDataset(FairseqDataset):
    def __init__(self, *datasets):
        super().__init__()
        self.datasets = datasets
        assert all(
            len(ds) == len(datasets[0]) for ds in datasets
        ), "datasets must have the same length"

    def __getitem__(self, index):
        return torch.cat([ds[index] for ds in self.datasets])

    def __len__(self):
        return len(self.datasets[0])

    def collater(self, samples):
        return self.datasets[0].collater(samples)

    @property
    def sizes(self):
        return sum(ds.sizes for ds in self.datasets)

    def num_tokens(self, index):
        return sum(ds.num_tokens(index) for ds in self.datasets)

    def size(self, index):
        return sum(ds.size(index) for ds in self.datasets)

    def ordered_indices(self):
        return self.datasets[0].ordered_indices()

    @property
    def supports_prefetch(self):
        return any(getattr(ds, "supports_prefetch", False) for ds in self.datasets)

    def prefetch(self, indices):
        for ds in self.datasets:
            if getattr(ds, "supports_prefetch", False):
                ds.prefetch(indices)

    def set_epoch(self, epoch):
        super().set_epoch(epoch)
        for ds in self.datasets:
            if hasattr(ds, "set_epoch"):
                ds.set_epoch(epoch)
PyTorch/NLP/new-Transformer/fairseq/data/data_utils.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
try
:
from
collections.abc
import
Iterable
except
ImportError
:
from
collections
import
Iterable
import
contextlib
import
itertools
import
logging
import
re
import
warnings
from
typing
import
Optional
,
Tuple
import
numpy
as
np
import
torch
from
fairseq.file_io
import
PathManager
from
fairseq
import
utils
import
os
logger
=
logging
.
getLogger
(
__name__
)
def
infer_language_pair
(
path
):
"""Infer language pair from filename: <split>.<lang1>-<lang2>.(...).idx"""
src
,
dst
=
None
,
None
for
filename
in
PathManager
.
ls
(
path
):
parts
=
filename
.
split
(
"."
)
if
len
(
parts
)
>=
3
and
len
(
parts
[
1
].
split
(
"-"
))
==
2
:
return
parts
[
1
].
split
(
"-"
)
return
src
,
dst
def
collate_tokens
(
values
,
pad_idx
,
eos_idx
=
None
,
left_pad
=
False
,
move_eos_to_beginning
=
False
,
pad_to_length
=
None
,
pad_to_multiple
=
1
,
pad_to_bsz
=
None
,
):
"""Convert a list of 1d tensors into a padded 2d tensor."""
size
=
max
(
v
.
size
(
0
)
for
v
in
values
)
size
=
size
if
pad_to_length
is
None
else
max
(
size
,
pad_to_length
)
if
pad_to_multiple
!=
1
and
size
%
pad_to_multiple
!=
0
:
size
=
int
(((
size
-
0.1
)
//
pad_to_multiple
+
1
)
*
pad_to_multiple
)
batch_size
=
len
(
values
)
if
pad_to_bsz
is
None
else
max
(
len
(
values
),
pad_to_bsz
)
res
=
values
[
0
].
new
(
batch_size
,
size
).
fill_
(
pad_idx
)
def
copy_tensor
(
src
,
dst
):
assert
dst
.
numel
()
==
src
.
numel
()
if
move_eos_to_beginning
:
if
eos_idx
is
None
:
# if no eos_idx is specified, then use the last token in src
dst
[
0
]
=
src
[
-
1
]
else
:
dst
[
0
]
=
eos_idx
dst
[
1
:]
=
src
[:
-
1
]
else
:
dst
.
copy_
(
src
)
for
i
,
v
in
enumerate
(
values
):
copy_tensor
(
v
,
res
[
i
][
size
-
len
(
v
)
:]
if
left_pad
else
res
[
i
][:
len
(
v
)])
return
res
def
load_indexed_dataset
(
path
,
dictionary
=
None
,
dataset_impl
=
None
,
combine
=
False
,
default
=
"cached"
):
"""A helper function for loading indexed datasets.
Args:
path (str): path to indexed dataset (e.g., 'data-bin/train')
dictionary (~fairseq.data.Dictionary): data dictionary
dataset_impl (str, optional): which dataset implementation to use. If
not provided, it will be inferred automatically. For legacy indexed
data we use the 'cached' implementation by default.
combine (bool, optional): automatically load and combine multiple
datasets. For example, if *path* is 'data-bin/train', then we will
combine 'data-bin/train', 'data-bin/train1', ... and return a
single ConcatDataset instance.
"""
import
fairseq.data.indexed_dataset
as
indexed_dataset
from
fairseq.data.concat_dataset
import
ConcatDataset
datasets
=
[]
for
k
in
itertools
.
count
():
path_k
=
path
+
(
str
(
k
)
if
k
>
0
else
""
)
try
:
path_k
=
indexed_dataset
.
get_indexed_dataset_to_local
(
path_k
)
except
Exception
as
e
:
if
"StorageException: [404] Path not found"
in
str
(
e
):
logger
.
warning
(
f
"path_k:
{
e
}
not found"
)
else
:
raise
e
dataset_impl_k
=
dataset_impl
if
dataset_impl_k
is
None
:
dataset_impl_k
=
indexed_dataset
.
infer_dataset_impl
(
path_k
)
dataset
=
indexed_dataset
.
make_dataset
(
path_k
,
impl
=
dataset_impl_k
or
default
,
fix_lua_indexing
=
True
,
dictionary
=
dictionary
,
)
if
dataset
is
None
:
break
logger
.
info
(
"loaded {:,} examples from: {}"
.
format
(
len
(
dataset
),
path_k
))
datasets
.
append
(
dataset
)
if
not
combine
:
break
if
len
(
datasets
)
==
0
:
return
None
elif
len
(
datasets
)
==
1
:
return
datasets
[
0
]
else
:
return
ConcatDataset
(
datasets
)
@
contextlib
.
contextmanager
def
numpy_seed
(
seed
,
*
addl_seeds
):
"""Context manager which seeds the NumPy PRNG with the specified seed and
restores the state afterward"""
if
seed
is
None
:
yield
return
if
len
(
addl_seeds
)
>
0
:
seed
=
int
(
hash
((
seed
,
*
addl_seeds
))
%
1e6
)
state
=
np
.
random
.
get_state
()
np
.
random
.
seed
(
seed
)
try
:
yield
finally
:
np
.
random
.
set_state
(
state
)
def
collect_filtered
(
function
,
iterable
,
filtered
):
"""
Similar to :func:`filter` but collects filtered elements in ``filtered``.
Args:
function (callable): function that returns ``False`` for elements that
should be filtered
iterable (iterable): iterable to filter
filtered (list): list to store filtered elements
"""
for
el
in
iterable
:
if
function
(
el
):
yield
el
else
:
filtered
.
append
(
el
)
def
_filter_by_size_dynamic
(
indices
,
size_fn
,
max_positions
,
raise_exception
=
False
):
def
compare_leq
(
a
,
b
):
return
a
<=
b
if
not
isinstance
(
a
,
tuple
)
else
max
(
a
)
<=
b
def
check_size
(
idx
):
if
isinstance
(
max_positions
,
float
)
or
isinstance
(
max_positions
,
int
):
return
size_fn
(
idx
)
<=
max_positions
elif
isinstance
(
max_positions
,
dict
):
idx_size
=
size_fn
(
idx
)
assert
isinstance
(
idx_size
,
dict
)
intersect_keys
=
set
(
max_positions
.
keys
())
&
set
(
idx_size
.
keys
())
return
all
(
all
(
a
is
None
or
b
is
None
or
a
<=
b
for
a
,
b
in
zip
(
idx_size
[
key
],
max_positions
[
key
])
)
for
key
in
intersect_keys
)
else
:
# For MultiCorpusSampledDataset, will generalize it later
if
not
isinstance
(
size_fn
(
idx
),
Iterable
):
return
all
(
size_fn
(
idx
)
<=
b
for
b
in
max_positions
)
return
all
(
a
is
None
or
b
is
None
or
a
<=
b
for
a
,
b
in
zip
(
size_fn
(
idx
),
max_positions
)
)
ignored
=
[]
itr
=
collect_filtered
(
check_size
,
indices
,
ignored
)
indices
=
np
.
fromiter
(
itr
,
dtype
=
np
.
int64
,
count
=-
1
)
return
indices
,
ignored
def
filter_by_size
(
indices
,
dataset
,
max_positions
,
raise_exception
=
False
):
"""
[deprecated] Filter indices based on their size.
Use `FairseqDataset::filter_indices_by_size` instead.
Args:
indices (List[int]): ordered list of dataset indices
dataset (FairseqDataset): fairseq dataset instance
max_positions (tuple): filter elements larger than this size.
Comparisons are done component-wise.
raise_exception (bool, optional): if ``True``, raise an exception if
any elements are filtered (default: False).
"""
warnings
.
warn
(
"data_utils.filter_by_size is deprecated. "
"Use `FairseqDataset::filter_indices_by_size` instead."
,
stacklevel
=
2
,
)
if
isinstance
(
max_positions
,
float
)
or
isinstance
(
max_positions
,
int
):
if
hasattr
(
dataset
,
"sizes"
)
and
isinstance
(
dataset
.
sizes
,
np
.
ndarray
):
ignored
=
indices
[
dataset
.
sizes
[
indices
]
>
max_positions
].
tolist
()
indices
=
indices
[
dataset
.
sizes
[
indices
]
<=
max_positions
]
elif
(
hasattr
(
dataset
,
"sizes"
)
and
isinstance
(
dataset
.
sizes
,
list
)
and
len
(
dataset
.
sizes
)
==
1
):
ignored
=
indices
[
dataset
.
sizes
[
0
][
indices
]
>
max_positions
].
tolist
()
indices
=
indices
[
dataset
.
sizes
[
0
][
indices
]
<=
max_positions
]
else
:
indices
,
ignored
=
_filter_by_size_dynamic
(
indices
,
dataset
.
size
,
max_positions
)
else
:
indices
,
ignored
=
_filter_by_size_dynamic
(
indices
,
dataset
.
size
,
max_positions
)
if
len
(
ignored
)
>
0
and
raise_exception
:
raise
Exception
(
(
"Size of sample #{} is invalid (={}) since max_positions={}, "
"skip this example with --skip-invalid-size-inputs-valid-test"
).
format
(
ignored
[
0
],
dataset
.
size
(
ignored
[
0
]),
max_positions
)
)
if
len
(
ignored
)
>
0
:
logger
.
warning
(
(
"{} samples have invalid sizes and will be skipped, "
"max_positions={}, first few sample ids={}"
).
format
(
len
(
ignored
),
max_positions
,
ignored
[:
10
])
)
return
indices
def
filter_paired_dataset_indices_by_size
(
src_sizes
,
tgt_sizes
,
indices
,
max_sizes
):
"""Filter a list of sample indices. Remove those that are longer
than specified in max_sizes.
Args:
indices (np.array): original array of sample indices
max_sizes (int or list[int] or tuple[int]): max sample size,
can be defined separately for src and tgt (then list or tuple)
Returns:
np.array: filtered sample array
list: list of removed indices
"""
if
max_sizes
is
None
:
return
indices
,
[]
if
type
(
max_sizes
)
in
(
int
,
float
):
max_src_size
,
max_tgt_size
=
max_sizes
,
max_sizes
else
:
max_src_size
,
max_tgt_size
=
max_sizes
if
tgt_sizes
is
None
:
ignored
=
indices
[
src_sizes
[
indices
]
>
max_src_size
]
else
:
ignored
=
indices
[
(
src_sizes
[
indices
]
>
max_src_size
)
|
(
tgt_sizes
[
indices
]
>
max_tgt_size
)
]
if
len
(
ignored
)
>
0
:
if
tgt_sizes
is
None
:
indices
=
indices
[
src_sizes
[
indices
]
<=
max_src_size
]
else
:
indices
=
indices
[
(
src_sizes
[
indices
]
<=
max_src_size
)
&
(
tgt_sizes
[
indices
]
<=
max_tgt_size
)
]
return
indices
,
ignored
.
tolist
()
def batch_by_size(
    indices,
    num_tokens_fn,
    num_tokens_vec=None,
    max_tokens=None,
    max_sentences=None,
    required_batch_size_multiple=1,
    fixed_shapes=None,
):
    """
    Yield mini-batches of indices bucketed by size. Batches may contain
    sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        num_tokens_vec (List[int], optional): precomputed vector of the number
            of tokens for each index in indices (to enable faster batch generation)
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be less than N or a multiple of N (default: 1).
        fixed_shapes (List[Tuple[int, int]], optional): if given, batches will
            only be created with the given shapes. *max_sentences* and
            *required_batch_size_multiple* will be ignored (default: None).
    """
    try:
        from fairseq.data.data_utils_fast import (
            batch_by_size_fn,
            batch_by_size_vec,
            batch_fixed_shapes_fast,
        )
    except ImportError:
        raise ImportError(
            "Please build Cython components with: "
            "`python setup.py build_ext --inplace`"
        )
    except ValueError:
        raise ValueError(
            "Please build (or rebuild) Cython components with `python setup.py build_ext --inplace`."
        )

    # added int() to avoid TypeError: an integer is required
    max_tokens = int(max_tokens) if max_tokens is not None else -1
    max_sentences = max_sentences if max_sentences is not None else -1
    bsz_mult = required_batch_size_multiple

    if not isinstance(indices, np.ndarray):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    if num_tokens_vec is not None and not isinstance(num_tokens_vec, np.ndarray):
        num_tokens_vec = np.fromiter(num_tokens_vec, dtype=np.int64, count=-1)

    if fixed_shapes is None:
        if num_tokens_vec is None:
            return batch_by_size_fn(
                indices,
                num_tokens_fn,
                max_tokens,
                max_sentences,
                bsz_mult,
            )
        else:
            return batch_by_size_vec(
                indices,
                num_tokens_vec,
                max_tokens,
                max_sentences,
                bsz_mult,
            )
    else:
        fixed_shapes = np.array(fixed_shapes, dtype=np.int64)
        sort_order = np.lexsort(
            [
                fixed_shapes[:, 1].argsort(),  # length
                fixed_shapes[:, 0].argsort(),  # bsz
            ]
        )
        fixed_shapes_sorted = fixed_shapes[sort_order]
        return batch_fixed_shapes_fast(indices, num_tokens_fn, fixed_shapes_sorted)
def post_process(sentence: str, symbol: str):
    if symbol == "sentencepiece":
        sentence = sentence.replace(" ", "").replace("\u2581", " ").strip()
    elif symbol == "wordpiece":
        sentence = sentence.replace(" ", "").replace("_", " ").strip()
    elif symbol == "letter":
        sentence = sentence.replace(" ", "").replace("|", " ").strip()
    elif symbol == "silence":
        import re

        sentence = sentence.replace("<SIL>", "")
        sentence = re.sub(" +", " ", sentence).strip()
    elif symbol == "_EOW":
        sentence = sentence.replace(" ", "").replace("_EOW", " ").strip()
    elif symbol in {"subword_nmt", "@@ ", "@@"}:
        if symbol == "subword_nmt":
            symbol = "@@ "
        sentence = (sentence + " ").replace(symbol, "").rstrip()
    elif symbol == "none":
        pass
    elif symbol is not None:
        raise NotImplementedError(f"Unknown post_process option: {symbol}")
    return sentence
def compute_mask_indices(
    shape: Tuple[int, int],
    padding_mask: Optional[torch.Tensor],
    mask_prob: float,
    mask_length: int,
    mask_type: str = "static",
    mask_other: float = 0.0,
    min_masks: int = 0,
    no_overlap: bool = False,
    min_space: int = 0,
    require_same_masks: bool = True,
    mask_dropout: float = 0.0,
) -> np.ndarray:
    """
    Computes random mask spans for a given shape

    Args:
        shape: the shape for which to compute masks.
            should be of size 2 where first element is batch size and 2nd is timesteps
        padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements
        mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by
            number of timesteps divided by length of mask span to mask approximately this percentage of all elements.
            however due to overlaps, the actual number will be smaller (unless no_overlap is True)
        mask_type: how to compute mask lengths
            static = fixed size
            uniform = sample from uniform distribution [mask_other, mask_length*2]
            normal = sample from normal distribution with mean mask_length and stdev mask_other. mask is min 1 element
            poisson = sample from poisson distribution with lambda = mask length
        min_masks: minimum number of masked spans
        no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping
        min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans
        require_same_masks: if true, will randomly drop out masks until same amount of masks remains in each sample
        mask_dropout: randomly dropout this percentage of masks in each example
    """

    bsz, all_sz = shape
    mask = np.full((bsz, all_sz), False)

    all_num_mask = int(
        # add a random number for probabilistic rounding
        mask_prob * all_sz / float(mask_length)
        + np.random.rand()
    )

    all_num_mask = max(min_masks, all_num_mask)

    mask_idcs = []
    for i in range(bsz):
        if padding_mask is not None:
            sz = all_sz - padding_mask[i].long().sum().item()
            num_mask = int(
                # add a random number for probabilistic rounding
                mask_prob * sz / float(mask_length)
                + np.random.rand()
            )
            num_mask = max(min_masks, num_mask)
        else:
            sz = all_sz
            num_mask = all_num_mask

        if mask_type == "static":
            lengths = np.full(num_mask, mask_length)
        elif mask_type == "uniform":
            lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask)
        elif mask_type == "normal":
            lengths = np.random.normal(mask_length, mask_other, size=num_mask)
            lengths = [max(1, int(round(x))) for x in lengths]
        elif mask_type == "poisson":
            lengths = np.random.poisson(mask_length, size=num_mask)
            lengths = [int(round(x)) for x in lengths]
        else:
            raise Exception("unknown mask selection " + mask_type)

        if sum(lengths) == 0:
            lengths[0] = min(mask_length, sz - 1)

        if no_overlap:
            mask_idc = []

            def arrange(s, e, length, keep_length):
                span_start = np.random.randint(s, e - length)
                mask_idc.extend(span_start + i for i in range(length))

                new_parts = []
                if span_start - s - min_space >= keep_length:
                    new_parts.append((s, span_start - min_space + 1))
                if e - span_start - length - min_space > keep_length:
                    new_parts.append((span_start + length + min_space, e))
                return new_parts

            parts = [(0, sz)]
            min_length = min(lengths)
            for length in sorted(lengths, reverse=True):
                lens = np.fromiter(
                    (e - s if e - s >= length + min_space else 0 for s, e in parts),
                    np.int,
                )
                l_sum = np.sum(lens)
                if l_sum == 0:
                    break
                probs = lens / np.sum(lens)
                c = np.random.choice(len(parts), p=probs)
                s, e = parts.pop(c)
                parts.extend(arrange(s, e, length, min_length))
            mask_idc = np.asarray(mask_idc)
        else:
            min_len = min(lengths)
            if sz - min_len <= num_mask:
                min_len = sz - num_mask - 1

            mask_idc = np.random.choice(sz - min_len, num_mask, replace=False)

            mask_idc = np.asarray(
                [
                    mask_idc[j] + offset
                    for j in range(len(mask_idc))
                    for offset in range(lengths[j])
                ]
            )

        mask_idcs.append(np.unique(mask_idc[mask_idc < sz]))

    min_len = min([len(m) for m in mask_idcs])
    for i, mask_idc in enumerate(mask_idcs):
        if len(mask_idc) > min_len and require_same_masks:
            mask_idc = np.random.choice(mask_idc, min_len, replace=False)
        if mask_dropout > 0:
            num_holes = np.rint(len(mask_idc) * mask_dropout).astype(int)
            mask_idc = np.random.choice(
                mask_idc, len(mask_idc) - num_holes, replace=False
            )

        mask[i, mask_idc] = True

    return mask
def get_mem_usage():
    try:
        import psutil

        mb = 1024 * 1024
        return f"used={psutil.virtual_memory().used / mb}Mb; avail={psutil.virtual_memory().available / mb}Mb"
    except ImportError:
        return "N/A"
# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_padding_mask(lens):
    bsz, max_lens = lens.size(0), torch.max(lens).item()
    mask = torch.arange(max_lens).to(lens.device).view(1, max_lens)
    mask = mask.expand(bsz, -1) >= lens.view(bsz, 1).expand(-1, max_lens)
    return mask


# lens: torch.LongTensor
# returns: torch.BoolTensor
def lengths_to_mask(lens):
    return ~lengths_to_padding_mask(lens)
def get_buckets(sizes, num_buckets):
    buckets = np.unique(
        np.percentile(
            sizes,
            np.linspace(0, 100, num_buckets + 1),
            interpolation="lower",
        )[1:]
    )
    return buckets


def get_bucketed_sizes(orig_sizes, buckets):
    sizes = np.copy(orig_sizes)
    assert np.min(sizes) >= 0
    start_val = -1
    for end_val in buckets:
        mask = (sizes > start_val) & (sizes <= end_val)
        sizes[mask] = end_val
        start_val = end_val
    return sizes
def _find_extra_valid_paths(dataset_path: str) -> set:
    paths = utils.split_paths(dataset_path)
    all_valid_paths = set()
    for sub_dir in paths:
        contents = PathManager.ls(sub_dir)
        valid_paths = [c for c in contents if re.match("valid*[0-9].*", c) is not None]
        all_valid_paths |= {os.path.basename(p) for p in valid_paths}
    # Remove .bin, .idx etc
    roots = {os.path.splitext(p)[0] for p in all_valid_paths}
    return roots


def raise_if_valid_subsets_unintentionally_ignored(train_cfg) -> None:
    """Raises if there are paths matching 'valid*[0-9].*' which are not combined or ignored."""
    if (
        train_cfg.dataset.ignore_unused_valid_subsets
        or train_cfg.dataset.combine_valid_subsets
        or train_cfg.dataset.disable_validation
        or not hasattr(train_cfg.task, "data")
    ):
        return
    other_paths = _find_extra_valid_paths(train_cfg.task.data)
    specified_subsets = train_cfg.dataset.valid_subset.split(",")
    ignored_paths = [p for p in other_paths if p not in specified_subsets]
    if ignored_paths:
        advice = "Set --combine-val to combine them or --ignore-unused-valid-subsets to ignore them."
        msg = f"Valid paths {ignored_paths} will be ignored. {advice}"
        raise ValueError(msg)
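
For orientation, a minimal usage sketch of a few of the helpers above (illustrative only, not part of the committed file; it assumes fairseq and its Cython extensions are built so that fairseq.data.data_utils imports cleanly):

import torch
from fairseq.data import data_utils

# post_process: undo sentencepiece segmentation on a decoded hypothesis
print(data_utils.post_process("\u2581Hello \u2581world", "sentencepiece"))  # "Hello world"

# lengths_to_padding_mask: True marks padded (invalid) positions
lens = torch.tensor([3, 1, 2])
print(data_utils.lengths_to_padding_mask(lens))
# tensor([[False, False, False],
#         [False,  True,  True],
#         [False, False,  True]])

# compute_mask_indices: wav2vec-style span masking over a (batch, timesteps) grid
mask = data_utils.compute_mask_indices((2, 10), None, mask_prob=0.5, mask_length=2)
print(mask.shape, mask.dtype)  # (2, 10) bool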
PyTorch/NLP/new-Transformer/fairseq/data/data_utils_fast.pyx
0 → 100644
View file @
c0f05c10
# cython: language_level=3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import numpy as np

cimport cython
cimport numpy as np

from libc.stdint cimport int32_t, int64_t
from libcpp cimport bool as bool_t

ctypedef int64_t DTYPE_t


@cython.cdivision(True)
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_vec(
    np.ndarray[int64_t, ndim=1] indices,
    np.ndarray[int64_t, ndim=1] num_tokens_vec,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    if indices.shape[0] == 0:
        return []

    assert max_tokens <= 0 or np.max(num_tokens_vec) <= max_tokens, (
        f"Sentences lengths should not exceed max_tokens={max_tokens}"
    )

    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int32_t, ndim=1] batches_ends = \
        np.zeros(indices_len, dtype=np.int32)
    cdef int32_t[:] batches_ends_view = batches_ends
    cdef int64_t[:] num_tokens_view = num_tokens_vec

    cdef int32_t pos = 0
    cdef int32_t new_batch_end = 0

    cdef int64_t new_batch_max_tokens = 0
    cdef int32_t new_batch_sentences = 0
    cdef int64_t new_batch_num_tokens = 0

    cdef bool_t overflow = False
    cdef bool_t size_matches_with_bsz_mult = False

    cdef int32_t batches_count = 0
    cdef int32_t batch_start = 0
    cdef int64_t tail_max_tokens = 0
    cdef int64_t batch_max_tokens = 0

    for pos in range(indices_len):
        # At every pos we keep stats about the last complete batch [batch_start:batch_end),
        #      and tail [batch_end:pos].
        # 1) Every time when (batch + tail) forms a valid batch
        #      (according to max_tokens, max_sentences and bsz_mult) we append tail to batch.
        # 2) When (batch+tail) violates max_tokens or max_sentences constraints
        #      we finalize running batch, and tail becomes a new batch.
        # 3) There is a corner case when tail also violates constraints.
        #      In that situation [batch_end:pos-1] (tail without the current pos)
        #      gets added to the finalized batches, while [pos:pos] becomes a new tail.
        #
        # Important: For the sake of performance try to avoid using function calls within this loop.

        tail_max_tokens = tail_max_tokens \
                            if tail_max_tokens > num_tokens_view[pos] \
                            else num_tokens_view[pos]
        new_batch_end = pos + 1
        new_batch_max_tokens = batch_max_tokens \
                                if batch_max_tokens > tail_max_tokens \
                                else tail_max_tokens
        new_batch_sentences = new_batch_end - batch_start
        new_batch_num_tokens = new_batch_sentences * new_batch_max_tokens

        overflow = (new_batch_sentences > max_sentences > 0 or
                    new_batch_num_tokens > max_tokens > 0)
        size_matches_with_bsz_mult = (new_batch_sentences < bsz_mult or
                                      new_batch_sentences % bsz_mult == 0)

        if overflow:
            tail_num_tokens = tail_max_tokens * \
                    (new_batch_end - batches_ends_view[batches_count])
            tail_overflow = tail_num_tokens > max_tokens > 0
            # In case of a tail overflow finalize two batches
            if tail_overflow:
                batches_count += 1
                batches_ends_view[batches_count] = pos
                tail_max_tokens = num_tokens_view[pos]
            batch_start = batches_ends_view[batches_count]
            batches_count += 1
            new_batch_max_tokens = tail_max_tokens

        if overflow or size_matches_with_bsz_mult:
            batches_ends_view[batches_count] = new_batch_end
            batch_max_tokens = new_batch_max_tokens
            tail_max_tokens = 0
    if batches_ends_view[batches_count] != indices_len:
        batches_count += 1
    # Memory and time-efficient split
    return np.split(indices, batches_ends[:batches_count])


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef list batch_by_size_fn(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    int64_t max_tokens,
    int64_t max_sentences,
    int32_t bsz_mult,
):
    cdef int32_t indices_len = indices.shape[0]
    cdef np.ndarray[int64_t, ndim=1] num_tokens_vec = np.zeros(indices_len,
                                                               dtype=np.int64)
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:] num_tokens_vec_view = num_tokens_vec
    cdef int64_t pos
    for pos in range(indices_len):
        num_tokens_vec[pos] = num_tokens_fn(indices_view[pos])
    return batch_by_size_vec(indices, num_tokens_vec, max_tokens,
                             max_sentences, bsz_mult,)


cdef _find_valid_shape(
    DTYPE_t[:, :] shapes_view,
    int64_t num_sentences,
    int64_t num_tokens,
):
    """Return index of first valid shape, or -1 if none is found."""
    for i in range(shapes_view.shape[0]):
        if num_sentences <= shapes_view[i][0] and num_tokens <= shapes_view[i][1]:
            return i
    return -1


@cython.cdivision(True)
cpdef list batch_fixed_shapes_fast(
    np.ndarray[DTYPE_t, ndim=1] indices,
    num_tokens_fn,
    np.ndarray[DTYPE_t, ndim=2] fixed_shapes_sorted,
):
    cdef int64_t sample_len = 0
    cdef list sample_lens = []
    cdef list batch = []
    cdef list batches = []
    cdef int64_t mod_len
    cdef int64_t i
    cdef int64_t idx
    cdef int64_t num_tokens
    cdef DTYPE_t[:] indices_view = indices
    cdef DTYPE_t[:, :] shapes_view = fixed_shapes_sorted

    for i in range(len(indices_view)):
        idx = indices_view[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)

        shape_idx = _find_valid_shape(shapes_view, len(batch) + 1, sample_len)
        if shape_idx == -1:
            batches.append(batch)
            batch = []
            sample_lens = []
            sample_len = 0
            shapes_view = fixed_shapes_sorted
        elif shape_idx > 0:
            # small optimization for the next call to _find_valid_shape
            shapes_view = shapes_view[shape_idx:]

        batch.append(idx)

    if len(batch) > 0:
        batches.append(batch)

    return batches
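
The loop comments inside batch_by_size_vec describe the packing rule in prose. For readers who do not want to trace the Cython bookkeeping, here is a rough pure-Python sketch of the same max_tokens / max_sentences constraint (illustrative only; it ignores the bsz_mult rounding and the batch/tail trick that makes the real implementation fast):

def naive_batch_by_size(indices, num_tokens, max_tokens=None, max_sentences=None):
    """Greedy packing: a batch costs len(batch) * (longest sample in it) tokens."""
    batches, batch, longest = [], [], 0
    for idx in indices:
        longest_if_added = max(longest, num_tokens[idx])
        too_many_sentences = max_sentences is not None and len(batch) + 1 > max_sentences
        too_many_tokens = (
            max_tokens is not None and (len(batch) + 1) * longest_if_added > max_tokens
        )
        if batch and (too_many_sentences or too_many_tokens):
            batches.append(batch)      # finalize the running batch
            batch, longest = [], 0     # the current sample starts a new one
        batch.append(idx)
        longest = max(longest, num_tokens[idx])
    if batch:
        batches.append(batch)
    return batches

# naive_batch_by_size([0, 1, 2, 3], {0: 5, 1: 6, 2: 2, 3: 9}, max_tokens=12)
# -> [[0, 1], [2], [3]]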
PyTorch/NLP/new-Transformer/fairseq/data/denoising_dataset.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import math

import numpy as np
import torch

from . import FairseqDataset, data_utils


def collate(
    samples,
    pad_idx,
    eos_idx,
    vocab,
    left_pad_source=False,
    left_pad_target=False,
    input_feeding=True,
    pad_to_length=None,
):
    assert input_feeding
    if len(samples) == 0:
        return {}

    def merge(key, left_pad, move_eos_to_beginning=False, pad_to_length=None):
        return data_utils.collate_tokens(
            [s[key] for s in samples],
            pad_idx,
            eos_idx=None,  # use eos_idx of each sample instead of vocab.eos()
            left_pad=left_pad,
            move_eos_to_beginning=move_eos_to_beginning,
            pad_to_length=pad_to_length,
        )

    id = torch.LongTensor([s["id"] for s in samples])
    src_tokens = merge(
        "source",
        left_pad=left_pad_source,
        pad_to_length=pad_to_length["source"] if pad_to_length is not None else None,
    )
    # sort by descending source length
    src_lengths = torch.LongTensor([s["source"].numel() for s in samples])
    src_lengths, sort_order = src_lengths.sort(descending=True)
    id = id.index_select(0, sort_order)
    src_tokens = src_tokens.index_select(0, sort_order)

    prev_output_tokens = None
    target = None
    if samples[0].get("target", None) is not None:
        target = merge(
            "target",
            left_pad=left_pad_target,
            pad_to_length=pad_to_length["target"]
            if pad_to_length is not None
            else None,
        )
        target = target.index_select(0, sort_order)
        ntokens = sum(len(s["target"]) for s in samples)

        if input_feeding:
            # we create a shifted version of targets for feeding the
            # previous output token(s) into the next decoder step
            prev_output_tokens = merge(
                "target",
                left_pad=left_pad_target,
                move_eos_to_beginning=True,
                pad_to_length=pad_to_length["target"]
                if pad_to_length is not None
                else None,
            )
            prev_output_tokens = prev_output_tokens.index_select(0, sort_order)
    else:
        ntokens = sum(len(s["source"]) for s in samples)

    batch = {
        "id": id,
        "ntokens": ntokens,
        "net_input": {
            "src_tokens": src_tokens,
            "src_lengths": src_lengths,
        },
        "target": target,
        "nsentences": samples[0]["source"].size(0),
        "sort_order": sort_order,
    }
    if prev_output_tokens is not None:
        batch["net_input"]["prev_output_tokens"] = prev_output_tokens

    return batch


class DenoisingDataset(FairseqDataset):
    """
    A wrapper around TokenBlockDataset for BART dataset.

    Args:
        dataset (TokenBlockDataset): dataset to wrap
        sizes (List[int]): sentence lengths
        vocab (~fairseq.data.Dictionary): vocabulary
        mask_idx (int): dictionary index used for masked token
        mask_whole_words: only mask whole words. This should be a byte mask
            over vocab indices, indicating whether it is the beginning of a
            word. We will extend any mask to encompass the whole word.
        shuffle (bool, optional): shuffle the elements before batching.
            Default: ``True``
        seed: Seed for random number generator for reproducibility.
        args: argparse arguments.
    """

    def __init__(
        self,
        dataset,
        sizes,
        vocab,
        mask_idx,
        mask_whole_words,
        shuffle,
        seed,
        args,
        eos=None,
        item_transform_func=None,
    ):
        self.dataset = dataset

        self.sizes = sizes

        self.vocab = vocab
        self.shuffle = shuffle
        self.seed = seed
        self.mask_idx = mask_idx
        self.mask_whole_word = mask_whole_words
        self.mask_ratio = args.mask
        self.random_ratio = args.mask_random
        self.insert_ratio = args.insert
        self.rotate_ratio = args.rotate
        self.permute_sentence_ratio = args.permute_sentences
        self.eos = eos if eos is not None else vocab.eos()
        self.item_transform_func = item_transform_func

        if args.bpe != "gpt2":
            self.full_stop_index = self.vocab.eos()
        else:
            assert args.bpe == "gpt2"
            self.full_stop_index = self.vocab.index("13")

        self.replace_length = args.replace_length
        if self.replace_length not in [-1, 0, 1]:
            raise ValueError(f"invalid arg: replace_length={self.replace_length}")
        if args.mask_length not in ["subword", "word", "span-poisson"]:
            raise ValueError(f"invalid arg: mask-length={args.mask_length}")
        if args.mask_length == "subword" and args.replace_length not in [0, 1]:
            raise ValueError(f"if using subwords, use replace-length=1 or 0")

        self.mask_span_distribution = None
        if args.mask_length == "span-poisson":
            _lambda = args.poisson_lambda

            lambda_to_the_k = 1
            e_to_the_minus_lambda = math.exp(-_lambda)
            k_factorial = 1
            ps = []
            for k in range(0, 128):
                ps.append(e_to_the_minus_lambda * lambda_to_the_k / k_factorial)
                lambda_to_the_k *= _lambda
                k_factorial *= k + 1
                if ps[-1] < 0.0000001:
                    break
            ps = torch.FloatTensor(ps)
            self.mask_span_distribution = torch.distributions.Categorical(ps)

        self.epoch = 0

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        return True  # only the noise changes, not item sizes

    def set_epoch(self, epoch, **unused):
        self.epoch = epoch

    def __getitem__(self, index):
        with data_utils.numpy_seed(self.seed, self.epoch, index):
            tokens = self.dataset[index]
            assert tokens[-1] == self.eos
            source, target = tokens, tokens.clone()

            if self.permute_sentence_ratio > 0.0:
                source = self.permute_sentences(source, self.permute_sentence_ratio)

            if self.mask_ratio > 0:
                source = self.add_whole_word_mask(source, self.mask_ratio)

            if self.insert_ratio > 0:
                source = self.add_insertion_noise(source, self.insert_ratio)

            if self.rotate_ratio > 0.0 and np.random.random() < self.rotate_ratio:
                source = self.add_rolling_noise(source)
        # there can be additional changes to make:
        if self.item_transform_func is not None:
            source, target = self.item_transform_func(source, target)

        assert (source >= 0).all()
        assert (source[1:-1] >= 1).all()
        assert (source <= len(self.vocab)).all()
        assert source[0] == self.vocab.bos()
        assert source[-1] == self.eos
        return {
            "id": index,
            "source": source,
            "target": target,
        }

    def __len__(self):
        return len(self.dataset)

    def permute_sentences(self, source, p=1.0):
        full_stops = source == self.full_stop_index
        # Pretend it ends with a full stop so last span is a sentence
        full_stops[-2] = 1

        # Tokens that are full stops, where the previous token is not
        sentence_ends = (full_stops[1:] * ~full_stops[:-1]).nonzero(as_tuple=False) + 2
        result = source.clone()

        num_sentences = sentence_ends.size(0)
        num_to_permute = math.ceil((num_sentences * 2 * p) / 2.0)
        substitutions = torch.randperm(num_sentences)[:num_to_permute]
        ordering = torch.arange(0, num_sentences)
        ordering[substitutions] = substitutions[torch.randperm(num_to_permute)]

        # Ignore <bos> at start
        index = 1
        for i in ordering:
            sentence = source[(sentence_ends[i - 1] if i > 0 else 1) : sentence_ends[i]]
            result[index : index + sentence.size(0)] = sentence
            index += sentence.size(0)
        return result

    def word_starts(self, source):
        if self.mask_whole_word is not None:
            is_word_start = self.mask_whole_word.gather(0, source)
        else:
            is_word_start = torch.ones(source.size())
        is_word_start[0] = 0
        is_word_start[-1] = 0
        return is_word_start

    def add_whole_word_mask(self, source, p):
        is_word_start = self.word_starts(source)
        num_to_mask = int(math.ceil(is_word_start.float().sum() * p))
        num_inserts = 0
        if num_to_mask == 0:
            return source

        if self.mask_span_distribution is not None:
            lengths = self.mask_span_distribution.sample(sample_shape=(num_to_mask,))

            # Make sure we have enough to mask
            cum_length = torch.cumsum(lengths, 0)
            while cum_length[-1] < num_to_mask:
                lengths = torch.cat(
                    [
                        lengths,
                        self.mask_span_distribution.sample(sample_shape=(num_to_mask,)),
                    ],
                    dim=0,
                )
                cum_length = torch.cumsum(lengths, 0)

            # Trim to masking budget
            i = 0
            while cum_length[i] < num_to_mask:
                i += 1
            lengths[i] = num_to_mask - (0 if i == 0 else cum_length[i - 1])
            num_to_mask = i + 1
            lengths = lengths[:num_to_mask]

            # Handle 0-length mask (inserts) separately
            lengths = lengths[lengths > 0]
            num_inserts = num_to_mask - lengths.size(0)
            num_to_mask -= num_inserts
            if num_to_mask == 0:
                return self.add_insertion_noise(source, num_inserts / source.size(0))

            assert (lengths > 0).all()
        else:
            lengths = torch.ones((num_to_mask,)).long()
        assert is_word_start[-1] == 0
        word_starts = is_word_start.nonzero(as_tuple=False)
        indices = word_starts[
            torch.randperm(word_starts.size(0))[:num_to_mask]
        ].squeeze(1)
        mask_random = torch.FloatTensor(num_to_mask).uniform_() < self.random_ratio

        source_length = source.size(0)
        assert source_length - 1 not in indices
        to_keep = torch.ones(source_length, dtype=torch.bool)
        is_word_start[
            -1
        ] = 255  # acts as a long length, so spans don't go over the end of doc
        if self.replace_length == 0:
            to_keep[indices] = 0
        else:
            # keep index, but replace it with [MASK]
            source[indices] = self.mask_idx
            source[indices[mask_random]] = torch.randint(
                1, len(self.vocab), size=(mask_random.sum(),)
            )

        if self.mask_span_distribution is not None:
            assert len(lengths.size()) == 1
            assert lengths.size() == indices.size()
            lengths -= 1
            while indices.size(0) > 0:
                assert lengths.size() == indices.size()
                lengths -= is_word_start[indices + 1].long()
                uncompleted = lengths >= 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                lengths = lengths[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )
        else:
            # A bit faster when all lengths are 1
            while indices.size(0) > 0:
                uncompleted = is_word_start[indices + 1] == 0
                indices = indices[uncompleted] + 1
                mask_random = mask_random[uncompleted]
                if self.replace_length != -1:
                    # delete token
                    to_keep[indices] = 0
                else:
                    # keep index, but replace it with [MASK]
                    source[indices] = self.mask_idx
                    source[indices[mask_random]] = torch.randint(
                        1, len(self.vocab), size=(mask_random.sum(),)
                    )

                assert source_length - 1 not in indices

        source = source[to_keep]

        if num_inserts > 0:
            source = self.add_insertion_noise(source, num_inserts / source.size(0))

        return source

    def add_permuted_noise(self, tokens, p):
        num_words = len(tokens)
        num_to_permute = math.ceil(((num_words * 2) * p) / 2.0)
        substitutions = torch.randperm(num_words - 2)[:num_to_permute] + 1
        tokens[substitutions] = tokens[substitutions[torch.randperm(num_to_permute)]]
        return tokens

    def add_rolling_noise(self, tokens):
        offset = np.random.randint(1, max(1, tokens.size(-1) - 1) + 1)
        tokens = torch.cat(
            (tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]),
            dim=0,
        )
        return tokens

    def add_insertion_noise(self, tokens, p):
        if p == 0.0:
            return tokens

        num_tokens = len(tokens)
        n = int(math.ceil(num_tokens * p))

        noise_indices = torch.randperm(num_tokens + n - 2)[:n] + 1
        noise_mask = torch.zeros(size=(num_tokens + n,), dtype=torch.bool)
        noise_mask[noise_indices] = 1
        result = torch.LongTensor(n + len(tokens)).fill_(-1)

        num_random = int(math.ceil(n * self.random_ratio))
        result[noise_indices[num_random:]] = self.mask_idx
        result[noise_indices[:num_random]] = torch.randint(
            low=1, high=len(self.vocab), size=(num_random,)
        )

        result[~noise_mask] = tokens

        assert (result >= 0).all()
        return result

    def collater(self, samples, pad_to_length=None):
        """Merge a list of samples to form a mini-batch.
        Args:
            samples (List[dict]): samples to collate
        Returns:
            dict: a mini-batch of data
        """
        return collate(
            samples, self.vocab.pad(), self.eos, self.vocab, pad_to_length=pad_to_length
        )

    def num_tokens(self, index):
        """Return the number of tokens in a sample. This value is used to
        enforce ``--max-tokens`` during batching."""
        return self.sizes[index]

    def size(self, index):
        """Return an example's size as a float or tuple. This value is used when
        filtering a dataset with ``--max-positions``."""
        return self.sizes[index]

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            indices = np.random.permutation(len(self))
        else:
            indices = np.arange(len(self))
        return indices[np.argsort(self.sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        self.src.prefetch(indices)
        self.tgt.prefetch(indices)

    @property
    def supports_prefetch(self):
        return (
            hasattr(self.src, "supports_prefetch")
            and self.src.supports_prefetch
            and hasattr(self.tgt, "supports_prefetch")
            and self.tgt.supports_prefetch
        )
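
As a standalone illustration of the rotation that add_rolling_noise applies (a toy sketch, not the committed code): the first and last tokens, i.e. <bos> and <eos>, stay in place while the interior is rolled by a random offset.

import torch

def rolling_noise(tokens, offset):
    # keep position 0 (<bos>) and the last position (<eos>), rotate the middle
    return torch.cat((tokens[0:1], tokens[offset:-1], tokens[1:offset], tokens[-1:]), dim=0)

toy = torch.tensor([0, 11, 12, 13, 14, 2])   # 0 = <bos>, 2 = <eos>
print(rolling_noise(toy, offset=3))          # tensor([ 0, 13, 14, 11, 12,  2])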
PyTorch/NLP/Transformer/fairseq/data/dictionary.py
→
PyTorch/NLP/
new-
Transformer/fairseq/data/dictionary.py
View file @
c0f05c10
(Changes shown as a plain diff against the old PyTorch/NLP/Transformer copy: lines prefixed with "-" are removed, "+" are added, unprefixed lines are unchanged context.)

-# Copyright (c) 2017-present, Facebook, Inc.
-# All rights reserved.
+# Copyright (c) Facebook, Inc. and its affiliates.
 #
-# This source code is licensed under the license found in the LICENSE file in
-# the root directory of this source tree. An additional grant of patent rights
-# can be found in the PATENTS file in the same directory.
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.

-from collections import Counter
 import os
+from collections import Counter
+from multiprocessing import Pool

 import torch

+from fairseq import utils
+from fairseq.data import data_utils
+from fairseq.file_chunker_utils import Chunker, find_offsets
+from fairseq.file_io import PathManager
+from fairseq.tokenizer import tokenize_line


-class Dictionary(object):
+class Dictionary:
     """A mapping from symbols to consecutive integers"""

-    def __init__(self, pad='<pad>', eos='</s>', unk='<unk>'):
-        self.unk_word, self.pad_word, self.eos_word = unk, pad, eos
+    def __init__(
+        self,
+        *,  # begin keyword-only arguments
+        bos="<s>",
+        pad="<pad>",
+        eos="</s>",
+        unk="<unk>",
+        extra_special_symbols=None,
+    ):
+        self.bos_word, self.unk_word, self.pad_word, self.eos_word = bos, unk, pad, eos
         self.symbols = []
         self.count = []
         self.indices = {}
-        # dictionary indexing starts at 1 for consistency with Lua
-        self.add_symbol('<Lua heritage>')
+        self.bos_index = self.add_symbol(bos)
         self.pad_index = self.add_symbol(pad)
         self.eos_index = self.add_symbol(eos)
         self.unk_index = self.add_symbol(unk)
+        if extra_special_symbols:
+            for s in extra_special_symbols:
+                self.add_symbol(s)
         self.nspecial = len(self.symbols)

     def __eq__(self, other):
...
@@ -33,45 +48,83 @@ class Dictionary(object):
             return self.symbols[idx]
         return self.unk_word

+    def get_count(self, idx):
+        return self.count[idx]
+
     def __len__(self):
         """Returns the number of symbols in the dictionary"""
         return len(self.symbols)

+    def __contains__(self, sym):
+        return sym in self.indices
+
     def index(self, sym):
         """Returns the index of the specified symbol"""
+        assert isinstance(sym, str)
         if sym in self.indices:
             return self.indices[sym]
         return self.unk_index

-    def string(self, tensor, bpe_symbol=None, escape_unk=False):
+    def string(
+        self,
+        tensor,
+        bpe_symbol=None,
+        escape_unk=False,
+        extra_symbols_to_ignore=None,
+        unk_string=None,
+        include_eos=False,
+        separator=" ",
+    ):
         """Helper for converting a tensor of token indices to a string.

         Can optionally remove BPE symbols or escape <unk> words.
         """
         if torch.is_tensor(tensor) and tensor.dim() == 2:
-            return '\n'.join(self.string(t) for t in tensor)
+            return "\n".join(
+                self.string(
+                    t,
+                    bpe_symbol,
+                    escape_unk,
+                    extra_symbols_to_ignore,
+                    include_eos=include_eos,
+                )
+                for t in tensor
+            )
+
+        extra_symbols_to_ignore = set(extra_symbols_to_ignore or [])
+        if not include_eos:
+            extra_symbols_to_ignore.add(self.eos())

         def token_string(i):
             if i == self.unk():
-                return self.unk_string(escape_unk)
+                if unk_string is not None:
+                    return unk_string
+                else:
+                    return self.unk_string(escape_unk)
             else:
                 return self[i]

-        sent = ' '.join(token_string(i) for i in tensor if i != self.eos())
-        if bpe_symbol is not None:
-            sent = (sent + ' ').replace(bpe_symbol, '').rstrip()
-        return sent
+        if hasattr(self, "bos_index"):
+            extra_symbols_to_ignore.add(self.bos())
+
+        sent = separator.join(
+            token_string(i)
+            for i in tensor
+            if utils.item(i) not in extra_symbols_to_ignore
+        )
+
+        return data_utils.post_process(sent, bpe_symbol)

     def unk_string(self, escape=False):
         """Return unknown string, optionally escaped as: <<unk>>"""
         if escape:
-            return '<{}>'.format(self.unk_word)
+            return "<{}>".format(self.unk_word)
         else:
             return self.unk_word

-    def add_symbol(self, word, n=1):
+    def add_symbol(self, word, n=1, overwrite=False):
         """Adds a word to the dictionary"""
-        if word in self.indices:
+        if word in self.indices and not overwrite:
             idx = self.indices[word]
             self.count[idx] = self.count[idx] + n
             return idx
...
@@ -109,11 +162,15 @@ class Dictionary(object):
         if nwords <= 0:
             nwords = len(self)

         new_indices = dict(zip(self.symbols[: self.nspecial], range(self.nspecial)))
         new_symbols = self.symbols[: self.nspecial]
         new_count = self.count[: self.nspecial]

-        c = Counter(dict(zip(self.symbols[self.nspecial:], self.count[self.nspecial:])))
+        c = Counter(
+            dict(
+                sorted(zip(self.symbols[self.nspecial :], self.count[self.nspecial :]))
+            )
+        )
         for symbol, count in c.most_common(nwords - self.nspecial):
             if count >= threshold:
                 new_indices[symbol] = len(new_symbols)
...
@@ -122,24 +179,27 @@ class Dictionary(object):
             else:
                 break

-        threshold_nwords = len(new_symbols)
-        if padding_factor > 1:
-            i = 0
-            while threshold_nwords % padding_factor != 0:
-                symbol = 'madeupword{:04d}'.format(i)
-                new_indices[symbol] = len(new_symbols)
-                new_symbols.append(symbol)
-                new_count.append(0)
-                i += 1
-                threshold_nwords += 1
-
-        assert len(new_symbols) % padding_factor == 0
         assert len(new_symbols) == len(new_indices)

         self.count = list(new_count)
         self.symbols = list(new_symbols)
         self.indices = new_indices

+        self.pad_to_multiple_(padding_factor)
+
+    def pad_to_multiple_(self, padding_factor):
+        """Pad Dictionary size to be a multiple of *padding_factor*."""
+        if padding_factor > 1:
+            i = 0
+            while len(self) % padding_factor != 0:
+                symbol = "madeupword{:04d}".format(i)
+                self.add_symbol(symbol, n=0)
+                i += 1
+
+    def bos(self):
+        """Helper to get index of beginning-of-sentence symbol"""
+        return self.bos_index
+
     def pad(self):
         """Helper to get index of pad symbol"""
         return self.pad_index
...
@@ -153,20 +213,7 @@ class Dictionary(object):
         return self.unk_index

     @classmethod
-    def loads(cls, s):
-        lines = s.strip().split('\n')
-        d = cls()
-        for line in lines:
-            idx = line.rfind(' ')
-            word = line[:idx]
-            count = int(line[idx + 1:])
-            d.indices[word] = len(d.symbols)
-            d.symbols.append(word)
-            d.count.append(count)
-        return d
-
-    @classmethod
-    def load(cls, f, ignore_utf_errors=False):
+    def load(cls, f):
         """Loads the dictionary from a text file with the format:

         ```
...
@@ -175,47 +222,180 @@ class Dictionary(object):
         ...
         ```
         """
+        d = cls()
+        d.add_from_file(f)
+        return d
+
+    def add_from_file(self, f):
+        """
+        Loads a pre-existing dictionary from a text file and adds its symbols
+        to this instance.
+        """
         if isinstance(f, str):
             try:
-                if not ignore_utf_errors:
-                    with open(f, 'r', encoding='utf-8') as fd:
-                        return cls.load(fd)
-                else:
-                    with open(f, 'r', encoding='utf-8', errors='ignore') as fd:
-                        return cls.load(fd)
+                with open(PathManager.get_local_path(f), "r", encoding="utf-8") as fd:
+                    self.add_from_file(fd)
             except FileNotFoundError as fnfe:
                 raise fnfe
-            except Exception:
-                raise Exception("Incorrect encoding detected in {}, please "
-                                "rebuild the dataset".format(f))
-
-        cont = f.read()
-        d = cls.loads(cont)
-        return d
+            except UnicodeError:
+                raise Exception(
+                    "Incorrect encoding detected in {}, please "
+                    "rebuild the dataset".format(f)
+                )
+            return

-    def save(self, f):
-        """Stores dictionary into a text file"""
+        lines = f.readlines()
+        indices_start_line = self._load_meta(lines)
+
+        for line in lines[indices_start_line:]:
+            try:
+                line, field = line.rstrip().rsplit(" ", 1)
+                if field == "#fairseq:overwrite":
+                    overwrite = True
+                    line, field = line.rsplit(" ", 1)
+                else:
+                    overwrite = False
+                count = int(field)
+                word = line
+                if word in self and not overwrite:
+                    raise RuntimeError(
+                        "Duplicate word found when loading Dictionary: '{}'. "
+                        "Duplicate words can overwrite earlier ones by adding the "
+                        "#fairseq:overwrite flag at the end of the corresponding row "
+                        "in the dictionary file. If using the Camembert model, please "
+                        "download an updated copy of the model file.".format(word)
+                    )
+                self.add_symbol(word, n=count, overwrite=overwrite)
+            except ValueError:
+                raise ValueError(
+                    f"Incorrect dictionary format, expected '<token> <cnt> [flags]': \"{line}\""
+                )
+
+    def _save(self, f, kv_iterator):
         if isinstance(f, str):
-            os.makedirs(os.path.dirname(f), exist_ok=True)
-            with open(f, 'w', encoding='utf-8') as fd:
+            PathManager.mkdirs(os.path.dirname(f))
+            with PathManager.open(f, "w", encoding="utf-8") as fd:
                 return self.save(fd)
-        d = self.saves()
-        f.write(d)
+        for k, v in kv_iterator:
+            print("{} {}".format(k, v), file=f)
+
+    def _get_meta(self):
+        return [], []

-    def saves(self):
-        rv = ''
-        for symbol, count in zip(self.symbols[self.nspecial:], self.count[self.nspecial:]):
-            rv += '{} {}\n'.format(symbol, count)
-        return rv
+    def _load_meta(self, lines):
+        return 0
+
+    def save(self, f):
+        """Stores dictionary into a text file"""
+        ex_keys, ex_vals = self._get_meta()
+        self._save(
+            f,
+            zip(
+                ex_keys + self.symbols[self.nspecial :],
+                ex_vals + self.count[self.nspecial :],
+            ),
+        )

     def dummy_sentence(self, length):
         t = torch.Tensor(length).uniform_(self.nspecial + 1, len(self)).long()
         t[-1] = self.eos()
         return t

-    def get_metadata(self):
-        return {'len': self.__len__(),
-                'pad': self.pad_index,
-                'eos': self.eos_index,
-                'unk': self.unk_index,
-                'nspecial': self.nspecial
-                }
+    def encode_line(
+        self,
+        line,
+        line_tokenizer=tokenize_line,
+        add_if_not_exist=True,
+        consumer=None,
+        append_eos=True,
+        reverse_order=False,
+    ) -> torch.IntTensor:
+        words = line_tokenizer(line)
+        if reverse_order:
+            words = list(reversed(words))
+        nwords = len(words)
+        ids = torch.IntTensor(nwords + 1 if append_eos else nwords)
+
+        for i, word in enumerate(words):
+            if add_if_not_exist:
+                idx = self.add_symbol(word)
+            else:
+                idx = self.index(word)
+            if consumer is not None:
+                consumer(word, idx)
+            ids[i] = idx
+        if append_eos:
+            ids[nwords] = self.eos_index
+        return ids
+
+    @staticmethod
+    def _add_file_to_dictionary_single_worker(
+        filename,
+        tokenize,
+        eos_word,
+        start_offset,
+        end_offset,
+    ):
+        counter = Counter()
+        with Chunker(filename, start_offset, end_offset) as line_iterator:
+            for line in line_iterator:
+                for word in tokenize(line):
+                    counter.update([word])
+                counter.update([eos_word])
+        return counter
+
+    @staticmethod
+    def add_file_to_dictionary(filename, dict, tokenize, num_workers):
+        def merge_result(counter):
+            for w, c in sorted(counter.items()):
+                dict.add_symbol(w, c)
+
+        local_file = PathManager.get_local_path(filename)
+        offsets = find_offsets(local_file, num_workers)
+        if num_workers > 1:
+            chunks = zip(offsets, offsets[1:])
+            pool = Pool(processes=num_workers)
+            results = []
+            for (start_offset, end_offset) in chunks:
+                results.append(
+                    pool.apply_async(
+                        Dictionary._add_file_to_dictionary_single_worker,
+                        (
+                            local_file,
+                            tokenize,
+                            dict.eos_word,
+                            start_offset,
+                            end_offset,
+                        ),
+                    )
+                )
+            pool.close()
+            pool.join()
+            for r in results:
+                merge_result(r.get())
+        else:
+            merge_result(
+                Dictionary._add_file_to_dictionary_single_worker(
+                    local_file, tokenize, dict.eos_word, offsets[0], offsets[1]
+                )
+            )
+
+
+class TruncatedDictionary(object):
+    def __init__(self, wrapped_dict, length):
+        self.__class__ = type(
+            wrapped_dict.__class__.__name__,
+            (self.__class__, wrapped_dict.__class__),
+            {},
+        )
+        self.__dict__ = wrapped_dict.__dict__
+        self.wrapped_dict = wrapped_dict
+        self.length = min(len(self.wrapped_dict), length)
+
+    def __len__(self):
+        return self.length
+
+    def __getitem__(self, i):
+        if i < self.length:
+            return self.wrapped_dict[i]
+        return self.wrapped_dict.unk()
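
A quick, hedged usage sketch of the updated Dictionary API shown in the diff above (assuming a working fairseq install; the token strings are arbitrary examples):

from fairseq.data import Dictionary

d = Dictionary()                     # bos/pad/eos/unk are now keyword-only defaults
idx = d.add_symbol("hello", n=3)     # returns the symbol's index and accumulates its count
assert d.index("hello") == idx
assert d.index("never-seen") == d.unk_index   # OOV falls back to <unk>

ids = d.encode_line("hello hello", add_if_not_exist=False, append_eos=True)
print(d.string(ids))                 # "hello hello" (eos is stripped by default)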
PyTorch/NLP/new-Transformer/fairseq/data/encoders/__init__.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
importlib
import
os
from
fairseq
import
registry
build_tokenizer
,
register_tokenizer
,
TOKENIZER_REGISTRY
,
_
=
registry
.
setup_registry
(
"--tokenizer"
,
default
=
None
,
)
build_bpe
,
register_bpe
,
BPE_REGISTRY
,
_
=
registry
.
setup_registry
(
"--bpe"
,
default
=
None
,
)
# automatically import any Python files in the encoders/ directory
for
file
in
sorted
(
os
.
listdir
(
os
.
path
.
dirname
(
__file__
))):
if
file
.
endswith
(
".py"
)
and
not
file
.
startswith
(
"_"
):
module
=
file
[:
file
.
find
(
".py"
)]
importlib
.
import_module
(
"fairseq.data.encoders."
+
module
)
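
byte_bpe.py below is a real example of plugging into this registry; as a stripped-down, hypothetical registration following the same pattern (the names here are made up for illustration only):

from dataclasses import dataclass

from fairseq.data.encoders import register_bpe
from fairseq.dataclass import FairseqDataclass


@dataclass
class PassthroughBpeConfig(FairseqDataclass):
    pass  # no options needed for this toy encoder


@register_bpe("passthrough_bpe", dataclass=PassthroughBpeConfig)
class PassthroughBPE(object):
    """A do-nothing encoder, selectable via --bpe passthrough_bpe."""

    def __init__(self, cfg):
        pass

    def encode(self, x: str) -> str:
        return x

    def decode(self, x: str) -> str:
        return x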
PyTorch/NLP/new-Transformer/fairseq/data/encoders/byte_bpe.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


from dataclasses import dataclass, field

from fairseq import file_utils
from fairseq.data.encoders import register_bpe
from fairseq.data.encoders.byte_utils import (
    SPACE,
    SPACE_ESCAPE,
    byte_encode,
    smart_byte_decode,
)
from fairseq.dataclass import FairseqDataclass


@dataclass
class ByteBpeConfig(FairseqDataclass):
    sentencepiece_model_path: str = field(
        default="???", metadata={"help": "path to sentencepiece model"}
    )


@register_bpe("byte_bpe", dataclass=ByteBpeConfig)
class ByteBPE(object):
    def __init__(self, cfg):
        vocab = file_utils.cached_path(cfg.sentencepiece_model_path)
        try:
            import sentencepiece as spm

            self.sp = spm.SentencePieceProcessor()
            self.sp.Load(vocab)
        except ImportError:
            raise ImportError(
                "Please install sentencepiece with: pip install sentencepiece"
            )

    def encode(self, x: str) -> str:
        byte_encoded = byte_encode(x)
        return SPACE.join(self.sp.EncodeAsPieces(byte_encoded))

    @staticmethod
    def decode(x: str) -> str:
        unescaped = x.replace(SPACE, "").replace(SPACE_ESCAPE, SPACE)
        return smart_byte_decode(unescaped)
PyTorch/NLP/new-Transformer/fairseq/data/encoders/byte_utils.py
0 → 100644
View file @
c0f05c10
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.


import re


WHITESPACE_NORMALIZER = re.compile(r"\s+")
SPACE = chr(32)
SPACE_ESCAPE = chr(9601)
# excluding non-breaking space (160) here
PRINTABLE_LATIN = set(
    list(range(32, 126 + 1)) + list(range(161, 172 + 1)) + list(range(174, 255 + 1))
)
BYTE_TO_BCHAR = {
    b: chr(b) if b in PRINTABLE_LATIN else chr(256 + b) for b in range(256)
}
BCHAR_TO_BYTE = {bc: b for b, bc in BYTE_TO_BCHAR.items()}


def byte_encode(x: str) -> str:
    normalized = WHITESPACE_NORMALIZER.sub(SPACE, x)
    return "".join([BYTE_TO_BCHAR[b] for b in normalized.encode("utf-8")])


def byte_decode(x: str) -> str:
    try:
        return bytes([BCHAR_TO_BYTE[bc] for bc in x]).decode("utf-8")
    except ValueError:
        return ""


def smart_byte_decode(x: str) -> str:
    output = byte_decode(x)
    if output == "":
        # DP the best recovery (max valid chars) if it's broken
        n_bytes = len(x)
        f = [0 for _ in range(n_bytes + 1)]
        pt = [0 for _ in range(n_bytes + 1)]
        for i in range(1, n_bytes + 1):
            f[i], pt[i] = f[i - 1], i - 1
            for j in range(1, min(4, i) + 1):
                if f[i - j] + 1 > f[i] and len(byte_decode(x[i - j : i])) > 0:
                    f[i], pt[i] = f[i - j] + 1, i - j
        cur_pt = n_bytes
        while cur_pt > 0:
            if f[cur_pt] == f[pt[cur_pt]] + 1:
                output = byte_decode(x[pt[cur_pt] : cur_pt]) + output
            cur_pt = pt[cur_pt]
    return output
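
To make the byte-level escaping concrete, a short round-trip check (a sketch; it only needs byte_utils.py itself to be importable):

from fairseq.data.encoders.byte_utils import byte_encode, byte_decode, smart_byte_decode

s = "héllo   world"                      # multi-byte character plus a run of spaces
enc = byte_encode(s)                     # one printable "byte character" per UTF-8 byte
assert byte_decode(enc) == "héllo world"                  # whitespace runs are normalized first
assert smart_byte_decode(byte_encode("hé")[:2]) == "h"    # recovers the valid prefix of a broken stream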