Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
VITA-Audio_pytorch
Commits
39ac40a9
Commit
39ac40a9
authored
Jun 06, 2025
by
chenzk
Browse files
v1.0
parents
Pipeline
#2747
failed with stages
in 0 seconds
Changes
427
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2253 additions
and
0 deletions
+2253
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/libri_labels.py
...irdparty/UniSpeech/src/examples/unispeech/libri_labels.py
+56
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/continue_pretrain.sh
...peech/src/examples/unispeech/scripts/continue_pretrain.sh
+14
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/finetune.sh
...arty/UniSpeech/src/examples/unispeech/scripts/finetune.sh
+14
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/inference.sh
...rty/UniSpeech/src/examples/unispeech/scripts/inference.sh
+12
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/multilingal_large_pretrain.sh
.../examples/unispeech/scripts/multilingal_large_pretrain.sh
+16
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/one2one_large_pretrain_en1350.sh
...amples/unispeech/scripts/one2one_large_pretrain_en1350.sh
+14
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/unispeech_manifest.py
...ty/UniSpeech/src/examples/unispeech/unispeech_manifest.py
+49
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/wav2vec_manifest.py
...arty/UniSpeech/src/examples/unispeech/wav2vec_manifest.py
+76
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/__init__.py
...eed-tts-eval/thirdparty/UniSpeech/src/fairseq/__init__.py
+41
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/binarizer.py
...ed-tts-eval/thirdparty/UniSpeech/src/fairseq/binarizer.py
+114
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/checkpoint_utils.py
...eval/thirdparty/UniSpeech/src/fairseq/checkpoint_utils.py
+803
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp
...iSpeech/src/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp
+47
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
...h/src/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
+76
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libbase/balanced_assignment.cpp
...niSpeech/src/fairseq/clib/libbase/balanced_assignment.cpp
+95
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libbleu/libbleu.cpp
...thirdparty/UniSpeech/src/fairseq/clib/libbleu/libbleu.cpp
+141
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libbleu/module.cpp
.../thirdparty/UniSpeech/src/fairseq/clib/libbleu/module.cpp
+37
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libnat/edit_dist.cpp
...hirdparty/UniSpeech/src/fairseq/clib/libnat/edit_dist.cpp
+231
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libnat_cuda/binding.cpp
...dparty/UniSpeech/src/fairseq/clib/libnat_cuda/binding.cpp
+60
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libnat_cuda/edit_dist.cu
...party/UniSpeech/src/fairseq/clib/libnat_cuda/edit_dist.cu
+332
-0
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libnat_cuda/edit_dist.h
...dparty/UniSpeech/src/fairseq/clib/libnat_cuda/edit_dist.h
+25
-0
No files found.
Too many changes to show.
To preserve performance only
427 of 427+
files are displayed.
Plain diff
Email patch
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/libri_labels.py
0 → 100644
View file @
39ac40a9
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Helper script to pre-compute embeddings for a wav2letter++ dataset
"""
import
argparse
import
os
def main():
    """Generate .ltr (letter) and .wrd (word) label files for a
    wav2letter++ dataset from a LibriSpeech-style .tsv manifest.

    The positional ``tsv`` manifest has the dataset root on its first line
    and one relative audio path per following line.  For each audio file
    the matching ``<spk>-<chapter>.trans.txt`` transcript is located (and
    cached per directory); the utterance text is written as words to
    ``<output-name>.wrd`` and as '|'-separated letters to
    ``<output-name>.ltr`` inside ``--output-dir``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("tsv")
    parser.add_argument("--output-dir", required=True)
    parser.add_argument("--output-name", required=True)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # cache: directory -> {utterance id -> transcription}
    transcriptions = {}

    with open(args.tsv, "r") as tsv, open(
        os.path.join(args.output_dir, args.output_name + ".ltr"), "w"
    ) as ltr_out, open(
        os.path.join(args.output_dir, args.output_name + ".wrd"), "w"
    ) as wrd_out:
        root = next(tsv).strip()
        for line in tsv:
            line = line.strip()
            # renamed from `dir` so the builtin is not shadowed
            dir_name = os.path.dirname(line)
            if dir_name not in transcriptions:
                parts = dir_name.split(os.path.sep)
                # LibriSpeech layout: <root>/<spk>/<chapter>/<spk>-<chapter>.trans.txt
                trans_path = f"{parts[-2]}-{parts[-1]}.trans.txt"
                path = os.path.join(root, dir_name, trans_path)
                assert os.path.exists(path)
                texts = {}
                with open(path, "r") as trans_f:
                    for tline in trans_f:
                        # first token is the utterance id, the rest is the text
                        items = tline.strip().split()
                        texts[items[0]] = " ".join(items[1:])
                transcriptions[dir_name] = texts
            part = os.path.basename(line).split(".")[0]
            assert part in transcriptions[dir_name]
            print(transcriptions[dir_name][part], file=wrd_out)
            # letters separated by spaces, word boundaries marked with '|'
            print(
                " ".join(list(transcriptions[dir_name][part].replace(" ", "|")))
                + " |",
                file=ltr_out,
            )


if __name__ == "__main__":
    main()
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/continue_pretrain.sh
0 → 100644
View file @
39ac40a9
#:: Copyright (c) Microsoft Corporation.
#:: Licensed under the MIT License.
model_path
=
MODEL_PATH
train_subset
=
pretrain_HOUR_16k
valid_subset
=
valSeqs_1.0_uniform_new_version_16k
WORLD_SIZE
=
8
update_freq
=
2
mkdir
-p
${
model_path
}
python train.py
--distributed-world-size
${
WORLD_SIZE
}
--distributed-port
0 examples/unispeech/data/LANG
--save-dir
${
model_path
}
--fp16
--num-workers
10
--task
audio_pretraining
--criterion
wav2vec
--arch
wav2vec2
--train-subset
${
train_subset
}
--valid-subset
${
valid_subset
}
--log-keys
'["prob_perplexity","code_perplexity","temp"]'
--quantize-targets
--normalize
--extractor-mode
"layer_norm"
--encoder-layers
24
--encoder-embed-dim
1024
--encoder-ffn-embed-dim
4096
--encoder-attention-heads
16
--final-dim
768
--layer-norm-first
--conv-feature-layers
'[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2'
--latent-vars
320
--latent-groups
2
--latent-temp
'(2,0.1,0.999995)'
--infonce
--optimizer
adam
--adam-betas
'(0.9,0.98)'
--adam-eps
1e-06
--lr-scheduler
polynomial_decay
--total-num-update
100000
--lr
0.0002
--warmup-updates
10000
--mask-length
10
--mask-prob
0.65
--mask-selection
static
--mask-other
0
--encoder-layerdrop
0.05
--dropout-input
0.1
--dropout-features
0.1
--feature-grad-mult
0.1
--loss-weights
'[0.1, 0]'
--conv-pos
128
--conv-pos-groups
16
--num-negatives
100
--cross-sample-negatives
0
--max-sample-size
250000
--min-sample-size
32000
--dropout
0.1
--attention-dropout
0.1
--weight-decay
0.01
--max-tokens
1000000
--max-update
100000
--skip-invalid-size-inputs-valid-test
--ddp-backend
no_c10d
--update-freq
${
update_freq
}
--pretrained-path
PRETRAINED_MODEL
--no-epoch-checkpoints
--transpose
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/finetune.sh
0 → 100644
View file @
39ac40a9
#:: Copyright (c) Microsoft Corporation.
#:: Licensed under the MIT License.

# Fine-tune a pretrained model with CTC on labelled data.
# Placeholders: MODEL_PATH (save dir), PRETRAINED_MODEL (w2v checkpoint),
# LANG (language data dir under examples/unispeech/data/).

model_path=MODEL_PATH
pretrained_model=PRETRAINED_MODEL
train_subset=trainSeqs_1.0_uniform_new_version_16k
valid_subset=valSeqs_1.0_uniform_new_version_16k

mkdir -p ${model_path}

WORLD_SIZE=4
# fixed typo: variable was previously spelled `updata_freq`
update_freq=1

python train.py --distributed-world-size $WORLD_SIZE --distributed-port 0 examples/unispeech/data/LANG/ \
    --save-dir ${model_path} --post-process word \
    --train-subset ${train_subset} --valid-subset ${valid_subset} \
    --no-epoch-checkpoints --best-checkpoint-metric uer --num-workers 4 \
    --max-update 20000 --sentence-avg --task audio_pretraining \
    --arch wav2vec_ctc --w2v-path ${pretrained_model} --labels id \
    --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.75 \
    --layerdrop 0.1 --mask-channel-selection static --mask-channel-other 0 \
    --mask-channel-length 64 --mask-channel-prob 0.25 --zero-infinity \
    --feature-grad-mult 0.0 --freeze-finetune-updates 2000 --validate-after-updates 2000 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 2e-05 \
    --lr-scheduler tri_stage --warmup-steps 2000 --hold-steps 8000 --decay-steps 10000 \
    --final-lr-scale 0.05 \
    --activation-dropout 0.1 --dropout 0.1 --attention-dropout 0.1 --final-dropout 0.1 \
    --dropout-input 0.1 \
    --criterion ctc --max-tokens 1000000 --seed 1337 \
    --log-format json --log-interval 100 --ddp-backend no_c10d --fp16 \
    --update-freq ${update_freq} \
    --dict-path examples/unispeech/data/LANG/phonesMatches_reduced.json \
    --save-interval 10 --validate-interval 10 --normalize
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/inference.sh
0 → 100644
View file @
39ac40a9
#:: Copyright (c) Microsoft Corporation.
#:: Licensed under the MIT License.

# Decode the test subset with a fine-tuned CTC model using the
# Viterbi decoder; results go under <model_path>/decode_ctc/<subset>.

