ModelZoo / VITA-Audio_pytorch / Commits

Commit 39ac40a9, authored Jun 06, 2025 by chenzk

    v1.0

Pipeline #2747 failed with stages in 0 seconds.
Changes: 427 | Pipelines: 1

Showing 20 changed files with 1053 additions and 0 deletions (+1053, -0).
Too many changes to show: to preserve performance, only 427 of 427+ changed files are displayed. The 20 files shown on this page:

third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gbw.yaml (+36, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gpt.yaml (+36, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gpt2_big.yaml (+36, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gpt2_medium.yaml (+36, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gpt2_small.yaml (+36, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_wiki103.yaml (+36, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/optimizer/adam.yaml (+5, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/optimizer/nag.yaml (+3, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/task/language_modeling.yaml (+10, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/.gitignore (+2, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/__init__.py (+9, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/scripts/ctc_decode.sh (+9, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/scripts/finetune.sh (+9, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/scripts/pretrain.sh (+81, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/README.md (+71, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/dump_hubert_feature.py (+152, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py (+126, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/dump_km_label.py (+98, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/dump_mfcc_feature.py (+116, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/learn_kmeans.py (+146, -0)
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gbw.yaml (new file, mode 100644)

```yaml
# @package _group_
activation_fn: "relu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 512
decoder_output_dim: 512
decoder_input_dim: 512
decoder_ffn_embed_dim: 4096
decoder_layers: 12
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
```
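These model YAMLs are OmegaConf/Hydra config groups (that is what the `# @package _group_` directive is for), which fairseq composes into its training configuration. As an illustrative aside, not part of the commit, here is a minimal sketch of loading one of them directly with OmegaConf; the relative path assumes the repository root as working directory:

```python
# Illustrative sketch only (not part of this commit): inspect the GBW language
# model config with OmegaConf. fairseq normally composes these files via Hydra.
from omegaconf import OmegaConf

cfg = OmegaConf.load(
    "third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/"
    "transformer_lm_gbw.yaml"
)
print(cfg.decoder_embed_dim, cfg.decoder_ffn_embed_dim, cfg.decoder_layers)  # 512 4096 12
print(OmegaConf.to_yaml(cfg))  # dump the full config
```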
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gpt.yaml (new file, mode 100644)

```yaml
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 768
decoder_output_dim: 768
decoder_input_dim: 768
decoder_ffn_embed_dim: 3072
decoder_layers: 12
decoder_attention_heads: 12
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gpt2_big.yaml (new file, mode 100644)

```yaml
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1600
decoder_output_dim: 1600
decoder_input_dim: 1600
decoder_ffn_embed_dim: 6400
decoder_layers: 48
decoder_attention_heads: 25
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gpt2_medium.yaml (new file, mode 100644)

```yaml
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1280
decoder_output_dim: 1280
decoder_input_dim: 1280
decoder_ffn_embed_dim: 5120
decoder_layers: 36
decoder_attention_heads: 20
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_gpt2_small.yaml (new file, mode 100644)

```yaml
# @package _group_
activation_fn: "gelu"
dropout: 0.1
attention_dropout: 0.1
activation_dropout: 0.0
relu_dropout: 0.0
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 24
decoder_attention_heads: 16
decoder_normalize_before: true
no_decoder_final_norm: false
adaptive_softmax_cutoff: null
adaptive_softmax_dropout: 0
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: false
adaptive_input_factor: 4
adaptive_input_cutoff: null
tie_adaptive_weights: false
tie_adaptive_proj: false
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/model/transformer_lm_wiki103.yaml (new file, mode 100644)

```yaml
# @package _group_
activation_fn: "relu"
dropout: 0.3
attention_dropout: 0.1
activation_dropout: 0.1
relu_dropout: 0.1
decoder_embed_dim: 1024
decoder_output_dim: 1024
decoder_input_dim: 1024
decoder_ffn_embed_dim: 4096
decoder_layers: 16
decoder_attention_heads: 8
decoder_normalize_before: true
no_decoder_final_norm: true
adaptive_softmax_cutoff: "20000,60000"
adaptive_softmax_dropout: 0.2
adaptive_softmax_factor: 4
no_token_positional_embeddings: false
share_decoder_input_output_embed: false
character_embeddings: false
character_filters: "[(1, 64), (2, 128), (3, 192), (4, 256), (5, 256), (6, 256), (7, 256)]"
character_embedding_dim: 4
char_embedder_highway_layers: 2
adaptive_input: true
adaptive_input_factor: 4
adaptive_input_cutoff: "20000,60000"
tie_adaptive_weights: true
tie_adaptive_proj: true
decoder_learned_pos: false
decoder_layerdrop: 0
decoder_layers_to_keep: null
layernorm_embedding: false
no_scale_embedding: false
quant_noise_pq: 0
quant_noise_pq_block_size: 8
quant_noise_scalar: 0
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/optimizer/adam.yaml (new file, mode 100644)

```yaml
# @package _group_
adam_betas: "(0.9, 0.999)"
adam_eps: 1.0e-8
weight_decay: 0
use_old_adam: false
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/optimizer/nag.yaml (new file, mode 100644)

```yaml
# @package _group_
momentum: 0.99
weight_decay: 0.0
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/config/task/language_modeling.yaml (new file, mode 100644)

```yaml
# @package _group_
data: ???
sample_break_mode: "none"
tokens_per_sample: 1024
output_dictionary_size: -1
self_target: false
future_target: false
past_target: false
add_bos_token: false
max_target_positions: null
```
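In this task config, `data: ???` is OmegaConf's mandatory-value marker: the field has no default and must be overridden (for example with the path to a binarized dataset) before the config is used. A tiny sketch of how that marker behaves, purely for illustration and not part of the commit:

```python
# Illustration only: "???" in OmegaConf marks a mandatory value that must be
# overridden before it is accessed.
from omegaconf import OmegaConf

cfg = OmegaConf.create({"data": "???", "tokens_per_sample": 1024})
print(OmegaConf.is_missing(cfg, "data"))  # True -> a path must be supplied
cfg.data = "/path/to/binarized/data"      # hypothetical override
print(cfg.data)
```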
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/.gitignore (new file, mode 100644)

```
!*/*.sh
!*/*.md
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/__init__.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

try:
    from fairseq.version import __version__  # noqa
except ImportError:
    pass
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/scripts/ctc_decode.sh (new file, mode 100644)

```sh
model_path=MODEL_PATH
gen_subset=test_clean
result_path=${model_path}/decode_ctc/${gen_subset}
mkdir -p ${result_path}
export PYTHONENCODING=UTF-8

python examples/speech_recognition/infer.py DATA_PATH --task audio_pretraining \
    --nbest 1 --path ${model_path}/checkpoint_best.pt --gen-subset ${gen_subset} \
    --results-path ${result_path} --w2l-decoder viterbi --word-score -1 --sil-weight 0 \
    --criterion ctc --max-tokens 1100000 --dict-path DICT_PATH --post-process letter --quiet
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/scripts/finetune.sh (new file, mode 100644)

```sh
model_path=MODEL_PATH
train_subset=train_clean_100
valid_subset=dev_other
mkdir -p ${model_path}

python train.py --distributed-world-size 8 --distributed-port 0 --nprocs-per-node 8 DATA_PATH \
    --save-dir ${model_path} --post-process letter \
    --train-subset ${train_subset} --valid-subset ${valid_subset} \
    --no-epoch-checkpoints --best-checkpoint-metric wer --num-workers 4 \
    --max-update 80000 --sentence-avg \
    --task hubert_pretraining --fine-tuning --single-target --arch hubert_ctc \
    --w2v-path PRETRAINED_MODEL_PATH '["ltr"]' \
    --apply-mask --mask-selection static --mask-other 0 --mask-length 10 --mask-prob 0.65 \
    --layerdrop 0.1 --mask-channel-selection static --mask-channel-other 0 \
    --mask-channel-length 64 --mask-channel-prob 0.5 --zero-infinity \
    --feature-grad-mult 0 --freeze-finetune-updates 0 \
    --optimizer adam --adam-betas '(0.9, 0.98)' --adam-eps 1e-08 --lr 0.00003 \
    --lr-scheduler tri_stage --warmup-steps 8000 --hold-steps 32000 --decay-steps 40000 \
    --final-lr-scale 0.05 --final-dropout 0.1 --dropout 0.1 --activation-dropout 0.1 \
    --criterion ctc --attention-dropout 0.1 --dropout-input 0.1 \
    --max-tokens 3200000 --seed 2337 --log-format json --log-interval 200 \
    --ddp-backend c10d --fp16 --update-freq 1 --keep-interval-updates 1 \
    --find-unused-parameters
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/scripts/pretrain.sh (new file, mode 100644)

```sh
pip install torch_complex

model_path=MODEL_PATH
data_path=DATA_PATH
label_path=LABEL_PATH
train_subset=train_960
valid_subset=valid
distributed_world_size=WORLD_SIZE
update_freq=$((32/$WORLD_SIZE))
max_tokens=1400000
warmup_updates=32000
total_num_update=400000
mkdir -p ${model_path}

python train.py \
--ddp-backend no_c10d \
--distributed-backend 'nccl' \
--distributed-world-size ${distributed_world_size} \
--distributed-port 29671 \
--nprocs-per-node 8 \
--find-unused-parameters \
--fp16 \
--log-format json \
--log-interval 200 \
--seed 1337 \
--save-dir ${model_path} \
--save-interval-updates 5000 \
--keep-interval-updates 10 \
--no-epoch-checkpoints \
--num-workers 6 \
--task hubert_pretraining \
--criterion hubert \
--arch ils_hubert \
--train-subset ${train_subset} \
--valid-subset ${valid_subset} \
--log-keys '[]' \
${data_path} \
--label-dir ${label_path} \
--labels '["km"]' \
--sample-rate 16000 \
--max-sample-size 250000 \
--min-sample-size 32000 \
--max-tokens ${max_tokens} \
--skip-invalid-size-inputs-valid-test \
--validate-interval 5 \
--validate-interval-updates 10000 \
--pred-masked-weight 1.0 \
--pred-nomask-weight 0.0 \
--loss-weights [10,] \
--label-rate 50 \
--mask-prob 0.80 \
--extractor-mode default \
--conv-feature-layers '[(512,10,5)] + [(512,3,2)] * 4 + [(512,2,2)] * 2' \
--final-dim 256 \
--encoder-layerdrop 0.05 \
--dropout-input 0.1 \
--dropout-features 0.1 \
--dropout 0.1 \
--attention-dropout 0.1 \
--feature-grad-mult 0.1 \
--activation-dropout 0.0 \
--optimizer adam \
--adam-betas '(0.9,0.98)' \
--adam-eps 1e-06 \
--weight-decay 0.01 \
--lr-scheduler polynomial_decay \
--warmup-updates ${warmup_updates} \
--total-num-update ${total_num_update} \
--max-update 400000 \
--lr 0.0005 \
--clip-norm 10.0 \
--update-freq ${update_freq} \
--predict-layers "[4,12]" \
--relative-position-embedding \
--num-buckets 320 \
--max-distance 800 \
--required-batch-size-multiple 1 \
--separate-label-embeds
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/README.md (new file, mode 100644)

````markdown
# Sharded Feature Extraction and K-means Application

This folder contains scripts for preparing HUBERT labels from tsv files. The steps are:
1. feature extraction
2. k-means clustering
3. k-means application

## Data preparation

`*.tsv` files contain a list of audio files, where the first line is the root directory and the
following lines are the subpath of each audio file:
```
<root-dir>
<audio-path-1>
<audio-path-2>
...
```

## Feature extraction

### MFCC feature
Suppose the tsv file is at `${tsv_dir}/${split}.tsv`. To extract 39-D
mfcc+delta+ddelta features for the 1st iteration HUBERT training, run:
```sh
python dump_mfcc_feature.py ${tsv_dir} ${split} ${nshard} ${rank} ${feat_dir}
```
This would shard the tsv file into `${nshard}` shards and extract features for the
`${rank}`-th shard, where rank is an integer in `[0, nshard-1]`. Features would
be saved at `${feat_dir}/${split}_${rank}_${nshard}.{npy,len}`.

### HUBERT feature
To extract features from the `${layer}`-th transformer layer of a trained
HUBERT model saved at `${ckpt_path}`, run:
```sh
python dump_hubert_feature.py ${tsv_dir} ${split} ${ckpt_path} ${layer} ${nshard} ${rank} ${feat_dir}
```
Features would also be saved at `${feat_dir}/${split}_${rank}_${nshard}.{npy,len}`.
- if out-of-memory, decrease the chunk size with `--max_chunk`

## K-means clustering
To fit a k-means model with `${n_clusters}` clusters on 10% of the `${split}` data, run
```sh
python learn_kmeans.py ${feat_dir} ${split} ${nshard} ${km_path} ${n_cluster} --percent 0.1
```
This saves the k-means model to `${km_path}`.
- set `--percent -1` to use all data
- more kmeans options can be found with the `-h` flag

## K-means application
To apply a trained k-means model `${km_path}` to obtain labels for `${split}`, run
```sh
python dump_km_label.py ${feat_dir} ${split} ${km_path} ${nshard} ${rank} ${lab_dir}
```
This would extract labels for the `${rank}`-th shard out of `${nshard}` shards
and dump them to `${lab_dir}/${split}_${rank}_${nshard}.km`.

Finally, merge shards for `${split}` by running
```sh
for rank in $(seq 0 $((nshard - 1))); do
  cat $lab_dir/${split}_${rank}_${nshard}.km
done > $lab_dir/${split}.km
```
````
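The "Data preparation" section above describes the tsv manifest only loosely; the dump scripts in this directory (`dump_mfcc_feature.py`, `dump_hubert_feature.py`) additionally expect each subpath line to carry a tab-separated sample count. The following is a hypothetical helper, not part of the commit, that builds such a manifest from a directory of audio files; the root path is a placeholder:

```python
# Hypothetical helper (not in this commit): write a ${split}.tsv manifest whose
# first line is the root directory and whose remaining lines are
# "<relative-path>\t<num-samples>", matching how the dump scripts parse it.
import os
import soundfile as sf

root = "/data/LibriSpeech/train-clean-100"  # placeholder root directory
with open("train.tsv", "w") as f:
    f.write(root + "\n")
    for dirpath, _, files in os.walk(root):
        for name in sorted(files):
            if name.endswith((".wav", ".flac")):
                path = os.path.join(dirpath, name)
                f.write(f"{os.path.relpath(path, root)}\t{sf.info(path).frames}\n")
```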
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/dump_hubert_feature.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import pdb
import io
import logging
import math
import os
import sys

sys.path.append(os.getcwd())

import fairseq
import soundfile as sf
import torch
import torch.nn.functional as F
import tqdm
from npy_append_array import NpyAppendArray
from fairseq.data.audio.audio_utils import (
    parse_path,
    read_from_stored_zip,
    is_sf_audio_data,
)

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_hubert_feature")


class HubertFeatureReader(object):
    def __init__(self, ckpt_path, layer, max_chunk=1600000):
        (
            model,
            cfg,
            task,
        ) = fairseq.checkpoint_utils.load_model_ensemble_and_task([ckpt_path])
        self.model = model[0].eval().cuda()
        if hasattr(self.model, 'w2v_encoder'):
            self.model = self.model.w2v_encoder.w2v_model
        self.task = task
        self.layer = layer
        self.max_chunk = max_chunk
        logger.info(f"TASK CONFIG:\n{self.task.cfg}")
        logger.info(f" max_chunk = {self.max_chunk}")

    def read_audio(self, path, ref_len=None):
        wav, sr = sf.read(path)
        assert sr == self.task.cfg.sample_rate, sr
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        x = self.read_audio(path, ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float().cuda()
            if self.task.cfg.normalize:
                x = F.layer_norm(x, x.shape)
            x = x.view(1, -1)

            feat = []
            for start in range(0, x.size(1), self.max_chunk):
                x_chunk = x[:, start: start + self.max_chunk]
                feat_chunk, _ = self.model.extract_features(
                    source=x_chunk,
                    padding_mask=None,
                    mask=False,
                    output_layer=self.layer,
                )
                feat.append(feat_chunk)
        return torch.cat(feat, 1).squeeze(0)


def get_path_iterator(tsv, nshard, rank):
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
    tot = len(lines)
    shard_size = math.ceil(tot / nshard)
    start, end = rank * shard_size, min((rank + 1) * shard_size, tot)
    assert start < end, "start={start}, end={end}"
    logger.info(
        f"rank {rank} of {nshard}, process {end-start} "
        f"({start}-{end}) out of {tot}"
    )
    lines = lines[start:end]

    def iterate():
        for line in lines:
            line = line.split("\t")
            subpath = line[0]
            nsample = line[1]
            path_or_fp = os.path.join(root, subpath)
            _path, slice_ptr = parse_path(path_or_fp)
            if len(slice_ptr) == 2:
                byte_data = read_from_stored_zip(_path, slice_ptr[0], slice_ptr[1])
                assert is_sf_audio_data(byte_data)
                path_or_fp = io.BytesIO(byte_data)
            yield path_or_fp, int(nsample)

    return iterate, len(lines)


def dump_feature(tsv_dir, split, ckpt_path, layer, nshard, rank, feat_dir, max_chunk):
    reader = HubertFeatureReader(ckpt_path, layer, max_chunk)
    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
    iterator = generator()

    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path, nsample in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path, nsample)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("tsv_dir")
    parser.add_argument("split")
    parser.add_argument("ckpt_path")
    parser.add_argument("layer", type=int)
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("feat_dir")
    parser.add_argument("--max_chunk", type=int, default=1600000)
    args = parser.parse_args()
    logger.info(args)

    dump_feature(**vars(args))
```
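For reference, the reader class above can also be used directly from Python. This is a hedged sketch only: the checkpoint path, layer index, and audio path are placeholders, and a CUDA device is required because the model is moved to the GPU in `__init__`:

```python
# Hypothetical usage of HubertFeatureReader (paths and layer are placeholders).
from dump_hubert_feature import HubertFeatureReader

reader = HubertFeatureReader("/path/to/hubert_checkpoint.pt", layer=9, max_chunk=1600000)
feat = reader.get_feats("/path/to/utterance.wav")  # (T, D) torch.Tensor on the GPU
print(feat.shape)
```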
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/dump_hubert_feature_s2t.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import csv
import io
import logging
import math
import os
import os.path as op
import sys

import tqdm
from dump_hubert_feature import HubertFeatureReader
from fairseq.data.audio.audio_utils import get_waveform
from fairseq.data.audio.speech_to_text_dataset import (
    read_from_uncompressed_zip,
)
from npy_append_array import NpyAppendArray

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_hubert_feature_s2t")


class HubertFeatureReaderS2T(HubertFeatureReader):
    def read_audio(self, path, ref_len=None):
        path, *extra = path.split(":")
        assert len(extra) == 2
        assert path.endswith(".zip")

        data = read_from_uncompressed_zip(path, int(extra[0]), int(extra[1]))
        f = io.BytesIO(data)
        wav, sr = get_waveform(f)
        assert sr == self.task.cfg.sample_rate, sr
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav


def get_path_iterator(root, tsv, nshard, rank):
    with open(tsv) as f:
        reader = csv.DictReader(
            f,
            delimiter="\t",
            quotechar=None,
            doublequote=False,
            lineterminator="\n",
            quoting=csv.QUOTE_NONE,
        )
        subpaths = [op.join(root, e["audio"]) for e in reader]
    tot = len(subpaths)
    shard_size = math.ceil(tot / nshard)
    start, end = rank * shard_size, min((rank + 1) * shard_size, tot)
    assert start < end, "start={start}, end={end}"
    logger.info(
        f"rank {rank} of {nshard}, process {end-start} "
        f"({start}-{end}) out of {tot}"
    )
    subpaths = subpaths[start:end]

    def iterate():
        for subpath in subpaths:
            yield op.join(root, subpath)

    return iterate, len(subpaths)


def dump_feature(
    root, tsv_path, ckpt_path, layer, nshard, rank, feat_dir, feat_name, max_chunk,
):
    reader = HubertFeatureReaderS2T(ckpt_path, layer, max_chunk)
    generator, num = get_path_iterator(root, tsv_path, nshard, rank)
    iterator = generator()

    feat_path = f"{feat_dir}/{feat_name}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{feat_name}_{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    if op.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("root")
    parser.add_argument("tsv_path")
    parser.add_argument("ckpt_path")
    parser.add_argument("layer", type=int)
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("feat_dir")
    parser.add_argument("feat_name")
    parser.add_argument("--max_chunk", type=int, default=1600000)
    args = parser.parse_args()
    logger.info(args)

    dump_feature(**vars(args))
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/dump_km_label.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys

import numpy as np

import joblib
import torch
import tqdm

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_km_label")


class ApplyKmeans(object):
    def __init__(self, km_path):
        self.km_model = joblib.load(km_path)
        self.C_np = self.km_model.cluster_centers_.transpose()
        self.Cnorm_np = (self.C_np ** 2).sum(0, keepdims=True)

        self.C = torch.from_numpy(self.C_np)
        self.Cnorm = torch.from_numpy(self.Cnorm_np)
        if torch.cuda.is_available():
            self.C = self.C.cuda()
            self.Cnorm = self.Cnorm.cuda()

    def __call__(self, x):
        if isinstance(x, torch.Tensor):
            dist = (
                x.pow(2).sum(1, keepdim=True)
                - 2 * torch.matmul(x, self.C)
                + self.Cnorm
            )
            return dist.argmin(dim=1).cpu().numpy()
        else:
            dist = (
                (x ** 2).sum(1, keepdims=True)
                - 2 * np.matmul(x, self.C_np)
                + self.Cnorm_np
            )
            return np.argmin(dist, axis=1)


def get_feat_iterator(feat_dir, split, nshard, rank):
    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
    with open(leng_path, "r") as f:
        lengs = [int(line.rstrip()) for line in f]
        offsets = [0] + np.cumsum(lengs[:-1]).tolist()

    def iterate():
        feat = np.load(feat_path, mmap_mode="r")
        assert feat.shape[0] == (offsets[-1] + lengs[-1])
        for offset, leng in zip(offsets, lengs):
            yield feat[offset: offset + leng]

    return iterate, len(lengs)


def dump_label(feat_dir, split, km_path, nshard, rank, lab_dir):
    apply_kmeans = ApplyKmeans(km_path)
    generator, num = get_feat_iterator(feat_dir, split, nshard, rank)
    iterator = generator()

    lab_path = f"{lab_dir}/{split}_{rank}_{nshard}.km"
    os.makedirs(lab_dir, exist_ok=True)
    with open(lab_path, "w") as f:
        for feat in tqdm.tqdm(iterator, total=num):
            # feat = torch.from_numpy(feat).cuda()
            lab = apply_kmeans(feat).tolist()
            f.write(" ".join(map(str, lab)) + "\n")
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("feat_dir")
    parser.add_argument("split")
    parser.add_argument("km_path")
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("lab_dir")
    args = parser.parse_args()
    logging.info(str(args))

    dump_label(**vars(args))
```
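`ApplyKmeans.__call__` above assigns each frame to its nearest centroid using the expansion ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2 (the ||x||^2 term is constant per row, so it does not affect the argmin). A small self-contained check of that identity, for illustration only and not part of the commit:

```python
# Illustration only: the expanded distance used in ApplyKmeans matches a direct
# squared-Euclidean computation, so the argmin gives the same cluster labels.
import numpy as np

x = np.random.randn(5, 8)             # 5 feature frames of dimension 8
C = np.random.randn(8, 3)             # 3 centroids, stored transposed as in the class
Cnorm = (C ** 2).sum(0, keepdims=True)

expanded = (x ** 2).sum(1, keepdims=True) - 2 * np.matmul(x, C) + Cnorm
direct = ((x[:, :, None] - C[None, :, :]) ** 2).sum(axis=1)
assert np.allclose(expanded, direct)
print(expanded.argmin(axis=1))        # labels, as written to the .km files
```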
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/dump_mfcc_feature.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import math
import os
import sys

import soundfile as sf
import torch
import torchaudio
import tqdm
from npy_append_array import NpyAppendArray

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("dump_mfcc_feature")


class MfccFeatureReader(object):
    def __init__(self, sample_rate):
        self.sample_rate = sample_rate

    def read_audio(self, path, ref_len=None):
        wav, sr = sf.read(path)
        assert sr == self.sample_rate, sr
        if wav.ndim == 2:
            wav = wav.mean(-1)
        assert wav.ndim == 1, wav.ndim
        if ref_len is not None and abs(ref_len - len(wav)) > 160:
            logging.warning(f"ref {ref_len} != read {len(wav)} ({path})")
        return wav

    def get_feats(self, path, ref_len=None):
        x = self.read_audio(path, ref_len)
        with torch.no_grad():
            x = torch.from_numpy(x).float()
            x = x.view(1, -1)

            mfccs = torchaudio.compliance.kaldi.mfcc(
                waveform=x,
                sample_frequency=self.sample_rate,
                use_energy=False,
            )  # (time, freq)
            mfccs = mfccs.transpose(0, 1)  # (freq, time)
            deltas = torchaudio.functional.compute_deltas(mfccs)
            ddeltas = torchaudio.functional.compute_deltas(deltas)
            concat = torch.cat([mfccs, deltas, ddeltas], dim=0)
            concat = concat.transpose(0, 1).contiguous()  # (freq, time)
            return concat


def get_path_iterator(tsv, nshard, rank):
    with open(tsv, "r") as f:
        root = f.readline().rstrip()
        lines = [line.rstrip() for line in f]
    tot = len(lines)
    shard_size = math.ceil(tot / nshard)
    start, end = rank * shard_size, min((rank + 1) * shard_size, tot)
    assert start < end, "start={start}, end={end}"
    logger.info(
        f"rank {rank} of {nshard}, process {end-start} "
        f"({start}-{end}) out of {tot}"
    )
    lines = lines[start:end]

    def iterate():
        for line in lines:
            subpath, nsample = line.split("\t")
            yield f"{root}/{subpath}", int(nsample)

    return iterate, len(lines)


def dump_feature(tsv_dir, split, sample_rate, nshard, rank, feat_dir):
    reader = MfccFeatureReader(sample_rate)
    generator, num = get_path_iterator(f"{tsv_dir}/{split}.tsv", nshard, rank)
    iterator = generator()

    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"

    os.makedirs(feat_dir, exist_ok=True)
    if os.path.exists(feat_path):
        os.remove(feat_path)

    feat_f = NpyAppendArray(feat_path)
    with open(leng_path, "w") as leng_f:
        for path, nsample in tqdm.tqdm(iterator, total=num):
            feat = reader.get_feats(path, nsample)
            feat_f.append(feat.cpu().numpy())
            leng_f.write(f"{len(feat)}\n")
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("tsv_dir")
    parser.add_argument("split")
    parser.add_argument("nshard", type=int)
    parser.add_argument("rank", type=int)
    parser.add_argument("feat_dir")
    parser.add_argument("--sample_rate", type=int, default=16000)
    args = parser.parse_args()
    logger.info(args)

    dump_feature(**vars(args))
```
third_party/seed-tts-eval/thirdparty/UniSpeech/src/examples/hubert/simple_kmeans/learn_kmeans.py (new file, mode 100644)

```python
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys

import numpy as np
from sklearn.cluster import MiniBatchKMeans

import joblib

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("learn_kmeans")


def get_km_model(
    n_clusters,
    init,
    max_iter,
    batch_size,
    tol,
    max_no_improvement,
    n_init,
    reassignment_ratio,
):
    return MiniBatchKMeans(
        n_clusters=n_clusters,
        init=init,
        max_iter=max_iter,
        batch_size=batch_size,
        verbose=1,
        compute_labels=False,
        tol=tol,
        max_no_improvement=max_no_improvement,
        init_size=None,
        n_init=n_init,
        reassignment_ratio=reassignment_ratio,
    )


def load_feature_shard(feat_dir, split, nshard, rank, percent):
    feat_path = f"{feat_dir}/{split}_{rank}_{nshard}.npy"
    leng_path = f"{feat_dir}/{split}_{rank}_{nshard}.len"
    with open(leng_path, "r") as f:
        lengs = [int(line.rstrip()) for line in f]
        offsets = [0] + np.cumsum(lengs[:-1]).tolist()

    if percent < 0:
        return np.load(feat_path, mmap_mode="r")
    else:
        nsample = int(np.ceil(len(lengs) * percent))
        indices = np.random.choice(len(lengs), nsample, replace=False)
        feat = np.load(feat_path, mmap_mode="r")
        sampled_feat = np.concatenate(
            [feat[offsets[i]: offsets[i] + lengs[i]] for i in indices], axis=0
        )
        logger.info(
            (
                f"sampled {nsample} utterances, {len(sampled_feat)} frames "
                f"from shard {rank}/{nshard}"
            )
        )
        return sampled_feat


def load_feature(feat_dir, split, nshard, seed, percent):
    assert percent <= 1.0
    feat = np.concatenate(
        [
            load_feature_shard(feat_dir, split, nshard, r, percent)
            for r in range(nshard)
        ],
        axis=0,
    )
    logging.info(f"loaded feature with dimension {feat.shape}")
    return feat


def learn_kmeans(
    feat_dir,
    split,
    nshard,
    km_path,
    n_clusters,
    seed,
    percent,
    init,
    max_iter,
    batch_size,
    tol,
    n_init,
    reassignment_ratio,
    max_no_improvement,
):
    np.random.seed(seed)
    feat = load_feature(feat_dir, split, nshard, seed, percent)
    km_model = get_km_model(
        n_clusters,
        init,
        max_iter,
        batch_size,
        tol,
        max_no_improvement,
        n_init,
        reassignment_ratio,
    )
    km_model.fit(feat)
    joblib.dump(km_model, km_path)

    inertia = -km_model.score(feat) / len(feat)
    logger.info("total intertia: %.5f", inertia)
    logger.info("finished successfully")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("feat_dir", type=str)
    parser.add_argument("split", type=str)
    parser.add_argument("nshard", type=int)
    parser.add_argument("km_path", type=str)
    parser.add_argument("n_clusters", type=int)
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument(
        "--percent", default=-1, type=float, help="sample a subset; -1 for all"
    )
    parser.add_argument("--init", default="k-means++")
    parser.add_argument("--max_iter", default=100, type=int)
    parser.add_argument("--batch_size", default=10000, type=int)
    parser.add_argument("--tol", default=0.0, type=float)
    parser.add_argument("--max_no_improvement", default=100, type=int)
    parser.add_argument("--n_init", default=20, type=int)
    parser.add_argument("--reassignment_ratio", default=0.0, type=float)
    args = parser.parse_args()
    logging.info(str(args))

    learn_kmeans(**vars(args))
```