ModelZoo / Conformer_pytorch
Commit a7785cc6, authored Mar 26, 2024 by Sugon_ldc
delete soft link
Parent: 9a2a05ca
Changes: 162
Showing 20 changed files with 1344 additions and 0 deletions (+1344, -0)
examples/aishell/s0/wenet/utils/__pycache__/cmvn.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/common.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/compute_acc.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/config.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/executor.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/file_utils.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/global_vars.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/init_model.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/mask.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/__pycache__/scheduler.cpython-38.pyc (+0, -0)
examples/aishell/s0/wenet/utils/checkpoint.py (+106, -0)
examples/aishell/s0/wenet/utils/cmvn.py (+93, -0)
examples/aishell/s0/wenet/utils/common.py (+257, -0)
examples/aishell/s0/wenet/utils/compute_acc.py (+395, -0)
examples/aishell/s0/wenet/utils/config.py (+39, -0)
examples/aishell/s0/wenet/utils/ctc_util.py (+83, -0)
examples/aishell/s0/wenet/utils/executor.py (+166, -0)
examples/aishell/s0/wenet/utils/file_utils.py (+66, -0)
examples/aishell/s0/wenet/utils/global_vars.py (+29, -0)
examples/aishell/s0/wenet/utils/init_model.py (+110, -0)
examples/aishell/s0/wenet/utils/__pycache__/cmvn.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/common.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/compute_acc.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/config.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/executor.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/file_utils.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/global_vars.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/init_model.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/mask.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/__pycache__/scheduler.cpython-38.pyc (new file, 0 → 100644): File added
examples/aishell/s0/wenet/utils/checkpoint.py (new file, 0 → 100644)
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import os
import re

import yaml
import torch
from collections import OrderedDict

import datetime


def load_checkpoint(model: torch.nn.Module, path: str) -> dict:
    if torch.cuda.is_available():
        logging.info('Checkpoint: loading from checkpoint %s for GPU' % path)
        checkpoint = torch.load(path)
    else:
        logging.info('Checkpoint: loading from checkpoint %s for CPU' % path)
        checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint, strict=False)
    info_path = re.sub('.pt$', '.yaml', path)
    configs = {}
    if os.path.exists(info_path):
        with open(info_path, 'r') as fin:
            configs = yaml.load(fin, Loader=yaml.FullLoader)
    return configs


def save_checkpoint(model: torch.nn.Module, path: str, infos=None):
    '''
    Args:
        infos (dict or None): any info you want to save.
    '''
    logging.info('Checkpoint: save to checkpoint %s' % path)
    if isinstance(model, torch.nn.DataParallel):
        state_dict = model.module.state_dict()
    elif isinstance(model, torch.nn.parallel.DistributedDataParallel):
        state_dict = model.module.state_dict()
    else:
        state_dict = model.state_dict()
    torch.save(state_dict, path)
    info_path = re.sub('.pt$', '.yaml', path)
    if infos is None:
        infos = {}
    infos['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
    with open(info_path, 'w') as fout:
        data = yaml.dump(infos)
        fout.write(data)


def filter_modules(model_state_dict, modules):
    new_mods = []
    incorrect_mods = []
    mods_model = model_state_dict.keys()
    for mod in modules:
        if any(key.startswith(mod) for key in mods_model):
            new_mods += [mod]
        else:
            incorrect_mods += [mod]
    if incorrect_mods:
        logging.warning(
            "module(s) %s don't match or (partially match) "
            "available modules in model.",
            incorrect_mods,
        )
        logging.warning("for information, the existing modules in model are:")
        logging.warning("%s", mods_model)

    return new_mods


def load_trained_modules(model: torch.nn.Module, args: None):
    # Load encoder modules with pre-trained model(s).
    enc_model_path = args.enc_init
    enc_modules = args.enc_init_mods
    main_state_dict = model.state_dict()
    logging.warning("model(s) found for pre-initialization")
    if os.path.isfile(enc_model_path):
        logging.info('Checkpoint: loading from checkpoint %s for CPU' %
                     enc_model_path)
        model_state_dict = torch.load(enc_model_path, map_location='cpu')
        modules = filter_modules(model_state_dict, enc_modules)
        partial_state_dict = OrderedDict()
        for key, value in model_state_dict.items():
            if any(key.startswith(m) for m in modules):
                partial_state_dict[key] = value
        main_state_dict.update(partial_state_dict)
    else:
        logging.warning("model was not found : %s", enc_model_path)
    model.load_state_dict(main_state_dict)
    configs = {}
    return configs
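A minimal usage sketch, not part of this commit: round-tripping a toy model through save_checkpoint/load_checkpoint. The 'exp/demo' path and the infos keys are illustrative only.

# Illustrative sketch: save a toy model, then reload it and its side-car yaml.
import os
import torch
from wenet.utils.checkpoint import save_checkpoint, load_checkpoint

os.makedirs('exp/demo', exist_ok=True)
model = torch.nn.Linear(4, 2)
# writes exp/demo/0.pt plus exp/demo/0.yaml holding infos and save_time
save_checkpoint(model, 'exp/demo/0.pt', infos={'epoch': 0, 'cv_loss': 1.23})
configs = load_checkpoint(model, 'exp/demo/0.pt')
print(configs)  # e.g. {'cv_loss': 1.23, 'epoch': 0, 'save_time': '...'}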
examples/aishell/s0/wenet/utils/cmvn.py (new file, 0 → 100644)
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging  # needed by the kaldi-binary error path below
import math
import sys      # needed by the kaldi-binary error path below

import numpy as np


def _load_json_cmvn(json_cmvn_file):
    """ Load the json format cmvn stats file and calculate cmvn

    Args:
        json_cmvn_file: cmvn stats file in json format

    Returns:
        a numpy array of [means, vars]
    """
    with open(json_cmvn_file) as f:
        cmvn_stats = json.load(f)

    means = cmvn_stats['mean_stat']
    variance = cmvn_stats['var_stat']
    count = cmvn_stats['frame_num']
    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def _load_kaldi_cmvn(kaldi_cmvn_file):
    """ Load the kaldi format cmvn stats file and calculate cmvn

    Args:
        kaldi_cmvn_file: kaldi text style global cmvn file, which
            is generated by:
            compute-cmvn-stats --binary=false scp:feats.scp global_cmvn

    Returns:
        a numpy array of [means, vars]
    """
    means = []
    variance = []
    with open(kaldi_cmvn_file, 'r') as fid:
        # kaldi binary file start with '\0B'
        if fid.read(2) == '\0B':
            logging.error('kaldi cmvn binary file is not supported, please '
                          'recompute it by: compute-cmvn-stats --binary=false '
                          ' scp:feats.scp global_cmvn')
            sys.exit(1)
        fid.seek(0)
        arr = fid.read().split()
        assert (arr[0] == '[')
        assert (arr[-2] == '0')
        assert (arr[-1] == ']')
        feat_dim = int((len(arr) - 2 - 2) / 2)
        for i in range(1, feat_dim + 1):
            means.append(float(arr[i]))
        count = float(arr[feat_dim + 1])
        for i in range(feat_dim + 2, 2 * feat_dim + 2):
            variance.append(float(arr[i]))

    for i in range(len(means)):
        means[i] /= count
        variance[i] = variance[i] / count - means[i] * means[i]
        if variance[i] < 1.0e-20:
            variance[i] = 1.0e-20
        variance[i] = 1.0 / math.sqrt(variance[i])
    cmvn = np.array([means, variance])
    return cmvn


def load_cmvn(cmvn_file, is_json):
    if is_json:
        cmvn = _load_json_cmvn(cmvn_file)
    else:
        cmvn = _load_kaldi_cmvn(cmvn_file)
    return cmvn[0], cmvn[1]
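A hedged usage sketch, not part of this commit: build a tiny 3-dimensional JSON cmvn stats file and load it. The numbers are made up; real stats files are produced by the recipe's cmvn computation step.

# Illustrative sketch: load_cmvn returns per-dim means and inverse stds.
import json
import tempfile
from wenet.utils.cmvn import load_cmvn

stats = {'mean_stat': [10.0, 20.0, 30.0],
         'var_stat': [60.0, 250.0, 560.0],
         'frame_num': 5}
with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as f:
    json.dump(stats, f)
    cmvn_path = f.name

mean, istd = load_cmvn(cmvn_path, is_json=True)
print(mean)   # per-dim means: [2. 4. 6.]
print(istd)   # per-dim inverse standard deviations (1 / sqrt(var))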
examples/aishell/s0/wenet/utils/common.py (new file, 0 → 100644)
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Modified from ESPnet(https://github.com/espnet/espnet)
"""Unility functions for Transformer."""
import
math
from
typing
import
List
,
Tuple
import
torch
from
torch.nn.utils.rnn
import
pad_sequence
IGNORE_ID
=
-
1
def
pad_list
(
xs
:
List
[
torch
.
Tensor
],
pad_value
:
int
):
"""Perform padding for the list of tensors.
Args:
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
pad_value (float): Value for padding.
Returns:
Tensor: Padded tensor (B, Tmax, `*`).
Examples:
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
>>> x
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
>>> pad_list(x, 0)
tensor([[1., 1., 1., 1.],
[1., 1., 0., 0.],
[1., 0., 0., 0.]])
"""
n_batch
=
len
(
xs
)
max_len
=
max
([
x
.
size
(
0
)
for
x
in
xs
])
pad
=
torch
.
zeros
(
n_batch
,
max_len
,
dtype
=
xs
[
0
].
dtype
,
device
=
xs
[
0
].
device
)
pad
=
pad
.
fill_
(
pad_value
)
for
i
in
range
(
n_batch
):
pad
[
i
,
:
xs
[
i
].
size
(
0
)]
=
xs
[
i
]
return
pad
def
add_blank
(
ys_pad
:
torch
.
Tensor
,
blank
:
int
,
ignore_id
:
int
)
->
torch
.
Tensor
:
""" Prepad blank for transducer predictor
Args:
ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
blank (int): index of <blank>
Returns:
ys_in (torch.Tensor) : (B, Lmax + 1)
Examples:
>>> blank = 0
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=torch.int32)
>>> ys_in = add_blank(ys_pad, 0, -1)
>>> ys_in
tensor([[0, 1, 2, 3, 4, 5],
[0, 4, 5, 6, 0, 0],
[0, 7, 8, 9, 0, 0]])
"""
bs
=
ys_pad
.
size
(
0
)
_blank
=
torch
.
tensor
([
blank
],
dtype
=
torch
.
long
,
requires_grad
=
False
,
device
=
ys_pad
.
device
)
_blank
=
_blank
.
repeat
(
bs
).
unsqueeze
(
1
)
# [bs,1]
out
=
torch
.
cat
([
_blank
,
ys_pad
],
dim
=
1
)
# [bs, Lmax+1]
return
torch
.
where
(
out
==
ignore_id
,
blank
,
out
)
def
add_sos_eos
(
ys_pad
:
torch
.
Tensor
,
sos
:
int
,
eos
:
int
,
ignore_id
:
int
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""Add <sos> and <eos> labels.
Args:
ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
sos (int): index of <sos>
eos (int): index of <eeos>
ignore_id (int): index of padding
Returns:
ys_in (torch.Tensor) : (B, Lmax + 1)
ys_out (torch.Tensor) : (B, Lmax + 1)
Examples:
>>> sos_id = 10
>>> eos_id = 11
>>> ignore_id = -1
>>> ys_pad
tensor([[ 1, 2, 3, 4, 5],
[ 4, 5, 6, -1, -1],
[ 7, 8, 9, -1, -1]], dtype=torch.int32)
>>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
>>> ys_in
tensor([[10, 1, 2, 3, 4, 5],
[10, 4, 5, 6, 11, 11],
[10, 7, 8, 9, 11, 11]])
>>> ys_out
tensor([[ 1, 2, 3, 4, 5, 11],
[ 4, 5, 6, 11, -1, -1],
[ 7, 8, 9, 11, -1, -1]])
"""
_sos
=
torch
.
tensor
([
sos
],
dtype
=
torch
.
long
,
requires_grad
=
False
,
device
=
ys_pad
.
device
)
_eos
=
torch
.
tensor
([
eos
],
dtype
=
torch
.
long
,
requires_grad
=
False
,
device
=
ys_pad
.
device
)
ys
=
[
y
[
y
!=
ignore_id
]
for
y
in
ys_pad
]
# parse padded ys
ys_in
=
[
torch
.
cat
([
_sos
,
y
],
dim
=
0
)
for
y
in
ys
]
ys_out
=
[
torch
.
cat
([
y
,
_eos
],
dim
=
0
)
for
y
in
ys
]
return
pad_list
(
ys_in
,
eos
),
pad_list
(
ys_out
,
ignore_id
)
def
reverse_pad_list
(
ys_pad
:
torch
.
Tensor
,
ys_lens
:
torch
.
Tensor
,
pad_value
:
float
=
-
1.0
)
->
torch
.
Tensor
:
"""Reverse padding for the list of tensors.
Args:
ys_pad (tensor): The padded tensor (B, Tokenmax).
ys_lens (tensor): The lens of token seqs (B)
pad_value (int): Value for padding.
Returns:
Tensor: Padded tensor (B, Tokenmax).
Examples:
>>> x
tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
>>> pad_list(x, 0)
tensor([[4, 3, 2, 1],
[7, 6, 5, 0],
[9, 8, 0, 0]])
"""
r_ys_pad
=
pad_sequence
([(
torch
.
flip
(
y
.
int
()[:
i
],
[
0
]))
for
y
,
i
in
zip
(
ys_pad
,
ys_lens
)],
True
,
pad_value
)
return
r_ys_pad
def
th_accuracy
(
pad_outputs
:
torch
.
Tensor
,
pad_targets
:
torch
.
Tensor
,
ignore_label
:
int
)
->
float
:
"""Calculate accuracy.
Args:
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
pad_targets (LongTensor): Target label tensors (B, Lmax).
ignore_label (int): Ignore label id.
Returns:
float: Accuracy value (0.0 - 1.0).
"""
pad_pred
=
pad_outputs
.
view
(
pad_targets
.
size
(
0
),
pad_targets
.
size
(
1
),
pad_outputs
.
size
(
1
)).
argmax
(
2
)
mask
=
pad_targets
!=
ignore_label
numerator
=
torch
.
sum
(
pad_pred
.
masked_select
(
mask
)
==
pad_targets
.
masked_select
(
mask
))
denominator
=
torch
.
sum
(
mask
)
return
float
(
numerator
)
/
float
(
denominator
)
def
get_rnn
(
rnn_type
:
str
)
->
torch
.
nn
.
Module
:
assert
rnn_type
in
[
"rnn"
,
"lstm"
,
"gru"
]
if
rnn_type
==
"rnn"
:
return
torch
.
nn
.
RNN
elif
rnn_type
==
"lstm"
:
return
torch
.
nn
.
LSTM
else
:
return
torch
.
nn
.
GRU
def
get_activation
(
act
):
"""Return activation function."""
# Lazy load to avoid unused import
from
wenet.transformer.swish
import
Swish
activation_funcs
=
{
"hardtanh"
:
torch
.
nn
.
Hardtanh
,
"tanh"
:
torch
.
nn
.
Tanh
,
"relu"
:
torch
.
nn
.
ReLU
,
"selu"
:
torch
.
nn
.
SELU
,
"swish"
:
getattr
(
torch
.
nn
,
"SiLU"
,
Swish
),
"gelu"
:
torch
.
nn
.
GELU
}
return
activation_funcs
[
act
]()
def
get_subsample
(
config
):
input_layer
=
config
[
"encoder_conf"
][
"input_layer"
]
assert
input_layer
in
[
"conv2d"
,
"conv2d6"
,
"conv2d8"
]
if
input_layer
==
"conv2d"
:
return
4
elif
input_layer
==
"conv2d6"
:
return
6
elif
input_layer
==
"conv2d8"
:
return
8
def
remove_duplicates_and_blank
(
hyp
:
List
[
int
])
->
List
[
int
]:
new_hyp
:
List
[
int
]
=
[]
cur
=
0
while
cur
<
len
(
hyp
):
if
hyp
[
cur
]
!=
0
:
new_hyp
.
append
(
hyp
[
cur
])
prev
=
cur
while
cur
<
len
(
hyp
)
and
hyp
[
cur
]
==
hyp
[
prev
]:
cur
+=
1
return
new_hyp
def
replace_duplicates_with_blank
(
hyp
:
List
[
int
])
->
List
[
int
]:
new_hyp
:
List
[
int
]
=
[]
cur
=
0
while
cur
<
len
(
hyp
):
new_hyp
.
append
(
hyp
[
cur
])
prev
=
cur
cur
+=
1
while
cur
<
len
(
hyp
)
and
hyp
[
cur
]
==
hyp
[
prev
]
and
hyp
[
cur
]
!=
0
:
new_hyp
.
append
(
0
)
cur
+=
1
return
new_hyp
def
log_add
(
args
:
List
[
int
])
->
float
:
"""
Stable log add
"""
if
all
(
a
==
-
float
(
'inf'
)
for
a
in
args
):
return
-
float
(
'inf'
)
a_max
=
max
(
args
)
lsp
=
math
.
log
(
sum
(
math
.
exp
(
a
-
a_max
)
for
a
in
args
))
return
a_max
+
lsp
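A small sketch, not part of this commit, showing the typical CTC post-processing path through these helpers: greedy ids go through remove_duplicates_and_blank, log_add combines log-domain scores, and add_sos_eos prepares decoder labels. All input values are made up.

# Illustrative sketch of the helpers above.
import torch
from wenet.utils.common import (remove_duplicates_and_blank, log_add,
                                add_sos_eos)

# collapse repeats and drop blanks (id 0) from a greedy CTC path
print(remove_duplicates_and_blank([0, 1, 1, 0, 2, 2, 2, 0, 3]))  # [1, 2, 3]

# numerically stable log(sum(exp(a_i))), as used by CTC prefix beam search
print(log_add([-2.3, -4.6, -float('inf')]))

# attach <sos>/<eos> to a padded label batch (ignore_id marks padding)
ys_pad = torch.tensor([[1, 2, 3], [4, 5, -1]])
ys_in, ys_out = add_sos_eos(ys_pad, sos=10, eos=11, ignore_id=-1)
print(ys_in)   # sos-prefixed decoder inputs, padded with eos
print(ys_out)  # eos-suffixed targets, padded with ignore_id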
examples/aishell/s0/wenet/utils/compute_acc.py (new file, 0 → 100755)
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# This file is originally copied from tools/compute-wer.py and modified to calculate the char accuracy
import re, sys, unicodedata
import codecs

remove_tag = True
spacelist = [' ', '\t', '\r', '\n']
puncts = ['!', ',', '?',
          '、', '。', '!', ',', ';', '?',
          ':', '「', '」', '︰', '『', '』', '《', '》']


def characterize(string):
    res = []
    i = 0
    while i < len(string):
        char = string[i]
        if char in puncts:
            i += 1
            continue
        cat1 = unicodedata.category(char)
        # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
        if cat1 == 'Zs' or cat1 == 'Cn' or char in spacelist:
            # space or not assigned
            i += 1
            continue
        if cat1 == 'Lo':  # letter-other
            res.append(char)
            i += 1
        else:
            # some input looks like: <unk><noise>, we want to separate it to
            # two words.
            sep = ' '
            if char == '<':
                sep = '>'
            j = i + 1
            while j < len(string):
                c = string[j]
                if ord(c) >= 128 or (c in spacelist) or (c == sep):
                    break
                j += 1
            if j < len(string) and string[j] == '>':
                j += 1
            res.append(string[i:j])
            i = j
    return res


def stripoff_tags(x):
    if not x:
        return ''
    chars = []
    i = 0
    T = len(x)
    while i < T:
        if x[i] == '<':
            while i < T and x[i] != '>':
                i += 1
            i += 1
        else:
            chars.append(x[i])
            i += 1
    return ''.join(chars)


def normalize(sentence, ignore_words, cs, split=None):
    """ sentence, ignore_words are both in unicode
    """
    new_sentence = []
    for token in sentence:
        x = token
        if not cs:
            x = x.upper()
        if x in ignore_words:
            continue
        if remove_tag:
            x = stripoff_tags(x)
        if not x:
            continue
        if split and x in split:
            new_sentence += split[x]
        else:
            new_sentence.append(x)
    return new_sentence


class Calculator:
    def __init__(self):
        self.data = {}
        self.space = []
        self.cost = {}
        self.cost['cor'] = 0
        self.cost['sub'] = 1
        self.cost['del'] = 1
        self.cost['ins'] = 1

    def calculate(self, lab, rec):
        # Initialization
        lab.insert(0, '')
        rec.insert(0, '')
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element['dist'] = 0
                element['error'] = 'non'
            while len(row) < len(rec):
                row.append({'dist': 0, 'error': 'non'})
        for i in range(len(lab)):
            self.space[i][0]['dist'] = i
            self.space[i][0]['error'] = 'del'
        for j in range(len(rec)):
            self.space[0][j]['dist'] = j
            self.space[0][j]['error'] = 'ins'
        self.space[0][0]['error'] = 'non'
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = {'all': 0, 'cor': 0, 'sub': 0,
                                    'ins': 0, 'del': 0}
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {'all': 0, 'cor': 0, 'sub': 0,
                                    'ins': 0, 'del': 0}
        # Computing edit distance
        for i, lab_token in enumerate(lab):
            for j, rec_token in enumerate(rec):
                if i == 0 or j == 0:
                    continue
                min_dist = sys.maxsize
                min_error = 'none'
                dist = self.space[i - 1][j]['dist'] + self.cost['del']
                error = 'del'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                dist = self.space[i][j - 1]['dist'] + self.cost['ins']
                error = 'ins'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                if lab_token == rec_token:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['cor']
                    error = 'cor'
                else:
                    dist = self.space[i - 1][j - 1]['dist'] + self.cost['sub']
                    error = 'sub'
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]['dist'] = min_dist
                self.space[i][j]['error'] = min_error
        # Tracing back
        result = {'lab': [], 'rec': [], 'all': 0, 'cor': 0, 'sub': 0,
                  'ins': 0, 'del': 0}
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            if self.space[i][j]['error'] == 'cor':  # correct
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['cor'] = self.data[lab[i]]['cor'] + 1
                    result['all'] = result['all'] + 1
                    result['cor'] = result['cor'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'sub':  # substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['sub'] = self.data[lab[i]]['sub'] + 1
                    result['all'] = result['all'] + 1
                    result['sub'] = result['sub'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, rec[j])
                i = i - 1
                j = j - 1
            elif self.space[i][j]['error'] == 'del':  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]['all'] = self.data[lab[i]]['all'] + 1
                    self.data[lab[i]]['del'] = self.data[lab[i]]['del'] + 1
                    result['all'] = result['all'] + 1
                    result['del'] = result['del'] + 1
                result['lab'].insert(0, lab[i])
                result['rec'].insert(0, "")
                i = i - 1
            elif self.space[i][j]['error'] == 'ins':  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]['ins'] = self.data[rec[j]]['ins'] + 1
                    result['ins'] = result['ins'] + 1
                result['lab'].insert(0, "")
                result['rec'].insert(0, rec[j])
                j = j - 1
            elif self.space[i][j]['error'] == 'non':  # starting point
                break
            else:  # shouldn't reach here
                print('this should not happen , i = {i} , j = {j} , '
                      'error = {error}'.format(
                          i=i, j=j, error=self.space[i][j]['error']))
        return result

    def overall(self):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in self.data:
            result['all'] = result['all'] + self.data[token]['all']
            result['cor'] = result['cor'] + self.data[token]['cor']
            result['sub'] = result['sub'] + self.data[token]['sub']
            result['ins'] = result['ins'] + self.data[token]['ins']
            result['del'] = result['del'] + self.data[token]['del']
        return result

    def cluster(self, data):
        result = {'all': 0, 'cor': 0, 'sub': 0, 'ins': 0, 'del': 0}
        for token in data:
            if token in self.data:
                result['all'] = result['all'] + self.data[token]['all']
                result['cor'] = result['cor'] + self.data[token]['cor']
                result['sub'] = result['sub'] + self.data[token]['sub']
                result['ins'] = result['ins'] + self.data[token]['ins']
                result['del'] = result['del'] + self.data[token]['del']
        return result

    def keys(self):
        return list(self.data.keys())


def width(string):
    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)


def default_cluster(word):
    unicode_names = [unicodedata.name(char) for char in word]
    for i in reversed(range(len(unicode_names))):
        if unicode_names[i].startswith('DIGIT'):  # 1
            unicode_names[i] = 'Number'  # 'DIGIT'
        elif (unicode_names[i].startswith('CJK UNIFIED IDEOGRAPH') or
              unicode_names[i].startswith('CJK COMPATIBILITY IDEOGRAPH')):
            # 明 / 郎
            unicode_names[i] = 'Mandarin'  # 'CJK IDEOGRAPH'
        elif (unicode_names[i].startswith('LATIN CAPITAL LETTER') or
              unicode_names[i].startswith('LATIN SMALL LETTER')):
            # A / a
            unicode_names[i] = 'English'  # 'LATIN LETTER'
        elif unicode_names[i].startswith('HIRAGANA LETTER'):  # は こ め
            unicode_names[i] = 'Japanese'  # 'GANA LETTER'
        elif (unicode_names[i].startswith('AMPERSAND') or
              unicode_names[i].startswith('APOSTROPHE') or
              unicode_names[i].startswith('COMMERCIAL AT') or
              unicode_names[i].startswith('DEGREE CELSIUS') or
              unicode_names[i].startswith('EQUALS SIGN') or
              unicode_names[i].startswith('FULL STOP') or
              unicode_names[i].startswith('HYPHEN-MINUS') or
              unicode_names[i].startswith('LOW LINE') or
              unicode_names[i].startswith('NUMBER SIGN') or
              unicode_names[i].startswith('PLUS SIGN') or
              unicode_names[i].startswith('SEMICOLON')):
            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
            del unicode_names[i]
        else:
            return 'Other'
    if len(unicode_names) == 0:
        return 'Other'
    if len(unicode_names) == 1:
        return unicode_names[0]
    for i in range(len(unicode_names) - 1):
        if unicode_names[i] != unicode_names[i + 1]:
            return 'Other'
    return unicode_names[0]


def compute_char_acc(args):
    calculator = Calculator()
    cluster_file = ''
    ignore_words = set()
    tochar = True
    verbose = 1
    padding_symbol = ' '
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None

    if not case_sensitive:
        ig = set([w.upper() for w in ignore_words])
        ignore_words = ig
    default_clusters = {}
    default_words = {}
    ref_file = args.val_ref_file
    hyp_file = args.val_hyp_file
    rec_set = {}
    if split and not case_sensitive:
        newsplit = dict()
        for w in split:
            words = split[w]
            for i in range(len(words)):
                words[i] = words[i].upper()
            newsplit[w.upper()] = words
        split = newsplit

    with codecs.open(hyp_file, 'r', 'utf-8') as fh:
        for line in fh:
            if tochar:
                array = characterize(line)
            else:
                array = line.strip().split()
            if len(array) == 0:
                continue
            fid = array[0]
            rec_set[fid] = normalize(array[1:], ignore_words, case_sensitive,
                                     split)

    # compute error rate on the intersection of reference file and hyp file
    for line in open(ref_file, 'r', encoding='utf-8'):
        if tochar:
            array = characterize(line)
        else:
            array = line.rstrip('\n').split()
        if len(array) == 0:
            continue
        fid = array[0]
        if fid not in rec_set:
            continue
        lab = normalize(array[1:], ignore_words, case_sensitive, split)
        rec = rec_set[fid]
        # if verbose:
        #     print('\nutt: %s' % fid)

        for word in rec + lab:
            if word not in default_words:
                default_cluster_name = default_cluster(word)
                if default_cluster_name not in default_clusters:
                    default_clusters[default_cluster_name] = {}
                if word not in default_clusters[default_cluster_name]:
                    default_clusters[default_cluster_name][word] = 1
                default_words[word] = default_cluster_name

        result = calculator.calculate(lab, rec)
        if verbose:
            if result['all'] != 0:
                wer = float(result['ins'] + result['sub'] +
                            result['del']) * 100.0 / result['all']
            else:
                wer = 0.0
            # print('WER: %4.2f %%' % wer, end = ' ')
            # print('N=%d C=%d S=%d D=%d I=%d' %
            #       (result['all'], result['cor'], result['sub'],
            #        result['del'], result['ins']))
            space = {}
            space['lab'] = []
            space['rec'] = []
            for idx in range(len(result['lab'])):
                len_lab = width(result['lab'][idx])
                len_rec = width(result['rec'][idx])
                length = max(len_lab, len_rec)
                space['lab'].append(length - len_lab)
                space['rec'].append(length - len_rec)
            upper_lab = len(result['lab'])
            upper_rec = len(result['rec'])
            lab1, rec1 = 0, 0
            while lab1 < upper_lab or rec1 < upper_rec:
                # if verbose > 1:
                #     print('lab(%s):' % fid.encode('utf-8'), end = ' ')
                # else:
                #     print('lab:', end = ' ')
                lab2 = min(upper_lab, lab1 + max_words_per_line)
                for idx in range(lab1, lab2):
                    token = result['lab'][idx]
                    # print('{token}'.format(token = token), end = '')
                    # for n in range(space['lab'][idx]) :
                    #     print(padding_symbol, end = '')
                    # print(' ',end='')
                # print()
                # if verbose > 1:
                #     print('rec(%s):' % fid.encode('utf-8'), end = ' ')
                # else:
                #     print('rec:', end = ' ')
                rec2 = min(upper_rec, rec1 + max_words_per_line)
                for idx in range(rec1, rec2):
                    token = result['rec'][idx]
                    # print('{token}'.format(token = token), end = '')
                    # for n in range(space['rec'][idx]) :
                    #     print(padding_symbol, end = '')
                    # print(' ',end='')
                # print('\n', end='\n')
                lab1 = lab2
                rec1 = rec2
    # if verbose:
    #     print('==================================================')
    #     print()

    result = calculator.overall()
    if result['all'] != 0:
        wer = float(result['ins'] + result['sub'] +
                    result['del']) * 100.0 / result['all']
    else:
        wer = 0.0
    # print('Overall -> %4.2f %%' % wer, end = ' ')
    # print('N=%d C=%d S=%d D=%d I=%d' %
    #       (result['all'], result['cor'], result['sub'],
    #        result['del'], result['ins']))
    # if not verbose:
    #     print()
    char_acc = 100.0 - wer
    return char_acc

    # Note: the cluster-level report below is unreachable after the early
    # return above; it is kept from the original compute-wer.py.
    if verbose:
        for cluster_id in default_clusters:
            result = calculator.cluster(
                [k for k in default_clusters[cluster_id]])
            if result['all'] != 0:
                wer = float(result['ins'] + result['sub'] +
                            result['del']) * 100.0 / result['all']
            else:
                wer = 0.0
            # print('%s -> %4.2f %%' % (cluster_id, wer), end = ' ')
            # print('N=%d C=%d S=%d D=%d I=%d' %
            #       (result['all'], result['cor'], result['sub'],
            #        result['del'], result['ins']))
        # print()
        # print('==================================================')
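A hedged usage sketch, not part of this commit: score one hypothesis against one reference. compute_char_acc only needs an object with val_ref_file and val_hyp_file attributes, so a SimpleNamespace stands in for the usual argparse args; the utterance id and texts are made up.

# Illustrative sketch: one substitution out of six characters -> ~83.3% acc.
import tempfile
from types import SimpleNamespace
from wenet.utils.compute_acc import compute_char_acc

with tempfile.NamedTemporaryFile('w', suffix='.ref', delete=False,
                                 encoding='utf-8') as f:
    f.write('UTT0001 今天天气很好\n')
    ref_path = f.name
with tempfile.NamedTemporaryFile('w', suffix='.hyp', delete=False,
                                 encoding='utf-8') as f:
    f.write('UTT0001 今天天汽很好\n')
    hyp_path = f.name

args = SimpleNamespace(val_ref_file=ref_path, val_hyp_file=hyp_path)
print(compute_char_acc(args))  # 5 of 6 characters correct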
examples/aishell/s0/wenet/utils/config.py (new file, 0 → 100644)
# Copyright (c) 2021 Shaoshang Qi
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import copy


def override_config(configs, override_list):
    new_configs = copy.deepcopy(configs)
    for item in override_list:
        arr = item.split()
        if len(arr) != 2:
            print(f"the override {item} format not correct, skip it")
            continue
        keys = arr[0].split('.')
        s_configs = new_configs
        for i, key in enumerate(keys):
            if key not in s_configs:
                print(f"the override {item} format not correct, skip it")
            if i == len(keys) - 1:
                param_type = type(s_configs[key])
                if param_type != bool:
                    s_configs[key] = param_type(arr[1])
                else:
                    s_configs[key] = arr[1] in ['true', 'True']
                print(f"override {arr[0]} with {arr[1]}")
            else:
                s_configs = s_configs[key]
    return new_configs
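A hedged usage sketch, not part of this commit: each override entry is a "dotted.key value" string, as a training script might pass on the command line. The config keys below are made up.

# Illustrative sketch: override a nested int and a bool without touching
# the original dict.
from wenet.utils.config import override_config

configs = {'encoder_conf': {'output_size': 256}, 'use_amp': False}
new_configs = override_config(
    configs, ['encoder_conf.output_size 512', 'use_amp true'])
print(new_configs)  # {'encoder_conf': {'output_size': 512}, 'use_amp': True}
print(configs)      # original dict is left untouched (deepcopy)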
examples/aishell/s0/wenet/utils/ctc_util.py (new file, 0 → 100644)
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch


def insert_blank(label, blank_id=0):
    """Insert blank token between every two label token."""
    label = np.expand_dims(label, 1)
    blanks = np.zeros((label.shape[0], 1), dtype=np.int64) + blank_id
    label = np.concatenate([blanks, label], axis=1)
    label = label.reshape(-1)
    label = np.append(label, label[0])
    return label


def forced_align(ctc_probs: torch.Tensor, y: torch.Tensor,
                 blank_id=0) -> list:
    """ctc forced alignment.

    Args:
        torch.Tensor ctc_probs: hidden state sequence, 2d tensor (T, D)
        torch.Tensor y: id sequence tensor 1d tensor (L)
        int blank_id: blank symbol index
    Returns:
        torch.Tensor: alignment result
    """
    y_insert_blank = insert_blank(y, blank_id)

    log_alpha = torch.zeros((ctc_probs.size(0), len(y_insert_blank)))
    log_alpha = log_alpha - float('inf')  # log of zero
    state_path = (torch.zeros(
        (ctc_probs.size(0), len(y_insert_blank)), dtype=torch.int16) - 1
    )  # state path

    # init start state
    log_alpha[0, 0] = ctc_probs[0][y_insert_blank[0]]
    log_alpha[0, 1] = ctc_probs[0][y_insert_blank[1]]

    for t in range(1, ctc_probs.size(0)):
        for s in range(len(y_insert_blank)):
            if y_insert_blank[s] == blank_id or s < 2 or y_insert_blank[
                    s] == y_insert_blank[s - 2]:
                candidates = torch.tensor(
                    [log_alpha[t - 1, s], log_alpha[t - 1, s - 1]])
                prev_state = [s, s - 1]
            else:
                candidates = torch.tensor([
                    log_alpha[t - 1, s],
                    log_alpha[t - 1, s - 1],
                    log_alpha[t - 1, s - 2],
                ])
                prev_state = [s, s - 1, s - 2]
            log_alpha[t, s] = torch.max(candidates) + ctc_probs[t][
                y_insert_blank[s]]
            state_path[t, s] = prev_state[torch.argmax(candidates)]

    state_seq = -1 * torch.ones((ctc_probs.size(0), 1), dtype=torch.int16)

    candidates = torch.tensor([
        log_alpha[-1, len(y_insert_blank) - 1],
        log_alpha[-1, len(y_insert_blank) - 2]
    ])
    prev_state = [len(y_insert_blank) - 1, len(y_insert_blank) - 2]
    state_seq[-1] = prev_state[torch.argmax(candidates)]
    for t in range(ctc_probs.size(0) - 2, -1, -1):
        state_seq[t] = state_path[t + 1, state_seq[t + 1, 0]]

    output_alignment = []
    for t in range(0, ctc_probs.size(0)):
        output_alignment.append(y_insert_blank[state_seq[t, 0]])

    return output_alignment
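A minimal sketch, not part of this commit: align a few frames of random log-probabilities to a short label sequence. The shapes and values are illustrative only; in training, ctc_probs would be the log-softmax output of the CTC layer.

# Illustrative sketch: forced_align returns one symbol id per frame.
import torch
from wenet.utils.ctc_util import forced_align

torch.manual_seed(0)
ctc_probs = torch.randn(6, 4).log_softmax(dim=-1)  # (T=6 frames, D=4 symbols)
y = torch.tensor([1, 2])                           # label ids, 0 is <blank>
alignment = forced_align(ctc_probs, y, blank_id=0)
print(alignment)  # list of per-frame symbol ids (blanks interleaved)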
examples/aishell/s0/wenet/utils/executor.py (new file, 0 → 100644)
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from contextlib import nullcontext
# if your python version < 3.7 use the below one
# from contextlib import suppress as nullcontext
import torch
from torch.nn.utils import clip_grad_norm_
from wenet.utils.global_vars import (get_global_steps, global_steps_inc,
                                     get_num_trained_samples,
                                     num_trained_samples_inc)
import time


class Executor:
    def __init__(self):
        self.step = 0

    def train(self, model, optimizer, scheduler, data_loader, device, writer,
              args, scaler):
        ''' Train one epoch
        '''
        model.train()
        clip = args.get('grad_clip', 50.0)
        log_interval = args.get('log_interval', 10)
        rank = args.get('rank', 0)
        epoch = args.get('epoch', 0)
        accum_grad = args.get('accum_grad', 1)
        is_distributed = args.get('is_distributed', True)
        use_amp = args.get('use_amp', False)
        logging.info('using accumulate grad, new batch size is {} times'
                     ' larger than before'.format(accum_grad))
        if use_amp:
            assert scaler is not None
        # A context manager to be used in conjunction with an instance of
        # torch.nn.parallel.DistributedDataParallel to be able to train
        # with uneven inputs across participating processes.
        if isinstance(model, torch.nn.parallel.DistributedDataParallel):
            model_context = model.join
        else:
            model_context = nullcontext
        num_seen_utts = 0
        with model_context():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths, target_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                target_lengths = target_lengths.to(device)
                num_utts = target_lengths.size(0)
                if num_utts == 0:
                    continue
                context = None
                # Disable gradient synchronizations across DDP processes.
                # Within this context, gradients will be accumulated on module
                # variables, which will later be synchronized.
                if is_distributed and batch_idx % accum_grad != 0:
                    context = model.no_sync
                # Used for single gpu training and DDP gradient synchronization
                # processes.
                else:
                    context = nullcontext
                with context():
                    # autocast context
                    # The more details about amp can be found in
                    # https://pytorch.org/docs/stable/notes/amp_examples.html
                    with torch.cuda.amp.autocast(scaler is not None):
                        loss_dict = model(feats, feats_lengths, target,
                                          target_lengths)
                        loss = loss_dict['loss'] / accum_grad
                    if use_amp:
                        scaler.scale(loss).backward()
                    else:
                        loss.backward()
                num_seen_utts += num_utts
                global_steps_inc()
                num_trained_samples_inc(num_utts)
                if batch_idx % accum_grad == 0:
                    # if rank == 0 and writer is not None:
                    #     writer.add_scalar('train_loss', loss, self.step)
                    # Use mixed precision training
                    if use_amp:
                        scaler.unscale_(optimizer)
                        grad_norm = clip_grad_norm_(model.parameters(), clip)
                        # Must invoke scaler.update() if unscale_() is used in
                        # the iteration to avoid the following error:
                        #   RuntimeError: unscale_() has already been called
                        #   on this optimizer since the last update().
                        # We don't check grad here since that if the gradient
                        # has inf/nan values, scaler.step will skip
                        # optimizer.step().
                        scaler.step(optimizer)
                        scaler.update()
                    else:
                        grad_norm = clip_grad_norm_(model.parameters(), clip)
                        if torch.isfinite(grad_norm):
                            optimizer.step()
                    optimizer.zero_grad()
                    scheduler.step()
                    self.step += 1
                # if batch_idx % log_interval == 0:
                #     lr = optimizer.param_groups[0]['lr']
                #     log_str = 'TRAIN Batch {}/{} loss {:.6f} '.format(
                #         epoch, batch_idx,
                #         loss.item() * accum_grad)
                #     for name, value in loss_dict.items():
                #         if name != 'loss' and value is not None:
                #             log_str += '{} {:.6f} '.format(name, value.item())
                #     log_str += 'lr {:.8f} rank {}'.format(lr, rank)
                #     logging.debug(log_str)
                lr = optimizer.param_groups[0]['lr']
                loss_str = "%.4f" % (loss.item() * accum_grad)
                global_steps = get_global_steps()
                num_trained_samples = get_num_trained_samples()
                step_output = f'[PerfLog] {{"event": "STEP_END", "value": {{"epoch": {epoch + 1}, "global_steps": {global_steps},"loss": {loss_str},"num_trained_samples": {num_trained_samples}, "learning_rate": {lr:.9f}}}}}'
                logging.info(f'rank {rank}: ' + step_output)

    def cv(self, model, data_loader, device, args):
        ''' Cross validation on
        '''
        model.eval()
        rank = args.get('rank', 0)
        epoch = args.get('epoch', 0)
        log_interval = args.get('log_interval', 10)
        # in order to avoid division by 0
        num_seen_utts = 1
        total_loss = 0.0
        with torch.no_grad():
            for batch_idx, batch in enumerate(data_loader):
                key, feats, target, feats_lengths, target_lengths = batch
                feats = feats.to(device)
                target = target.to(device)
                feats_lengths = feats_lengths.to(device)
                target_lengths = target_lengths.to(device)
                num_utts = target_lengths.size(0)
                if num_utts == 0:
                    continue
                loss_dict = model(feats, feats_lengths, target, target_lengths)
                loss = loss_dict['loss']
                if torch.isfinite(loss):
                    num_seen_utts += num_utts
                    total_loss += loss.item() * num_utts
                if batch_idx % log_interval == 0:
                    log_str = 'CV Batch {}/{} loss {:.6f} '.format(
                        epoch, batch_idx, loss.item())
                    for name, value in loss_dict.items():
                        if name != 'loss' and value is not None:
                            log_str += '{} {:.6f} '.format(name, value.item())
                    log_str += 'history loss {:.6f}'.format(total_loss /
                                                            num_seen_utts)
                    log_str += ' rank {}'.format(rank)
                    logging.debug(log_str)
        return total_loss, num_seen_utts
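A self-contained sketch, not part of this commit, of the batch and args interface Executor.train expects. ToyModel only mimics the {'loss': tensor} return contract of the real ASR model; every shape and hyper-parameter below is illustrative.

# Illustrative sketch: drive one training step on CPU with a toy model.
import torch
from wenet.utils.executor import Executor

class ToyModel(torch.nn.Module):
    """Mimics the interface Executor expects: forward returns {'loss': ...}."""
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(8, 4)

    def forward(self, feats, feats_lengths, target, target_lengths):
        out = self.proj(feats).mean()
        return {'loss': (out - target.float().mean()) ** 2}

model = ToyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)
# one batch: (keys, feats, target, feats_lengths, target_lengths)
batch = (['utt1', 'utt2'], torch.randn(2, 5, 8), torch.randint(0, 4, (2, 3)),
         torch.tensor([5, 5]), torch.tensor([3, 3]))
args = {'grad_clip': 5.0, 'accum_grad': 1, 'log_interval': 1, 'rank': 0,
        'epoch': 0, 'is_distributed': False, 'use_amp': False}
Executor().train(model, optimizer, scheduler, [batch], device='cpu',
                 writer=None, args=args, scaler=None)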
examples/aishell/s0/wenet/utils/file_utils.py (new file, 0 → 100644)
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import re


def read_lists(list_file):
    lists = []
    with open(list_file, 'r', encoding='utf8') as fin:
        for line in fin:
            lists.append(line.strip())
    return lists


def read_non_lang_symbols(non_lang_sym_path):
    """read non-linguistic symbol from file.

    The file format is like below:

    {NOISE}\n
    {BRK}\n
    ...

    Args:
        non_lang_sym_path: non-linguistic symbol file path, None means no any
        syms.
    """
    if non_lang_sym_path is None:
        return None
    else:
        syms = read_lists(non_lang_sym_path)
        non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
        for sym in syms:
            if non_lang_syms_pattern.fullmatch(sym) is None:
                class BadSymbolFormat(Exception):
                    pass
                raise BadSymbolFormat(
                    "Non-linguistic symbols should be "
                    "formatted in {xxx}/<xxx>/[xxx], consider"
                    " modifying '%s' to meet the requirement. "
                    "More details can be found in discussions here : "
                    "https://github.com/wenet-e2e/wenet/pull/819" % (sym))
        return syms


def read_symbol_table(symbol_table_file):
    symbol_table = {}
    with open(symbol_table_file, 'r', encoding='utf8') as fin:
        for line in fin:
            arr = line.strip().split()
            assert len(arr) == 2
            symbol_table[arr[0]] = int(arr[1])
    return symbol_table
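A hedged usage sketch, not part of this commit: write a tiny symbol table to a temporary file and read it back. The token set is illustrative only; real dictionaries come from the recipe's data preparation stage.

# Illustrative sketch: read_symbol_table maps "token id" lines to a dict.
import tempfile
from wenet.utils.file_utils import read_symbol_table

with tempfile.NamedTemporaryFile('w', suffix='.txt', delete=False,
                                 encoding='utf8') as f:
    f.write('<blank> 0\n<unk> 1\n你 2\n好 3\n<sos/eos> 4\n')
    dict_path = f.name

symbol_table = read_symbol_table(dict_path)
print(symbol_table['你'])  # -> 2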
examples/aishell/s0/wenet/utils/global_vars.py (new file, 0 → 100644)
# Global variables and their getters and setters for cross python packages

# global_steps
global_steps = 0


def get_global_steps():
    return global_steps


def set_global_steps(value):
    global global_steps
    global_steps = value


def global_steps_inc():
    global global_steps
    global_steps += 1


# num_trained_samples
num_trained_samples = 0


def get_num_trained_samples():
    return num_trained_samples


def set_num_trained_samples(value):
    global num_trained_samples
    num_trained_samples = value


def num_trained_samples_inc(value):
    global num_trained_samples
    num_trained_samples += value
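A minimal sketch, not part of this commit: because the counters are plain module-level globals, every module that imports these helpers (for example executor.py above) observes the same values.

# Illustrative sketch of the shared counters.
from wenet.utils import global_vars

global_vars.set_global_steps(0)
global_vars.set_num_trained_samples(0)
for _ in range(3):
    global_vars.global_steps_inc()
    global_vars.num_trained_samples_inc(16)
print(global_vars.get_global_steps())         # -> 3
print(global_vars.get_num_trained_samples())  # -> 48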
examples/aishell/s0/wenet/utils/init_model.py (new file, 0 → 100644)
# Copyright (c) 2022 Binbin Zhang (binbzha@qq.com)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch

from wenet.transducer.joint import TransducerJoint
from wenet.transducer.predictor import (ConvPredictor, EmbeddingPredictor,
                                        RNNPredictor)
from wenet.transducer.transducer import Transducer
from wenet.transformer.asr_model import ASRModel
from wenet.transformer.cmvn import GlobalCMVN
from wenet.transformer.ctc import CTC
from wenet.transformer.decoder import BiTransformerDecoder, TransformerDecoder
from wenet.transformer.encoder import ConformerEncoder, TransformerEncoder
from wenet.squeezeformer.encoder import SqueezeformerEncoder
from wenet.efficient_conformer.encoder import EfficientConformerEncoder
from wenet.utils.cmvn import load_cmvn


def init_model(configs):
    if configs['cmvn_file'] is not None:
        mean, istd = load_cmvn(configs['cmvn_file'], configs['is_json_cmvn'])
        global_cmvn = GlobalCMVN(
            torch.from_numpy(mean).float(),
            torch.from_numpy(istd).float())
    else:
        global_cmvn = None

    input_dim = configs['input_dim']
    vocab_size = configs['output_dim']

    encoder_type = configs.get('encoder', 'conformer')
    decoder_type = configs.get('decoder', 'bitransformer')

    if encoder_type == 'conformer':
        encoder = ConformerEncoder(input_dim,
                                   global_cmvn=global_cmvn,
                                   **configs['encoder_conf'])
    elif encoder_type == 'squeezeformer':
        encoder = SqueezeformerEncoder(input_dim,
                                       global_cmvn=global_cmvn,
                                       **configs['encoder_conf'])
    elif encoder_type == 'efficientConformer':
        encoder = EfficientConformerEncoder(
            input_dim,
            global_cmvn=global_cmvn,
            **configs['encoder_conf'],
            **configs['encoder_conf']['efficient_conf']
            if 'efficient_conf' in configs['encoder_conf'] else {})
    else:
        encoder = TransformerEncoder(input_dim,
                                     global_cmvn=global_cmvn,
                                     **configs['encoder_conf'])
    if decoder_type == 'transformer':
        decoder = TransformerDecoder(vocab_size, encoder.output_size(),
                                     **configs['decoder_conf'])
    else:
        assert 0.0 < configs['model_conf']['reverse_weight'] < 1.0
        assert configs['decoder_conf']['r_num_blocks'] > 0
        decoder = BiTransformerDecoder(vocab_size, encoder.output_size(),
                                       **configs['decoder_conf'])
    ctc = CTC(vocab_size, encoder.output_size())

    # Init joint CTC/Attention or Transducer model
    if 'predictor' in configs:
        predictor_type = configs.get('predictor', 'rnn')
        if predictor_type == 'rnn':
            predictor = RNNPredictor(vocab_size, **configs['predictor_conf'])
        elif predictor_type == 'embedding':
            predictor = EmbeddingPredictor(vocab_size,
                                           **configs['predictor_conf'])
            configs['predictor_conf']['output_size'] = configs[
                'predictor_conf']['embed_size']
        elif predictor_type == 'conv':
            predictor = ConvPredictor(vocab_size, **configs['predictor_conf'])
            configs['predictor_conf']['output_size'] = configs[
                'predictor_conf']['embed_size']
        else:
            raise NotImplementedError(
                "only rnn, embedding and conv type support now")
        configs['joint_conf']['enc_output_size'] = configs['encoder_conf'][
            'output_size']
        configs['joint_conf']['pred_output_size'] = configs['predictor_conf'][
            'output_size']
        joint = TransducerJoint(vocab_size, **configs['joint_conf'])
        model = Transducer(vocab_size=vocab_size,
                           blank=0,
                           predictor=predictor,
                           encoder=encoder,
                           attention_decoder=decoder,
                           joint=joint,
                           ctc=ctc,
                           **configs['model_conf'])
    else:
        model = ASRModel(vocab_size=vocab_size,
                         encoder=encoder,
                         decoder=decoder,
                         ctc=ctc,
                         **configs['model_conf'])
    return model
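A hedged sketch, not part of this commit, of the config dict init_model expects. The values are illustrative rather than taken from any recipe; setting 'decoder' to 'transformer' avoids the bitransformer-specific asserts, and 'cmvn_file': None disables GlobalCMVN.

# Illustrative sketch: build a plain joint CTC/attention model.
from wenet.utils.init_model import init_model

configs = {
    'cmvn_file': None,
    'is_json_cmvn': True,
    'input_dim': 80,      # e.g. 80-dim fbank features
    'output_dim': 4233,   # vocabulary size (example value)
    'encoder': 'conformer',
    'decoder': 'transformer',
    'encoder_conf': {'output_size': 256, 'attention_heads': 4,
                     'linear_units': 2048, 'num_blocks': 12},
    'decoder_conf': {'attention_heads': 4, 'linear_units': 2048,
                     'num_blocks': 6},
    'model_conf': {'ctc_weight': 0.3, 'lsm_weight': 0.1},
}
model = init_model(configs)
print(type(model).__name__)  # -> ASRModel (no 'predictor' key, so no Transducer)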