model_path=MODEL_PATH
gen_subset=testSeqs_uniform_new_version_16k
result_path=${model_path}/decode_ctc/${gen_subset}
mkdir -p ${result_path}

# Fixed: the environment variable Python actually honours is
# PYTHONIOENCODING; `PYTHONENCODING` is silently ignored by the interpreter.
export PYTHONIOENCODING=UTF-8

python examples/speech_recognition/infer.py examples/unispeech/data/LANG \
    --task audio_pretraining --nbest 1 \
    --path ${model_path}/checkpoint_best.pt --gen-subset ${gen_subset} \
    --results-path ${result_path} \
    --w2l-decoder viterbi --word-score -1 --sil-weight 0 \
    --criterion ctc --max-tokens 4000000 \
    --dict-path examples/unispeech/data/LANG/phonesMatches_reduced.json \
    --post-process none --quiet
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/multilingal_large_pretrain.sh
0 → 100644
View file @
39ac40a9
#:: Copyright (c) Microsoft Corporation.
#:: Licensed under the MIT License.
model_path
=
MODEL_PATH
valid_subset
=
en/valid_16k
WORLD_SIZE
=
NUM_OF_GPUS
update_freq
=
$((
64
/
$WORLD_SIZE
))
#ngpu * update_freq = 64
DISTRIBUTED_ARGS
=
"--nproc_per_node
$GPUS_PER_NODE
--nnodes
$NNODES
--node_rank
$NODE_RANK
--master_addr
$MASTER_ADDR
--master_port
$MASTER_PORT
"
mkdir
-p
${
model_path
}
python
-m
torch.distributed.launch
$DISTRIBUTED_ARGS
train.py
--distributed-world-size
${
WORLD_SIZE
}
--distributed-port
0 examples/unispeech/data
--save-dir
${
model_path
}
--fp16
--num-workers
10
--task
audio_pretraining
--criterion
wav2vec_mtl
--arch
unispeech
--extractor-mode
"layer_norm"
--encoder-layers
24
--encoder-embed-dim
1024
--encoder-ffn-embed-dim
4096
--encoder-attention-heads
16
--final-dim
768
--layer-norm-first
--conv-bias
--logit-temp
0.1
--train-subset
en/pretrain_1350_16k,es/pretrain_168_16k_sep,fr/pretrain_353_16k_sep,it/pretrain_90_16k_sep
--valid-subset
${
valid_subset
}
--log-keys
'["prob_perplexity","code_perplexity","temp"]'
--quantize-targets
--conv-feature-layers
'[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2'
--latent-vars
320
--latent-groups
2
--latent-temp
'(2,0.1,0.999995)'
--infonce
--optimizer
adam
--adam-betas
'(0.9,0.98)'
--adam-eps
1e-06
--lr-scheduler
polynomial_decay
--total-num-update
200000
--lr
0.001
--warmup-updates
25000
--mask-length
10
--mask-prob
0.5
--mask-selection
static
--mask-other
0
--encoder-layerdrop
0.0
--dropout-input
0.1
--dropout-features
0.1
--feature-grad-mult
1.0
--loss-weights
'[0.1, 0]'
--conv-pos
128
--conv-pos-groups
16
--num-negatives
100
--cross-sample-negatives
0
--max-sample-size
320000
--min-sample-size
32000
--dropout
0.1
--attention-dropout
0.1
--weight-decay
0.01
--max-tokens
1200000
--max-update
250000
--skip-invalid-size-inputs-valid-test
--ddp-backend
no_c10d
--update-freq
${
update_freq
}
--post-process
word
--labels
ltr
--dict-path
examples/unispeech/data/mtl/vocab_sep.json
--negatives-from-everywhere
--mtlalpha
0.5
--replace-prob
0.5
--transpose
--no-epoch-checkpoints
--log-format
json
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/scripts/one2one_large_pretrain_en1350.sh
0 → 100644
View file @
39ac40a9
#:: Copyright (c) Microsoft Corporation.
#:: Licensed under the MIT License.
model_path
=
MODEL_PATH
train_subset
=
pretrain_1350_16k
valid_subset
=
valid_16k
WORLD_SIZE
=
NUM_OF_GPUS
update_freq
=
$((
64
/
$WORLD_SIZE
))
#ngpu * update_freq = 64
mkdir
-p
${
model_path
}
python train.py
--distributed-world-size
${
WORLD_SIZE
}
--distributed-port
0 examples/unispeech/data/en
--save-dir
${
model_path
}
--fp16
--num-workers
10
--task
audio_pretraining
--criterion
unispeech_criterion
--arch
unispeech
--extractor-mode
"layer_norm"
--encoder-layers
24
--encoder-embed-dim
1024
--encoder-ffn-embed-dim
4096
--encoder-attention-heads
16
--final-dim
768
--layer-norm-first
--conv-bias
--logit-temp
0.1
--train-subset
${
train_subset
}
--valid-subset
${
valid_subset
}
--log-keys
'["prob_perplexity","code_perplexity","temp"]'
--quantize-targets
--conv-feature-layers
'[(512, 10, 5)] + [(512, 3, 2)] * 4 + [(512,2,2)] * 2'
--latent-vars
320
--latent-groups
2
--latent-temp
'(2,0.1,0.999995)'
--infonce
--optimizer
adam
--adam-betas
'(0.9,0.98)'
--adam-eps
1e-06
--lr-scheduler
polynomial_decay
--total-num-update
200000
--lr
0.001
--warmup-updates
25000
--mask-length
10
--mask-prob
0.5
--mask-selection
static
--mask-other
0
--encoder-layerdrop
0.0
--dropout-input
0.1
--dropout-features
0.1
--feature-grad-mult
1.0
--loss-weights
'[0.1, 0]'
--conv-pos
128
--conv-pos-groups
16
--num-negatives
100
--cross-sample-negatives
0
--max-sample-size
320000
--min-sample-size
32000
--dropout
0.1
--attention-dropout
0.1
--weight-decay
0.01
--max-tokens
1200000
--max-update
250000
--skip-invalid-size-inputs-valid-test
--ddp-backend
no_c10d
--update-freq
${
update_freq
}
--post-process
none
--labels
id
--dict-path
examples/unispeech/data/en/vocab.json
--negatives-from-everywhere
--mtlalpha
0.5
--replace-prob
0.5
--transpose
--no-epoch-checkpoints
--log-format
json
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/unispeech_manifest.py
0 → 100644
View file @
39ac40a9
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
import
argparse
import
os
def get_parser():
    """Build the command-line parser for the manifest-conversion script."""
    p = argparse.ArgumentParser()
    p.add_argument('input', type=str, help="input .tsv file")
    p.add_argument('--dest', type=str, help="output directory")
    return p
def main(args):
    """Convert a UniSpeech .tsv manifest into a .list / .text file pair.

    The input .tsv has one header line followed by tab-separated rows whose
    second column is the wav name and whose third column is the
    transcription.  Writes ``<stem>.list`` (one wav name per line) and
    ``<stem>.text`` (``wav<TAB>transcription`` per line) into ``args.dest``.
    """
    wav_names = []
    text = []
    with open(args.input) as f:
        f.readline()  # skip the header row
        for line in f:
            items = line.strip().split("\t")
            wav_names.append(items[1])
            text.append(items[2])

    base_name = os.path.basename(args.input)
    file_name = os.path.splitext(base_name)[0]

    # robustness: make sure the destination directory exists before writing
    os.makedirs(args.dest, exist_ok=True)

    with open(os.path.join(args.dest, file_name + '.list'), 'w') as f:
        for name in wav_names:
            f.write(name + "\n")
    with open(os.path.join(args.dest, file_name + '.text'), 'w') as f:
        # zip() instead of indexing over range(len(...)) — same pairing, idiomatic
        for name, transcript in zip(wav_names, text):
            f.write("{}\t{}\n".format(name, transcript))


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    main(args)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/unispeech/wav2vec_manifest.py
0 → 100644
View file @
39ac40a9
#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Data pre-processing: build vocabularies and binarize training data.
"""
import
argparse
import
glob
import
os
import
random
import
soundfile
def get_parser():
    """Build the CLI parser for the wav2vec manifest generator."""
    p = argparse.ArgumentParser()
    p.add_argument(
        "root", metavar="DIR", help="root directory containing flac files to index"
    )
    p.add_argument(
        "--valid-percent",
        default=0.01,
        type=float,
        metavar="D",
        help="percentage of data to use as validation set (between 0 and 1)",
    )
    p.add_argument(
        "--dest", default=".", type=str, metavar="DIR", help="output directory"
    )
    p.add_argument(
        "--ext", default="flac", type=str, metavar="EXT", help="extension to look for"
    )
    p.add_argument("--seed", default=42, type=int, metavar="N", help="random seed")
    p.add_argument(
        "--path-must-contain",
        default=None,
        type=str,
        metavar="FRAG",
        help="if set, path must contain this substring for a file to be included in the manifest",
    )
    return p
def main(args):
    """Index audio files under ``args.root`` into train/valid .tsv manifests.

    Recursively globs ``*.<ext>`` files under the root, optionally filters
    paths with ``--path-must-contain``, and writes each file's
    root-relative path plus its frame count to ``train.tsv`` or
    ``valid.tsv`` in ``args.dest``.  The split is random with probability
    ``--valid-percent``, deterministic for a fixed ``--seed``.
    """
    assert args.valid_percent >= 0 and args.valid_percent <= 1.0

    # robustness: create the destination before opening the manifests
    os.makedirs(args.dest, exist_ok=True)

    dir_path = os.path.realpath(args.root)
    search_path = os.path.join(dir_path, "**/*." + args.ext)
    rand = random.Random(args.seed)

    with open(os.path.join(args.dest, "train.tsv"), "w") as train_f, open(
        os.path.join(args.dest, "valid.tsv"), "w"
    ) as valid_f:
        # first line of each manifest is the dataset root
        print(dir_path, file=train_f)
        print(dir_path, file=valid_f)

        for fname in glob.iglob(search_path, recursive=True):
            file_path = os.path.realpath(fname)

            if args.path_must_contain and args.path_must_contain not in file_path:
                continue

            # frame count is used downstream for length-based filtering
            frames = soundfile.info(fname).frames
            dest = train_f if rand.random() > args.valid_percent else valid_f
            print(
                "{}\t{}".format(os.path.relpath(file_path, dir_path), frames),
                file=dest,
            )


if __name__ == "__main__":
    parser = get_parser()
    args = parser.parse_args()
    main(args)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/__init__.py
0 → 100644
View file @
39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""isort:skip_file"""

import os
import sys

try:
    # prefer the version module generated at build time
    from .version import __version__  # noqa
except ImportError:
    # source checkout: fall back to version.txt next to this file
    version_txt = os.path.join(os.path.dirname(__file__), "version.txt")
    with open(version_txt) as f:
        __version__ = f.read().strip()

__all__ = ["pdb"]

# backwards compatibility to support `from fairseq.X import Y`
from fairseq.distributed import utils as distributed_utils
from fairseq.logging import meters, metrics, progress_bar  # noqa

# alias the legacy flat module paths so old-style imports keep resolving
sys.modules["fairseq.distributed_utils"] = distributed_utils
sys.modules["fairseq.meters"] = meters
sys.modules["fairseq.metrics"] = metrics
sys.modules["fairseq.progress_bar"] = progress_bar

# initialize hydra
from fairseq.dataclass.initialize import hydra_init

hydra_init()

# NOTE(review): these imports are presumably for their module-level side
# effects (e.g. registry population) — confirm before removing any.
import fairseq.criterions  # noqa
import fairseq.distributed  # noqa
import fairseq.models  # noqa
import fairseq.modules  # noqa
import fairseq.optim  # noqa
import fairseq.optim.lr_scheduler  # noqa
import fairseq.pdb  # noqa
import fairseq.tasks  # noqa
import fairseq.token_generation_constraints  # noqa
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/binarizer.py
0 → 100644
View file @
39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
os
from
collections
import
Counter
import
torch
from
fairseq.file_io
import
PathManager
from
fairseq.tokenizer
import
tokenize_line
from
typing
import
List
,
Dict
def safe_readline(f):
    """Read one line from ``f``; if the current offset falls inside a
    multi-byte character, back up one byte at a time until decoding succeeds."""
    start = f.tell()
    while True:
        try:
            return f.readline()
        except UnicodeDecodeError:
            start -= 1
            f.seek(start)  # search where this character begins
class Binarizer:
    """Static helpers that stream text files and feed encoded tensors to a
    consumer callback, supporting byte-offset chunking for parallel workers."""

    @staticmethod
    def binarize(
        filename,
        dict,
        consumer,
        tokenize=tokenize_line,
        append_eos=True,
        reverse_order=False,
        offset=0,
        end=-1,
        already_numberized=False,
    ) -> Dict[str, int]:
        """Encode each line of ``filename`` between byte ``offset`` and ``end``
        and pass the resulting IntTensor to ``consumer``.

        When ``already_numberized`` is true, lines are parsed as
        whitespace-separated integer token ids; otherwise they are encoded
        via ``dict.encode_line`` with ``tokenize``.  Returns statistics:
        ``nseq`` (sequences), ``ntok`` (tokens), ``nunk`` (words replaced by
        unk), and the Counter of ``replaced`` words.
        """
        nseq, ntok = 0, 0
        replaced = Counter()

        def replaced_consumer(word, idx):
            # tally words that mapped to <unk> (excluding the literal unk word)
            if idx == dict.unk_index and word != dict.unk_word:
                replaced.update([word])

        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            f.seek(offset)
            # next(f) breaks f.tell(), hence readline() must be used
            line = safe_readline(f)
            while line:
                # f.tell() does not always give the byte position in the file
                # sometimes it skips to a very large number
                # it is unlikely that through a normal read we go from
                # end bytes to end + 2**32 bytes (4 GB) and this makes it unlikely
                # that the procedure breaks by the undeterministic behavior of
                # f.tell()
                if end > 0 and f.tell() > end and f.tell() < end + 2**32:
                    break
                if already_numberized:
                    id_strings = line.strip().split()
                    id_list = [int(id_string) for id_string in id_strings]
                    if reverse_order:
                        id_list.reverse()
                    if append_eos:
                        id_list.append(dict.eos())
                    ids = torch.IntTensor(id_list)
                else:
                    ids = dict.encode_line(
                        line=line,
                        line_tokenizer=tokenize,
                        add_if_not_exist=False,
                        consumer=replaced_consumer,
                        append_eos=append_eos,
                        reverse_order=reverse_order,
                    )
                nseq += 1
                ntok += len(ids)
                consumer(ids)
                line = f.readline()
        return {
            "nseq": nseq,
            "nunk": sum(replaced.values()),
            "ntok": ntok,
            "replaced": replaced,
        }

    @staticmethod
    def binarize_alignments(
        filename, alignment_parser, consumer, offset=0, end=-1
    ) -> Dict[str, int]:
        """Parse each alignment line between byte ``offset`` and ``end`` with
        ``alignment_parser`` and feed the result to ``consumer``.

        Returns ``{"nseq": <number of lines processed>}``.
        """
        nseq = 0
        with open(PathManager.get_local_path(filename), "r") as f:
            f.seek(offset)
            line = safe_readline(f)
            while line:
                if end > 0 and f.tell() > end:
                    break
                ids = alignment_parser(line)
                nseq += 1
                consumer(ids)
                line = f.readline()
        return {"nseq": nseq}

    @staticmethod
    def find_offsets(filename, num_chunks) -> List[int]:
        """Return ``num_chunks + 1`` byte offsets splitting ``filename`` into
        roughly equal, line-aligned chunks.

        Entry 0 is 0 and interior entries are advanced to the next line
        boundary via safe_readline.  NOTE(review): the final entry is left
        at 0 — callers appear to treat it as "read to EOF"; confirm before
        relying on it as a real offset.
        """
        with open(PathManager.get_local_path(filename), "r", encoding="utf-8") as f:
            size = os.fstat(f.fileno()).st_size
            chunk_size = size // num_chunks
            offsets = [0 for _ in range(num_chunks + 1)]
            for i in range(1, num_chunks):
                f.seek(chunk_size * i)
                safe_readline(f)
                offsets[i] = f.tell()
            return offsets
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/checkpoint_utils.py
0 → 100644
View file @
39ac40a9
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import
ast
import
collections
import
contextlib
import
logging
import
os
import
re
import
time
import
traceback
from
collections
import
OrderedDict
from
typing
import
Any
,
Dict
,
Optional
,
Union
from
random
import
randint
import
torch
from
fairseq.dataclass.configs
import
CheckpointConfig
from
fairseq.dataclass.utils
import
(
convert_namespace_to_omegaconf
,
overwrite_args_by_name
,
)
from
fairseq.distributed.fully_sharded_data_parallel
import
FSDP
,
has_FSDP
from
fairseq.file_io
import
PathManager
from
fairseq.models
import
FairseqDecoder
,
FairseqEncoder
from
omegaconf
import
DictConfig
,
open_dict
,
OmegaConf
logger
=
logging
.
getLogger
(
__name__
)
def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    """Save due checkpoints for the current step and prune stale ones.

    Determines which checkpoint filenames are due this call (per-epoch,
    per-update, best-so-far, top-N-best, last), writes the first due file via
    the trainer, copies it to the remaining names, then removes old
    checkpoints according to the ``keep_*`` settings in ``cfg``.

    Args:
        cfg: checkpoint configuration (save dir, intervals, keep counts).
        trainer: owns model/optimizer state; performs the actual save.
        epoch_itr: epoch iterator; supplies epoch number and its state_dict.
        val_loss: latest validation metric, or None if no validation ran.

    Note: the best metric seen so far is persisted as a function attribute,
    ``save_checkpoint.best``, across calls within the process.
    """
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    # Track the best validation metric across calls via a function attribute.
    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return

    trainer.consolidate_optimizer()

    # TODO(SS): do we need this if no_save_optimizer_state
    """
    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return
    """
    # Only the data-parallel master rank writes checkpoint files.
    if not trainer.is_data_parallel_master:
        return

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        # Comparison direction depends on whether the metric is maximized.
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    # Map each candidate checkpoint filename -> bool (should it be written now?).
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch
        and not cfg.no_epoch_checkpoints
        and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*)\.pt".format(
                cfg.best_checkpoint_metric
            ),
        )
        if len(chkpts) > 0:
            # checkpoint_paths returns files sorted in descending metric order;
            # pick the worst of the currently-kept "best" checkpoints and parse
            # its metric value back out of the filename.
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace(".pt", ""))
        # add random digits to resolve ties
        rand_sfx = randint(0, cfg.keep_best_checkpoints)
        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds["checkpoint_last{}.pt".format(suffix)] = not cfg.no_last_checkpoints

    extra_state = {"train_iterator": epoch_itr.state_dict(), "val_loss": val_loss}
    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    if len(checkpoints) > 0:
        # Write once, then copy to every other due filename.
        trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if not end_of_epoch and cfg.keep_interval_updates > 0:
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            # Keep every checkpoint whose update count is a multiple of
            # keep_interval_updates_pattern; the rest are pruning candidates.
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]
        for old_chk in checkpoints[cfg.keep_interval_updates:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)

    if cfg.keep_best_checkpoints > 0:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints:]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.

    Returns:
        tuple: ``(extra_state, epoch_itr)`` — the checkpoint's extra state
        (or None if nothing was restored) and the training epoch iterator.

    Raises:
        ValueError: when ``--finetune-from-model`` is combined with an
            incompatible reset flag or a non-default ``--restore-file``, or
            when the finetune checkpoint does not exist.
    """
    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                # Fixed typo in the user-facing flag name ("--funetune-from-model").
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        # Carry the previously-best validation metric into save_checkpoint's
        # cross-call state so "best" checkpoints keep improving monotonically.
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr
def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    # Apply overrides to the legacy argparse-Namespace config, if present.
    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:
        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import _utils

        # Temporarily treat every type as primitive so OmegaConf.create accepts
        # embedded Namespace objects; restored immediately afterwards.
        old_primitive = _utils.is_primitive_type
        _utils.is_primitive_type = lambda _: True

        state["cfg"] = OmegaConf.create(state["cfg"])

        _utils.is_primitive_type = old_primitive
        OmegaConf.set_struct(state["cfg"], True)

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    # Migrate old-format checkpoints to the current layout before returning.
    state = _upgrade_state_dict(state)
    return state
def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    # Delegate to the richer loader and drop the task it returns.
    models, model_args, _unused_task = load_model_ensemble_and_task(
        filenames,
        arg_overrides=arg_overrides,
        task=task,
        strict=strict,
        suffix=suffix,
        num_shards=num_shards,
        state=state,
    )
    return models, model_args
def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    """Resolve the on-disk filename for one shard of a checkpoint.

    Preference order: an existing FSDP shard file, then the model-parallel
    part file when multiple shards are expected, otherwise the plain
    suffixed filename.
    """
    base = filename
    suffixed = base.replace(".pt", suffix + ".pt")
    # Drop the trailing ".pt" (3 chars) before appending the shard tag.
    fsdp_candidate = suffixed[:-3] + f"-shard{shard_idx}.pt"
    part_candidate = base[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_candidate):
        return fsdp_candidate
    if num_shards > 1:
        return part_candidate
    return suffixed
def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Load an ensemble of models plus the config and task used to build them.

    For each filename, loads ``num_shards`` shard files (FSDP-consolidated or
    model-parallel) and builds one model. Returns ``(ensemble, cfg, task)``
    where ``cfg`` is the config of the last loaded checkpoint.

    Args:
        filenames (List[str]): checkpoint files to load.
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training.
        task (fairseq.tasks.FairseqTask, optional): task to use for loading;
            built from the first checkpoint's config when None.
        strict: passed to ``load_state_dict``; incompatible with sharding.
        suffix: checkpoint filename suffix (e.g. per-rank).
        num_shards: number of shard files per checkpoint.
        state: pre-loaded state dict for a single checkpoint (optional).
    """
    # A pre-loaded state can only correspond to a single checkpoint file.
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        # Accumulates per-shard weights/metadata for FSDP consolidation.
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            # Prefer the legacy argparse "args" (converted), else Hydra "cfg".
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                # Only after the last shard can the full state be consolidated
                # and loaded into a freshly built model.
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    model = task.build_model(cfg.model)
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                model = task.build_model(cfg.model)
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            # Periodic progress logging for many-shard checkpoints.
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx + 1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task
def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    matcher = re.compile(pattern)
    scored = []
    for position, fname in enumerate(PathManager.ls(path)):
        hit = matcher.fullmatch(fname)
        if hit is None:
            continue
        # Sort key: the first capture group when the pattern has one,
        # otherwise the file's position in the directory listing.
        sort_key = float(hit.group(1)) if len(hit.groups()) > 0 else position
        scored.append((sort_key, hit.group(0)))
    ordered = sorted(scored, reverse=True)
    if keep_match:
        return [(os.path.join(path, name), key) for key, name in ordered]
    return [os.path.join(path, name) for key, name in ordered]
def torch_persistent_save(obj, filename, async_write: bool = False):
    """Persist `obj` to `filename`, atomically when the backend supports rename."""
    if async_write:
        # Asynchronous write through the PathManager append/async handle.
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
        return
    if not PathManager.supports_rename(filename):
        # fallback to non-atomic save
        with PathManager.open(filename, "wb") as f:
            _torch_persistent_save(obj, f)
        return
    # do atomic save: write a temp file, then rename it over the target
    tmp_name = filename + ".tmp"
    with PathManager.open(tmp_name, "wb") as f:
        _torch_persistent_save(obj, f)
    PathManager.rename(tmp_name, filename)
def
_torch_persistent_save
(
obj
,
f
):
if
isinstance
(
f
,
str
):
with
PathManager
.
open
(
f
,
"wb"
)
as
h
:
torch_persistent_save
(
obj
,
h
)
return
for
i
in
range
(
3
):
try
:
return
torch
.
save
(
obj
,
f
)
except
Exception
:
if
i
==
2
:
logger
.
error
(
traceback
.
format_exc
())
def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints.

    Applies a sequence of in-place migrations so checkpoints written by older
    fairseq versions load under the current layout: restructures optimizer
    history and extra_state, renames deprecated args, and converts the legacy
    argparse "args" Namespace into a Hydra/omegaconf "cfg". Returns the same
    (mutated) state dict.
    """
    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # old model checkpoints may not have separate source/target positions
    if (
        "args" in state
        and hasattr(state["args"], "max_positions")
        and not hasattr(state["args"], "max_source_positions")
    ):
        state["args"].max_source_positions = state["args"].max_positions
        state["args"].max_target_positions = state["args"].max_positions
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"]["epoch"],
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --postprocess
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None (criteria will supply a default value of [])
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining => audio pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]
        # remove keys in state["args"] related to teacher-student learning
        for key in [
            "static_teachers",
            "static_teacher_weights",
            "dynamic_teachers",
            "dynamic_teacher_weights",
        ]:
            if key in state["args"]:
                delattr(state["args"], key)

        # Finally, materialize the modern config from the migrated Namespace.
        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        with open_dict(cfg):
            # any upgrades for Hydra-based configs
            # print_alignment changed from bool to a string mode ("hard").
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = "hard"
            # Same migration for configs nested inside wav2vec fine-tuning models.
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
                and cfg.model.w2v_args.task.eval_wer_config is not None
                and isinstance(
                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
                )
            ):
                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"

    return state
def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.
    """
    arch = None
    if model_cfg is not None:
        arch = (
            model_cfg._name
            if isinstance(model_cfg, DictConfig)
            else getattr(model_cfg, "arch", None)
        )

    if not model_cfg or arch is None or arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        # `layers_to_keep` is a comma-separated list of layer indices to keep;
        # build a mapping from original layer index -> new contiguous index.
        keep_layers = sorted(
            int(layer_string) for layer_string in layers_to_keep.split(",")
        )
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    new_state_dict = {}
    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # if layer has no number in it, it is a supporting layer, such as an
        # embedding
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # otherwise, layer should be pruned.
        original_layer_number = match.group(1)
        # figure out which mapping dict to replace from
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
                "substitution_regex"
            ].search(layer_name):
                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
                substitution_match = pruning_pass["substitution_regex"].search(
                    layer_name
                )
                # Splice the renumbered layer index into the key, replacing only
                # the captured digits (group 1 of the regex).
                new_state_key = (
                    layer_name[: substitution_match.start(1)]
                    + new_layer_number
                    + layer_name[substitution_match.end(1):]
                )
                new_state_dict[new_state_key] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    if isinstance(model_cfg, DictConfig):
        context = open_dict(model_cfg)
    else:
        # Non-DictConfig configs need no struct-unlocking; use a no-op context.
        context = contextlib.ExitStack()
    with context:
        if hasattr(model_cfg, "encoder_layers_to_keep"):
            model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"):
            model_cfg.decoder_layers_to_keep = None

    return new_state_dict
def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder], checkpoint: str
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        prefix = "encoder"
    elif isinstance(component, FairseqDecoder):
        prefix = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )
    # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
    stripped = OrderedDict(
        (full_key[len(prefix) + 1:], value)
        for full_key, value in state["model"].items()
        if full_key.startswith(prefix)
    )
    component.load_state_dict(stripped, strict=True)
    return component
def verify_checkpoint_directory(save_dir: str) -> None:
    """Ensure `save_dir` exists and is writable by creating and removing a probe file."""
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    probe = os.path.join(save_dir, "dummy")
    try:
        open(probe, "w").close()
    except OSError as err:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise err
    # Writable: clean up the probe file.
    os.remove(probe)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/cuda/ngram_repeat_block_cuda.cpp
0 → 100644
View file @
39ac40a9
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
#include <torch/extension.h>
#include <vector>
/*
CPP Binding for CUDA OP
*/
// CUDA forward declarations
torch
::
Tensor
ngram_repeat_block_cuda_forward
(
torch
::
Tensor
tokens
,
torch
::
Tensor
lprobs
,
int
bsz
,
int
step
,
int
beam_size
,
int
no_repeat_ngram_size
);
#define CHECK_CUDA(x) \
TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
// Input check and call to CUDA OP
// Backward method not required
torch::Tensor ngram_repeat_block_forward(
    torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size) {
  // Validate that both tensors are CUDA and contiguous.
  CHECK_INPUT(tokens);
  CHECK_INPUT(lprobs);
  // Bug fix: the original used assert(), which is compiled out under NDEBUG
  // (the usual release build), so invalid sizes reached the CUDA kernel
  // unchecked. TORCH_CHECK always runs and surfaces a Python-visible error.
  TORCH_CHECK(bsz > 0, "bsz must be positive");
  TORCH_CHECK(step >= 0, "step must be non-negative");
  TORCH_CHECK(beam_size > 0, "beam_size must be positive");
  TORCH_CHECK(no_repeat_ngram_size > 0, "no_repeat_ngram_size must be positive");

  // Delegate to the CUDA implementation; it modifies lprobs in place and
  // returns it.
  return ngram_repeat_block_cuda_forward(
      tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size);
}
// Python binding: exposes the n-gram repeat-blocking op to Python as
// `forward(tokens, lprobs, bsz, step, beam_size, no_repeat_ngram_size)`.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def(
      "forward",
      &ngram_repeat_block_forward,
      "No Repeat Ngram Block forward (CUDA)");
}
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/cuda/ngram_repeat_block_cuda_kernel.cu
0 → 100644
View file @
39ac40a9
/*
Copyright (c) Microsoft Corporation.
Licensed under the MIT License.
*/
/*
Kernel implementation for blocking repeated n-grams.
*/
#include <cuda.h>
#include <cuda_runtime.h>
#include <math.h>
#include <torch/extension.h>
#include <vector>
// Ban repeated ngrams of length = 'no_repeat_ngram_size'
// One block per sequence row; one thread per candidate n-gram start position.
// Each thread compares the (n-1)-token prefix at its offset with the prefix of
// the final n-gram; on a match it bans the token that would complete the
// repeat by setting its log-prob to -inf (in-place write to lprobs).
__global__ void banRepeatedTokens(
    long* __restrict__ tokens,   // token history, rows of length max_predict_len
    float* __restrict__ lprobs,  // log-probs, rows of length vocab_size; modified in place
    int max_predict_len,
    int vocab_size,
    int no_repeat_ngram_size) {
  auto row = blockIdx.x;   // sequence (batch x beam) index
  auto col = threadIdx.x;  // candidate n-gram start offset within the row
  auto start = row * (max_predict_len) + col;
  // Each thread compares ngram starting from
  // thread index with final ngram starting from
  // step - no_repeat_ngram_size +2
  auto check_start_pos = blockDim.x;
  auto lprob_start = row * vocab_size;
  bool is_banned = true;
  // Dynamically sized shared buffer; the launch provides (step + 1) longs.
  extern __shared__ long tokens_shm[];
  // Each thread stages its own token; the last thread additionally copies the
  // trailing tokens so the final n-gram's positions are populated before the
  // comparison below.
  tokens_shm[col] = tokens[start];
  if (col == blockDim.x - 1) {
    for (int i = 1; i < no_repeat_ngram_size; i++) {
      if (col + i < max_predict_len) {
        tokens_shm[col + i] = tokens[start + i];
      }
    }
  }
  __syncthreads();

  // Prefix comparison: any mismatch clears the ban flag.
  for (int k = 0; k < no_repeat_ngram_size - 1; k++) {
    if (tokens_shm[col + k] != tokens_shm[check_start_pos + k]) {
      is_banned = false;
    }
  }
  if (is_banned == true) {
    auto token_to_be_banned = tokens_shm[col + no_repeat_ngram_size - 1];
    lprobs[lprob_start + token_to_be_banned] = -INFINITY;
  }
}
// Allocate blocks and threads based on
// batch size and sequence length and launch
// kernel
torch::Tensor ngram_repeat_block_cuda_forward(
    const torch::Tensor tokens,
    torch::Tensor lprobs,
    int bsz,
    int step,
    int beam_size,
    int no_repeat_ngram_size) {
  // Number of candidate n-gram start positions. Non-positive means the
  // sequence is still too short to contain any repeat — nothing to ban.
  int threads = step - no_repeat_ngram_size + 2;
  if (threads <= 0)
    return lprobs;
  int max_predict_len = tokens.size(1);
  int vocab_size = lprobs.size(1);
  // NOTE(review): data_ptr<long>/<float> assumes tokens are int64 and lprobs
  // are float32 — confirm against the Python caller.
  auto token_ptr = tokens.data_ptr<long>();
  auto lprob_ptr = lprobs.data_ptr<float>();
  int blocks = bsz * beam_size;
  int shared_mem_size = (step + 1) * sizeof(long);

  // Launching N blocks where N is number of samples in a batch (beams*bsz)
  // Launching T threads where T is number of previous ngrams in a sample
  // Allocating shared mem per block for faster access of input tokens since
  // each token will be accessed N times to compare with current Ngram where
  // N is Ngram size.
  banRepeatedTokens<<<blocks, threads, shared_mem_size>>>(
      token_ptr, lprob_ptr, max_predict_len, vocab_size, no_repeat_ngram_size);
  return lprobs;
}
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libbase/balanced_assignment.cpp
0 → 100644
View file @
39ac40a9
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
C++ code for solving the linear assignment problem.
Based on the Auction Algorithm from https://dspace.mit.edu/bitstream/handle/1721.1/3265/P-2108-26912652.pdf and the implementation from:
https://github.com/bkj/auction-lap
Adapted to be more efficient when each worker is looking for k jobs instead of 1.
*/
#include <torch/extension.h>
#include <iostream>
using
namespace
torch
::
indexing
;
// Solve a balanced linear assignment (each of the `num_workers` workers gets
// exactly num_jobs / num_workers jobs) via an auction algorithm. Input is a
// [num_jobs, num_workers] score matrix; returns a flat tensor of the job
// indices assigned to each worker (jobs_per_worker entries per worker).
torch::Tensor balanced_assignment(torch::Tensor job_and_worker_to_score) {
  // After max_iterations, switch to a faster, suboptimal convergence mode.
  int max_iterations = 100;
  // Minimum bid increment; scaled to the score range, floored at 1e-4.
  torch::Tensor epsilon =
      (job_and_worker_to_score.max() - job_and_worker_to_score.min()) / 50;
  epsilon.clamp_min_(1e-04);
  // Work on a detached [num_workers, num_jobs] view of the scores.
  torch::Tensor worker_and_job_to_score =
      job_and_worker_to_score.detach().transpose(0, 1).contiguous();
  int num_workers = worker_and_job_to_score.size(0);
  int num_jobs = worker_and_job_to_score.size(1);
  auto device = worker_and_job_to_score.device();
  int jobs_per_worker = num_jobs / num_workers;
  // value = score - current price of each job (updated every round).
  torch::Tensor value = worker_and_job_to_score.clone();
  int counter = 0;
  torch::Tensor max_value = worker_and_job_to_score.max();

  // Pre-allocated round buffers (filled in place inside the loop).
  torch::Tensor bid_indices;
  torch::Tensor cost = worker_and_job_to_score.new_zeros({1, num_jobs});
  torch::Tensor bids =
      worker_and_job_to_score.new_empty({num_workers, num_jobs});
  torch::Tensor bid_increments =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker});
  torch::Tensor top_values =
      worker_and_job_to_score.new_empty({num_workers, jobs_per_worker + 1});
  torch::Tensor high_bids = worker_and_job_to_score.new_empty({num_jobs});

  torch::Tensor top_index = top_values.to(torch::kLong);
  torch::Tensor high_bidders = top_index.new_empty({num_jobs});
  torch::Tensor have_bids = high_bidders.to(torch::kBool);
  torch::Tensor jobs_indices =
      torch::arange({num_jobs}, torch::dtype(torch::kLong).device(device));
  torch::Tensor true_tensor =
      torch::ones({1}, torch::dtype(torch::kBool).device(device));

  while (true) {
    bids.zero_();
    // Each worker's k+1 most valuable jobs (k = jobs_per_worker).
    torch::topk_out(top_values, top_index, value, jobs_per_worker + 1, 1);

    // Each worker bids the difference in value between that job and the k+1th job
    torch::sub_out(
        bid_increments,
        top_values.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        top_values.index({Slice(None, None), jobs_per_worker}).unsqueeze(1));

    bid_increments.add_(epsilon);
    bids.scatter_(
        1,
        top_index.index({Slice(None, None), Slice(0, jobs_per_worker)}),
        bid_increments);

    if (counter < max_iterations && counter > 0) {
      // Put in a minimal bid to retain items from the last round if no-one else bids for them this round
      bids.view(-1).index_put_({bid_indices}, epsilon);
    }

    // Find the highest bidding worker per job
    torch::max_out(high_bids, high_bidders, bids, 0);
    torch::gt_out(have_bids, high_bids, 0);

    if (have_bids.all().item<bool>()) {
      // All jobs were bid for
      break;
    }

    // Make popular items more expensive
    cost.add_(high_bids);
    torch::sub_out(value, worker_and_job_to_score, cost);

    // Flat indices (worker * num_jobs + job) of this round's winning bids.
    bid_indices = ((high_bidders * num_jobs) + jobs_indices).index({have_bids});

    if (counter < max_iterations) {
      // Make sure that this item will be in the winning worker's top-k next time.
      value.view(-1).index_put_({bid_indices}, max_value);
    } else {
      // Suboptimal approximation that converges quickly from current solution
      value.view(-1).index_put_(
          {bid_indices}, worker_and_job_to_score.view(-1).index({bid_indices}));
    }
    counter += 1;
  }

  // Final assignment: each worker's top jobs_per_worker jobs, flattened.
  return top_index.index({Slice(None, None), Slice(0, jobs_per_worker)})
      .reshape(-1);
}
// Python binding: exposes balanced_assignment() to Python as part of a
// PyTorch C++ extension (module name supplied by TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("balanced_assignment", &balanced_assignment, "Balanced Assignment");
}
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libbleu/libbleu.cpp
0 → 100644
View file @
39ac40a9
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <map>
#include <array>
#include <cstring>
#include <cstdio>
// Accumulated corpus-level BLEU statistics. `countN` is the total number of
// predicted N-grams seen and `matchN` the number of them that matched the
// reference (clipped counting, see bleu_addngram).
typedef struct {
  size_t reflen;   // total reference length (after pad/eos trimming)
  size_t predlen;  // total prediction length (after pad/eos trimming)
  size_t match1;
  size_t count1;
  size_t match2;
  size_t count2;
  size_t match3;
  size_t count3;
  size_t match4;
  size_t count4;
} bleu_stat;
// Left-trim a sentence in place: advance *sent past any leading `pad`
// tokens and shrink *len accordingly.
void bleu_ltrim(size_t* len, int** sent, int pad) {
  size_t skipped = 0;
  while (skipped < *len && (*sent)[skipped] == pad) {
    skipped++;
  }
  *sent += skipped;
  *len -= skipped;
}
// Right-trim a sentence in place: drop trailing `eos` and `pad` tokens by
// shrinking *len. The token pointer itself is not moved.
//
// Fix: the original computed `*len - 1` unconditionally; for an empty
// sentence (*len == 0) the size_t subtraction wraps to SIZE_MAX and the
// loop reads far out of bounds. Guard the empty case explicitly.
void bleu_rtrim(size_t* len, int** sent, int pad, int eos) {
  if (*len == 0) {
    return;  // nothing to trim; avoids size_t underflow below
  }
  size_t end = *len - 1;
  while (end > 0) {
    int tok = *(*sent + end);
    if (tok != eos && tok != pad) {
      break;
    }
    end--;
  }
  // Note: if every token is pad/eos, one leading token is still kept
  // (end stays 0) — matching the original behavior for callers.
  *len = end + 1;
}
// Trim both ends: strip leading pad tokens, then trailing eos/pad tokens.
void bleu_trim(size_t* len, int** sent, int pad, int eos) {
  bleu_ltrim(len, sent, pad);
  bleu_rtrim(len, sent, pad, eos);
}
// Hash an n-gram of `len` ints with FNV-1a over its raw bytes
// (64-bit offset basis 14695981039346656037, prime 0x100000001b3).
size_t bleu_hash(int len, int* data) {
  const size_t fnv_prime = 0x100000001b3;
  size_t hash = 14695981039346656037ul;
  const char* bytes = (const char*)data;
  const size_t nbytes = sizeof(int) * len;
  for (size_t i = 0; i < nbytes; i++) {
    hash ^= bytes[i];
    hash *= fnv_prime;
  }
  return hash;
}
// Accumulate clipped n-gram match statistics for one sentence pair.
// Adds the number of predicted n-grams to *ntotal, and the number of
// predicted n-grams that also occur in the reference (each reference
// occurrence consumable at most once) to *nmatch.
void bleu_addngram(
    size_t* ntotal,
    size_t* nmatch,
    size_t n,
    size_t reflen,
    int* ref,
    size_t predlen,
    int* pred) {
  if (predlen < n) {
    return;
  }
  const size_t num_pred_ngrams = predlen - n + 1;
  *ntotal += num_pred_ngrams;
  if (reflen < n) {
    return;
  }
  const size_t num_ref_ngrams = reflen - n + 1;
  // Count each predicted n-gram by its hash.
  std::map<size_t, size_t> pred_counts;
  for (size_t i = 0; i < num_pred_ngrams; i++) {
    pred_counts[bleu_hash(n, pred + i)]++;
  }
  // Consume one predicted occurrence per matching reference n-gram.
  for (size_t i = 0; i < num_ref_ngrams; i++) {
    size_t key = bleu_hash(n, ref + i);
    if (pred_counts[key] > 0) {
      ++*nmatch;
      pred_counts[key] -= 1;
    }
  }
}
// C ABI exported for ctypes/cffi callers (no C++ name mangling).
extern "C" {

#ifdef _WIN64
__declspec(dllexport)
#endif
// Reset all BLEU statistics to zero.
void bleu_zero_init(bleu_stat* stat) {
  std::memset(stat, 0, sizeof(bleu_stat));
}

#ifdef _WIN64
__declspec(dllexport)
#endif
// Smoothed initialization: 2/3/4-gram counts and matches start at 1 while
// unigram stats start at 0 (add-one smoothing for higher-order n-grams).
void bleu_one_init(bleu_stat* stat) {
  bleu_zero_init(stat);
  stat->count1 = 0;
  stat->count2 = 1;
  stat->count3 = 1;
  stat->count4 = 1;
  stat->match1 = 0;
  stat->match2 = 1;
  stat->match3 = 1;
  stat->match4 = 1;
}

#ifdef _WIN64
__declspec(dllexport)
#endif
// Accumulate statistics for one (reference, prediction) sentence pair:
// trims pad/eos from both, adds lengths, then adds clipped n-gram counts
// for n = 1..4.
void bleu_add(
    bleu_stat* stat,
    size_t reflen,
    int* ref,
    size_t predlen,
    int* pred,
    int pad,
    int eos) {
  bleu_trim(&reflen, &ref, pad, eos);
  bleu_trim(&predlen, &pred, pad, eos);
  stat->reflen += reflen;
  stat->predlen += predlen;
  bleu_addngram(&stat->count1, &stat->match1, 1, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count2, &stat->match2, 2, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count3, &stat->match3, 3, reflen, ref, predlen, pred);
  bleu_addngram(&stat->count4, &stat->match4, 4, reflen, ref, predlen, pred);
}

}
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libbleu/module.cpp
0 → 100644
View file @
39ac40a9
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <Python.h>
// No Python-level methods: this module only exists so the C symbols above
// can be loaded; the sentinel entry terminates the table.
static PyMethodDef method_def[] = {
    {NULL, NULL, 0, NULL}};
static struct PyModuleDef module_def = {
    PyModuleDef_HEAD_INIT,
    "libbleu", /* name of module */
    NULL,      /* module documentation, may be NULL */
    -1,        /* size of per-interpreter state of the module,
                  or -1 if the module keeps state in global variables. */
    method_def};
// Module entry point invoked by the CPython import machinery.
// NOTE(review): the Python 2 branch reuses the Python 3 body
// (PyModule_Create + return), which is not a valid Python 2 init
// function — presumably this path was never exercised; confirm before use.
#if PY_MAJOR_VERSION == 2
PyMODINIT_FUNC init_libbleu()
#else
PyMODINIT_FUNC PyInit_libbleu()
#endif
{
  PyObject* m = PyModule_Create(&module_def);
  if (!m) {
    return NULL;
  }
  return m;
}
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libnat/edit_dist.cpp
0 → 100644
View file @
39ac40a9
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <torch/torch.h> // @manual=//caffe2:torch_extension
#include <pybind11/detail/common.h>
#include <pybind11/pybind11.h>
#include <vector>
#include <algorithm>
#include <cstdint>
#include <iosfwd>
#include <memory>
#include <new>
#include <string>
#include <utility>
using
namespace
::
std
;
// Build the full dynamic-programming table of edit distances between x and y.
// Cell [i][j] is the distance between the prefixes x[0..i) and y[0..j) with
// insertion/deletion cost 1 and substitution cost 2 (so a substitution is
// never cheaper than a delete+insert pair).
std::vector<std::vector<uint32_t>> edit_distance2_with_dp(
    std::vector<uint32_t>& x,
    std::vector<uint32_t>& y) {
  const uint32_t rows = x.size();
  const uint32_t cols = y.size();
  std::vector<std::vector<uint32_t>> table(
      rows + 1, std::vector<uint32_t>(cols + 1));
  for (uint32_t r = 0; r <= rows; r++) {
    table[r][0] = r;  // deleting the whole prefix
  }
  for (uint32_t c = 0; c <= cols; c++) {
    table[0][c] = c;  // inserting the whole prefix
  }
  for (uint32_t r = 1; r <= rows; r++) {
    for (uint32_t c = 1; c <= cols; c++) {
      const uint32_t indel =
          std::min(table[r - 1][c], table[r][c - 1]) + 1;
      const uint32_t subst =
          table[r - 1][c - 1] + (x.at(r - 1) == y.at(c - 1) ? 0 : 2);
      table[r][c] = std::min(indel, subst);
    }
  }
  return table;
}
// Backtrack through the DP table `d` (from edit_distance2_with_dp) and
// convert the optimal edit script into per-slot insertion sequences plus a
// delete mask. Returns x.size() + 2 sequences: cell s (0..x.size()) holds
// the tokens inserted before/after the s-th kept position, and the last
// cell holds one 0/1 flag per consumed x token (1 = deleted).
// Empty cells are filled with `terminal_symbol`.
vector<vector<uint32_t>> edit_distance2_backtracking(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 2, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */
  if (x.size() == 0) {
    // Degenerate case: everything in y is a single insertion block.
    edit_seqs.at(0) = y;
    return edit_seqs;
  }
  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;
  // i/j are unsigned so (i >= 0) && (j >= 0) is always true; the loop is
  // terminated by the explicit (0, 0) break below.
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }
    // Ops are pushed as (op, token) pairs in reverse order:
    // 1 = insert y[j-1], 2 = delete x[i-1], 3 = keep x[i-1].
    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1);
      // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2);
      // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3);
      // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }
  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  // Replay the script in forward order; `s` tracks the current slot and only
  // advances when the previous op was not an insertion, so consecutive
  // insertions land in the same slot.
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(x.size() + 1).push_back(1);
    } else {
      edit_seqs.at(x.size() + 1).push_back(0);
    }
    prev_op = op;
  }
  // Pad empty cells with the terminal symbol so every slot is non-empty.
  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs[k].size() == 0) {
      edit_seqs[k].push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}
// Variant of edit_distance2_backtracking that inlines deletions: instead of
// a separate delete-mask cell, a deleted x token contributes
// `deletion_symbol` into the current slot's sequence. Returns
// x.size() + 1 sequences; empty cells are filled with `terminal_symbol`.
vector<vector<uint32_t>> edit_distance2_backtracking_with_delete(
    vector<vector<uint32_t>>& d,
    vector<uint32_t>& x,
    vector<uint32_t>& y,
    uint32_t terminal_symbol,
    uint32_t deletion_symbol) {
  vector<uint32_t> seq;
  vector<vector<uint32_t>> edit_seqs(x.size() + 1, vector<uint32_t>());
  /*
  edit_seqs:
  0~x.size() cell is the insertion sequences
  last cell is the delete sequence
  */
  if (x.size() == 0) {
    // Degenerate case: everything in y is a single insertion block.
    edit_seqs.at(0) = y;
    return edit_seqs;
  }
  uint32_t i = d.size() - 1;
  uint32_t j = d.at(0).size() - 1;
  // i/j are unsigned so (i >= 0) && (j >= 0) is always true; the loop is
  // terminated by the explicit (0, 0) break below.
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }
    // Ops are pushed as (op, token) pairs in reverse order:
    // 1 = insert y[j-1], 2 = delete x[i-1], 3 = keep x[i-1].
    if ((j > 0) && (d.at(i).at(j - 1) < d.at(i).at(j))) {
      seq.push_back(1);
      // insert
      seq.push_back(y.at(j - 1));
      j--;
    } else if ((i > 0) && (d.at(i - 1).at(j) < d.at(i).at(j))) {
      seq.push_back(2);
      // delete
      seq.push_back(x.at(i - 1));
      i--;
    } else {
      seq.push_back(3);
      // keep
      seq.push_back(x.at(i - 1));
      i--;
      j--;
    }
  }
  uint32_t prev_op, op, s, word;
  prev_op = 0, s = 0;
  // Replay the script in forward order; `s` tracks the current slot and only
  // advances when the previous op was not an insertion.
  for (uint32_t k = 0; k < seq.size() / 2; k++) {
    op = seq.at(seq.size() - 2 * k - 2);
    word = seq.at(seq.size() - 2 * k - 1);
    if (prev_op != 1) {
      s++;
    }
    if (op == 1) // insert
    {
      edit_seqs.at(s - 1).push_back(word);
    } else if (op == 2) // delete
    {
      edit_seqs.at(s - 1).push_back(deletion_symbol);
    }
    // op == 3 (keep) emits nothing in this variant.
    prev_op = op;
  }
  // Pad empty cells with the terminal symbol so every slot is non-empty.
  for (uint32_t k = 0; k < edit_seqs.size(); k++) {
    if (edit_seqs.at(k).size() == 0) {
      edit_seqs.at(k).push_back(terminal_symbol);
    }
  }
  return edit_seqs;
}
vector
<
uint32_t
>
compute_ed2
(
vector
<
vector
<
uint32_t
>>&
xs
,
vector
<
vector
<
uint32_t
>>&
ys
)
{
vector
<
uint32_t
>
distances
(
xs
.
size
());
for
(
uint32_t
i
=
0
;
i
<
xs
.
size
();
i
++
)
{
vector
<
vector
<
uint32_t
>>
d
=
edit_distance2_with_dp
(
xs
.
at
(
i
),
ys
.
at
(
i
));
distances
.
at
(
i
)
=
d
.
at
(
xs
.
at
(
i
).
size
()).
at
(
ys
.
at
(
i
).
size
());
}
return
distances
;
}
vector
<
vector
<
vector
<
uint32_t
>>>
suggested_ed2_path
(
vector
<
vector
<
uint32_t
>>&
xs
,
vector
<
vector
<
uint32_t
>>&
ys
,
uint32_t
terminal_symbol
)
{
vector
<
vector
<
vector
<
uint32_t
>>>
seq
(
xs
.
size
());
for
(
uint32_t
i
=
0
;
i
<
xs
.
size
();
i
++
)
{
vector
<
vector
<
uint32_t
>>
d
=
edit_distance2_with_dp
(
xs
.
at
(
i
),
ys
.
at
(
i
));
seq
.
at
(
i
)
=
edit_distance2_backtracking
(
d
,
xs
.
at
(
i
),
ys
.
at
(
i
),
terminal_symbol
);
}
return
seq
;
}
vector
<
vector
<
vector
<
uint32_t
>>>
suggested_ed2_path_with_delete
(
vector
<
vector
<
uint32_t
>>&
xs
,
vector
<
vector
<
uint32_t
>>&
ys
,
uint32_t
terminal_symbol
,
uint32_t
deletion_symbol
)
{
vector
<
vector
<
vector
<
uint32_t
>>>
seq
(
xs
.
size
());
for
(
uint32_t
i
=
0
;
i
<
xs
.
size
();
i
++
)
{
vector
<
vector
<
uint32_t
>>
d
=
edit_distance2_with_dp
(
xs
.
at
(
i
),
ys
.
at
(
i
));
seq
.
at
(
i
)
=
edit_distance2_backtracking_with_delete
(
d
,
xs
.
at
(
i
),
ys
.
at
(
i
),
terminal_symbol
,
deletion_symbol
);
}
return
seq
;
}
// Python binding: exposes the CPU edit-distance helpers as module `libnat`.
PYBIND11_MODULE(libnat, m) {
  m.def("compute_ed2", &compute_ed2, "compute_ed2");
  m.def("suggested_ed2_path", &suggested_ed2_path, "suggested_ed2_path");
  m.def(
      "suggested_ed2_path_with_delete",
      &suggested_ed2_path_with_delete,
      "suggested_ed2_path_with_delete");
}
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libnat_cuda/binding.cpp
0 → 100644
View file @
39ac40a9
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
/*
This code is partially adopted from https://github.com/1ytic/pytorch-edit-distance
*/
#include "edit_dist.h"
#include <torch/types.h>
#ifndef TORCH_CHECK
#define TORCH_CHECK AT_CHECK
#endif
#define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
// Validate that all inputs are contiguous CUDA tensors, then dispatch to
// the CUDA implementation. Returns the per-pair edit-operation tensor
// produced by LevenshteinDistanceCuda.
torch::Tensor LevenshteinDistance(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  CHECK_INPUT(source);
  CHECK_INPUT(target);
  CHECK_INPUT(source_length);
  CHECK_INPUT(target_length);
  return LevenshteinDistanceCuda(source, target, source_length, target_length);
}
// Validate inputs (contiguous CUDA tensors) and dispatch to the CUDA
// deletion-label generator.
torch::Tensor GenerateDeletionLabel(
    torch::Tensor source,
    torch::Tensor operations) {
  CHECK_INPUT(source);
  CHECK_INPUT(operations);
  return GenerateDeletionLabelCuda(source, operations);
}
// Validate inputs (contiguous CUDA tensors) and dispatch to the CUDA
// insertion-label generator. Returns (labels, masks).
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabel(
    torch::Tensor target,
    torch::Tensor operations) {
  CHECK_INPUT(target);
  CHECK_INPUT(operations);
  return GenerateInsertionLabelCuda(target, operations);
}
// Python binding: exposes the CUDA edit-distance entry points as a
// PyTorch C++ extension (module name supplied by TORCH_EXTENSION_NAME).
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("levenshtein_distance", &LevenshteinDistance, "Levenshtein distance");
  m.def(
      "generate_deletion_labels",
      &GenerateDeletionLabel,
      "Generate Deletion Label");
  m.def(
      "generate_insertion_labels",
      &GenerateInsertionLabel,
      "Generate Insertion Label");
}
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libnat_cuda/edit_dist.cu
0 → 100644
View file @
39ac40a9
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#include "edit_dist.h"
#include <THC/THC.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <device_launch_parameters.h>
#include <utility> // std::pair
// One thread block per batch element (blockIdx.x), single thread per block.
// Reads that element's edit-operation row (codes: 0 = stop/pad, 1 = insert,
// 2 = delete, 3 = keep) and writes a label per source token: 3 - op, i.e.
// 1 for a deleted token, 0 for a kept one. Insertions consume no source
// token and are skipped.
template <typename scalar_t>
__global__ void generate_deletion_label_kernel(
    const scalar_t* __restrict__ source,
    const size_t source_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * source_size;
  // Clear this row of labels first; positions past the real sentence stay 0.
  for (int i = 0; i < source_size; i++) {
    labels[offset_label + i] = 0;
  }
  int k = 0;  // index of the next source token to label
  for (int i = 0; i < operation_size; i++) {
    if (operations[offset + i] == 0) {
      break;  // 0 marks the end of the operation sequence
    } else if (operations[offset + i] == 1) {
      continue;  // insertion: no source token consumed
    } else {
      // op 2 (delete) -> label 1; op 3 (keep) -> label 0
      labels[offset_label + k] = 3 - operations[offset + i];
      k++;
    }
  }
}
// One thread block per batch element (blockIdx.x), single thread per block.
// From the edit-operation row (0 = stop, 1 = insert, 2 = delete, 3 = keep)
// derives, per target token: `labels[k]` = number of insertions immediately
// preceding the k-th kept token, and `masks[m]` = 1 where the target token
// was produced by an insertion, 0 where it was kept. Deletions consume no
// target token and are skipped. Note the loop stops at operation_size - 1.
template <typename scalar_t>
__global__ void generate_insertion_label_kernel(
    const scalar_t* __restrict__ target,
    const size_t target_size,
    const size_t operation_size,
    int* __restrict__ operations,
    int* __restrict__ labels,
    int* __restrict__ masks) {
  const int index = blockIdx.x;
  const int offset = index * operation_size;
  const int offset_label = index * target_size;
  int k = 0;  // index into labels (kept tokens)
  int u = 0;  // running count of insertions since the last kept token
  int m = 0;  // index into masks (all target tokens)
  // Clear this row of labels and masks first.
  for (int i = 0; i < target_size; i++) {
    labels[offset_label + i] = 0;
    masks[offset_label + i] = 0;
  }
  for (int i = 0; i < operation_size - 1; i++) {
    if (operations[offset + i] == 0) {
      break;  // 0 marks the end of the operation sequence
    } else if (operations[offset + i] == 2) {
      continue;  // deletion: no target token consumed
    } else if (operations[offset + i] == 1) {
      // insertion: mark the target position and bump the pending count
      masks[offset_label + m] = 1;
      u++;
      m++;
    } else {
      // keep: record how many insertions preceded this kept token
      labels[offset_label + k] = u;
      masks[offset_label + m] = 0;
      k++;
      m++;
      u = 0;
    }
  }
}
// One thread block per batch element (blockIdx.x), single thread per block.
// Fills the full edit-distance DP table for this pair in global memory
// (`errors_curr`), then backtracks it into a left-aligned operation
// sequence (1 = insertion, 2 = deletion, 3 = keep, 0 = padding).
// Substitution costs 2, insert/delete cost 1, matching the CPU variant.
template <typename scalar_t>
__global__ void levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations,
    int* __restrict__ errors_curr) {
  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  // Base of this element's DP table inside the flat errors buffer.
  const int d = index * (source_size + 1) * (target_size + 1);
  const int t = target_size + 1;  // row stride of the DP table
  auto err_idx = [d, t](int i, int j) {
    return d + i * t + j;
  };
  auto opt_idx = [offset](int k) {
    return offset + k;
  };
  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;
  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)],
              errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }
  // back-tracing: emit ops right-to-left starting at slot hyp_len + ref_len
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;
  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }
    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--;
      // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--;
      // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--;
      // do nothing
    }
  }
  // moving to the left: shift the ops to the start of the row, pad with 0
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0;
      // padding
    }
  }
}
// Same algorithm as levenshtein_distance_kernel, but the DP table lives in
// dynamically-sized shared memory (`short` cells) instead of global memory,
// so no per-element table base offset is needed. The launcher must pass
// (source_size + 1) * (target_size + 1) * sizeof(short) as the shared-memory
// size; it selects this kernel only when that fits (see
// LevenshteinDistanceCuda). Op codes: 1 = insertion, 2 = deletion,
// 3 = keep, 0 = padding.
template <typename scalar_t>
__global__ void faster_levenshtein_distance_kernel(
    const scalar_t* __restrict__ source,
    const scalar_t* __restrict__ target,
    const int* __restrict__ source_length,
    const int* __restrict__ target_length,
    const size_t source_size,
    const size_t target_size,
    int* __restrict__ operations) {
  extern __shared__ short errors[];  // per-block DP table
  auto errors_curr = errors;
  const int index = blockIdx.x;
  const int offset = index * (source_size + target_size);
  const int t = target_size + 1;  // row stride of the DP table
  auto err_idx = [t](int i, int j) {
    return i * t + j;
  };
  auto opt_idx = [offset](int k) {
    return offset + k;
  };
  const int hyp_len = source_length[index];
  const int ref_len = target_length[index];
  const scalar_t* hyp_begin = source + index * source_size;
  const scalar_t* ref_begin = target + index * target_size;
  // dynamic programming
  for (int i = 0; i <= hyp_len; i++) {
    errors_curr[err_idx(i, 0)] = i;
  }
  for (int j = 0; j <= ref_len; j++) {
    errors_curr[err_idx(0, j)] = j;
  }
  for (int i = 1; i <= hyp_len; i++) {
    for (int j = 1; j <= ref_len; j++) {
      errors_curr[err_idx(i, j)] = min(
          min(errors_curr[err_idx(i - 1, j)],
              errors_curr[err_idx(i, j - 1)]) +
              1,
          errors_curr[err_idx(i - 1, j - 1)] +
              2 * (*(hyp_begin + i - 1) == *(ref_begin + j - 1) ? 0 : 1));
    }
  }
  // back-tracing: emit ops right-to-left starting at slot hyp_len + ref_len
  int i = hyp_len;
  int j = ref_len;
  int o = hyp_len + ref_len;
  for (int k = 0; k < source_size + target_size; k++) {
    operations[opt_idx(k)] = 0;
  }
  while ((i >= 0) && (j >= 0)) {
    if ((i == 0) && (j == 0)) {
      break;
    }
    if ((j > 0) &&
        (errors_curr[err_idx(i, j - 1)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 1;
      j--;
      // insertion
    } else if (
        (i > 0) &&
        (errors_curr[err_idx(i - 1, j)] < errors_curr[err_idx(i, j)])) {
      o--;
      operations[opt_idx(o)] = 2;
      i--;
      // deletion
    } else {
      o--;
      operations[opt_idx(o)] = 3;
      i--;
      j--;
      // do nothing
    }
  }
  // moving to the left: shift the ops to the start of the row, pad with 0
  for (int k = 0; k < hyp_len + ref_len; k++) {
    if (k + o < hyp_len + ref_len) {
      operations[opt_idx(k)] = operations[opt_idx(k + o)];
    } else {
      operations[opt_idx(k)] = 0;
      // padding
    }
  }
}
// Host launcher: allocates an Int labels tensor shaped like `source`
// (batch x source_len) and runs generate_deletion_label_kernel with one
// single-thread block per batch element on the current CUDA stream.
torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations) {
  const auto batch_size = source.size(0);
  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, source.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  AT_DISPATCH_ALL_TYPES(source.scalar_type(), "generate_deletion_labels", ([&] {
    generate_deletion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
        source.data_ptr<scalar_t>(),
        source.size(1),
        operations.size(1),
        operations.data_ptr<int>(),
        labels.data_ptr<int>());
  }));

  return labels;
}
// Host launcher: allocates Int `labels` and `masks` tensors shaped like
// `target` (batch x target_len) and runs generate_insertion_label_kernel
// with one single-thread block per batch element on the current CUDA
// stream. Returns (labels, masks).
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor target,
    torch::Tensor operations) {
  const auto batch_size = target.size(0);
  at::TensorOptions options(target.device());
  options = options.dtype(at::ScalarType::Int);
  auto labels = torch::empty({batch_size, target.size(1)}, options);
  auto masks = torch::empty({batch_size, target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(target.device().index());

  AT_DISPATCH_ALL_TYPES(target.scalar_type(), "generate_insertion_labels", ([&] {
    generate_insertion_label_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
        target.data_ptr<scalar_t>(),
        target.size(1),
        operations.size(1),
        operations.data_ptr<int>(),
        labels.data_ptr<int>(),
        masks.data_ptr<int>());
  }));

  return std::make_pair(labels, masks);
}
// Host launcher for the batched Levenshtein kernels. Allocates the Int
// `operations` output (batch x (src_len + tgt_len)) and picks the kernel by
// DP-table size: when the (src_len+1) x (tgt_len+1) table of `short` cells
// exceeds 40000 bytes (a conservative shared-memory budget), the slower
// global-memory kernel is used with a scratch `distances` tensor; otherwise
// the faster shared-memory kernel runs with `shared_size` bytes of dynamic
// shared memory. One single-thread block per batch element.
torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length) {
  const auto batch_size = source.size(0);
  const auto shared_size =
      (source.size(1) + 1) * (target.size(1) + 1) * sizeof(short);

  at::TensorOptions options(source.device());
  options = options.dtype(at::ScalarType::Int);
  auto operations =
      torch::empty({batch_size, source.size(1) + target.size(1)}, options);
  auto stream = at::cuda::getCurrentCUDAStream(source.device().index());

  if (shared_size > 40000) {
    // DP table too large for shared memory: keep it in a global-memory
    // scratch tensor instead.
    auto distances = torch::empty(
        {batch_size, (source.size(1) + 1) * (target.size(1) + 1)}, options);
    AT_DISPATCH_ALL_TYPES(source.scalar_type(), "levenshtein_distance", ([&] {
      levenshtein_distance_kernel<scalar_t><<<batch_size, 1, 0, stream>>>(
          source.data_ptr<scalar_t>(),
          target.data_ptr<scalar_t>(),
          source_length.data_ptr<int>(),
          target_length.data_ptr<int>(),
          source.size(1),
          target.size(1),
          operations.data_ptr<int>(),
          distances.data_ptr<int>());
    }));
  } else {
    AT_DISPATCH_ALL_TYPES(
        source.scalar_type(), "faster_levenshtein_distance", ([&] {
          faster_levenshtein_distance_kernel<scalar_t>
              <<<batch_size, 1, shared_size, stream>>>(
                  source.data_ptr<scalar_t>(),
                  target.data_ptr<scalar_t>(),
                  source_length.data_ptr<int>(),
                  target_length.data_ptr<int>(),
                  source.size(1),
                  target.size(1),
                  operations.data_ptr<int>());
        }));
  }

  return operations;
}
third_party/seed-tts-eval/thirdparty/UniSpeech/src/fairseq/clib/libnat_cuda/edit_dist.h
0 → 100644
View file @
39ac40a9
/**
* Copyright 2017-present, Facebook, Inc.
* All rights reserved.
*
* This source code is licensed under the license found in the
* LICENSE file in the root directory of this source tree.
*/
#pragma once

#include <torch/extension.h>

// Batched edit-distance alignment on GPU: returns per-pair, left-aligned
// operation sequences (see edit_dist.cu for the op encoding).
torch::Tensor LevenshteinDistanceCuda(
    torch::Tensor source,
    torch::Tensor target,
    torch::Tensor source_length,
    torch::Tensor target_length);

// Per-source-token deletion labels derived from an operation tensor.
torch::Tensor GenerateDeletionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);

// Per-target-token insertion labels and masks derived from an operation
// tensor; returns (labels, masks). NOTE(review): the parameter is named
// `source` here but `target` in the .cu definition — same type, names only.
std::pair<torch::Tensor, torch::Tensor> GenerateInsertionLabelCuda(
    torch::Tensor source,
    torch::Tensor operations);
Prev
1
…
16
17
18
19
20
21
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment