Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
sunzhq2
yidong-infer
Commits
60a2c57a
Commit
60a2c57a
authored
Jan 27, 2026
by
sunzhq2
Committed by
xuxo
Jan 27, 2026
Browse files
update conformer
parent
4a699441
Changes
216
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
5873 additions
and
0 deletions
+5873
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py
...b/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py
+172
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/dnn_wpe.py
...uild/lib/espnet/nets/pytorch_backend/frontends/dnn_wpe.py
+93
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/feature_transform.py
...spnet/nets/pytorch_backend/frontends/feature_transform.py
+261
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/frontend.py
...ild/lib/espnet/nets/pytorch_backend/frontends/frontend.py
+148
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/mask_estimator.py
...b/espnet/nets/pytorch_backend/frontends/mask_estimator.py
+76
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/gtn_ctc.py
...20240621/build/lib/espnet/nets/pytorch_backend/gtn_ctc.py
+118
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/initialization.py
...1/build/lib/espnet/nets/pytorch_backend/initialization.py
+55
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/lm/__init__.py
...0621/build/lib/espnet/nets/pytorch_backend/lm/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/lm/default.py
...40621/build/lib/espnet/nets/pytorch_backend/lm/default.py
+429
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/lm/seq_rnn.py
...40621/build/lib/espnet/nets/pytorch_backend/lm/seq_rnn.py
+178
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/lm/transformer.py
...1/build/lib/espnet/nets/pytorch_backend/lm/transformer.py
+250
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/maskctc/__init__.py
...build/lib/espnet/nets/pytorch_backend/maskctc/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/maskctc/add_mask_token.py
...lib/espnet/nets/pytorch_backend/maskctc/add_mask_token.py
+39
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/maskctc/mask.py
...621/build/lib/espnet/nets/pytorch_backend/maskctc/mask.py
+24
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/nets_utils.py
...40621/build/lib/espnet/nets/pytorch_backend/nets_utils.py
+503
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/__init__.py
...621/build/lib/espnet/nets/pytorch_backend/rnn/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/argument.py
...621/build/lib/espnet/nets/pytorch_backend/rnn/argument.py
+156
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/attentions.py
...1/build/lib/espnet/nets/pytorch_backend/rnn/attentions.py
+1791
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/decoders.py
...621/build/lib/espnet/nets/pytorch_backend/rnn/decoders.py
+1206
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/encoders.py
...621/build/lib/espnet/nets/pytorch_backend/rnn/encoders.py
+371
-0
No files found.
Too many changes to show.
To preserve performance only
216 of 216+
files are displayed.
Plain diff
Email patch
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/dnn_beamformer.py
0 → 100644
View file @
60a2c57a
"""DNN beamformer module."""
from
typing
import
Tuple
import
torch
from
torch.nn
import
functional
as
F
from
torch_complex.tensor
import
ComplexTensor
from
espnet.nets.pytorch_backend.frontends.beamformer
import
(
# noqa: H301
apply_beamforming_vector
,
get_mvdr_vector
,
get_power_spectral_density_matrix
,
)
from
espnet.nets.pytorch_backend.frontends.mask_estimator
import
MaskEstimator
class DNN_Beamformer(torch.nn.Module):
    """DNN mask based Beamformer

    Citation:
        Multichannel End-to-end Speech Recognition; T. Ochiai et al., 2017;
        https://arxiv.org/abs/1703.04783
    """

    def __init__(
        self,
        bidim,
        btype="blstmp",
        blayers=3,
        bunits=300,
        bprojs=320,
        bnmask=2,
        dropout_rate=0.0,
        badim=320,
        ref_channel: int = -1,
        beamformer_type="mvdr",
    ):
        # bidim: input feature (frequency) dimension for the mask estimator.
        # bnmask: number of estimated masks (2 = speech + noise;
        #         >2 = multi-speaker case, where the last mask is noise).
        # ref_channel: fixed reference microphone index; a negative value
        #              selects the reference via AttentionReference instead.
        super().__init__()
        self.mask = MaskEstimator(
            btype, bidim, blayers, bunits, bprojs, dropout_rate, nmask=bnmask
        )
        self.ref = AttentionReference(bidim, badim)
        self.ref_channel = ref_channel

        self.nmask = bnmask

        # Only the MVDR beamformer is implemented here.
        if beamformer_type != "mvdr":
            raise ValueError(
                "Not supporting beamformer_type={}".format(beamformer_type)
            )
        self.beamformer_type = beamformer_type

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq

        Args:
            data (ComplexTensor): (B, T, C, F)
            ilens (torch.Tensor): (B,)
        Returns:
            enhanced (ComplexTensor): (B, T, F)
            ilens (torch.Tensor): (B,)
        """

        def apply_beamforming(data, ilens, psd_speech, psd_noise):
            # Estimate the reference-selection vector u, then derive and
            # apply the MVDR filter ws.
            # u: (B, C)
            if self.ref_channel < 0:
                u, _ = self.ref(psd_speech, ilens)
            else:
                # (optional) Create onehot vector for fixed reference microphone
                u = torch.zeros(
                    *(data.size()[:-3] + (data.size(-2),)), device=data.device
                )
                u[..., self.ref_channel].fill_(1)

            ws = get_mvdr_vector(psd_speech, psd_noise, u)
            enhanced = apply_beamforming_vector(ws, data)

            return enhanced, ws

        # data (B, T, C, F) -> (B, F, C, T)
        data = data.permute(0, 3, 2, 1)

        # mask: (B, F, C, T)
        masks, _ = self.mask(data, ilens)
        assert self.nmask == len(masks)

        if self.nmask == 2:
            # (mask_speech, mask_noise)
            mask_speech, mask_noise = masks

            psd_speech = get_power_spectral_density_matrix(data, mask_speech)
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced, ws = apply_beamforming(data, ilens, psd_speech, psd_noise)

            # (..., F, T) -> (..., T, F)
            enhanced = enhanced.transpose(-1, -2)
            mask_speech = mask_speech.transpose(-1, -3)
        else:
            # multi-speaker case: (mask_speech1, ..., mask_noise)
            mask_speech = list(masks[:-1])
            mask_noise = masks[-1]

            psd_speeches = [
                get_power_spectral_density_matrix(data, mask) for mask in mask_speech
            ]
            psd_noise = get_power_spectral_density_matrix(data, mask_noise)

            enhanced = []
            ws = []
            for i in range(self.nmask - 1):
                # Temporarily remove speaker i's PSD so that the remainder
                # can be summed into the noise term, then put it back.
                psd_speech = psd_speeches.pop(i)
                # treat all other speakers' psd_speech as noises
                enh, w = apply_beamforming(
                    data, ilens, psd_speech, sum(psd_speeches) + psd_noise
                )
                psd_speeches.insert(i, psd_speech)

                # (..., F, T) -> (..., T, F)
                enh = enh.transpose(-1, -2)
                mask_speech[i] = mask_speech[i].transpose(-1, -3)

                enhanced.append(enh)
                ws.append(w)

        # NOTE(review): in the multi-speaker branch `enhanced` and
        # `mask_speech` are lists of per-speaker tensors, not single tensors.
        return enhanced, ilens, mask_speech
class AttentionReference(torch.nn.Module):
    """Select a reference microphone via a soft attention over channels.

    Produces an attention weight vector ``u`` of shape (B, C) from a power
    spectral density matrix; used by ``DNN_Beamformer`` when no fixed
    reference channel is configured.
    """

    def __init__(self, bidim, att_dim):
        super().__init__()
        self.mlp_psd = torch.nn.Linear(bidim, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

    def forward(
        self,
        psd_in: "ComplexTensor",
        ilens: "torch.LongTensor",
        scaling: "float" = 2.0,
    ) -> "Tuple[torch.Tensor, torch.LongTensor]":
        """The forward function

        Args:
            psd_in (ComplexTensor): (B, F, C, C)
            ilens (torch.Tensor): (B,) — passed through unchanged
            scaling (float): softmax temperature multiplier
        Returns:
            u (torch.Tensor): (B, C)
            ilens (torch.Tensor): (B,)
        """
        batch, _, n_chan = psd_in.size()[:3]
        assert psd_in.size(2) == psd_in.size(3), psd_in.size()

        # Zero the diagonal (self-correlation) entries of each C x C matrix.
        diag = torch.eye(n_chan, dtype=torch.bool, device=psd_in.device)
        off_diag = psd_in.masked_fill(diag, 0)

        # Average the off-diagonal entries over the last axis and bring the
        # channel axis forward: (B, F, C, C) -> (B, C, F)
        averaged = (off_diag.sum(dim=-1) / (n_chan - 1)).transpose(-1, -2)

        # Magnitude of the complex averages: (B, C, F)
        magnitude = (averaged.real ** 2 + averaged.imag ** 2) ** 0.5

        # (B, C, F) -> (B, C, att_dim) -> (B, C, 1) -> (B, C)
        hidden = torch.tanh(self.mlp_psd(magnitude))
        score = self.gvec(hidden).squeeze(-1)

        u = F.softmax(scaling * score, dim=-1)
        return u, ilens
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/dnn_wpe.py
0 → 100644
View file @
60a2c57a
from
typing
import
Tuple
import
torch
from
pytorch_wpe
import
wpe_one_iteration
from
torch_complex.tensor
import
ComplexTensor
from
espnet.nets.pytorch_backend.frontends.mask_estimator
import
MaskEstimator
from
espnet.nets.pytorch_backend.nets_utils
import
make_pad_mask
class DNN_WPE(torch.nn.Module):
    """Weighted prediction error (WPE) dereverberation with an optional
    DNN-estimated power mask (first iteration only)."""

    def __init__(
        self,
        wtype: str = "blstmp",
        widim: int = 257,
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        dropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask: bool = True,
        iterations: int = 1,
        normalization: bool = False,
    ):
        # taps/delay: WPE filter length and prediction delay passed to
        # wpe_one_iteration.
        # use_dnn_mask: when True, a MaskEstimator refines the power
        # estimate on the first iteration.
        super().__init__()
        self.iterations = iterations
        self.taps = taps
        self.delay = delay

        self.normalization = normalization
        self.use_dnn_mask = use_dnn_mask

        self.inverse_power = True

        if self.use_dnn_mask:
            self.mask_est = MaskEstimator(
                wtype, widim, wlayers, wunits, wprojs, dropout_rate, nmask=1
            )

    def forward(
        self, data: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[ComplexTensor, torch.LongTensor, ComplexTensor]:
        """The forward function

        Notation:
            B: Batch
            C: Channel
            T: Time or Sequence length
            F: Freq or Some dimension of the feature vector

        Args:
            data: (B, C, T, F)
            ilens: (B,)
        Returns:
            data: (B, C, T, F)
            ilens: (B,)
        """
        # (B, T, C, F) -> (B, F, C, T)
        enhanced = data = data.permute(0, 3, 2, 1)
        mask = None

        for i in range(self.iterations):
            # Calculate power: (..., C, T)
            power = enhanced.real ** 2 + enhanced.imag ** 2
            # The DNN mask is applied only on the first iteration; later
            # iterations reuse the previous WPE output's power directly.
            if i == 0 and self.use_dnn_mask:
                # mask: (B, F, C, T)
                (mask,), _ = self.mask_est(enhanced, ilens)
                if self.normalization:
                    # Normalize along T
                    mask = mask / mask.sum(dim=-1)[..., None]
                # (..., C, T) * (..., C, T) -> (..., C, T)
                power = power * mask

            # Averaging along the channel axis: (..., C, T) -> (..., T)
            power = power.mean(dim=-2)

            # enhanced: (..., C, T) -> (..., C, T)
            # NOTE: WPE always filters the *original* observation `data`,
            # only the power estimate changes across iterations.
            enhanced = wpe_one_iteration(
                data.contiguous(),
                power,
                taps=self.taps,
                delay=self.delay,
                inverse_power=self.inverse_power,
            )

            # Zero out the padded frames in place.
            enhanced.masked_fill_(make_pad_mask(ilens, enhanced.real), 0)

        # (B, F, C, T) -> (B, T, C, F)
        enhanced = enhanced.permute(0, 3, 2, 1)
        if mask is not None:
            mask = mask.transpose(-1, -3)
        return enhanced, ilens, mask
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/feature_transform.py
0 → 100644
View file @
60a2c57a
from
typing
import
List
,
Tuple
,
Union
import
librosa
import
numpy
as
np
import
torch
from
torch_complex.tensor
import
ComplexTensor
from
espnet.nets.pytorch_backend.nets_utils
import
make_pad_mask
class FeatureTransform(torch.nn.Module):
    """STFT -> log-mel fbank feature pipeline with optional global and
    per-utterance mean/variance normalization.

    Args:
        fs: sampling rate, forwarded to LogMel
        n_fft: number of FFT components
        n_mels: number of mel bins
        fmin/fmax: mel filterbank frequency range
        stats_file: npy statistics file for GlobalMVN (None disables it)
        apply_uttmvn: enable per-utterance MVN
        uttmvn_norm_means/uttmvn_norm_vars: UtteranceMVN options
    """

    def __init__(
        self,
        # Mel options,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        # Normalization
        stats_file: str = None,
        apply_uttmvn: bool = True,
        uttmvn_norm_means: bool = True,
        uttmvn_norm_vars: bool = False,
    ):
        super().__init__()
        self.apply_uttmvn = apply_uttmvn

        self.logmel = LogMel(fs=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
        self.stats_file = stats_file
        if stats_file is not None:
            self.global_mvn = GlobalMVN(stats_file)
        else:
            self.global_mvn = None

        # BUGFIX: was `if self.apply_uttmvn is not None:`, which is always
        # true for a bool and constructed UtteranceMVN even when disabled.
        # forward() already gates on truthiness, so behavior is unchanged.
        if self.apply_uttmvn:
            self.uttmvn = UtteranceMVN(
                norm_means=uttmvn_norm_means, norm_vars=uttmvn_norm_vars
            )
        else:
            self.uttmvn = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, np.ndarray, List[int]]
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Convert a complex STFT to normalized log-mel features.

        Args:
            x (ComplexTensor): (B, T, F) or multi-channel (B, T, C, F)
            ilens: (B,) valid lengths; converted to a tensor if needed
        Returns:
            h (torch.Tensor): (B, T, n_mels)
            ilens (torch.LongTensor): (B,)
        Raises:
            ValueError: if x is not 3- or 4-dimensional
        """
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(np.asarray(ilens)).to(x.device)

        if x.dim() == 4:
            # h: (B, T, C, F) -> h: (B, T, F)
            if self.training:
                # Select 1ch randomly
                ch = np.random.randint(x.size(2))
                h = x[:, :, ch, :]
            else:
                # Use the first channel
                h = x[:, :, 0, :]
        else:
            h = x

        # h: ComplexTensor(B, T, F) -> torch.Tensor(B, T, F)  (power spectrum)
        h = h.real ** 2 + h.imag ** 2

        h, _ = self.logmel(h, ilens)
        if self.stats_file is not None:
            h, _ = self.global_mvn(h, ilens)
        if self.apply_uttmvn:
            h, _ = self.uttmvn(h, ilens)

        return h, ilens
class LogMel(torch.nn.Module):
    """Convert STFT to fbank feats

    The arguments is same as librosa.filters.mel

    Args:
        fs: number > 0 [scalar] sampling rate of the incoming signal
        n_fft: int > 0 [scalar] number of FFT components
        n_mels: int > 0 [scalar] number of Mel bands to generate
        fmin: float >= 0 [scalar] lowest frequency (in Hz)
        fmax: float >= 0 [scalar] highest frequency (in Hz).
            If `None`, use `fmax = fs / 2.0`
        htk: use HTK formula instead of Slaney
        norm: {None, 1, np.inf} [scalar]
            if 1, divide the triangular mel weights by the width of the mel band
            (area normalization).  Otherwise, leave all the triangles aiming for
            a peak value of 1.0
    """

    def __init__(
        self,
        fs: int = 16000,
        n_fft: int = 512,
        n_mels: int = 80,
        fmin: float = 0.0,
        fmax: float = None,
        htk: bool = False,
        norm=1,
    ):
        super().__init__()

        _mel_options = dict(
            sr=fs, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax, htk=htk, norm=norm
        )
        # Kept for extra_repr; not used at forward time.
        self.mel_options = _mel_options

        # Note(kamo): The mel matrix of librosa is different from kaldi.
        # NOTE(review): recent librosa versions require these as keyword
        # arguments; the **dict call already satisfies that — confirm the
        # installed librosa accepts norm=1.
        melmat = librosa.filters.mel(**_mel_options)
        # melmat: (D2, D1) -> (D1, D2)
        # Registered as a buffer so it moves with .to(device) / state_dict.
        self.register_buffer("melmat", torch.from_numpy(melmat.T).float())

    def extra_repr(self):
        return ", ".join(f"{k}={v}" for k, v in self.mel_options.items())

    def forward(
        self, feat: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        # feat: (B, T, D1) x melmat: (D1, D2) -> mel_feat: (B, T, D2)
        mel_feat = torch.matmul(feat, self.melmat)

        # 1e-20 floor keeps the log finite on silent (zero) bins.
        logmel_feat = (mel_feat + 1e-20).log()
        # Zero padding
        logmel_feat = logmel_feat.masked_fill(make_pad_mask(ilens, logmel_feat, 1), 0.0)
        return logmel_feat, ilens
class GlobalMVN(torch.nn.Module):
    """Apply global mean and variance normalization

    Args:
        stats_file(str): npy file of 1-dim array or text file.
            From the first element to the {(len(array) - 1) / 2}th element
            are treated as the sum of features, and the rest excluding the
            last element are treated as the sum of the square value of
            features, and the last element equals to the number of samples.
        norm_means(bool): subtract the global mean
        norm_vars(bool): divide by the global standard deviation
        eps(float): std floor to avoid division by ~0
    """

    def __init__(
        self,
        stats_file: str,
        norm_means: bool = True,
        norm_vars: bool = True,
        eps: float = 1.0e-20,
    ):
        super().__init__()
        self.norm_means = norm_means
        self.norm_vars = norm_vars
        self.stats_file = stats_file

        stats = np.load(stats_file)
        stats = stats.astype(float)
        # Layout: [sum(x) (D), sum(x^2) (D), count] -> odd overall length.
        assert (len(stats) - 1) % 2 == 0, stats.shape

        count = stats.flatten()[-1]
        mean = stats[: (len(stats) - 1) // 2] / count
        # E[x^2] - E[x]^2
        var = stats[(len(stats) - 1) // 2 : -1] / count - mean * mean
        std = np.maximum(np.sqrt(var), eps)
        # Stored as buffers so they follow .to(device) and state_dict.
        self.register_buffer("bias", torch.from_numpy(-mean.astype(np.float32)))
        self.register_buffer("scale", torch.from_numpy(1 / std.astype(np.float32)))

    def extra_repr(self):
        return (
            f"stats_file={self.stats_file}, "
            f"norm_means={self.norm_means}, norm_vars={self.norm_vars}"
        )

    def forward(
        self, x: torch.Tensor, ilens: torch.LongTensor
    ) -> Tuple[torch.Tensor, torch.LongTensor]:
        """Normalize x in place.

        Args:
            x: (B, T, D), zero padded beyond ilens
            ilens: (B,) valid lengths
        Returns:
            x (torch.Tensor): normalized (B, T, D), padding re-zeroed
            ilens (torch.LongTensor): (B,)
        """
        # feat: (B, T, D)
        if self.norm_means:
            x += self.bias.type_as(x)
            # BUGFIX: the original called the out-of-place `masked_fill`
            # and discarded the result, so the padded frames kept the
            # (nonzero) bias.  Use the in-place variant to actually re-zero
            # the padding, matching the "Zero padding" intent.
            x.masked_fill_(make_pad_mask(ilens, x, 1), 0.0)

        if self.norm_vars:
            x *= self.scale.type_as(x)
        return x, ilens
class UtteranceMVN(torch.nn.Module):
    """Per-utterance mean/variance normalization.

    Thin module wrapper around the :func:`utterance_mvn` function; the
    options are fixed at construction time.
    """

    def __init__(
        self, norm_means: bool = True, norm_vars: bool = False, eps: float = 1.0e-20
    ):
        super().__init__()
        self.norm_means, self.norm_vars, self.eps = norm_means, norm_vars, eps

    def extra_repr(self):
        # Same string as an f-string would produce.
        return "norm_means={}, norm_vars={}".format(self.norm_means, self.norm_vars)

    def forward(
        self, x: "torch.Tensor", ilens: "torch.LongTensor"
    ) -> "Tuple[torch.Tensor, torch.LongTensor]":
        """Delegate to utterance_mvn with the configured options."""
        return utterance_mvn(
            x,
            ilens,
            norm_means=self.norm_means,
            norm_vars=self.norm_vars,
            eps=self.eps,
        )
def utterance_mvn(
    x: torch.Tensor,
    ilens: torch.LongTensor,
    norm_means: bool = True,
    norm_vars: bool = False,
    eps: float = 1.0e-20,
) -> Tuple[torch.Tensor, torch.LongTensor]:
    """Apply utterance mean and variance normalization

    Args:
        x: (B, T, D), assumed zero padded; modified in place when
            norm_means or norm_vars is True
        ilens: (B,) valid lengths  (docstring fixed: was "(B, T, D)")
        norm_means: subtract the per-utterance mean
        norm_vars: divide by the per-utterance standard deviation
        eps: variance floor
    """
    ilens_ = ilens.type_as(x)
    # mean: (B, D) — sum over T divided by the true length, so zero padding
    # does not bias the estimate.
    mean = x.sum(dim=1) / ilens_[:, None]

    if norm_means:
        x -= mean[:, None, :]
        x_ = x
    else:
        # Keep x untouched; the centered copy is only used for the variance.
        x_ = x - mean[:, None, :]

    # Zero padding
    # BUGFIX: the original called the out-of-place `masked_fill` and discarded
    # the result, so padded frames leaked into the variance below. Use the
    # in-place variant so the padding is actually zeroed.
    x_.masked_fill_(make_pad_mask(ilens, x_, 1), 0.0)
    if norm_vars:
        var = x_.pow(2).sum(dim=1) / ilens_[:, None]
        var = torch.clamp(var, min=eps)
        # NOTE(review): with norm_means=False this scales the original
        # (non-centered) x, matching the original control flow.
        x /= var.sqrt()[:, None, :]
        x_ = x
    return x_, ilens
def feature_transform_for(args, n_fft):
    """Build a ``FeatureTransform`` from the parsed CLI namespace ``args``.

    ``n_fft`` comes from the STFT configuration rather than the namespace.
    """
    mel_opts = dict(
        fs=args.fbank_fs,
        n_fft=n_fft,
        n_mels=args.n_mels,
        fmin=args.fbank_fmin,
        fmax=args.fbank_fmax,
    )
    norm_opts = dict(
        stats_file=args.stats_file,
        apply_uttmvn=args.apply_uttmvn,
        uttmvn_norm_means=args.uttmvn_norm_means,
        uttmvn_norm_vars=args.uttmvn_norm_vars,
    )
    return FeatureTransform(**mel_opts, **norm_opts)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/frontend.py
0 → 100644
View file @
60a2c57a
from
typing
import
List
,
Optional
,
Tuple
,
Union
import
numpy
import
torch
import
torch.nn
as
nn
from
torch_complex.tensor
import
ComplexTensor
from
espnet.nets.pytorch_backend.frontends.dnn_beamformer
import
DNN_Beamformer
from
espnet.nets.pytorch_backend.frontends.dnn_wpe
import
DNN_WPE
class Frontend(nn.Module):
    """Speech enhancement frontend combining optional WPE dereverberation
    and DNN-mask beamforming, applied to multi-channel STFT input."""

    def __init__(
        self,
        idim: int,
        # WPE options
        use_wpe: bool = False,
        wtype: str = "blstmp",
        wlayers: int = 3,
        wunits: int = 300,
        wprojs: int = 320,
        wdropout_rate: float = 0.0,
        taps: int = 5,
        delay: int = 3,
        use_dnn_mask_for_wpe: bool = True,
        # Beamformer options
        use_beamformer: bool = False,
        btype: str = "blstmp",
        blayers: int = 3,
        bunits: int = 300,
        bprojs: int = 320,
        bnmask: int = 2,
        badim: int = 320,
        ref_channel: int = -1,
        bdropout_rate: float = 0.0,
    ):
        super().__init__()

        self.use_beamformer = use_beamformer
        self.use_wpe = use_wpe
        self.use_dnn_mask_for_wpe = use_dnn_mask_for_wpe
        # use frontend for all the data,
        # e.g. in the case of multi-speaker speech separation
        self.use_frontend_for_all = bnmask > 2

        if self.use_wpe:
            if self.use_dnn_mask_for_wpe:
                # Use DNN for power estimation
                # (Not observed significant gains)
                iterations = 1
            else:
                # Performing as conventional WPE, without DNN Estimator
                iterations = 2

            self.wpe = DNN_WPE(
                wtype=wtype,
                widim=idim,
                wunits=wunits,
                wprojs=wprojs,
                wlayers=wlayers,
                taps=taps,
                delay=delay,
                dropout_rate=wdropout_rate,
                iterations=iterations,
                use_dnn_mask=use_dnn_mask_for_wpe,
            )
        else:
            self.wpe = None

        if self.use_beamformer:
            self.beamformer = DNN_Beamformer(
                btype=btype,
                bidim=idim,
                bunits=bunits,
                bprojs=bprojs,
                blayers=blayers,
                bnmask=bnmask,
                dropout_rate=bdropout_rate,
                badim=badim,
                ref_channel=ref_channel,
            )
        else:
            self.beamformer = None

    def forward(
        self, x: ComplexTensor, ilens: Union[torch.LongTensor, numpy.ndarray, List[int]]
    ) -> Tuple[ComplexTensor, torch.LongTensor, Optional[ComplexTensor]]:
        """Enhance x with the configured sub-frontends.

        Args:
            x (ComplexTensor): (B, T, F) single channel or (B, T, C, F)
                multi-channel STFT; enhancement only runs on 4-dim input
            ilens: (B,) valid lengths (tensor, ndarray, or list)
        Returns:
            h: enhanced features; ilens; mask from the last applied stage
                (None when nothing was applied)
        """
        assert len(x) == len(ilens), (len(x), len(ilens))
        # (B, T, F) or (B, T, C, F)
        if x.dim() not in (3, 4):
            raise ValueError(f"Input dim must be 3 or 4: {x.dim()}")
        if not torch.is_tensor(ilens):
            ilens = torch.from_numpy(numpy.asarray(ilens)).to(x.device)

        mask = None
        h = x
        if h.dim() == 4:
            if self.training:
                # During training randomly pick exactly one of the enabled
                # frontends per batch (or possibly none, unless the
                # multi-speaker case forces a frontend for all data).
                choices = [(False, False)] if not self.use_frontend_for_all else []
                if self.use_wpe:
                    choices.append((True, False))
                if self.use_beamformer:
                    choices.append((False, True))

                use_wpe, use_beamformer = choices[numpy.random.randint(len(choices))]

            else:
                use_wpe = self.use_wpe
                use_beamformer = self.use_beamformer

            # 1. WPE
            if use_wpe:
                # h: (B, T, C, F) -> h: (B, T, C, F)
                h, ilens, mask = self.wpe(h, ilens)

            # 2. Beamformer
            if use_beamformer:
                # h: (B, T, C, F) -> h: (B, T, F)
                h, ilens, mask = self.beamformer(h, ilens)

        return h, ilens, mask
def frontend_for(args, idim):
    """Build a ``Frontend`` configured from the parsed CLI namespace ``args``."""
    wpe_opts = dict(
        use_wpe=args.use_wpe,
        wtype=args.wtype,
        wlayers=args.wlayers,
        wunits=args.wunits,
        wprojs=args.wprojs,
        wdropout_rate=args.wdropout_rate,
        taps=args.wpe_taps,
        delay=args.wpe_delay,
        use_dnn_mask_for_wpe=args.use_dnn_mask_for_wpe,
    )
    beamformer_opts = dict(
        use_beamformer=args.use_beamformer,
        btype=args.btype,
        blayers=args.blayers,
        bunits=args.bunits,
        bprojs=args.bprojs,
        bnmask=args.bnmask,
        badim=args.badim,
        ref_channel=args.ref_channel,
        bdropout_rate=args.bdropout_rate,
    )
    return Frontend(idim=idim, **wpe_opts, **beamformer_opts)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/frontends/mask_estimator.py
0 → 100644
View file @
60a2c57a
from
typing
import
Tuple
import
numpy
as
np
import
torch
from
torch.nn
import
functional
as
F
from
torch_complex.tensor
import
ComplexTensor
from
espnet.nets.pytorch_backend.nets_utils
import
make_pad_mask
from
espnet.nets.pytorch_backend.rnn.encoders
import
RNN
,
RNNP
class MaskEstimator(torch.nn.Module):
    """Estimate time-frequency masks from multi-channel amplitude features
    with a (projected) RNN followed by per-mask linear + sigmoid heads.

    Args:
        type: RNN type string, e.g. "blstmp"; a trailing "p" selects the
            projected variant (RNNP).  NOTE: shadows the builtin ``type``,
            kept for interface compatibility.
        idim: input feature (frequency) dimension
        layers/units/projs/dropout: RNN configuration
        nmask: number of masks to emit
    """

    def __init__(self, type, idim, layers, units, projs, dropout, nmask=1):
        super().__init__()
        # No subsampling between RNN layers.
        subsample = np.ones(layers + 1, dtype=np.int64)

        typ = type.lstrip("vgg").rstrip("p")
        if type[-1] == "p":
            self.brnn = RNNP(idim, layers, units, projs, subsample, dropout, typ=typ)
        else:
            self.brnn = RNN(idim, layers, units, projs, dropout, typ=typ)

        self.type = type
        self.nmask = nmask
        self.linears = torch.nn.ModuleList(
            [torch.nn.Linear(projs, idim) for _ in range(nmask)]
        )

    def forward(
        self, xs: ComplexTensor, ilens: torch.LongTensor
    ) -> Tuple[Tuple[torch.Tensor, ...], torch.LongTensor]:
        """The forward function

        Args:
            xs: (B, F, C, T)
            ilens: (B,)
        Returns:
            hs (torch.Tensor): The hidden vector (B, F, C, T)
            masks: A tuple of the masks. (B, F, C, T)
            ilens: (B,)
        """
        assert xs.size(0) == ilens.size(0), (xs.size(0), ilens.size(0))
        _, _, C, input_length = xs.size()
        # (B, F, C, T) -> (B, C, T, F)
        xs = xs.permute(0, 2, 3, 1)

        # Calculate amplitude: (B, C, T, F) -> (B, C, T, F)
        xs = (xs.real ** 2 + xs.imag ** 2) ** 0.5
        # xs: (B, C, T, F) -> xs: (B * C, T, F)  (channels fed as batch)
        xs = xs.contiguous().view(-1, xs.size(-2), xs.size(-1))
        # ilens: (B,) -> ilens_: (B * C)
        ilens_ = ilens[:, None].expand(-1, C).contiguous().view(-1)

        # xs: (B * C, T, F) -> xs: (B * C, T, D)
        xs, _, _ = self.brnn(xs, ilens_)
        # xs: (B * C, T, D) -> xs: (B, C, T, D)
        xs = xs.view(-1, C, xs.size(-2), xs.size(-1))

        masks = []
        for linear in self.linears:
            # xs: (B, C, T, D) -> mask:(B, C, T, F)
            mask = linear(xs)

            mask = torch.sigmoid(mask)
            # Zero padding
            # BUGFIX: the original discarded the out-of-place masked_fill
            # result, leaving sigmoid(0)=0.5 values in the padded frames.
            # Assign the result (out-of-place on purpose: in-place writes to
            # sigmoid's output would break its backward pass).
            mask = mask.masked_fill(make_pad_mask(ilens, mask, length_dim=2), 0)

            # (B, C, T, F) -> (B, F, C, T)
            mask = mask.permute(0, 3, 1, 2)

            # Take cares of multi gpu cases: If input_length > max(ilens)
            if mask.size(-1) < input_length:
                mask = F.pad(mask, [0, input_length - mask.size(-1)], value=0)
            masks.append(mask)

        return tuple(masks), ilens
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/gtn_ctc.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""GTN CTC implementation."""
import
gtn
import
torch
class GTNCTCLossFunction(torch.autograd.Function):
    """GTN CTC module."""

    # Copied from FB's GTN example implementation:
    # https://github.com/facebookresearch/gtn_applications/blob/master/utils.py#L251
    @staticmethod
    def create_ctc_graph(target, blank_idx):
        """Build gtn graph.

        :param list target: single target sequence
        :param int blank_idx: index of blank token
        :return: gtn graph of target sequence
        :rtype: gtn.Graph
        """
        g_criterion = gtn.Graph(False)
        L = len(target)
        # CTC state chain interleaves a blank between labels: S = 2L + 1.
        S = 2 * L + 1
        for s in range(S):
            idx = (s - 1) // 2
            # Start node at s == 0; the last label and last blank both accept.
            g_criterion.add_node(s == 0, s == S - 1 or s == S - 2)
            # Odd states carry labels, even states carry the blank.
            label = target[idx] if s % 2 else blank_idx
            # Self-loop (token repetition).
            g_criterion.add_arc(s, s, label)
            if s > 0:
                g_criterion.add_arc(s - 1, s, label)
            # Skip the blank between two *different* consecutive labels.
            if s % 2 and s > 1 and label != target[idx - 1]:
                g_criterion.add_arc(s - 2, s, label)
        g_criterion.arc_sort(False)
        return g_criterion

    @staticmethod
    def forward(ctx, log_probs, targets, ilens, blank_idx=0, reduction="none"):
        """Forward computation.

        :param torch.tensor log_probs: batched log softmax probabilities (B, Tmax, oDim)
        :param list targets: batched target sequences, list of lists
        :param int blank_idx: index of blank token
        :return: ctc loss value
        :rtype: torch.Tensor
        """
        B, _, C = log_probs.shape
        losses = [None] * B
        scales = [None] * B
        emissions_graphs = [None] * B

        def process(b):
            # create emission graph
            T = ilens[b]
            g_emissions = gtn.linear_graph(T, C, log_probs.requires_grad)
            # gtn reads the weights through a raw pointer, so the tensor
            # must live on CPU and be contiguous.
            cpu_data = log_probs[b][:T].cpu().contiguous()
            g_emissions.set_weights(cpu_data.data_ptr())

            # create criterion graph
            g_criterion = GTNCTCLossFunction.create_ctc_graph(targets[b], blank_idx)
            # compose the graphs
            g_loss = gtn.negate(
                gtn.forward_score(gtn.intersect(g_emissions, g_criterion))
            )

            scale = 1.0
            if reduction == "mean":
                L = len(targets[b])
                scale = 1.0 / L if L > 0 else scale
            elif reduction != "none":
                raise ValueError(
                    "invalid value for reduction '" + str(reduction) + "'"
                )

            # Save for backward:
            losses[b] = g_loss
            scales[b] = scale
            emissions_graphs[b] = g_emissions

        gtn.parallel_for(process, range(B))

        # Stashed manually (not via save_for_backward) because the saved
        # objects are gtn graphs, not tensors.
        ctx.auxiliary_data = (losses, scales, emissions_graphs, log_probs.shape, ilens)

        loss = torch.tensor([losses[b].item() * scales[b] for b in range(B)])
        return torch.mean(loss.cuda() if log_probs.is_cuda else loss)

    @staticmethod
    def backward(ctx, grad_output):
        """Backward computation.

        :param torch.tensor grad_output: backward passed gradient value
        :return: cumulative gradient output
        :rtype: (torch.Tensor, None, None, None)
        """
        losses, scales, emissions_graphs, in_shape, ilens = ctx.auxiliary_data
        B, T, C = in_shape
        input_grad = torch.zeros((B, T, C))

        def process(b):
            # Per-utterance length shadows the padded T from in_shape.
            T = ilens[b]
            gtn.backward(losses[b], False)
            emissions = emissions_graphs[b]
            grad = emissions.grad().weights_to_numpy()
            input_grad[b][:T] = torch.from_numpy(grad).view(1, T, C) * scales[b]

        gtn.parallel_for(process, range(B))

        if grad_output.is_cuda:
            input_grad = input_grad.cuda()
        # forward returned the mean over the batch, hence the 1/B factor.
        input_grad *= grad_output / B

        return (
            input_grad,
            None,  # targets
            None,  # ilens
            None,  # blank_idx
            None,  # reduction
        )
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/initialization.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python
# Copyright 2019 Kyoto University (Hirofumi Inaguma)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Initialization functions for RNN sequence-to-sequence models."""
import
math
def lecun_normal_init_parameters(module):
    """Initialize parameters in the LeCun's manner.

    Biases (1-D) are zeroed; linear weights (2-D) and conv kernels
    (3-D/4-D) are drawn from N(0, 1/sqrt(fan_in)).  Any other parameter
    dimensionality is rejected.
    """
    for param in module.parameters():
        tensor = param.data
        ndim = tensor.dim()
        if ndim == 1:
            # bias
            tensor.zero_()
        elif ndim == 2:
            # linear weight: fan_in is the input-feature count
            tensor.normal_(0, 1.0 / math.sqrt(tensor.size(1)))
        elif ndim in (3, 4):
            # conv weight: fan_in = in_channels * prod(kernel dims)
            fan_in = tensor.size(1) * math.prod(tensor.size()[2:])
            tensor.normal_(0, 1.0 / math.sqrt(fan_in))
        else:
            raise NotImplementedError
def uniform_init_parameters(module):
    """Initialize parameters with an uniform distribution.

    Biases and linear weights are drawn from U(-0.1, 0.1); conv weights
    keep the PyTorch default initialization; anything else is rejected.
    """
    for param in module.parameters():
        ndim = param.data.dim()
        if ndim in (1, 2):
            # bias or linear weight
            param.data.uniform_(-0.1, 0.1)
        elif ndim in (3, 4):
            # conv weight: use the pytorch default
            pass
        else:
            raise NotImplementedError
def set_forget_bias_to_one(bias):
    """Initialize a bias vector in the forget gate with one.

    Fills the second quarter of the vector with 1.0.
    NOTE(review): assumes PyTorch's (input, forget, cell, output) gate
    packing, where the forget-gate slice is [n/4, n/2) — confirm for the
    RNN implementation in use.
    """
    length = bias.size(0)
    bias.data[length // 4 : length // 2].fill_(1.0)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/lm/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/lm/default.py
0 → 100644
View file @
60a2c57a
"""Default Recurrent Neural Network Languge Model in `lm_train.py`."""
import
logging
from
typing
import
Any
,
List
,
Tuple
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
espnet.nets.lm_interface
import
LMInterface
from
espnet.nets.pytorch_backend.e2e_asr
import
to_device
from
espnet.nets.scorer_interface
import
BatchScorerInterface
from
espnet.utils.cli_utils
import
strtobool
class DefaultRNNLM(BatchScorerInterface, LMInterface, nn.Module):
    """Default RNNLM for `LMInterface` Implementation.

    Wraps an :class:`RNNLM` predictor in a :class:`ClassifierWithState`
    and exposes the training (`forward`) and beam-search scoring
    (`score` / `batch_score` / `final_score`) interfaces.

    Note:
        PyTorch seems to have memory leak when one GPU compute this after data parallel.
        If parallel GPUs compute this, it seems to be fine.
        See also https://github.com/espnet/espnet/issues/1075
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        parser.add_argument(
            "--type",
            type=str,
            default="lstm",
            nargs="?",
            choices=["lstm", "gru"],
            help="Which type of RNN to use",
        )
        parser.add_argument(
            "--layer", "-l", type=int, default=2, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit", "-u", type=int, default=650, help="Number of hidden units"
        )
        parser.add_argument(
            "--embed-unit",
            default=None,
            type=int,
            help="Number of hidden units in embedding layer, "
            "if it is not specified, it keeps the same number with hidden units.",
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        parser.add_argument(
            "--emb-dropout-rate",
            type=float,
            default=0.0,
            help="emb dropout probability",
        )
        parser.add_argument(
            "--tie-weights",
            type=strtobool,
            default=False,
            help="Tie input and output embeddings",
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`
        """
        nn.Module.__init__(self)
        # getattr with defaults lets old checkpoints (missing these fields) load.
        # NOTE: for a compatibility with less than 0.5.0 version models
        dropout_rate = getattr(args, "dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.6.1 version models
        embed_unit = getattr(args, "embed_unit", None)
        # NOTE: for a compatibility with less than 0.9.7 version models
        emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.9.7 version models
        tie_weights = getattr(args, "tie_weights", False)
        self.model = ClassifierWithState(
            RNNLM(
                n_vocab,
                args.layer,
                args.unit,
                embed_unit,
                args.type,
                dropout_rate,
                emb_dropout_rate,
                tie_weights,
            )
        )

    def state_dict(self):
        """Dump state dict."""
        # Delegate to the wrapped classifier so checkpoints use its key names.
        return self.model.state_dict()

    def load_state_dict(self, d):
        """Load state dict."""
        self.model.load_state_dict(d)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
        """
        loss = 0
        logp = 0
        count = torch.tensor(0).long()
        state = None
        batch_size, sequence_length = x.shape
        for i in range(sequence_length):
            # Compute the loss at this time step and accumulate it
            state, loss_batch = self.model(state, x[:, i], t[:, i])
            # Tokens with id 0 are treated as padding and excluded from the count.
            non_zeros = torch.sum(x[:, i] != 0, dtype=loss_batch.dtype)
            loss += loss_batch.mean() * non_zeros
            logp += torch.sum(loss_batch * non_zeros)
            count += int(non_zeros)
        return loss / batch_size, loss, count.to(loss.device)

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys
        """
        # Only the last prefix token is fed; the RNN state carries the rest.
        new_state, scores = self.model.predict(state, y[-1].unsqueeze(0))
        return scores.squeeze(0), new_state

    def final_score(self, state):
        """Score eos.

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score
        """
        return self.model.final(state)

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = self.model.predictor.n_layers
        # LSTM states carry both cell ("c") and hidden ("h"); GRU only "h".
        if self.model.predictor.typ == "lstm":
            keys = ("c", "h")
        else:
            keys = ("h",)

        if states[0] is None:
            states = None
        else:
            # transpose state of [batch, key, layer] into [key, layer, batch]
            states = {
                k: [
                    torch.stack([states[b][k][i] for b in range(n_batch)])
                    for i in range(n_layers)
                ]
                for k in keys
            }
        states, logp = self.model.predict(states, ys[:, -1])

        # transpose state of [key, layer, batch] into [batch, key, layer]
        return (
            logp,
            [
                {k: [states[k][i][b] for i in range(n_layers)] for k in keys}
                for b in range(n_batch)
            ],
        )
class ClassifierWithState(nn.Module):
    """A wrapper for pytorch RNNLM.

    Adds loss computation and log-probability prediction on top of a
    stateful predictor network (e.g. :class:`RNNLM`).
    """

    def __init__(
        self,
        predictor,
        # NOTE(review): the default loss module is created once at class
        # definition time and shared by all instances; CrossEntropyLoss is
        # stateless so this appears intentional/safe here.
        lossfun=nn.CrossEntropyLoss(reduction="none"),
        label_key=-1,
    ):
        """Initialize class.

        :param torch.nn.Module predictor : The RNNLM
        :param function lossfun : The loss function to use
        :param int/str label_key :
        """
        if not (isinstance(label_key, (int, str))):
            raise TypeError("label_key must be int or str, but is %s" % type(label_key))
        super(ClassifierWithState, self).__init__()
        self.lossfun = lossfun
        # y: last prediction, loss: last loss; kept as attributes for inspection.
        self.y = None
        self.loss = None
        self.label_key = label_key
        self.predictor = predictor

    def forward(self, state, *args, **kwargs):
        """Compute the loss value for an input and label pair.

        Notes:
            It also computes accuracy and stores it to the attribute.
            When ``label_key`` is ``int``, the corresponding element in ``args``
            is treated as ground truth labels. And when it is ``str``, the
            element in ``kwargs`` is used.
            The all elements of ``args`` and ``kwargs`` except the groundtruth
            labels are features.
            It feeds features to the predictor and compare the result
            with ground truth labels.

        :param torch.Tensor state : the LM state
        :param list[torch.Tensor] args : Input minibatch
        :param dict[torch.Tensor] kwargs : Input minibatch
        :return loss value
        :rtype torch.Tensor
        """
        if isinstance(self.label_key, int):
            if not (-len(args) <= self.label_key < len(args)):
                msg = "Label key %d is out of bounds" % self.label_key
                raise ValueError(msg)
            # Pull the label out of args and forward the remaining features.
            t = args[self.label_key]
            if self.label_key == -1:
                args = args[:-1]
            else:
                args = args[: self.label_key] + args[self.label_key + 1 :]
        elif isinstance(self.label_key, str):
            if self.label_key not in kwargs:
                msg = 'Label key "%s" is not found' % self.label_key
                raise ValueError(msg)
            t = kwargs[self.label_key]
            del kwargs[self.label_key]

        self.y = None
        self.loss = None
        state, self.y = self.predictor(state, *args, **kwargs)
        self.loss = self.lossfun(self.y, t)
        return state, self.loss

    def predict(self, state, x):
        """Predict log probabilities for given state and input x using the predictor.

        :param torch.Tensor state : The current state
        :param torch.Tensor x : The input
        :return a tuple (new state, log prob vector)
        :rtype (torch.Tensor, torch.Tensor)
        """
        # A predictor that already outputs normalized log-probs is used as-is;
        # otherwise apply log_softmax over the vocabulary dimension.
        if hasattr(self.predictor, "normalized") and self.predictor.normalized:
            return self.predictor(state, x)
        else:
            state, z = self.predictor(state, x)
            return state, F.log_softmax(z, dim=1)

    def buff_predict(self, state, x, n):
        """Predict new tokens from buffered inputs."""
        # RNNLM handles a whole batch natively; other predictors are fed
        # one hypothesis at a time and their outputs concatenated.
        if self.predictor.__class__.__name__ == "RNNLM":
            return self.predict(state, x)

        new_state = []
        new_log_y = []
        for i in range(n):
            state_i = None if state is None else state[i]
            state_i, log_y = self.predict(state_i, x[i].unsqueeze(0))
            new_state.append(state_i)
            new_log_y.append(log_y)

        return new_state, torch.cat(new_log_y)

    def final(self, state, index=None):
        """Predict final log probabilities for given state using the predictor.

        :param state: The state
        :return The final log probabilities
        :rtype torch.Tensor
        """
        if hasattr(self.predictor, "final"):
            if index is not None:
                return self.predictor.final(state[index])
            else:
                return self.predictor.final(state)
        else:
            # Predictors without a final() contribute nothing at <eos>.
            return 0.0
# Definition of a recurrent net for language modeling
class RNNLM(nn.Module):
    """A pytorch RNNLM.

    A stack of LSTM/GRU cells stepped one token at a time; the recurrent
    state is a dict with per-layer hidden lists ("h", plus "c" for LSTM).
    """

    def __init__(
        self,
        n_vocab,
        n_layers,
        n_units,
        n_embed=None,
        typ="lstm",
        dropout_rate=0.5,
        emb_dropout_rate=0.0,
        tie_weights=False,
    ):
        """Initialize class.

        :param int n_vocab: The size of the vocabulary
        :param int n_layers: The number of layers to create
        :param int n_units: The number of units per layer
        :param str typ: The RNN type
        """
        super(RNNLM, self).__init__()
        if n_embed is None:
            n_embed = n_units
        self.embed = nn.Embedding(n_vocab, n_embed)
        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)
        # First cell maps embeddings to hidden size; the rest are hidden-to-hidden.
        if typ == "lstm":
            self.rnn = nn.ModuleList(
                [nn.LSTMCell(n_embed, n_units)]
                + [nn.LSTMCell(n_units, n_units) for _ in range(n_layers - 1)]
            )
        else:
            self.rnn = nn.ModuleList(
                [nn.GRUCell(n_embed, n_units)]
                + [nn.GRUCell(n_units, n_units) for _ in range(n_layers - 1)]
            )
        # One dropout per cell input plus one before the output projection.
        self.dropout = nn.ModuleList(
            [nn.Dropout(dropout_rate) for _ in range(n_layers + 1)]
        )
        self.lo = nn.Linear(n_units, n_vocab)
        self.n_layers = n_layers
        self.n_units = n_units
        self.typ = typ
        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
        if tie_weights:
            assert (
                n_embed == n_units
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.lo.weight = self.embed.weight
        # initialize parameters from uniform distribution
        for param in self.parameters():
            param.data.uniform_(-0.1, 0.1)

    def zero_state(self, batchsize):
        """Initialize state."""
        # Use an existing parameter to pick up the model's device and dtype.
        p = next(self.parameters())
        return torch.zeros(batchsize, self.n_units).to(device=p.device, dtype=p.dtype)

    def forward(self, state, x):
        """Forward neural networks."""
        if state is None:
            # Fresh sequence: start with zero states on the input's device.
            h = [to_device(x, self.zero_state(x.size(0))) for n in range(self.n_layers)]
            state = {"h": h}
            if self.typ == "lstm":
                c = [
                    to_device(x, self.zero_state(x.size(0)))
                    for n in range(self.n_layers)
                ]
                state = {"c": c, "h": h}

        h = [None] * self.n_layers
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(x))
        else:
            emb = self.embed(x)
        if self.typ == "lstm":
            c = [None] * self.n_layers
            h[0], c[0] = self.rnn[0](
                self.dropout[0](emb), (state["h"][0], state["c"][0])
            )
            for n in range(1, self.n_layers):
                h[n], c[n] = self.rnn[n](
                    self.dropout[n](h[n - 1]), (state["h"][n], state["c"][n])
                )
            state = {"c": c, "h": h}
        else:
            h[0] = self.rnn[0](self.dropout[0](emb), state["h"][0])
            for n in range(1, self.n_layers):
                h[n] = self.rnn[n](self.dropout[n](h[n - 1]), state["h"][n])
            state = {"h": h}
        # Project the top layer's hidden state to vocabulary logits.
        y = self.lo(self.dropout[-1](h[-1]))
        return state, y
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/lm/seq_rnn.py
0 → 100644
View file @
60a2c57a
"""Sequential implementation of Recurrent Neural Network Language Model."""
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
espnet.nets.lm_interface
import
LMInterface
class SequentialRNNLM(LMInterface, torch.nn.Module):
    """Sequential RNNLM.

    Uses ``nn.LSTM``/``nn.GRU``/``nn.RNN`` over whole sequences (as opposed
    to the per-step cells in ``default.py``).

    See also:
        https://github.com/pytorch/examples/blob/4581968193699de14b56527296262dd76ab43557/word_language_model/model.py
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        parser.add_argument(
            "--type",
            type=str,
            default="lstm",
            nargs="?",
            choices=["lstm", "gru"],
            help="Which type of RNN to use",
        )
        parser.add_argument(
            "--layer", "-l", type=int, default=2, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit", "-u", type=int, default=650, help="Number of hidden units"
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`
        """
        torch.nn.Module.__init__(self)
        self._setup(
            rnn_type=args.type.upper(),
            ntoken=n_vocab,
            ninp=args.unit,
            nhid=args.unit,
            nlayers=args.layer,
            dropout=args.dropout_rate,
        )

    def _setup(
        self, rnn_type, ntoken, ninp, nhid, nlayers, dropout=0.5, tie_weights=False
    ):
        # Build the embedding -> (multi-layer) RNN -> linear decoder pipeline.
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        if rnn_type in ["LSTM", "GRU"]:
            self.rnn = getattr(nn, rnn_type)(ninp, nhid, nlayers, dropout=dropout)
        else:
            try:
                nonlinearity = {"RNN_TANH": "tanh", "RNN_RELU": "relu"}[rnn_type]
            except KeyError:
                raise ValueError(
                    "An invalid option for `--model` was supplied, "
                    "options are ['LSTM', 'GRU', 'RNN_TANH' or 'RNN_RELU']"
                )
            self.rnn = nn.RNN(
                ninp, nhid, nlayers, nonlinearity=nonlinearity, dropout=dropout
            )
        self.decoder = nn.Linear(nhid, ntoken)

        # Optionally tie weights as in:
        # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
        # https://arxiv.org/abs/1608.05859
        # and
        # "Tying Word Vectors and Word Classifiers:
        # A Loss Framework for Language Modeling" (Inan et al. 2016)
        # https://arxiv.org/abs/1611.01462
        if tie_weights:
            if nhid != ninp:
                raise ValueError(
                    "When using the tied flag, nhid must be equal to emsize"
                )
            self.decoder.weight = self.encoder.weight

        self._init_weights()

        self.rnn_type = rnn_type
        self.nhid = nhid
        self.nlayers = nlayers

    def _init_weights(self):
        # NOTE: original init in pytorch/examples
        # initrange = 0.1
        # self.encoder.weight.data.uniform_(-initrange, initrange)
        # self.decoder.bias.data.zero_()
        # self.decoder.weight.data.uniform_(-initrange, initrange)
        # NOTE: our default.py:RNNLM init
        for param in self.parameters():
            param.data.uniform_(-0.1, 0.1)

    def forward(self, x, t):
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
        """
        y = self._before_loss(x, None)[0]
        # Token id 0 is treated as padding and masked out of the loss.
        mask = (x != 0).to(y.dtype)
        loss = F.cross_entropy(y.view(-1, y.shape[-1]), t.view(-1), reduction="none")
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count

    def _before_loss(self, input, hidden):
        # Embed, run the RNN, and project back to vocabulary logits.
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(
            output.view(output.size(0) * output.size(1), output.size(2))
        )
        return (
            decoded.view(output.size(0), output.size(1), decoded.size(1)),
            hidden,
        )

    def init_state(self, x):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        bsz = 1
        # new_zeros inherits the parameters' device and dtype.
        weight = next(self.parameters())
        if self.rnn_type == "LSTM":
            return (
                weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid),
            )
        else:
            return weight.new_zeros(self.nlayers, bsz, self.nhid)

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys
        """
        # Feed only the last prefix token; the hidden state carries history.
        y, new_state = self._before_loss(y[-1].view(1, 1), state)
        logp = y.log_softmax(dim=-1).view(-1)
        return logp, new_state
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/lm/transformer.py
0 → 100644
View file @
60a2c57a
"""Transformer language model."""
import
logging
from
typing
import
Any
,
List
,
Tuple
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
from
espnet.nets.lm_interface
import
LMInterface
from
espnet.nets.pytorch_backend.transformer.embedding
import
PositionalEncoding
from
espnet.nets.pytorch_backend.transformer.encoder
import
Encoder
from
espnet.nets.pytorch_backend.transformer.mask
import
subsequent_mask
from
espnet.nets.scorer_interface
import
BatchScorerInterface
from
espnet.utils.cli_utils
import
strtobool
class TransformerLM(nn.Module, LMInterface, BatchScorerInterface):
    """Transformer language model.

    Causal self-attention LM: embedding -> Transformer encoder with a
    subsequent (lower-triangular) mask -> linear projection to vocabulary.
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        parser.add_argument(
            "--layer", type=int, default=4, help="Number of hidden layers"
        )
        parser.add_argument(
            "--unit",
            type=int,
            default=1024,
            help="Number of hidden units in feedforward layer",
        )
        parser.add_argument(
            "--att-unit",
            type=int,
            default=256,
            help="Number of hidden units in attention layer",
        )
        parser.add_argument(
            "--embed-unit",
            type=int,
            default=128,
            help="Number of hidden units in embedding layer",
        )
        parser.add_argument(
            "--head", type=int, default=2, help="Number of multi head attention"
        )
        parser.add_argument(
            "--dropout-rate", type=float, default=0.5, help="dropout probability"
        )
        parser.add_argument(
            "--att-dropout-rate",
            type=float,
            default=0.0,
            help="att dropout probability",
        )
        parser.add_argument(
            "--emb-dropout-rate",
            type=float,
            default=0.0,
            help="emb dropout probability",
        )
        parser.add_argument(
            "--tie-weights",
            type=strtobool,
            default=False,
            help="Tie input and output embeddings",
        )
        parser.add_argument(
            "--pos-enc",
            default="sinusoidal",
            choices=["sinusoidal", "none"],
            help="positional encoding",
        )
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations. see py:method:`add_arguments`
        """
        nn.Module.__init__(self)
        # getattr with defaults lets old checkpoints (missing these fields) load.
        # NOTE: for a compatibility with less than 0.9.7 version models
        emb_dropout_rate = getattr(args, "emb_dropout_rate", 0.0)
        # NOTE: for a compatibility with less than 0.9.7 version models
        tie_weights = getattr(args, "tie_weights", False)
        # NOTE: for a compatibility with less than 0.9.7 version models
        att_dropout_rate = getattr(args, "att_dropout_rate", 0.0)

        if args.pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        elif args.pos_enc == "none":

            def pos_enc_class(*args, **kwargs):
                return nn.Sequential()  # identity

        else:
            raise ValueError(f"unknown pos-enc option: {args.pos_enc}")

        self.embed = nn.Embedding(n_vocab, args.embed_unit)
        if emb_dropout_rate == 0.0:
            self.embed_drop = None
        else:
            self.embed_drop = nn.Dropout(emb_dropout_rate)
        self.encoder = Encoder(
            idim=args.embed_unit,
            attention_dim=args.att_unit,
            attention_heads=args.head,
            linear_units=args.unit,
            num_blocks=args.layer,
            dropout_rate=args.dropout_rate,
            attention_dropout_rate=att_dropout_rate,
            input_layer="linear",
            pos_enc_class=pos_enc_class,
        )
        self.decoder = nn.Linear(args.att_unit, n_vocab)

        logging.info("Tie weights set to {}".format(tie_weights))
        logging.info("Dropout set to {}".format(args.dropout_rate))
        logging.info("Emb Dropout set to {}".format(emb_dropout_rate))
        logging.info("Att Dropout set to {}".format(att_dropout_rate))

        if tie_weights:
            assert (
                args.att_unit == args.embed_unit
            ), "Tie Weights: True need embedding and final dimensions to match"
            self.decoder.weight = self.embed.weight

    def _target_mask(self, ys_in_pad):
        # Combine the padding mask (token id != 0) with a causal
        # lower-triangular mask so positions only attend to the past.
        ys_mask = ys_in_pad != 0
        m = subsequent_mask(ys_mask.size(-1), device=ys_mask.device).unsqueeze(0)
        return ys_mask.unsqueeze(-2) & m

    def forward(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of elements in x (scalar)

        Notes:
            The last two return values are used
            in perplexity: p(t)^{-n} = exp(-log p(t) / n)
        """
        xm = x != 0
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(x))
        else:
            emb = self.embed(x)
        h, _ = self.encoder(emb, self._target_mask(x))
        y = self.decoder(h)
        # Per-token loss, then mask out padding before averaging.
        loss = F.cross_entropy(
            y.view(-1, y.shape[-1]), t.view(-1), reduction="none"
        )
        mask = xm.to(dtype=loss.dtype)
        logp = loss * mask.view(-1)
        logp = logp.sum()
        count = mask.sum()
        return logp / count, logp, count

    def score(
        self, y: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys
        """
        y = y.unsqueeze(0)
        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(y))
        else:
            emb = self.embed(y)
        # forward_one_step reuses cached per-layer activations for the prefix.
        h, _, cache = self.encoder.forward_one_step(
            emb, self._target_mask(y), cache=state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1).squeeze(0)
        return logp, cache

    # batch beam search API (see BatchScorerInterface)
    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # merge states
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)
        if states[0] is None:
            batch_state = None
        else:
            # transpose state of [batch, layer] into [layer, batch]
            batch_state = [
                torch.stack([states[b][i] for b in range(n_batch)])
                for i in range(n_layers)
            ]

        if self.embed_drop is not None:
            emb = self.embed_drop(self.embed(ys))
        else:
            emb = self.embed(ys)

        # batch decoding
        h, _, states = self.encoder.forward_one_step(
            emb, self._target_mask(ys), cache=batch_state
        )
        h = self.decoder(h[:, -1])
        logp = h.log_softmax(dim=-1)

        # transpose state of [layer, batch] into [batch, layer]
        state_list = [[states[i][b] for i in range(n_layers)] for b in range(n_batch)]
        return logp, state_list
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/maskctc/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/maskctc/add_mask_token.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Waseda University (Yosuke Higuchi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Token masking module for Masked LM."""
import
numpy
def mask_uniform(ys_pad, mask_token, eos, ignore_id):
    """Replace random tokens with <mask> label and add <eos> label.

    The number of <mask> is chosen from a uniform distribution
    between one and the target sequence's length.

    :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax)
    :param int mask_token: index of <mask>
    :param int eos: index of <eos>
    :param int ignore_id: index of padding
    :return: padded input tensor with random positions masked (B, Lmax)
    :rtype: torch.Tensor
    :return: padded target tensor revealing only the masked positions (B, Lmax)
    :rtype: torch.Tensor
    """
    from espnet.nets.pytorch_backend.nets_utils import pad_list

    # Strip the padding from each sequence.
    seqs = [y[y != ignore_id] for y in ys_pad]
    # Targets are ignore_id everywhere except at the masked positions.
    ys_out = [y.new(y.size()).fill_(ignore_id) for y in seqs]
    ys_in = [y.clone() for y in seqs]
    for i, seq in enumerate(seqs):
        # Draw how many positions to mask, then which ones (with replacement).
        num_samples = numpy.random.randint(1, len(seq) + 1)
        idx = numpy.random.choice(len(seq), num_samples)
        ys_in[i][idx] = mask_token
        ys_out[i][idx] = seq[idx]

    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/maskctc/mask.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2020 Johns Hopkins University (Shinji Watanabe)
# Waseda University (Yosuke Higuchi)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Attention masking module for Masked LM."""
def square_mask(ys_in_pad, ignore_id):
    """Create attention mask to avoid attending on padding tokens.

    Position (i, j) is True only when both token i and token j are
    non-padding, giving a square mask per batch element.

    :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
    :param int ignore_id: index of padding
    :rtype: torch.Tensor (B, Lmax, Lmax)
    """
    non_pad = (ys_in_pad != ignore_id).unsqueeze(-2)  # (B, 1, Lmax)
    length = non_pad.size(-1)
    # Expand to rows (query side) and columns (key side), then intersect.
    rows = non_pad.transpose(1, 2).repeat(1, 1, length)  # (B, Lmax, Lmax)
    cols = non_pad.repeat(1, length, 1)  # (B, Lmax, Lmax)
    return cols & rows
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/nets_utils.py
0 → 100644
View file @
60a2c57a
# -*- coding: utf-8 -*-
"""Network related utility tools."""
import
logging
from
typing
import
Dict
import
numpy
as
np
import
torch
def to_device(m, x):
    """Send tensor into the device of the module.

    Args:
        m (torch.nn.Module or torch.Tensor): Torch module (its first
            parameter's device is used) or a reference tensor.
        x (Tensor): Torch tensor.

    Returns:
        Tensor: Torch tensor located in the same place as torch module.

    Raises:
        TypeError: If ``m`` is neither a module nor a tensor.
    """
    if isinstance(m, torch.nn.Module):
        device = next(m.parameters()).device
    elif isinstance(m, torch.Tensor):
        device = m.device
    else:
        # Fixed typo in the original error message ("bot got" -> "but got").
        raise TypeError(
            "Expected torch.nn.Module or torch.tensor, " f"but got: {type(m)}"
        )
    return x.to(device)
def pad_list(xs, pad_value):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])
    """
    batch = len(xs)
    longest = max(x.size(0) for x in xs)
    # Allocate the output on the same device/dtype as the first tensor,
    # pre-filled with the padding value.
    padded = xs[0].new(batch, longest, *xs[0].size()[1:]).fill_(pad_value)
    for i, x in enumerate(xs):
        padded[i, : x.size(0)] = x
    return padded
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Make mask tensor containing indices of padded part.

    The mask is True (1) at padded positions and False (0) elsewhere.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension of ``xs`` that corresponds
            to the time axis (may be negative; must not be 0).
        maxlen (int, optional): Explicit maximum length; only valid when
            ``xs`` is None.

    Returns:
        Tensor: Mask tensor containing indices of padded part
            (bool on PyTorch >= 1.2, uint8 before).

    Examples:
        >>> make_pad_mask([5, 3, 2])
        masks = [[0, 0, 0, 0, 0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    if not isinstance(lengths, list):
        lengths = lengths.long().tolist()

    batch = int(len(lengths))
    # Resolve the padded length: explicit maxlen, the reference tensor's
    # size along length_dim, or the longest sequence.
    if maxlen is not None:
        assert xs is None
        assert maxlen >= int(max(lengths))
    elif xs is None:
        maxlen = int(max(lengths))
    else:
        maxlen = xs.size(length_dim)

    positions = torch.arange(0, maxlen, dtype=torch.int64)
    positions = positions.unsqueeze(0).expand(batch, maxlen)
    limits = positions.new(lengths).unsqueeze(-1)
    # A position is padding once it reaches its sequence's length.
    mask = positions >= limits

    if xs is not None:
        assert xs.size(0) == batch, (xs.size(0), batch)

        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # Keep the batch and length axes; insert singleton axes elsewhere
        # so the (B, maxlen) mask broadcasts to the shape of xs.
        index = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[index].expand_as(xs).to(xs.device)
    return mask
def make_non_pad_mask(lengths, xs=None, length_dim=-1):
    """Make mask tensor containing indices of non-padded part.

    Logical negation of :func:`make_pad_mask`: True (1) at valid
    positions, False (0) at padded ones. See :func:`make_pad_mask` for
    the meaning of the arguments and shape behavior.

    Args:
        lengths (LongTensor or List): Batch of lengths (B,).
        xs (Tensor, optional): The reference tensor.
            If set, masks will be the same shape as this tensor.
        length_dim (int, optional): Dimension of ``xs`` that corresponds
            to the time axis.

    Returns:
        ByteTensor: mask tensor containing indices of non-padded part
            (bool on PyTorch >= 1.2, uint8 before).

    Examples:
        >>> make_non_pad_mask([5, 3, 2])
        masks = [[1, 1, 1, 1, 1],
                 [1, 1, 1, 0, 0],
                 [1, 1, 0, 0, 0]]
    """
    pad_mask = make_pad_mask(lengths, xs, length_dim)
    return ~pad_mask
def mask_by_length(xs, lengths, fill=0):
    """Mask tensor according to length.

    Frames beyond each sequence's length are overwritten with ``fill``.

    Args:
        xs (Tensor): Batch of input tensor (B, `*`).
        lengths (LongTensor or List): Batch of lengths (B,).
        fill (int or float): Value to fill masked part.

    Returns:
        Tensor: Batch of masked input tensor (B, `*`).

    Examples:
        >>> x = torch.arange(5).repeat(3, 1) + 1
        >>> lengths = [5, 3, 2]
        >>> mask_by_length(x, lengths)
        tensor([[1, 2, 3, 4, 5],
                [1, 2, 3, 0, 0],
                [1, 2, 0, 0, 0]])
    """
    assert xs.size(0) == len(lengths)

    # `new_full` replaces the deprecated `xs.data.new(*size).fill_(fill)`
    # idiom: it avoids touching `.data` (which bypasses autograd tracking)
    # and keeps the result on the same device/dtype as `xs`.
    ret = xs.new_full(xs.size(), fill)
    for i, length in enumerate(lengths):
        ret[i, :length] = xs[i, :length]
    return ret
def th_accuracy(pad_outputs, pad_targets, ignore_label):
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).
    """
    batch, length = pad_targets.size(0), pad_targets.size(1)
    # Fold the flat predictions back to (B, Lmax) label indices.
    predictions = pad_outputs.view(batch, length, pad_outputs.size(1)).argmax(2)
    # Positions carrying the ignore label do not count in either direction.
    valid = pad_targets != ignore_label
    hits = predictions.masked_select(valid) == pad_targets.masked_select(valid)
    return float(torch.sum(hits)) / float(torch.sum(valid))
def to_torch_tensor(x):
    """Change to torch.Tensor or ComplexTensor from numpy.ndarray.

    Args:
        x: Inputs. It should be one of numpy.ndarray, Tensor, ComplexTensor, and dict.

    Returns:
        Tensor or ComplexTensor: Type converted inputs.

    Examples:
        >>> xs = np.ones(3, dtype=np.float32)
        >>> xs = to_torch_tensor(xs)
        tensor([1., 1., 1.])
        >>> xs = torch.ones(3, 4, 5)
        >>> assert to_torch_tensor(xs) is xs
        >>> xs = {'real': xs, 'imag': xs}
        >>> to_torch_tensor(xs)
        ComplexTensor(...)
    """
    # torch.Tensor passes through untouched.
    if isinstance(x, torch.Tensor):
        return x

    # numpy arrays: complex dtype becomes a ComplexTensor, others a Tensor.
    if isinstance(x, np.ndarray):
        if x.dtype.kind == "c":
            # Dynamically importing because torch_complex requires python3
            from torch_complex.tensor import ComplexTensor

            return ComplexTensor(x)
        return torch.from_numpy(x)

    # {'real': ..., 'imag': ...} dicts become a ComplexTensor.
    if isinstance(x, dict):
        # Dynamically importing because torch_complex requires python3
        from torch_complex.tensor import ComplexTensor

        if "real" not in x or "imag" not in x:
            raise ValueError("has 'real' and 'imag' keys: {}".format(list(x)))
        return ComplexTensor(x["real"], x["imag"])

    # Anything else: either an already-built ComplexTensor, or an error.
    error = (
        "x must be numpy.ndarray, torch.Tensor or a dict like "
        "{{'real': torch.Tensor, 'imag': torch.Tensor}}, "
        "but got {}".format(type(x))
    )
    try:
        from torch_complex.tensor import ComplexTensor
    except Exception:
        # If PY2
        raise ValueError(error)
    else:
        # If PY3
        if isinstance(x, ComplexTensor):
            return x
        raise ValueError(error)
def get_subsample(train_args, mode, arch):
    """Parse the subsampling factors from the args for the specified `mode` and `arch`.

    Args:
        train_args: argument Namespace containing options.
        mode: one of ('asr', 'mt', 'st')
        arch: one of ('rnn', 'rnn-t', 'rnn_mix', 'rnn_mulenc', 'transformer')

    Returns:
        np.ndarray / List[np.ndarray]: subsampling factors.

    Raises:
        ValueError: If the (mode, arch) combination is not supported.
    """
    if arch == "transformer":
        # The transformer performs its subsampling inside the input layer.
        return np.array([1])

    if mode == "mt" and arch == "rnn":
        # +1 means input (+1) and layers outputs (train_args.elayer)
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        logging.warning("Subsampling is not performed for machine translation.")
        _log_subsample(subsample)
        return subsample

    if (mode == "asr" and arch in ("rnn", "rnn-t")) or (
        mode in ("mt", "st") and arch == "rnn"
    ):
        subsample = np.ones(train_args.elayers + 1, dtype=np.int64)
        _fill_subsample(subsample, train_args.etype, train_args.subsample)
        _log_subsample(subsample)
        return subsample

    if mode == "asr" and arch == "rnn_mix":
        # Speaker-differentiating layers plus shared layers plus the input.
        subsample = np.ones(
            train_args.elayers_sd + train_args.elayers + 1, dtype=np.int64
        )
        _fill_subsample(subsample, train_args.etype, train_args.subsample)
        _log_subsample(subsample)
        return subsample

    if mode == "asr" and arch == "rnn_mulenc":
        # One factor array per encoder; options are indexed per encoder.
        subsample_list = []
        for idx in range(train_args.num_encs):
            subsample = np.ones(train_args.elayers[idx] + 1, dtype=np.int64)
            _fill_subsample(
                subsample,
                train_args.etype[idx],
                train_args.subsample[idx],
                encoder_idx=idx + 1,
            )
            _log_subsample(subsample)
            subsample_list.append(subsample)
        return subsample_list

    raise ValueError("Invalid options: mode={}, arch={}".format(mode, arch))


def _fill_subsample(subsample, etype, subsample_str, encoder_idx=None):
    """Overwrite ``subsample`` in place from an "x_y_z" spec when applicable.

    Subsampling only applies to pyramidal (etype ending in "p") non-VGG
    encoders; otherwise a warning is emitted and ``subsample`` is left as
    all ones (VGG encoders subsample in their max-pooling layers instead).

    Args:
        subsample (np.ndarray): Factor array to fill, modified in place.
        etype (str): Encoder type string, e.g. "blstmp".
        subsample_str (str): Underscore-separated factors, e.g. "2_2_1".
        encoder_idx (int, optional): 1-based encoder index; when given, the
            warning message is prefixed with the encoder number (rnn_mulenc).
    """
    if etype.endswith("p") and not etype.startswith("vgg"):
        ss = subsample_str.split("_")
        # len(subsample) == elayers + 1, matching the original loop bound.
        for j in range(min(len(subsample), len(ss))):
            subsample[j] = int(ss[j])
    elif encoder_idx is None:
        logging.warning(
            "Subsampling is not performed for vgg*. "
            "It is performed in max pooling layers at CNN."
        )
    else:
        logging.warning(
            "Encoder %d: Subsampling is not performed for vgg*. "
            "It is performed in max pooling layers at CNN.",
            encoder_idx,
        )


def _log_subsample(subsample):
    """Log the resolved subsampling factors as a space-separated string."""
    logging.info("subsample: " + " ".join([str(x) for x in subsample]))
def rename_state_dict(
    old_prefix: str, new_prefix: str, state_dict: Dict[str, torch.Tensor]
):
    """Replace keys of old prefix with new prefix in state dict."""
    # Materialize the matching keys first so the dict is not mutated while
    # being iterated.
    matching = [key for key in state_dict if key.startswith(old_prefix)]
    if matching:
        logging.warning(f"Rename: {old_prefix} -> {new_prefix}")
    for key in matching:
        state_dict[key.replace(old_prefix, new_prefix)] = state_dict.pop(key)
def get_activation(act):
    """Return activation function."""
    # Lazy load to avoid unused import
    from espnet.nets.pytorch_backend.conformer.swish import Swish

    registry = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        "swish": Swish,
    }
    # Instantiate the selected activation; an unknown name raises KeyError.
    return registry[act]()
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/argument.py
0 → 100644
View file @
60a2c57a
# Copyright 2020 Hirofumi Inaguma
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""RNN common arguments."""
def add_arguments_rnn_encoder_common(group):
    """Define common arguments for RNN encoder."""
    # All supported RNN/VGG-RNN encoder variants.
    encoder_choices = [
        "lstm", "blstm", "lstmp", "blstmp",
        "vgglstmp", "vggblstmp", "vgglstm", "vggblstm",
        "gru", "bgru", "grup", "bgrup",
        "vgggrup", "vggbgrup", "vgggru", "vggbgru",
    ]
    group.add_argument(
        "--etype",
        type=str,
        default="blstmp",
        choices=encoder_choices,
        help="Type of encoder network architecture",
    )
    group.add_argument(
        "--elayers", type=int, default=4, help="Number of encoder layers"
    )
    group.add_argument(
        "--eunits", "-u", type=int, default=300, help="Number of encoder hidden units"
    )
    group.add_argument(
        "--eprojs", type=int, default=320, help="Number of encoder projection units"
    )
    group.add_argument(
        "--subsample",
        type=str,
        default="1",
        help="Subsample input frames x_y_z means "
        "subsample every x frame at 1st layer, "
        "every y frame at 2nd layer etc.",
    )
    return group
def add_arguments_rnn_decoder_common(group):
    """Define common arguments for RNN decoder."""
    group.add_argument(
        "--dtype",
        type=str,
        default="lstm",
        choices=["lstm", "gru"],
        help="Type of decoder network architecture",
    )
    group.add_argument(
        "--dlayers", type=int, default=1, help="Number of decoder layers"
    )
    group.add_argument(
        "--dunits", type=int, default=320, help="Number of decoder hidden units"
    )
    group.add_argument(
        "--dropout-rate-decoder",
        type=float,
        default=0.0,
        help="Dropout rate for the decoder",
    )
    group.add_argument(
        "--sampling-probability",
        type=float,
        default=0.0,
        help="Ratio of predicted labels fed back to decoder",
    )
    # nargs="?" with const="" lets a bare `--lsm-type` mean "disabled".
    group.add_argument(
        "--lsm-type",
        type=str,
        nargs="?",
        const="",
        default="",
        choices=["", "unigram"],
        help="Apply label smoothing with a specified distribution type",
    )
    return group
def add_arguments_rnn_attention_common(group):
    """Define common arguments for RNN attention."""
    # Attention mechanism variant used between encoder and decoder.
    group.add_argument(
        "--atype",
        default="dot",
        type=str,
        choices=[
            "noatt",
            "dot",
            "add",
            "location",
            "coverage",
            "coverage_location",
            "location2d",
            "location_recurrent",
            "multi_head_dot",
            "multi_head_add",
            "multi_head_loc",
            "multi_head_multi_res_loc",
        ],
        help="Type of attention architecture",
    )
    group.add_argument(
        "--adim",
        default=320,
        type=int,
        help="Number of attention transformation dimensions",
    )
    group.add_argument(
        "--awin", default=5, type=int, help="Window size for location2d attention"
    )
    group.add_argument(
        "--aheads",
        default=4,
        type=int,
        help="Number of heads for multi head attention",
    )
    # NOTE(review): the two help strings below use a backslash line
    # continuation inside the string literal.
    group.add_argument(
        "--aconv-chans",
        default=-1,
        type=int,
        help="Number of attention convolution channels \
(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--aconv-filts",
        default=100,
        type=int,
        help="Number of attention convolution filters \
(negative value indicates no location-aware attention)",
    )
    group.add_argument(
        "--dropout-rate",
        default=0.0,
        type=float,
        help="Dropout rate for the encoder",
    )
    return group
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/attentions.py
0 → 100644
View file @
60a2c57a
"""Attention modules for RNN."""
import
math
import
torch
import
torch.nn.functional
as
F
from
espnet.nets.pytorch_backend.nets_utils
import
make_pad_mask
,
to_device
def
_apply_attention_constraint
(
e
,
last_attended_idx
,
backward_window
=
1
,
forward_window
=
3
):
"""Apply monotonic attention constraint.
This function apply the monotonic attention constraint
introduced in `Deep Voice 3: Scaling
Text-to-Speech with Convolutional Sequence Learning`_.
Args:
e (Tensor): Attention energy before applying softmax (1, T).
last_attended_idx (int): The index of the inputs of the last attended [0, T].
backward_window (int, optional): Backward window size in attention constraint.
forward_window (int, optional): Forward window size in attetion constraint.
Returns:
Tensor: Monotonic constrained attention energy (1, T).
.. _`Deep Voice 3: Scaling Text-to-Speech with Convolutional Sequence Learning`:
https://arxiv.org/abs/1710.07654
"""
if
e
.
size
(
0
)
!=
1
:
raise
NotImplementedError
(
"Batch attention constraining is not yet supported."
)
backward_idx
=
last_attended_idx
-
backward_window
forward_idx
=
last_attended_idx
+
forward_window
if
backward_idx
>
0
:
e
[:,
:
backward_idx
]
=
-
float
(
"inf"
)
if
forward_idx
<
e
.
size
(
1
):
e
[:,
forward_idx
:]
=
-
float
(
"inf"
)
return
e
class NoAtt(torch.nn.Module):
    """Attention stub that returns a fixed, length-normalized uniform context.

    It keeps the interface of the other attention modules but has no
    learnable parameters: on the first decoder step the context is the
    average of the valid encoder frames, and it is reused afterwards.
    """

    def __init__(self):
        super(NoAtt, self).__init__()
        # Per-utterance caches, filled on the first decoder step.
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.c = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """NoAtt forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, T_max, D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: dummy (does not use)
        :param torch.Tensor att_prev: dummy (does not use)
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # Cache the encoder output on the first decoder step only.
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # First step with no given weights: use a uniform distribution
            # over the valid (non-padded) frames and compute the context once.
            if att_prev is None:
                # if no bias, 0 0-pad goes 0
                valid = 1.0 - make_pad_mask(enc_hs_len).float()
                att_prev = valid / valid.new(enc_hs_len).unsqueeze(-1)
                att_prev = att_prev.to(self.enc_h)
                self.c = torch.sum(
                    self.enc_h * att_prev.view(batch, self.h_length, 1), dim=1
                )
        return self.c, att_prev
class AttDot(torch.nn.Module):
    """Dot product attention

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttDot, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.han_mode = han_mode
        # Per-utterance caches, filled lazily on the first decoder step.
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttDot forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec) or None
        :param torch.Tensor att_prev: dummy (does not use)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weight (B x T_max)
        :rtype: torch.Tensor
        """
        batch = enc_hs_pad.size(0)

        # Project (and cache) the encoder states; han_mode recomputes each call.
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            self.pre_compute_enc_h = torch.tanh(self.mlp_enc(self.enc_h))

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        # Energy: dot product of projected encoder and decoder states (utt x frame).
        dec_proj = torch.tanh(self.mlp_dec(dec_z)).view(batch, 1, self.att_dim)
        energy = torch.sum(self.pre_compute_enc_h * dec_proj, dim=2)

        # Exclude padded frames from the softmax.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        energy.masked_fill_(self.mask, -float("inf"))
        weights = F.softmax(scaling * energy, dim=1)

        # Context: weighted sum of encoder states over frames (utt x hdim).
        context = torch.sum(self.enc_h * weights.view(batch, self.h_length, 1), dim=1)
        return context, weights
class AttAdd(torch.nn.Module):
    """Additive attention

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttAdd, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.gvec = torch.nn.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.han_mode = han_mode
        # Per-utterance caches, filled lazily on the first decoder step.
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttAdd forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (does not use)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)

        # Project (and cache) the encoder states; han_mode recomputes each call.
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)  # utt x frame x att_dim

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        # Broadcast the projected decoder state over frames and score each frame
        # with the shared vector gvec: utt x frame x att_dim -> utt x frame.
        dec_proj = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        energy = self.gvec(torch.tanh(self.pre_compute_enc_h + dec_proj)).squeeze(2)

        # Exclude padded frames from the softmax.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        energy.masked_fill_(self.mask, -float("inf"))
        weights = F.softmax(scaling * energy, dim=1)

        # Context: weighted sum of encoder states over frames (utt x hdim).
        context = torch.sum(self.enc_h * weights.view(batch, self.h_length, 1), dim=1)
        return context, weights
class AttLoc(torch.nn.Module):
    """location-aware attention module.

    Reference: Attention-Based Models for Speech Recognition
    (https://arxiv.org/pdf/1506.07503.pdf)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
        super(AttLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # 1-D convolution over the previous attention distribution.
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.han_mode = han_mode
        # Per-utterance caches, filled lazily on the first decoder step.
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        scaling=2.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttLoc forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: previous attention weight (B x T_max)
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)

        # Project (and cache) the encoder states; han_mode recomputes each call.
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)  # utt x frame x att_dim

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        # On the first step, start from uniform weights over valid frames.
        if att_prev is None:
            # if no bias, 0 0-pad goes 0
            att_prev = 1.0 - make_pad_mask(enc_hs_len).to(
                device=dec_z.device, dtype=dec_z.dtype
            )
            att_prev = att_prev / att_prev.new(enc_hs_len).unsqueeze(-1)

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        conv_feat = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # -> utt x frame x att_conv_chans
        conv_feat = conv_feat.squeeze(2).transpose(1, 2)
        # -> utt x frame x att_dim
        conv_feat = self.mlp_att(conv_feat)

        # Score each frame: utt x frame x att_dim -> utt x frame.
        dec_proj = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        energy = self.gvec(
            torch.tanh(conv_feat + self.pre_compute_enc_h + dec_proj)
        ).squeeze(2)

        # Exclude padded frames from the softmax.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        energy.masked_fill_(self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            energy = _apply_attention_constraint(
                energy, last_attended_idx, backward_window, forward_window
            )

        weights = F.softmax(scaling * energy, dim=1)

        # Context: weighted sum of encoder states over frames (utt x hdim).
        context = torch.sum(self.enc_h * weights.view(batch, self.h_length, 1), dim=1)
        return context, weights
class AttCov(torch.nn.Module):
    """Coverage mechanism attention

    Reference: Get To The Point: Summarization with Pointer-Generator Network
    (https://arxiv.org/abs/1704.04368)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, han_mode=False):
        super(AttCov, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        # Projects the scalar coverage value of each frame into att_dim.
        self.wvec = torch.nn.Linear(1, att_dim)
        self.gvec = torch.nn.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.han_mode = han_mode
        # Per-utterance caches, filled lazily on the first decoder step.
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
        """AttCov forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param list att_prev_list: list of previous attention weight
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights
        :rtype: list
        """
        batch = len(enc_hs_pad)

        # Project (and cache) the encoder states; han_mode recomputes each call.
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)  # utt x frame x att_dim

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        # On the first step, start from uniform weights over valid frames.
        if att_prev_list is None:
            # if no bias, 0 0-pad goes 0
            uniform = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
            att_prev_list = [uniform / uniform.new(enc_hs_len).unsqueeze(-1)]

        # Coverage: accumulate all previous distributions, L' * [B x T] => B x T,
        # then B x T => B x T x 1 => B x T x att_dim.
        cov_vec = sum(att_prev_list)
        cov_vec = self.wvec(cov_vec.unsqueeze(-1))

        # Score each frame: utt x frame x att_dim -> utt x frame.
        dec_proj = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        energy = self.gvec(
            torch.tanh(cov_vec + self.pre_compute_enc_h + dec_proj)
        ).squeeze(2)

        # Exclude padded frames from the softmax.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        energy.masked_fill_(self.mask, -float("inf"))
        weights = F.softmax(scaling * energy, dim=1)
        # Record the new distribution in the (mutated in place) history list.
        att_prev_list.append(weights)

        # Context: weighted sum of encoder states over frames (utt x hdim).
        context = torch.sum(self.enc_h * weights.view(batch, self.h_length, 1), dim=1)
        return context, att_prev_list
class AttLoc2D(torch.nn.Module):
    """2D location-aware attention

    This attention is an extended version of location aware attention.
    It take not only one frame before attention weights,
    but also earlier frames into account.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param int att_win: attention window size (default=5)
    :param bool han_mode:
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, att_win, aconv_chans, aconv_filts, han_mode=False):
        super(AttLoc2D, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # 2-D convolution over the last att_win attention distributions;
        # the kernel height att_win collapses the window to a single row.
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (att_win, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)

        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        self.aconv_chans = aconv_chans
        self.att_win = att_win
        self.han_mode = han_mode
        # Per-utterance caches, filled lazily on the first decoder step.
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttLoc2D forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: previous attention weight (B x att_win x T_max)
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x att_win x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)

        # Project (and cache) the encoder states; han_mode recomputes each call.
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)  # utt x frame x att_dim

        dec_z = (
            enc_hs_pad.new_zeros(batch, self.dunits)
            if dec_z is None
            else dec_z.view(batch, self.dunits)
        )

        # On the first step, fill the whole window with uniform weights.
        if att_prev is None:
            # B * [Li x att_win]
            # if no bias, 0 0-pad goes 0
            uniform = to_device(enc_hs_pad, (1.0 - make_pad_mask(enc_hs_len).float()))
            uniform = uniform / uniform.new(enc_hs_len).unsqueeze(-1)
            att_prev = uniform.unsqueeze(1).expand(-1, self.att_win, -1)

        # att_prev: B x att_win x Tmax -> B x 1 x att_win x Tmax -> B x C x 1 x Tmax
        conv_feat = self.loc_conv(att_prev.unsqueeze(1))
        # B x C x 1 x Tmax -> B x Tmax x C
        conv_feat = conv_feat.squeeze(2).transpose(1, 2)
        # B x Tmax x C -> B x Tmax x att_dim
        conv_feat = self.mlp_att(conv_feat)

        # Score each frame: utt x frame x att_dim -> utt x frame.
        dec_proj = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)
        energy = self.gvec(
            torch.tanh(conv_feat + self.pre_compute_enc_h + dec_proj)
        ).squeeze(2)

        # Exclude padded frames from the softmax.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        energy.masked_fill_(self.mask, -float("inf"))
        weights = F.softmax(scaling * energy, dim=1)

        # Context: weighted sum of encoder states over frames (utt x hdim).
        context = torch.sum(self.enc_h * weights.view(batch, self.h_length, 1), dim=1)

        # Slide the window: append the new weights and drop the oldest row,
        # B x att_win x Tmax -> B x att_win+1 x Tmax -> B x att_win x Tmax.
        att_prev = torch.cat([att_prev, weights.unsqueeze(1)], dim=1)[:, 1:]
        return context, att_prev
class
AttLocRec
(
torch
.
nn
.
Module
):
"""location-aware recurrent attention
This attention is an extended version of location aware attention.
With the use of RNN,
it take the effect of the history of attention weights into account.
:param int eprojs: # projection-units of encoder
:param int dunits: # units of decoder
:param int att_dim: attention dimension
:param int aconv_chans: # channels of attention convolution
:param int aconv_filts: filter size of attention convolution
:param bool han_mode:
flag to swith on mode of hierarchical attention and not store pre_compute_enc_h
"""
    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
        """Initialize the location-aware recurrent attention module.

        :param int eprojs: # projection-units of encoder
        :param int dunits: # units of decoder
        :param int att_dim: attention dimension
        :param int aconv_chans: # channels of attention convolution
        :param int aconv_filts: filter size of attention convolution
        :param bool han_mode: hierarchical attention mode flag
        """
        super(AttLocRec, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        # 1-D convolution over the previous attention distribution.
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        # LSTM cell that accumulates attention history across decoder steps.
        self.att_lstm = torch.nn.LSTMCell(aconv_chans, att_dim, bias=False)
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # Per-utterance caches, filled lazily on the first decoder step.
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.han_mode = han_mode
    def reset(self):
        """reset states"""
        # Drop the cached encoder states/projection and padding mask so the
        # next forward() call recomputes them for a new utterance.
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
def
forward
(
self
,
enc_hs_pad
,
enc_hs_len
,
dec_z
,
att_prev_states
,
scaling
=
2.0
):
"""AttLocRec forward
:param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
:param list enc_hs_len: padded encoder hidden state length (B)
:param torch.Tensor dec_z: decoder hidden state (B x D_dec)
:param tuple att_prev_states: previous attention weight and lstm states
((B, T_max), ((B, att_dim), (B, att_dim)))
:param float scaling: scaling parameter before applying softmax
:return: attention weighted encoder state (B, D_enc)
:rtype: torch.Tensor
:return: previous attention weights and lstm states (w, (hx, cx))
((B, T_max), ((B, att_dim), (B, att_dim)))
:rtype: tuple
"""
batch
=
len
(
enc_hs_pad
)
# pre-compute all h outside the decoder loop
if
self
.
pre_compute_enc_h
is
None
or
self
.
han_mode
:
self
.
enc_h
=
enc_hs_pad
# utt x frame x hdim
self
.
h_length
=
self
.
enc_h
.
size
(
1
)
# utt x frame x att_dim
self
.
pre_compute_enc_h
=
self
.
mlp_enc
(
self
.
enc_h
)
if
dec_z
is
None
:
dec_z
=
enc_hs_pad
.
new_zeros
(
batch
,
self
.
dunits
)
else
:
dec_z
=
dec_z
.
view
(
batch
,
self
.
dunits
)
if
att_prev_states
is
None
:
# initialize attention weight with uniform dist.
# if no bias, 0 0-pad goes 0
att_prev
=
to_device
(
enc_hs_pad
,
(
1.0
-
make_pad_mask
(
enc_hs_len
).
float
()))
att_prev
=
att_prev
/
att_prev
.
new
(
enc_hs_len
).
unsqueeze
(
-
1
)
# initialize lstm states
att_h
=
enc_hs_pad
.
new_zeros
(
batch
,
self
.
att_dim
)
att_c
=
enc_hs_pad
.
new_zeros
(
batch
,
self
.
att_dim
)
att_states
=
(
att_h
,
att_c
)
else
:
att_prev
=
att_prev_states
[
0
]
att_states
=
att_prev_states
[
1
]
# B x 1 x 1 x T -> B x C x 1 x T
att_conv
=
self
.
loc_conv
(
att_prev
.
view
(
batch
,
1
,
1
,
self
.
h_length
))
# apply non-linear
att_conv
=
F
.
relu
(
att_conv
)
# B x C x 1 x T -> B x C x 1 x 1 -> B x C
att_conv
=
F
.
max_pool2d
(
att_conv
,
(
1
,
att_conv
.
size
(
3
))).
view
(
batch
,
-
1
)
att_h
,
att_c
=
self
.
att_lstm
(
att_conv
,
att_states
)
# dec_z_tiled: utt x frame x att_dim
dec_z_tiled
=
self
.
mlp_dec
(
dec_z
).
view
(
batch
,
1
,
self
.
att_dim
)
# dot with gvec
# utt x frame x att_dim -> utt x frame
e
=
self
.
gvec
(
torch
.
tanh
(
att_h
.
unsqueeze
(
1
)
+
self
.
pre_compute_enc_h
+
dec_z_tiled
)
).
squeeze
(
2
)
# NOTE consider zero padding when compute w.
if
self
.
mask
is
None
:
self
.
mask
=
to_device
(
enc_hs_pad
,
make_pad_mask
(
enc_hs_len
))
e
.
masked_fill_
(
self
.
mask
,
-
float
(
"inf"
))
w
=
F
.
softmax
(
scaling
*
e
,
dim
=
1
)
# weighted sum over flames
# utt x hdim
# NOTE use bmm instead of sum(*)
c
=
torch
.
sum
(
self
.
enc_h
*
w
.
view
(
batch
,
self
.
h_length
,
1
),
dim
=
1
)
return
c
,
(
w
,
(
att_h
,
att_c
))
class AttCovLoc(torch.nn.Module):
    """Coverage mechanism location aware attention.

    This attention is a combination of coverage and location-aware attentions.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode:
        flag to switch on mode of hierarchical attention and not store pre_compute_enc_h
    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts, han_mode=False):
        super(AttCovLoc, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # 1-D convolution over the coverage vector (summed attention history)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # Per-utterance caches, filled on first forward() and cleared by reset()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.aconv_chans = aconv_chans
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev_list, scaling=2.0):
        """AttCovLoc forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param list att_prev_list: list of previous attention weight
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weights
        :rtype: list
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        # initialize attention weight with uniform dist.
        if att_prev_list is None:
            # if no bias, 0 0-pad goes 0
            mask = 1.0 - make_pad_mask(enc_hs_len).float()
            att_prev_list = [
                to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
            ]

        # att_prev_list: L' * [B x T] => cov_vec B x T
        # coverage = accumulated attention mass over all previous steps
        cov_vec = sum(att_prev_list)
        # cov_vec: B x T -> B x 1 x 1 x T -> B x C x 1 x T
        att_conv = self.loc_conv(cov_vec.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)

        # NOTE consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        # in-place fill so padded frames get zero probability after softmax
        e.masked_fill_(self.mask, -float("inf"))
        w = F.softmax(scaling * e, dim=1)
        att_prev_list += [w]

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        return c, att_prev_list
class AttMultiHeadDot(torch.nn.Module):
    """Multi head dot product attention.

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
        super(AttMultiHeadDot, self).__init__()
        # one query/key/value projection per head
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
        # output projection applied to the concatenated per-head contexts
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # scaled dot-product factor 1/sqrt(d_k)
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        # Per-utterance caches, filled on first forward() and cleared by reset()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadDot forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (does not use)
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [
                torch.tanh(self.mlp_k[h](self.enc_h)) for h in range(self.aheads)
            ]

        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        c = []
        w = []
        for h in range(self.aheads):
            # dot product of key and (tanh-squashed) query, summed over att_dim_k
            e = torch.sum(
                self.pre_compute_k[h]
                * torch.tanh(self.mlp_q[h](dec_z)).view(batch, 1, self.att_dim_k),
                dim=2,
            )  # utt x frame

            # NOTE consider zero padding when compute w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]

            # weighted sum over frames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]

        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))

        return c, w
class AttMultiHeadAdd(torch.nn.Module):
    """Multi head additive attention.

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using additive attention for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(self, eprojs, dunits, aheads, att_dim_k, att_dim_v, han_mode=False):
        super(AttMultiHeadAdd, self).__init__()
        # one query/key/value projection and one scoring vector per head
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
        # output projection applied to the concatenated per-head contexts
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # scaling factor 1/sqrt(d_k) applied before softmax
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        # Per-utterance caches, filled on first forward() and cleared by reset()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadAdd forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: dummy (does not use)
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)]

        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        c = []
        w = []
        for h in range(self.aheads):
            # additive (Bahdanau-style) score for this head
            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)

            # NOTE consider zero padding when compute w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]

            # weighted sum over frames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]

        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))

        return c, w
class AttMultiHeadLoc(torch.nn.Module):
    """Multi head location based attention.

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using location-aware attention
    for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(
        self,
        eprojs,
        dunits,
        aheads,
        att_dim_k,
        att_dim_v,
        aconv_chans,
        aconv_filts,
        han_mode=False,
    ):
        super(AttMultiHeadLoc, self).__init__()
        # per-head projections, scoring vectors and location convolutions
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        self.loc_conv = torch.nn.ModuleList()
        self.mlp_att = torch.nn.ModuleList()
        for _ in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
            self.loc_conv += [
                torch.nn.Conv2d(
                    1,
                    aconv_chans,
                    (1, 2 * aconv_filts + 1),
                    padding=(0, aconv_filts),
                    bias=False,
                )
            ]
            self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
        # output projection applied to the concatenated per-head contexts
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # scaling factor 1/sqrt(d_k); note forward() takes its own `scaling` arg
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        # Per-utterance caches, filled on first forward() and cleared by reset()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev, scaling=2.0):
        """AttMultiHeadLoc forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev:
            list of previous attention weight (B x T_max) * aheads
        :param float scaling: scaling parameter before applying softmax
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)]

        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev is None:
            att_prev = []
            for _ in range(self.aheads):
                # initialize each head uniformly over the valid frames
                # if no bias, 0 0-pad goes 0
                mask = 1.0 - make_pad_mask(enc_hs_len).float()
                att_prev += [
                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
                ]

        c = []
        w = []
        for h in range(self.aheads):
            # location feature from this head's previous attention weights
            att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
            att_conv = att_conv.squeeze(2).transpose(1, 2)
            att_conv = self.mlp_att[h](att_conv)

            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + att_conv
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)

            # NOTE consider zero padding when compute w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(scaling * e, dim=1)]

            # weighted sum over frames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]

        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))

        return c, w
class AttMultiHeadMultiResLoc(torch.nn.Module):
    """Multi head multi resolution location based attention.

    Reference: Attention is all you need
        (https://arxiv.org/abs/1706.03762)

    This attention is multi head attention using location-aware attention
    for each head. Furthermore, it uses different filter size for each head.

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int att_dim_k: dimension k in multi head attention
    :param int att_dim_v: dimension v in multi head attention
    :param int aconv_chans: maximum # channels of attention convolution
        each head use #ch = aconv_chans * (head + 1) / aheads
        e.g. aheads=4, aconv_chans=100 => filter size = 25, 50, 75, 100
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
        and not store pre_compute_k and pre_compute_v
    """

    def __init__(
        self,
        eprojs,
        dunits,
        aheads,
        att_dim_k,
        att_dim_v,
        aconv_chans,
        aconv_filts,
        han_mode=False,
    ):
        super(AttMultiHeadMultiResLoc, self).__init__()
        # per-head projections, scoring vectors and location convolutions;
        # the convolution filter width grows with the head index h
        self.mlp_q = torch.nn.ModuleList()
        self.mlp_k = torch.nn.ModuleList()
        self.mlp_v = torch.nn.ModuleList()
        self.gvec = torch.nn.ModuleList()
        self.loc_conv = torch.nn.ModuleList()
        self.mlp_att = torch.nn.ModuleList()
        for h in range(aheads):
            self.mlp_q += [torch.nn.Linear(dunits, att_dim_k)]
            self.mlp_k += [torch.nn.Linear(eprojs, att_dim_k, bias=False)]
            self.mlp_v += [torch.nn.Linear(eprojs, att_dim_v, bias=False)]
            self.gvec += [torch.nn.Linear(att_dim_k, 1)]
            # head-specific half-width: scales linearly with the head index
            afilts = aconv_filts * (h + 1) // aheads
            self.loc_conv += [
                torch.nn.Conv2d(
                    1, aconv_chans, (1, 2 * afilts + 1), padding=(0, afilts), bias=False
                )
            ]
            self.mlp_att += [torch.nn.Linear(aconv_chans, att_dim_k, bias=False)]
        # output projection applied to the concatenated per-head contexts
        self.mlp_o = torch.nn.Linear(aheads * att_dim_v, eprojs, bias=False)
        self.dunits = dunits
        self.eprojs = eprojs
        self.aheads = aheads
        self.att_dim_k = att_dim_k
        self.att_dim_v = att_dim_v
        # scaled dot-product factor 1/sqrt(d_k) used before softmax
        self.scaling = 1.0 / math.sqrt(att_dim_k)
        # Per-utterance caches, filled on first forward() and cleared by reset()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None
        self.han_mode = han_mode

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_k = None
        self.pre_compute_v = None
        self.mask = None

    def forward(self, enc_hs_pad, enc_hs_len, dec_z, att_prev):
        """AttMultiHeadMultiResLoc forward

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: list of previous attention weight
            (B x T_max) * aheads
        :return: attention weighted encoder state (B x D_enc)
        :rtype: torch.Tensor
        :return: list of previous attention weight (B x T_max) * aheads
        :rtype: list
        """
        batch = enc_hs_pad.size(0)
        # pre-compute all k and v outside the decoder loop
        if self.pre_compute_k is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_k = [self.mlp_k[h](self.enc_h) for h in range(self.aheads)]

        if self.pre_compute_v is None or self.han_mode:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_v = [self.mlp_v[h](self.enc_h) for h in range(self.aheads)]

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev is None:
            att_prev = []
            for _ in range(self.aheads):
                # initialize each head uniformly over the valid frames
                # if no bias, 0 0-pad goes 0
                mask = 1.0 - make_pad_mask(enc_hs_len).float()
                att_prev += [
                    to_device(enc_hs_pad, mask / mask.new(enc_hs_len).unsqueeze(-1))
                ]

        c = []
        w = []
        for h in range(self.aheads):
            # location feature from this head's previous attention weights
            att_conv = self.loc_conv[h](att_prev[h].view(batch, 1, 1, self.h_length))
            att_conv = att_conv.squeeze(2).transpose(1, 2)
            att_conv = self.mlp_att[h](att_conv)

            e = self.gvec[h](
                torch.tanh(
                    self.pre_compute_k[h]
                    + att_conv
                    + self.mlp_q[h](dec_z).view(batch, 1, self.att_dim_k)
                )
            ).squeeze(2)

            # NOTE consider zero padding when compute w.
            if self.mask is None:
                self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
            e.masked_fill_(self.mask, -float("inf"))
            w += [F.softmax(self.scaling * e, dim=1)]

            # weighted sum over frames
            # utt x hdim
            # NOTE use bmm instead of sum(*)
            c += [
                torch.sum(
                    self.pre_compute_v[h] * w[h].view(batch, self.h_length, 1), dim=1
                )
            ]

        # concat all of c
        c = self.mlp_o(torch.cat(c, dim=1))

        return c, w
class AttForward(torch.nn.Module):
    """Forward attention module.

    Reference:
        Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
            (https://arxiv.org/pdf/1807.06736.pdf)

    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    """

    def __init__(self, eprojs, dunits, att_dim, aconv_chans, aconv_filts):
        super(AttForward, self).__init__()
        self.mlp_enc = torch.nn.Linear(eprojs, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # 1-D convolution over the previous attention weights (location feature)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eprojs = eprojs
        self.att_dim = att_dim
        # Per-utterance caches, filled on first forward() and cleared by reset()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        scaling=1.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttForward forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B x T_max x D_enc)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B x D_dec)
        :param torch.Tensor att_prev: attention weights of previous step
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, D_enc)
        :rtype: torch.Tensor
        :return: previous attention weights (B x T_max)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
            att_prev[:, 0] = 1.0

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).unsqueeze(1)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(self.pre_compute_enc_h + dec_z_tiled + att_conv)
        ).squeeze(2)

        # NOTE: consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )

        w = F.softmax(scaling * e, dim=1)

        # forward attention: combine previous weights with their one-step
        # right shift so attention can only stay or advance by one frame
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
        w = (att_prev + att_prev_shift) * w
        # NOTE: clamp is needed to avoid nan gradient
        w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.unsqueeze(-1), dim=1)

        return c, w
class AttForwardTA(torch.nn.Module):
    """Forward attention with transition agent module.

    Reference:
        Forward attention in sequence-to-sequence acoustic modeling for speech synthesis
            (https://arxiv.org/pdf/1807.06736.pdf)

    :param int eunits: # units of encoder
    :param int dunits: # units of decoder
    :param int att_dim: attention dimension
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param int odim: output dimension
    """

    def __init__(self, eunits, dunits, att_dim, aconv_chans, aconv_filts, odim):
        super(AttForwardTA, self).__init__()
        self.mlp_enc = torch.nn.Linear(eunits, att_dim)
        self.mlp_dec = torch.nn.Linear(dunits, att_dim, bias=False)
        # transition agent: predicts stay/advance probability from
        # (context, previous output, decoder state)
        self.mlp_ta = torch.nn.Linear(eunits + dunits + odim, 1)
        self.mlp_att = torch.nn.Linear(aconv_chans, att_dim, bias=False)
        # 1-D convolution over the previous attention weights (location feature)
        self.loc_conv = torch.nn.Conv2d(
            1,
            aconv_chans,
            (1, 2 * aconv_filts + 1),
            padding=(0, aconv_filts),
            bias=False,
        )
        self.gvec = torch.nn.Linear(att_dim, 1)
        self.dunits = dunits
        self.eunits = eunits
        self.att_dim = att_dim
        # Per-utterance caches, filled on first forward() and cleared by reset()
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        # transition agent probability; starts at 0.5 (equal stay/advance)
        self.trans_agent_prob = 0.5

    def reset(self):
        """reset states"""
        self.h_length = None
        self.enc_h = None
        self.pre_compute_enc_h = None
        self.mask = None
        self.trans_agent_prob = 0.5

    def forward(
        self,
        enc_hs_pad,
        enc_hs_len,
        dec_z,
        att_prev,
        out_prev,
        scaling=1.0,
        last_attended_idx=None,
        backward_window=1,
        forward_window=3,
    ):
        """Calculate AttForwardTA forward propagation.

        :param torch.Tensor enc_hs_pad: padded encoder hidden state (B, Tmax, eunits)
        :param list enc_hs_len: padded encoder hidden state length (B)
        :param torch.Tensor dec_z: decoder hidden state (B, dunits)
        :param torch.Tensor att_prev: attention weights of previous step
        :param torch.Tensor out_prev: decoder outputs of previous step (B, odim)
        :param float scaling: scaling parameter before applying softmax
        :param int last_attended_idx: index of the inputs of the last attended
        :param int backward_window: backward window size in attention constraint
        :param int forward_window: forward window size in attention constraint
        :return: attention weighted encoder state (B, dunits)
        :rtype: torch.Tensor
        :return: previous attention weights (B, Tmax)
        :rtype: torch.Tensor
        """
        batch = len(enc_hs_pad)
        # pre-compute all h outside the decoder loop
        if self.pre_compute_enc_h is None:
            self.enc_h = enc_hs_pad  # utt x frame x hdim
            self.h_length = self.enc_h.size(1)
            # utt x frame x att_dim
            self.pre_compute_enc_h = self.mlp_enc(self.enc_h)

        if dec_z is None:
            dec_z = enc_hs_pad.new_zeros(batch, self.dunits)
        else:
            dec_z = dec_z.view(batch, self.dunits)

        if att_prev is None:
            # initial attention will be [1, 0, 0, ...]
            att_prev = enc_hs_pad.new_zeros(*enc_hs_pad.size()[:2])
            att_prev[:, 0] = 1.0

        # att_prev: utt x frame -> utt x 1 x 1 x frame
        # -> utt x att_conv_chans x 1 x frame
        att_conv = self.loc_conv(att_prev.view(batch, 1, 1, self.h_length))
        # att_conv: utt x att_conv_chans x 1 x frame -> utt x frame x att_conv_chans
        att_conv = att_conv.squeeze(2).transpose(1, 2)
        # att_conv: utt x frame x att_conv_chans -> utt x frame x att_dim
        att_conv = self.mlp_att(att_conv)

        # dec_z_tiled: utt x frame x att_dim
        dec_z_tiled = self.mlp_dec(dec_z).view(batch, 1, self.att_dim)

        # dot with gvec
        # utt x frame x att_dim -> utt x frame
        e = self.gvec(
            torch.tanh(att_conv + self.pre_compute_enc_h + dec_z_tiled)
        ).squeeze(2)

        # NOTE consider zero padding when compute w.
        if self.mask is None:
            self.mask = to_device(enc_hs_pad, make_pad_mask(enc_hs_len))
        e.masked_fill_(self.mask, -float("inf"))

        # apply monotonic attention constraint (mainly for TTS)
        if last_attended_idx is not None:
            e = _apply_attention_constraint(
                e, last_attended_idx, backward_window, forward_window
            )

        w = F.softmax(scaling * e, dim=1)

        # forward attention: transition agent interpolates between staying
        # (att_prev) and advancing one frame (att_prev_shift)
        att_prev_shift = F.pad(att_prev, (1, 0))[:, :-1]
        w = (
            self.trans_agent_prob * att_prev
            + (1 - self.trans_agent_prob) * att_prev_shift
        ) * w
        # NOTE: clamp is needed to avoid nan gradient
        w = F.normalize(torch.clamp(w, 1e-6), p=1, dim=1)

        # weighted sum over frames
        # utt x hdim
        # NOTE use bmm instead of sum(*)
        c = torch.sum(self.enc_h * w.view(batch, self.h_length, 1), dim=1)

        # update transition agent prob
        self.trans_agent_prob = torch.sigmoid(
            self.mlp_ta(torch.cat([c, out_prev, dec_z], dim=1))
        )

        return c, w
def att_for(args, num_att=1, han_mode=False):
    """Instantiate attention module(s) from the program arguments.

    :param Namespace args: the program arguments
    :param int num_att: number of attention modules
        (in multi-speaker case, it can be 2 or more)
    :param bool han_mode: switch on/off mode of hierarchical attention network (HAN)
    :rtype: torch.nn.Module
    :return: a single attention module (HAN mode) or a ModuleList of them
    :raises ValueError: if ``args.num_encs`` is smaller than one
    """
    att_list = torch.nn.ModuleList()
    # use getattr to keep compatibility with old config Namespaces that
    # predate these options
    num_encs = getattr(args, "num_encs", 1)
    aheads = getattr(args, "aheads", None)
    awin = getattr(args, "awin", None)
    aconv_chans = getattr(args, "aconv_chans", None)
    aconv_filts = getattr(args, "aconv_filts", None)

    if num_encs == 1:
        # single encoder: one attention module per decoding stream
        # (num_att > 1 in the multi-speaker case)
        for _ in range(num_att):
            att = initial_att(
                args.atype,
                args.eprojs,
                args.dunits,
                aheads,
                args.adim,
                awin,
                aconv_chans,
                aconv_filts,
            )
            att_list.append(att)
    elif num_encs > 1:  # no multi-speaker mode
        if han_mode:
            # hierarchical attention network over the per-encoder contexts
            att = initial_att(
                args.han_type,
                args.eprojs,
                args.dunits,
                args.han_heads,
                args.han_dim,
                args.han_win,
                args.han_conv_chans,
                args.han_conv_filts,
                han_mode=True,
            )
            return att
        else:
            # one (possibly differently configured) attention per encoder;
            # the per-encoder options are lists indexed by encoder
            for idx in range(num_encs):
                att = initial_att(
                    args.atype[idx],
                    args.eprojs,
                    args.dunits,
                    aheads[idx],
                    args.adim[idx],
                    awin[idx],
                    aconv_chans[idx],
                    aconv_filts[idx],
                )
                att_list.append(att)
    else:
        # reached only when num_encs < 1; the old message claimed "more than
        # one" which contradicted the (valid) num_encs == 1 branch above
        raise ValueError(
            "Number of encoders needs to be at least one. {}".format(num_encs)
        )
    return att_list
def initial_att(
    atype, eprojs, dunits, aheads, adim, awin, aconv_chans, aconv_filts, han_mode=False
):
    """Instantiate a single attention module.

    :param str atype: attention type (e.g. "noatt", "dot", "location", ...)
    :param int eprojs: # projection-units of encoder
    :param int dunits: # units of decoder
    :param int aheads: # heads of multi head attention
    :param int adim: attention dimension
    :param int awin: attention window size
    :param int aconv_chans: # channels of attention convolution
    :param int aconv_filts: filter size of attention convolution
    :param bool han_mode: flag to switch on mode of hierarchical attention
    :return: The attention module
    :raises ValueError: if ``atype`` is not a known attention type
    """
    if atype == "noatt":
        att = NoAtt()
    elif atype == "dot":
        att = AttDot(eprojs, dunits, adim, han_mode)
    elif atype == "add":
        att = AttAdd(eprojs, dunits, adim, han_mode)
    elif atype == "location":
        att = AttLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "location2d":
        att = AttLoc2D(eprojs, dunits, adim, awin, aconv_chans, aconv_filts, han_mode)
    elif atype == "location_recurrent":
        att = AttLocRec(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "coverage":
        att = AttCov(eprojs, dunits, adim, han_mode)
    elif atype == "coverage_location":
        att = AttCovLoc(eprojs, dunits, adim, aconv_chans, aconv_filts, han_mode)
    elif atype == "multi_head_dot":
        att = AttMultiHeadDot(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_add":
        att = AttMultiHeadAdd(eprojs, dunits, aheads, adim, adim, han_mode)
    elif atype == "multi_head_loc":
        att = AttMultiHeadLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    elif atype == "multi_head_multi_res_loc":
        att = AttMultiHeadMultiResLoc(
            eprojs, dunits, aheads, adim, adim, aconv_chans, aconv_filts, han_mode
        )
    else:
        # previously an unknown atype fell through every branch and crashed
        # with UnboundLocalError on "return att"; fail with a clear message
        raise ValueError("Unknown attention type: " + str(atype))
    return att
def att_to_numpy(att_ws, att):
    """Convert collected attention weights to a numpy array.

    :param list att_ws: attention weights gathered at each decoding step
    :param torch.nn.Module att: the attention module that produced them
    :rtype: np.ndarray
    :return: the attention weights with shape (B, Lmax, Tmax)
    """
    if isinstance(att, AttLoc2D):
        # each element is a stack of previous attentions; keep the newest row
        stacked = torch.stack([weights[:, -1] for weights in att_ws], dim=1)
    elif isinstance(att, (AttCov, AttCovLoc)):
        # each element is the list of all previous attentions; pick the one
        # produced at that step
        stacked = torch.stack(
            [weights[step] for step, weights in enumerate(att_ws)], dim=1
        )
    elif isinstance(att, AttLocRec):
        # each element is a tuple of (attention, hidden states); keep the attention
        stacked = torch.stack([weights[0] for weights in att_ws], dim=1)
    elif isinstance(
        att,
        (AttMultiHeadDot, AttMultiHeadAdd, AttMultiHeadLoc, AttMultiHeadMultiResLoc),
    ):
        # each element is a list of per-head attentions; regroup them by head
        n_heads = len(att_ws[0])
        per_head = [
            torch.stack([weights[head] for weights in att_ws], dim=1)
            for head in range(n_heads)
        ]
        stacked = torch.stack(per_head, dim=1)
    else:
        # plain case: each element is a single attention tensor
        stacked = torch.stack(att_ws, dim=1)
    return stacked.cpu().numpy()
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/decoders.py
0 → 100644
View file @
60a2c57a
"""RNN decoder module."""
import
logging
import
math
import
random
from
argparse
import
Namespace
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
espnet.nets.ctc_prefix_score
import
CTCPrefixScore
,
CTCPrefixScoreTH
from
espnet.nets.e2e_asr_common
import
end_detect
from
espnet.nets.pytorch_backend.nets_utils
import
(
mask_by_length
,
pad_list
,
th_accuracy
,
to_device
,
)
from
espnet.nets.pytorch_backend.rnn.attentions
import
att_to_numpy
from
espnet.nets.scorer_interface
import
ScorerInterface
# Maximum number of utterances whose predicted/groundtruth sequences are
# logged for debugging in Decoder.forward (only when verbose > 0).
MAX_DECODER_OUTPUT = 5
# Pre-pruning width for CTC prefix scoring: candidates kept before CTC
# rescoring are min(vocab, int(beam_size * CTC_SCORING_RATIO)).
CTC_SCORING_RATIO = 1.5
class
Decoder
(
torch
.
nn
.
Module
,
ScorerInterface
):
"""Decoder module
:param int eprojs: encoder projection units
:param int odim: dimension of outputs
:param str dtype: gru or lstm
:param int dlayers: decoder layers
:param int dunits: decoder units
:param int sos: start of sequence symbol id
:param int eos: end of sequence symbol id
:param torch.nn.Module att: attention module
:param int verbose: verbose level
:param list char_list: list of character strings
:param ndarray labeldist: distribution of label smoothing
:param float lsm_weight: label smoothing weight
:param float sampling_probability: scheduled sampling probability
:param float dropout: dropout rate
:param float context_residual: if True, use context vector for token generation
:param float replace_sos: use for multilingual (speech/text) translation
"""
def
__init__
(
self
,
eprojs
,
odim
,
dtype
,
dlayers
,
dunits
,
sos
,
eos
,
att
,
verbose
=
0
,
char_list
=
None
,
labeldist
=
None
,
lsm_weight
=
0.0
,
sampling_probability
=
0.0
,
dropout
=
0.0
,
context_residual
=
False
,
replace_sos
=
False
,
num_encs
=
1
,
):
torch
.
nn
.
Module
.
__init__
(
self
)
self
.
dtype
=
dtype
self
.
dunits
=
dunits
self
.
dlayers
=
dlayers
self
.
context_residual
=
context_residual
self
.
embed
=
torch
.
nn
.
Embedding
(
odim
,
dunits
)
self
.
dropout_emb
=
torch
.
nn
.
Dropout
(
p
=
dropout
)
self
.
decoder
=
torch
.
nn
.
ModuleList
()
self
.
dropout_dec
=
torch
.
nn
.
ModuleList
()
self
.
decoder
+=
[
torch
.
nn
.
LSTMCell
(
dunits
+
eprojs
,
dunits
)
if
self
.
dtype
==
"lstm"
else
torch
.
nn
.
GRUCell
(
dunits
+
eprojs
,
dunits
)
]
self
.
dropout_dec
+=
[
torch
.
nn
.
Dropout
(
p
=
dropout
)]
for
_
in
range
(
1
,
self
.
dlayers
):
self
.
decoder
+=
[
torch
.
nn
.
LSTMCell
(
dunits
,
dunits
)
if
self
.
dtype
==
"lstm"
else
torch
.
nn
.
GRUCell
(
dunits
,
dunits
)
]
self
.
dropout_dec
+=
[
torch
.
nn
.
Dropout
(
p
=
dropout
)]
# NOTE: dropout is applied only for the vertical connections
# see https://arxiv.org/pdf/1409.2329.pdf
self
.
ignore_id
=
-
1
if
context_residual
:
self
.
output
=
torch
.
nn
.
Linear
(
dunits
+
eprojs
,
odim
)
else
:
self
.
output
=
torch
.
nn
.
Linear
(
dunits
,
odim
)
self
.
loss
=
None
self
.
att
=
att
self
.
dunits
=
dunits
self
.
sos
=
sos
self
.
eos
=
eos
self
.
odim
=
odim
self
.
verbose
=
verbose
self
.
char_list
=
char_list
# for label smoothing
self
.
labeldist
=
labeldist
self
.
vlabeldist
=
None
self
.
lsm_weight
=
lsm_weight
self
.
sampling_probability
=
sampling_probability
self
.
dropout
=
dropout
self
.
num_encs
=
num_encs
# for multilingual E2E-ST
self
.
replace_sos
=
replace_sos
self
.
logzero
=
-
10000000000.0
def
zero_state
(
self
,
hs_pad
):
return
hs_pad
.
new_zeros
(
hs_pad
.
size
(
0
),
self
.
dunits
)
def
rnn_forward
(
self
,
ey
,
z_list
,
c_list
,
z_prev
,
c_prev
):
if
self
.
dtype
==
"lstm"
:
z_list
[
0
],
c_list
[
0
]
=
self
.
decoder
[
0
](
ey
,
(
z_prev
[
0
],
c_prev
[
0
]))
for
i
in
range
(
1
,
self
.
dlayers
):
z_list
[
i
],
c_list
[
i
]
=
self
.
decoder
[
i
](
self
.
dropout_dec
[
i
-
1
](
z_list
[
i
-
1
]),
(
z_prev
[
i
],
c_prev
[
i
])
)
else
:
z_list
[
0
]
=
self
.
decoder
[
0
](
ey
,
z_prev
[
0
])
for
i
in
range
(
1
,
self
.
dlayers
):
z_list
[
i
]
=
self
.
decoder
[
i
](
self
.
dropout_dec
[
i
-
1
](
z_list
[
i
-
1
]),
z_prev
[
i
]
)
return
z_list
,
c_list
    def forward(self, hs_pad, hlens, ys_pad, strm_idx=0, lang_ids=None):
        """Decoder forward

        Teacher-forced (optionally scheduled-sampled) decoding over the whole
        padded target sequence, returning the attention loss, accuracy, and
        perplexity.

        :param torch.Tensor hs_pad: batch of padded hidden state sequences (B, Tmax, D)
                                    [in multi-encoder case,
                                    list of torch.Tensor,
                                    [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
        :param torch.Tensor hlens: batch of lengths of hidden state sequences (B)
                                   [in multi-encoder case, list of torch.Tensor,
                                   [(B), (B), ..., ]
        :param torch.Tensor ys_pad: batch of padded character id sequence tensor
                                    (B, Lmax)
        :param int strm_idx: stream index indicates the index of decoding stream.
        :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
        :return: attention loss value
        :rtype: torch.Tensor
        :return: accuracy
        :rtype: float
        :return: perplexity (exp of the mean cross-entropy)
        :rtype: float
        """
        # to support mutiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            hs_pad = [hs_pad]
            hlens = [hlens]

        # TODO(kan-bayashi): need to make more smart way
        # strip the ignore_id (-1) padding from each target sequence
        ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

        # attention index for the attention module
        # in SPA (speaker parallel attention),
        # att_idx is used to select attention module. In other cases, it is 0.
        att_idx = min(strm_idx, len(self.att) - 1)

        # hlens should be list of list of integer
        hlens = [list(map(int, hlens[idx])) for idx in range(self.num_encs)]

        self.loss = None
        # prepare input and output word sequences with sos/eos IDs
        eos = ys[0].new([self.eos])
        sos = ys[0].new([self.sos])
        if self.replace_sos:
            # multilingual translation: the per-utterance language id
            # replaces <sos> at the start of the input sequence
            ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
        else:
            ys_in = [torch.cat([sos, y], dim=0) for y in ys]
        ys_out = [torch.cat([y, eos], dim=0) for y in ys]

        # padding for ys with -1
        # pys: utt x olen
        ys_in_pad = pad_list(ys_in, self.eos)
        ys_out_pad = pad_list(ys_out, self.ignore_id)

        # get dim, length info
        batch = ys_out_pad.size(0)
        olength = ys_out_pad.size(1)
        for idx in range(self.num_encs):
            # NOTE(review): the log message lacks a separator after the class
            # name ("DecoderNumber of Encoder:...") — cosmetic only
            logging.info(
                self.__class__.__name__
                + "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, hlens[idx]
                )
            )
        logging.info(
            self.__class__.__name__
            + " output lengths: "
            + str([y.size(0) for y in ys_out])
        )

        # initialization: one zero hidden/cell state per decoder layer
        c_list = [self.zero_state(hs_pad[0])]
        z_list = [self.zero_state(hs_pad[0])]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(hs_pad[0]))
            z_list.append(self.zero_state(hs_pad[0]))
        z_all = []
        if self.num_encs == 1:
            att_w = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # pre-computation of embedding
        eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

        # loop for an output sequence
        for i in range(olength):
            if self.num_encs == 1:
                att_c, att_w = self.att[att_idx](
                    hs_pad[0], hlens[0], self.dropout_dec[0](z_list[0]), att_w
                )
            else:
                # per-encoder attentions, then HAN over the stacked contexts
                for idx in range(self.num_encs):
                    att_c_list[idx], att_w_list[idx] = self.att[idx](
                        hs_pad[idx],
                        hlens[idx],
                        self.dropout_dec[0](z_list[0]),
                        att_w_list[idx],
                    )
                hs_pad_han = torch.stack(att_c_list, dim=1)
                hlens_han = [self.num_encs] * len(ys_in)
                att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                    hs_pad_han,
                    hlens_han,
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[self.num_encs],
                )
            if i > 0 and random.random() < self.sampling_probability:
                # scheduled sampling: feed back the model's own previous
                # prediction instead of the ground-truth token
                logging.info(" scheduled sampling ")
                z_out = self.output(z_all[-1])
                z_out = np.argmax(z_out.detach().cpu(), axis=1)
                z_out = self.dropout_emb(self.embed(to_device(hs_pad[0], z_out)))
                ey = torch.cat((z_out, att_c), dim=1)  # utt x (zdim + hdim)
            else:
                ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
            z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)
            if self.context_residual:
                z_all.append(
                    torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                )  # utt x (zdim + hdim)
            else:
                z_all.append(self.dropout_dec[-1](z_list[-1]))  # utt x (zdim)

        # flatten time into the batch dim before the shared output projection
        z_all = torch.stack(z_all, dim=1).view(batch * olength, -1)
        # compute loss
        y_all = self.output(z_all)
        self.loss = F.cross_entropy(
            y_all,
            ys_out_pad.view(-1),
            ignore_index=self.ignore_id,
            reduction="mean",
        )
        # compute perplexity (before the length rescaling below)
        ppl = math.exp(self.loss.item())
        # -1: eos, which is removed in the loss computation
        self.loss *= np.mean([len(x) for x in ys_in]) - 1
        acc = th_accuracy(y_all, ys_out_pad, ignore_label=self.ignore_id)
        logging.info("att loss:" + "".join(str(self.loss.item()).split("\n")))

        # show predicted character sequence for debug
        if self.verbose > 0 and self.char_list is not None:
            ys_hat = y_all.view(batch, olength, -1)
            ys_true = ys_out_pad
            for (i, y_hat), y_true in zip(
                enumerate(ys_hat.detach().cpu().numpy()),
                ys_true.detach().cpu().numpy()
            ):
                if i == MAX_DECODER_OUTPUT:
                    break
                idx_hat = np.argmax(y_hat[y_true != self.ignore_id], axis=1)
                idx_true = y_true[y_true != self.ignore_id]
                seq_hat = [self.char_list[int(idx)] for idx in idx_hat]
                seq_true = [self.char_list[int(idx)] for idx in idx_true]
                seq_hat = "".join(seq_hat)
                seq_true = "".join(seq_true)
                logging.info("groundtruth[%d]: " % i + seq_true)
                logging.info("prediction [%d]: " % i + seq_hat)

        if self.labeldist is not None:
            if self.vlabeldist is None:
                # lazily move the label distribution to the model's device
                self.vlabeldist = to_device(hs_pad[0], torch.from_numpy(self.labeldist))
            # label-smoothing regularizer: cross entropy against labeldist
            loss_reg = -torch.sum(
                (F.log_softmax(y_all, dim=1) * self.vlabeldist).view(-1), dim=0
            ) / len(ys_in)
            self.loss = (
                1.0 - self.lsm_weight
            ) * self.loss + self.lsm_weight * loss_reg

        return self.loss, acc, ppl
    def recognize_beam(self, h, lpz, recog_args, char_list, rnnlm=None, strm_idx=0):
        """beam search implementation

        :param torch.Tensor h: encoder hidden state (T, eprojs)
                                [in multi-encoder case, list of torch.Tensor,
                                [(T1, eprojs), (T2, eprojs), ...] ]
        :param torch.Tensor lpz: ctc log softmax output (T, odim)
                                [in multi-encoder case, list of torch.Tensor,
                                [(T1, odim), (T2, odim), ...] ]
        :param Namespace recog_args: argument Namespace containing options
        :param char_list: list of character strings
        :param torch.nn.Module rnnlm: language module
        :param int strm_idx:
            stream index for speaker parallel attention in multi-speaker case
        :return: N-best decoding results
        :rtype: list of dicts
        """
        # to support mutiple encoder asr mode, in single encoder mode,
        # convert torch.Tensor to List of torch.Tensor
        if self.num_encs == 1:
            h = [h]
            lpz = [lpz]
        if self.num_encs > 1 and lpz is None:
            lpz = [lpz] * self.num_encs

        for idx in range(self.num_encs):
            # NOTE(review): this always logs h[0]'s length even for idx > 0 —
            # looks unintended; confirm against recognize_beam_batch which
            # indexes per encoder
            logging.info(
                "Number of Encoder:{}; enc{}: input lengths: {}.".format(
                    self.num_encs, idx + 1, h[0].size(0)
                )
            )
        att_idx = min(strm_idx, len(self.att) - 1)
        # initialization: per-layer zero states for a single hypothesis (batch 1)
        c_list = [self.zero_state(h[0].unsqueeze(0))]
        z_list = [self.zero_state(h[0].unsqueeze(0))]
        for _ in range(1, self.dlayers):
            c_list.append(self.zero_state(h[0].unsqueeze(0)))
            z_list.append(self.zero_state(h[0].unsqueeze(0)))
        if self.num_encs == 1:
            a = None
            self.att[att_idx].reset()  # reset pre-computation of h
        else:
            a = [None] * (self.num_encs + 1)  # atts + han
            att_w_list = [None] * (self.num_encs + 1)  # atts + han
            att_c_list = [None] * (self.num_encs)  # atts
            for idx in range(self.num_encs + 1):
                self.att[idx].reset()  # reset pre-computation of h in atts and han

        # search parms
        beam = recog_args.beam_size
        penalty = recog_args.penalty
        ctc_weight = getattr(recog_args, "ctc_weight", False)  # for NMT

        if lpz[0] is not None and self.num_encs > 1:
            # weights-ctc,
            # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
            weights_ctc_dec = recog_args.weights_ctc_dec / np.sum(
                recog_args.weights_ctc_dec
            )  # normalize
            logging.info(
                "ctc weights (decoding): " + " ".join([str(x) for x in weights_ctc_dec])
            )
        else:
            weights_ctc_dec = [1.0]

        # preprate sos
        if self.replace_sos and recog_args.tgt_lang:
            # multilingual: seed decoding with the target-language token
            y = char_list.index(recog_args.tgt_lang)
        else:
            y = self.sos
        logging.info("<sos> index: " + str(y))
        logging.info("<sos> mark: " + char_list[y])
        vy = h[0].new_zeros(1).long()

        # cap decoding length by the shortest encoder output
        maxlen = np.amin([h[idx].size(0) for idx in range(self.num_encs)])
        if recog_args.maxlenratio != 0:
            # maxlen >= 1
            maxlen = max(1, int(recog_args.maxlenratio * maxlen))
        minlen = int(recog_args.minlenratio * maxlen)
        logging.info("max output length: " + str(maxlen))
        logging.info("min output length: " + str(minlen))

        # initialize hypothesis
        if rnnlm:
            hyp = {
                "score": 0.0,
                "yseq": [y],
                "c_prev": c_list,
                "z_prev": z_list,
                "a_prev": a,
                "rnnlm_prev": None,
            }
        else:
            hyp = {
                "score": 0.0,
                "yseq": [y],
                "c_prev": c_list,
                "z_prev": z_list,
                "a_prev": a,
            }
        if lpz[0] is not None:
            # one CTC prefix scorer per encoder
            ctc_prefix_score = [
                CTCPrefixScore(lpz[idx].detach().numpy(), 0, self.eos, np)
                for idx in range(self.num_encs)
            ]
            hyp["ctc_state_prev"] = [
                ctc_prefix_score[idx].initial_state() for idx in range(self.num_encs)
            ]
            hyp["ctc_score_prev"] = [0.0] * self.num_encs
            if ctc_weight != 1.0:
                # pre-pruning based on attention scores
                ctc_beam = min(lpz[0].shape[-1], int(beam * CTC_SCORING_RATIO))
            else:
                ctc_beam = lpz[0].shape[-1]
        hyps = [hyp]
        ended_hyps = []

        for i in range(maxlen):
            logging.debug("position " + str(i))

            hyps_best_kept = []
            for hyp in hyps:
                # feed the last emitted token of this hypothesis
                vy[0] = hyp["yseq"][i]
                ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
                if self.num_encs == 1:
                    att_c, att_w = self.att[att_idx](
                        h[0].unsqueeze(0),
                        [h[0].size(0)],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"],
                    )
                else:
                    # per-encoder attentions, then HAN over the stacked contexts
                    for idx in range(self.num_encs):
                        att_c_list[idx], att_w_list[idx] = self.att[idx](
                            h[idx].unsqueeze(0),
                            [h[idx].size(0)],
                            self.dropout_dec[0](hyp["z_prev"][0]),
                            hyp["a_prev"][idx],
                        )
                    h_han = torch.stack(att_c_list, dim=1)
                    att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                        h_han,
                        [self.num_encs],
                        self.dropout_dec[0](hyp["z_prev"][0]),
                        hyp["a_prev"][self.num_encs],
                    )
                ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
                z_list, c_list = self.rnn_forward(
                    ey, z_list, c_list, hyp["z_prev"], hyp["c_prev"]
                )

                # get nbest local scores and their ids
                if self.context_residual:
                    logits = self.output(
                        torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
                    )
                else:
                    logits = self.output(self.dropout_dec[-1](z_list[-1]))
                local_att_scores = F.log_softmax(logits, dim=1)
                if rnnlm:
                    rnnlm_state, local_lm_scores = rnnlm.predict(hyp["rnnlm_prev"], vy)
                    local_scores = (
                        local_att_scores + recog_args.lm_weight * local_lm_scores
                    )
                else:
                    local_scores = local_att_scores

                if lpz[0] is not None:
                    # pre-prune with attention scores, then rescore the kept
                    # candidates with the (incremental) CTC prefix scores
                    local_best_scores, local_best_ids = torch.topk(
                        local_att_scores, ctc_beam, dim=1
                    )
                    ctc_scores, ctc_states = (
                        [None] * self.num_encs,
                        [None] * self.num_encs,
                    )
                    for idx in range(self.num_encs):
                        ctc_scores[idx], ctc_states[idx] = ctc_prefix_score[idx](
                            hyp["yseq"], local_best_ids[0], hyp["ctc_state_prev"][idx]
                        )
                    local_scores = (1.0 - ctc_weight) * local_att_scores[
                        :, local_best_ids[0]
                    ]
                    if self.num_encs == 1:
                        local_scores += ctc_weight * torch.from_numpy(
                            ctc_scores[0] - hyp["ctc_score_prev"][0]
                        )
                    else:
                        for idx in range(self.num_encs):
                            local_scores += (
                                ctc_weight
                                * weights_ctc_dec[idx]
                                * torch.from_numpy(
                                    ctc_scores[idx] - hyp["ctc_score_prev"][idx]
                                )
                            )
                    if rnnlm:
                        local_scores += (
                            recog_args.lm_weight * local_lm_scores[:, local_best_ids[0]]
                        )
                    local_best_scores, joint_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )
                    # map the joint ranking back to vocabulary ids
                    local_best_ids = local_best_ids[:, joint_best_ids[0]]
                else:
                    local_best_scores, local_best_ids = torch.topk(
                        local_scores, beam, dim=1
                    )

                for j in range(beam):
                    new_hyp = {}
                    # [:] is needed!
                    new_hyp["z_prev"] = z_list[:]
                    new_hyp["c_prev"] = c_list[:]
                    if self.num_encs == 1:
                        new_hyp["a_prev"] = att_w[:]
                    else:
                        new_hyp["a_prev"] = [
                            att_w_list[idx][:] for idx in range(self.num_encs + 1)
                        ]
                    new_hyp["score"] = hyp["score"] + local_best_scores[0, j]
                    new_hyp["yseq"] = [0] * (1 + len(hyp["yseq"]))
                    new_hyp["yseq"][: len(hyp["yseq"])] = hyp["yseq"]
                    new_hyp["yseq"][len(hyp["yseq"])] = int(local_best_ids[0, j])
                    if rnnlm:
                        new_hyp["rnnlm_prev"] = rnnlm_state
                    if lpz[0] is not None:
                        new_hyp["ctc_state_prev"] = [
                            ctc_states[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                        ]
                        new_hyp["ctc_score_prev"] = [
                            ctc_scores[idx][joint_best_ids[0, j]]
                            for idx in range(self.num_encs)
                        ]
                    # will be (2 x beam) hyps at most
                    hyps_best_kept.append(new_hyp)

                hyps_best_kept = sorted(
                    hyps_best_kept, key=lambda x: x["score"], reverse=True
                )[:beam]

            # sort and get nbest
            hyps = hyps_best_kept
            logging.debug("number of pruned hypotheses: " + str(len(hyps)))
            logging.debug(
                "best hypo: "
                + "".join([char_list[int(x)] for x in hyps[0]["yseq"][1:]])
            )

            # add eos in the final loop to avoid that there are no ended hyps
            if i == maxlen - 1:
                logging.info("adding <eos> in the last position in the loop")
                for hyp in hyps:
                    hyp["yseq"].append(self.eos)

            # add ended hypotheses to a final list,
            # and removed them from current hypotheses
            # (this will be a problem, number of hyps < beam)
            remained_hyps = []
            for hyp in hyps:
                if hyp["yseq"][-1] == self.eos:
                    # only store the sequence that has more than minlen outputs
                    # also add penalty
                    if len(hyp["yseq"]) > minlen:
                        hyp["score"] += (i + 1) * penalty
                        if rnnlm:  # Word LM needs to add final <eos> score
                            hyp["score"] += recog_args.lm_weight * rnnlm.final(
                                hyp["rnnlm_prev"]
                            )
                        ended_hyps.append(hyp)
                else:
                    remained_hyps.append(hyp)

            # end detection
            if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0:
                logging.info("end detected at %d", i)
                break

            hyps = remained_hyps
            if len(hyps) > 0:
                logging.debug("remaining hypotheses: " + str(len(hyps)))
            else:
                logging.info("no hypothesis. Finish decoding.")
                break

            for hyp in hyps:
                logging.debug(
                    "hypo: " + "".join([char_list[int(x)] for x in hyp["yseq"][1:]])
                )

            logging.debug("number of ended hypotheses: " + str(len(ended_hyps)))

        nbest_hyps = sorted(ended_hyps, key=lambda x: x["score"], reverse=True)[
            : min(len(ended_hyps), recog_args.nbest)
        ]

        # check number of hypotheses
        if len(nbest_hyps) == 0:
            logging.warning(
                "there is no N-best results, "
                "perform recognition again with smaller minlenratio."
            )
            # should copy because Namespace will be overwritten globally
            recog_args = Namespace(**vars(recog_args))
            recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1)
            if self.num_encs == 1:
                # h/lpz were wrapped into lists above; unwrap before recursing
                return self.recognize_beam(h[0], lpz[0], recog_args, char_list, rnnlm)
            else:
                return self.recognize_beam(h, lpz, recog_args, char_list, rnnlm)

        logging.info("total log probability: " + str(nbest_hyps[0]["score"]))
        logging.info(
            "normalized log probability: "
            + str(nbest_hyps[0]["score"] / len(nbest_hyps[0]["yseq"]))
        )

        # remove sos
        return nbest_hyps
def
recognize_beam_batch
(
self
,
h
,
hlens
,
lpz
,
recog_args
,
char_list
,
rnnlm
=
None
,
normalize_score
=
True
,
strm_idx
=
0
,
lang_ids
=
None
,
):
# to support mutiple encoder asr mode, in single encoder mode,
# convert torch.Tensor to List of torch.Tensor
if
self
.
num_encs
==
1
:
h
=
[
h
]
hlens
=
[
hlens
]
lpz
=
[
lpz
]
if
self
.
num_encs
>
1
and
lpz
is
None
:
lpz
=
[
lpz
]
*
self
.
num_encs
att_idx
=
min
(
strm_idx
,
len
(
self
.
att
)
-
1
)
for
idx
in
range
(
self
.
num_encs
):
logging
.
info
(
"Number of Encoder:{}; enc{}: input lengths: {}."
.
format
(
self
.
num_encs
,
idx
+
1
,
h
[
idx
].
size
(
1
)
)
)
h
[
idx
]
=
mask_by_length
(
h
[
idx
],
hlens
[
idx
],
0.0
)
# search params
batch
=
len
(
hlens
[
0
])
beam
=
recog_args
.
beam_size
penalty
=
recog_args
.
penalty
ctc_weight
=
getattr
(
recog_args
,
"ctc_weight"
,
0
)
# for NMT
att_weight
=
1.0
-
ctc_weight
ctc_margin
=
getattr
(
recog_args
,
"ctc_window_margin"
,
0
)
# use getattr to keep compatibility
# weights-ctc,
# e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
if
lpz
[
0
]
is
not
None
and
self
.
num_encs
>
1
:
weights_ctc_dec
=
recog_args
.
weights_ctc_dec
/
np
.
sum
(
recog_args
.
weights_ctc_dec
)
# normalize
logging
.
info
(
"ctc weights (decoding): "
+
" "
.
join
([
str
(
x
)
for
x
in
weights_ctc_dec
])
)
else
:
weights_ctc_dec
=
[
1.0
]
n_bb
=
batch
*
beam
pad_b
=
to_device
(
h
[
0
],
torch
.
arange
(
batch
)
*
beam
).
view
(
-
1
,
1
)
max_hlen
=
np
.
amin
([
max
(
hlens
[
idx
])
for
idx
in
range
(
self
.
num_encs
)])
if
recog_args
.
maxlenratio
==
0
:
maxlen
=
max_hlen
else
:
maxlen
=
max
(
1
,
int
(
recog_args
.
maxlenratio
*
max_hlen
))
minlen
=
int
(
recog_args
.
minlenratio
*
max_hlen
)
logging
.
info
(
"max output length: "
+
str
(
maxlen
))
logging
.
info
(
"min output length: "
+
str
(
minlen
))
# initialization
c_prev
=
[
to_device
(
h
[
0
],
torch
.
zeros
(
n_bb
,
self
.
dunits
))
for
_
in
range
(
self
.
dlayers
)
]
z_prev
=
[
to_device
(
h
[
0
],
torch
.
zeros
(
n_bb
,
self
.
dunits
))
for
_
in
range
(
self
.
dlayers
)
]
c_list
=
[
to_device
(
h
[
0
],
torch
.
zeros
(
n_bb
,
self
.
dunits
))
for
_
in
range
(
self
.
dlayers
)
]
z_list
=
[
to_device
(
h
[
0
],
torch
.
zeros
(
n_bb
,
self
.
dunits
))
for
_
in
range
(
self
.
dlayers
)
]
vscores
=
to_device
(
h
[
0
],
torch
.
zeros
(
batch
,
beam
))
rnnlm_state
=
None
if
self
.
num_encs
==
1
:
a_prev
=
[
None
]
att_w_list
,
ctc_scorer
,
ctc_state
=
[
None
],
[
None
],
[
None
]
self
.
att
[
att_idx
].
reset
()
# reset pre-computation of h
else
:
a_prev
=
[
None
]
*
(
self
.
num_encs
+
1
)
# atts + han
att_w_list
=
[
None
]
*
(
self
.
num_encs
+
1
)
# atts + han
att_c_list
=
[
None
]
*
(
self
.
num_encs
)
# atts
ctc_scorer
,
ctc_state
=
[
None
]
*
(
self
.
num_encs
),
[
None
]
*
(
self
.
num_encs
)
for
idx
in
range
(
self
.
num_encs
+
1
):
self
.
att
[
idx
].
reset
()
# reset pre-computation of h in atts and han
if
self
.
replace_sos
and
recog_args
.
tgt_lang
:
logging
.
info
(
"<sos> index: "
+
str
(
char_list
.
index
(
recog_args
.
tgt_lang
)))
logging
.
info
(
"<sos> mark: "
+
recog_args
.
tgt_lang
)
yseq
=
[[
char_list
.
index
(
recog_args
.
tgt_lang
)]
for
_
in
range
(
n_bb
)]
elif
lang_ids
is
not
None
:
# NOTE: used for evaluation during training
yseq
=
[[
lang_ids
[
b
//
recog_args
.
beam_size
]]
for
b
in
range
(
n_bb
)]
else
:
logging
.
info
(
"<sos> index: "
+
str
(
self
.
sos
))
logging
.
info
(
"<sos> mark: "
+
char_list
[
self
.
sos
])
yseq
=
[[
self
.
sos
]
for
_
in
range
(
n_bb
)]
accum_odim_ids
=
[
self
.
sos
for
_
in
range
(
n_bb
)]
stop_search
=
[
False
for
_
in
range
(
batch
)]
nbest_hyps
=
[[]
for
_
in
range
(
batch
)]
ended_hyps
=
[[]
for
_
in
range
(
batch
)]
exp_hlens
=
[
hlens
[
idx
].
repeat
(
beam
).
view
(
beam
,
batch
).
transpose
(
0
,
1
).
contiguous
()
for
idx
in
range
(
self
.
num_encs
)
]
exp_hlens
=
[
exp_hlens
[
idx
].
view
(
-
1
).
tolist
()
for
idx
in
range
(
self
.
num_encs
)]
exp_h
=
[
h
[
idx
].
unsqueeze
(
1
).
repeat
(
1
,
beam
,
1
,
1
).
contiguous
()
for
idx
in
range
(
self
.
num_encs
)
]
exp_h
=
[
exp_h
[
idx
].
view
(
n_bb
,
h
[
idx
].
size
()[
1
],
h
[
idx
].
size
()[
2
])
for
idx
in
range
(
self
.
num_encs
)
]
if
lpz
[
0
]
is
not
None
:
scoring_num
=
min
(
int
(
beam
*
CTC_SCORING_RATIO
)
if
att_weight
>
0.0
and
not
lpz
[
0
].
is_cuda
else
0
,
lpz
[
0
].
size
(
-
1
),
)
ctc_scorer
=
[
CTCPrefixScoreTH
(
lpz
[
idx
],
hlens
[
idx
],
0
,
self
.
eos
,
margin
=
ctc_margin
,
)
for
idx
in
range
(
self
.
num_encs
)
]
for
i
in
range
(
maxlen
):
logging
.
debug
(
"position "
+
str
(
i
))
vy
=
to_device
(
h
[
0
],
torch
.
LongTensor
(
self
.
_get_last_yseq
(
yseq
)))
ey
=
self
.
dropout_emb
(
self
.
embed
(
vy
))
if
self
.
num_encs
==
1
:
att_c
,
att_w
=
self
.
att
[
att_idx
](
exp_h
[
0
],
exp_hlens
[
0
],
self
.
dropout_dec
[
0
](
z_prev
[
0
]),
a_prev
[
0
]
)
att_w_list
=
[
att_w
]
else
:
for
idx
in
range
(
self
.
num_encs
):
att_c_list
[
idx
],
att_w_list
[
idx
]
=
self
.
att
[
idx
](
exp_h
[
idx
],
exp_hlens
[
idx
],
self
.
dropout_dec
[
0
](
z_prev
[
0
]),
a_prev
[
idx
],
)
exp_h_han
=
torch
.
stack
(
att_c_list
,
dim
=
1
)
att_c
,
att_w_list
[
self
.
num_encs
]
=
self
.
att
[
self
.
num_encs
](
exp_h_han
,
[
self
.
num_encs
]
*
n_bb
,
self
.
dropout_dec
[
0
](
z_prev
[
0
]),
a_prev
[
self
.
num_encs
],
)
ey
=
torch
.
cat
((
ey
,
att_c
),
dim
=
1
)
# attention decoder
z_list
,
c_list
=
self
.
rnn_forward
(
ey
,
z_list
,
c_list
,
z_prev
,
c_prev
)
if
self
.
context_residual
:
logits
=
self
.
output
(
torch
.
cat
((
self
.
dropout_dec
[
-
1
](
z_list
[
-
1
]),
att_c
),
dim
=-
1
)
)
else
:
logits
=
self
.
output
(
self
.
dropout_dec
[
-
1
](
z_list
[
-
1
]))
local_scores
=
att_weight
*
F
.
log_softmax
(
logits
,
dim
=
1
)
# rnnlm
if
rnnlm
:
rnnlm_state
,
local_lm_scores
=
rnnlm
.
buff_predict
(
rnnlm_state
,
vy
,
n_bb
)
local_scores
=
local_scores
+
recog_args
.
lm_weight
*
local_lm_scores
# ctc
if
ctc_scorer
[
0
]:
local_scores
[:,
0
]
=
self
.
logzero
# avoid choosing blank
part_ids
=
(
torch
.
topk
(
local_scores
,
scoring_num
,
dim
=-
1
)[
1
]
if
scoring_num
>
0
else
None
)
for
idx
in
range
(
self
.
num_encs
):
att_w
=
att_w_list
[
idx
]
att_w_
=
att_w
if
isinstance
(
att_w
,
torch
.
Tensor
)
else
att_w
[
0
]
local_ctc_scores
,
ctc_state
[
idx
]
=
ctc_scorer
[
idx
](
yseq
,
ctc_state
[
idx
],
part_ids
,
att_w_
)
local_scores
=
(
local_scores
+
ctc_weight
*
weights_ctc_dec
[
idx
]
*
local_ctc_scores
)
local_scores
=
local_scores
.
view
(
batch
,
beam
,
self
.
odim
)
if
i
==
0
:
local_scores
[:,
1
:,
:]
=
self
.
logzero
# accumulate scores
eos_vscores
=
local_scores
[:,
:,
self
.
eos
]
+
vscores
vscores
=
vscores
.
view
(
batch
,
beam
,
1
).
repeat
(
1
,
1
,
self
.
odim
)
vscores
[:,
:,
self
.
eos
]
=
self
.
logzero
vscores
=
(
vscores
+
local_scores
).
view
(
batch
,
-
1
)
# global pruning
accum_best_scores
,
accum_best_ids
=
torch
.
topk
(
vscores
,
beam
,
1
)
accum_odim_ids
=
(
torch
.
fmod
(
accum_best_ids
,
self
.
odim
).
view
(
-
1
).
data
.
cpu
().
tolist
()
)
accum_padded_beam_ids
=
(
(
accum_best_ids
//
self
.
odim
+
pad_b
).
view
(
-
1
).
data
.
cpu
().
tolist
()
)
y_prev
=
yseq
[:][:]
yseq
=
self
.
_index_select_list
(
yseq
,
accum_padded_beam_ids
)
yseq
=
self
.
_append_ids
(
yseq
,
accum_odim_ids
)
vscores
=
accum_best_scores
vidx
=
to_device
(
h
[
0
],
torch
.
LongTensor
(
accum_padded_beam_ids
))
a_prev
=
[]
num_atts
=
self
.
num_encs
if
self
.
num_encs
==
1
else
self
.
num_encs
+
1
for
idx
in
range
(
num_atts
):
if
isinstance
(
att_w_list
[
idx
],
torch
.
Tensor
):
_a_prev
=
torch
.
index_select
(
att_w_list
[
idx
].
view
(
n_bb
,
*
att_w_list
[
idx
].
shape
[
1
:]),
0
,
vidx
)
elif
isinstance
(
att_w_list
[
idx
],
list
):
# handle the case of multi-head attention
_a_prev
=
[
torch
.
index_select
(
att_w_one
.
view
(
n_bb
,
-
1
),
0
,
vidx
)
for
att_w_one
in
att_w_list
[
idx
]
]
else
:
# handle the case of location_recurrent when return is a tuple
_a_prev_
=
torch
.
index_select
(
att_w_list
[
idx
][
0
].
view
(
n_bb
,
-
1
),
0
,
vidx
)
_h_prev_
=
torch
.
index_select
(
att_w_list
[
idx
][
1
][
0
].
view
(
n_bb
,
-
1
),
0
,
vidx
)
_c_prev_
=
torch
.
index_select
(
att_w_list
[
idx
][
1
][
1
].
view
(
n_bb
,
-
1
),
0
,
vidx
)
_a_prev
=
(
_a_prev_
,
(
_h_prev_
,
_c_prev_
))
a_prev
.
append
(
_a_prev
)
z_prev
=
[
torch
.
index_select
(
z_list
[
li
].
view
(
n_bb
,
-
1
),
0
,
vidx
)
for
li
in
range
(
self
.
dlayers
)
]
c_prev
=
[
torch
.
index_select
(
c_list
[
li
].
view
(
n_bb
,
-
1
),
0
,
vidx
)
for
li
in
range
(
self
.
dlayers
)
]
# pick ended hyps
if
i
>=
minlen
:
k
=
0
penalty_i
=
(
i
+
1
)
*
penalty
thr
=
accum_best_scores
[:,
-
1
]
for
samp_i
in
range
(
batch
):
if
stop_search
[
samp_i
]:
k
=
k
+
beam
continue
for
beam_j
in
range
(
beam
):
_vscore
=
None
if
eos_vscores
[
samp_i
,
beam_j
]
>
thr
[
samp_i
]:
yk
=
y_prev
[
k
][:]
if
len
(
yk
)
<=
min
(
hlens
[
idx
][
samp_i
]
for
idx
in
range
(
self
.
num_encs
)
):
_vscore
=
eos_vscores
[
samp_i
][
beam_j
]
+
penalty_i
elif
i
==
maxlen
-
1
:
yk
=
yseq
[
k
][:]
_vscore
=
vscores
[
samp_i
][
beam_j
]
+
penalty_i
if
_vscore
:
yk
.
append
(
self
.
eos
)
if
rnnlm
:
_vscore
+=
recog_args
.
lm_weight
*
rnnlm
.
final
(
rnnlm_state
,
index
=
k
)
_score
=
_vscore
.
data
.
cpu
().
numpy
()
ended_hyps
[
samp_i
].
append
(
{
"yseq"
:
yk
,
"vscore"
:
_vscore
,
"score"
:
_score
}
)
k
=
k
+
1
# end detection
stop_search
=
[
stop_search
[
samp_i
]
or
end_detect
(
ended_hyps
[
samp_i
],
i
)
for
samp_i
in
range
(
batch
)
]
stop_search_summary
=
list
(
set
(
stop_search
))
if
len
(
stop_search_summary
)
==
1
and
stop_search_summary
[
0
]:
break
if
rnnlm
:
rnnlm_state
=
self
.
_index_select_lm_state
(
rnnlm_state
,
0
,
vidx
)
if
ctc_scorer
[
0
]:
for
idx
in
range
(
self
.
num_encs
):
ctc_state
[
idx
]
=
ctc_scorer
[
idx
].
index_select_state
(
ctc_state
[
idx
],
accum_best_ids
)
torch
.
cuda
.
empty_cache
()
dummy_hyps
=
[
{
"yseq"
:
[
self
.
sos
,
self
.
eos
],
"score"
:
np
.
array
([
-
float
(
"inf"
)])}
]
ended_hyps
=
[
ended_hyps
[
samp_i
]
if
len
(
ended_hyps
[
samp_i
])
!=
0
else
dummy_hyps
for
samp_i
in
range
(
batch
)
]
if
normalize_score
:
for
samp_i
in
range
(
batch
):
for
x
in
ended_hyps
[
samp_i
]:
x
[
"score"
]
/=
len
(
x
[
"yseq"
])
nbest_hyps
=
[
sorted
(
ended_hyps
[
samp_i
],
key
=
lambda
x
:
x
[
"score"
],
reverse
=
True
)[
:
min
(
len
(
ended_hyps
[
samp_i
]),
recog_args
.
nbest
)
]
for
samp_i
in
range
(
batch
)
]
return
nbest_hyps
def calculate_all_attentions(self, hs_pad, hlen, ys_pad, strm_idx=0, lang_ids=None):
    """Calculate all of attentions

    Runs one teacher-forced decoding pass (ground-truth tokens are fed as
    inputs) purely to collect the attention weights; no loss is computed.

    :param torch.Tensor hs_pad: batch of padded hidden state sequences
        (B, Tmax, D)
        in multi-encoder case, list of torch.Tensor,
        [(B, Tmax_1, D), (B, Tmax_2, D), ..., ] ]
    :param torch.Tensor hlen: batch of lengths of hidden state sequences (B)
        [in multi-encoder case, list of torch.Tensor,
        [(B), (B), ..., ]
    :param torch.Tensor ys_pad:
        batch of padded character id sequence tensor (B, Lmax)
    :param int strm_idx:
        stream index for parallel speaker attention in multi-speaker case
    :param torch.Tensor lang_ids: batch of target language id tensor (B, 1)
    :return: attention weights with the following shape,
        1) multi-head case => attention weights (B, H, Lmax, Tmax),
        2) multi-encoder case =>
        [(B, Lmax, Tmax1), (B, Lmax, Tmax2), ..., (B, Lmax, NumEncs)]
        3) other case => attention weights (B, Lmax, Tmax).
    :rtype: float ndarray
    """
    # to support mutiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        hs_pad = [hs_pad]
        hlen = [hlen]

    # TODO(kan-bayashi): need to make more smart way
    # strip the ignore_id padding from each target sequence
    ys = [y[y != self.ignore_id] for y in ys_pad]  # parse padded ys

    # clamp the stream index to the number of available attention modules
    att_idx = min(strm_idx, len(self.att) - 1)

    # hlen should be list of list of integer
    hlen = [list(map(int, hlen[idx])) for idx in range(self.num_encs)]

    self.loss = None
    # prepare input and output word sequences with sos/eos IDs
    eos = ys[0].new([self.eos])
    sos = ys[0].new([self.sos])
    if self.replace_sos:
        # language-id token replaces <sos> for multilingual training
        ys_in = [torch.cat([idx, y], dim=0) for idx, y in zip(lang_ids, ys)]
    else:
        ys_in = [torch.cat([sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, eos], dim=0) for y in ys]

    # padding for ys with -1
    # pys: utt x olen
    ys_in_pad = pad_list(ys_in, self.eos)
    ys_out_pad = pad_list(ys_out, self.ignore_id)

    # get length info
    olength = ys_out_pad.size(1)

    # initialization: one zero hidden/cell state per decoder layer
    c_list = [self.zero_state(hs_pad[0])]
    z_list = [self.zero_state(hs_pad[0])]
    for _ in range(1, self.dlayers):
        c_list.append(self.zero_state(hs_pad[0]))
        z_list.append(self.zero_state(hs_pad[0]))
    att_ws = []
    if self.num_encs == 1:
        att_w = None
        self.att[att_idx].reset()  # reset pre-computation of h
    else:
        att_w_list = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        for idx in range(self.num_encs + 1):
            self.att[idx].reset()  # reset pre-computation of h in atts and han

    # pre-computation of embedding
    eys = self.dropout_emb(self.embed(ys_in_pad))  # utt x olen x zdim

    # loop for an output sequence
    for i in range(olength):
        if self.num_encs == 1:
            att_c, att_w = self.att[att_idx](
                hs_pad[0], hlen[0], self.dropout_dec[0](z_list[0]), att_w
            )
            att_ws.append(att_w)
        else:
            # per-encoder attentions followed by hierarchical attention (HAN)
            for idx in range(self.num_encs):
                att_c_list[idx], att_w_list[idx] = self.att[idx](
                    hs_pad[idx],
                    hlen[idx],
                    self.dropout_dec[0](z_list[0]),
                    att_w_list[idx],
                )
            # the per-encoder context vectors become the HAN "encoder output"
            hs_pad_han = torch.stack(att_c_list, dim=1)
            hlen_han = [self.num_encs] * len(ys_in)
            att_c, att_w_list[self.num_encs] = self.att[self.num_encs](
                hs_pad_han,
                hlen_han,
                self.dropout_dec[0](z_list[0]),
                att_w_list[self.num_encs],
            )
            # copy: att_w_list is mutated in place on the next iteration
            att_ws.append(att_w_list.copy())
        ey = torch.cat((eys[:, i, :], att_c), dim=1)  # utt x (zdim + hdim)
        # NOTE(review): z_list/c_list are passed as both current and
        # previous state here — presumably intentional teacher forcing.
        z_list, c_list = self.rnn_forward(ey, z_list, c_list, z_list, c_list)

    if self.num_encs == 1:
        # convert to numpy array with the shape (B, Lmax, Tmax)
        att_ws = att_to_numpy(att_ws, self.att[att_idx])
    else:
        _att_ws = []
        # zip(*att_ws) regroups per-step lists into per-attention sequences
        for idx, ws in enumerate(zip(*att_ws)):
            ws = att_to_numpy(ws, self.att[idx])
            _att_ws.append(ws)
        att_ws = _att_ws
    return att_ws
@staticmethod
def _get_last_yseq(exp_yseq):
    """Return the last token id of every hypothesis.

    :param list exp_yseq: list of token-id sequences (one per hypothesis)
    :return: list containing the final token of each sequence
    :rtype: list
    """
    # comprehension replaces the manual append loop (same result)
    return [y_seq[-1] for y_seq in exp_yseq]
@staticmethod
def _append_ids(yseq, ids):
    """Append token id(s) in place to every hypothesis.

    :param list yseq: list of token-id lists (mutated in place)
    :param ids: either a list with one id per hypothesis, or a single
        id that is broadcast to all hypotheses
    :return: the same ``yseq`` object, for convenience
    :rtype: list
    """
    if not isinstance(ids, list):
        # broadcast the single id to every hypothesis
        for hyp in yseq:
            hyp.append(ids)
    else:
        # one id per hypothesis, matched by position
        for pos, token in enumerate(ids):
            yseq[pos].append(token)
    return yseq
@staticmethod
def _index_select_list(yseq, lst):
    """Gather hypotheses by index, copying each selected sequence.

    :param list yseq: list of token-id lists
    :param list lst: indices into ``yseq`` (duplicates allowed)
    :return: new list with a shallow copy of each selected sequence
    :rtype: list
    """
    # comprehension replaces the manual append loop; [:] copies each
    # selected list so later appends do not alias shared hypotheses
    return [yseq[i][:] for i in lst]
@staticmethod
def _index_select_lm_state(rnnlm_state, dim, vidx):
    """Reorder an RNNLM state along the beam dimension.

    :param rnnlm_state: dict of lists of tensors (default RNNLM) or a
        plain list of per-hypothesis states
    :param int dim: dimension along which to index-select tensors
    :param torch.Tensor vidx: indices of the surviving hypotheses
    :return: reordered state of the same kind as the input
    """
    if isinstance(rnnlm_state, dict):
        new_state = {
            key: [torch.index_select(tensor, dim, vidx) for tensor in tensors]
            for key, tensors in rnnlm_state.items()
        }
    elif isinstance(rnnlm_state, list):
        new_state = [rnnlm_state[int(i)][:] for i in vidx]
    # NOTE(review): any other state type falls through to an
    # UnboundLocalError below — presumably unreachable in practice.
    return new_state
# scorer interface methods
def init_state(self, x):
    """Create the initial decoder state for the scorer interface.

    :param x: encoder output (T, D); in multi-encoder mode a list of
        such tensors
    :return: state dict with keys c_prev, z_prev, a_prev and workspace
    :rtype: dict
    """
    # In single-encoder mode wrap the encoder output in a list so the
    # multi-encoder code path below can be shared.
    if self.num_encs == 1:
        x = [x]

    # one zero hidden/cell state per decoder layer
    c_list = [self.zero_state(x[0].unsqueeze(0)) for _ in range(self.dlayers)]
    z_list = [self.zero_state(x[0].unsqueeze(0)) for _ in range(self.dlayers)]

    # TODO(karita): support strm_index for `asr_mix`
    strm_index = 0
    att_idx = min(strm_index, len(self.att) - 1)
    if self.num_encs == 1:
        att_prev = None
        # reset pre-computation of h
        self.att[att_idx].reset()
    else:
        # one slot per encoder attention plus one for the HAN
        att_prev = [None] * (self.num_encs + 1)
        for enc_idx in range(self.num_encs + 1):
            # reset pre-computation of h in atts and han
            self.att[enc_idx].reset()
    return dict(
        c_prev=c_list[:],
        z_prev=z_list[:],
        a_prev=att_prev,
        workspace=(att_idx, z_list, c_list),
    )
def score(self, yseq, state, x):
    """Score one step of a partial hypothesis (scorer interface).

    :param torch.Tensor yseq: token ids decoded so far; only the last
        one is fed to the decoder
    :param dict state: decoder state produced by ``init_state`` or a
        previous ``score`` call (keys: c_prev, z_prev, a_prev, workspace)
    :param x: encoder output (T, D); in multi-encoder mode a list of
        such tensors
    :return: tuple of (log-probabilities over the vocabulary, new state)
    """
    # to support mutiple encoder asr mode, in single encoder mode,
    # convert torch.Tensor to List of torch.Tensor
    if self.num_encs == 1:
        x = [x]

    att_idx, z_list, c_list = state["workspace"]
    # embed only the most recent token (batch of one hypothesis)
    vy = yseq[-1].unsqueeze(0)
    ey = self.dropout_emb(self.embed(vy))  # utt list (1) x zdim
    if self.num_encs == 1:
        att_c, att_w = self.att[att_idx](
            x[0].unsqueeze(0),
            [x[0].size(0)],
            self.dropout_dec[0](state["z_prev"][0]),
            state["a_prev"],
        )
    else:
        att_w = [None] * (self.num_encs + 1)  # atts + han
        att_c_list = [None] * (self.num_encs)  # atts
        # per-encoder attentions followed by hierarchical attention (HAN)
        for idx in range(self.num_encs):
            att_c_list[idx], att_w[idx] = self.att[idx](
                x[idx].unsqueeze(0),
                [x[idx].size(0)],
                self.dropout_dec[0](state["z_prev"][0]),
                state["a_prev"][idx],
            )
        # stack the per-encoder context vectors as the HAN input
        h_han = torch.stack(att_c_list, dim=1)
        att_c, att_w[self.num_encs] = self.att[self.num_encs](
            h_han,
            [self.num_encs],
            self.dropout_dec[0](state["z_prev"][0]),
            state["a_prev"][self.num_encs],
        )
    ey = torch.cat((ey, att_c), dim=1)  # utt(1) x (zdim + hdim)
    z_list, c_list = self.rnn_forward(
        ey, z_list, c_list, state["z_prev"], state["c_prev"]
    )
    if self.context_residual:
        # concatenate the attention context to the decoder output
        logits = self.output(
            torch.cat((self.dropout_dec[-1](z_list[-1]), att_c), dim=-1)
        )
    else:
        logits = self.output(self.dropout_dec[-1](z_list[-1]))
    logp = F.log_softmax(logits, dim=1).squeeze(0)
    return (
        logp,
        dict(
            c_prev=c_list[:],
            z_prev=z_list[:],
            a_prev=att_w,
            workspace=(att_idx, z_list, c_list),
        ),
    )
def decoder_for(args, odim, sos, eos, att, labeldist):
    """Instantiate a Decoder from the training arguments.

    :param Namespace args: program arguments
    :param int odim: output (vocabulary) dimension
    :param int sos: start-of-sequence token id
    :param int eos: end-of-sequence token id
    :param att: attention module(s) for the decoder
    :param labeldist: label distribution for label smoothing
    :return: the decoder module
    """
    # use getattr so configs saved before these options existed still load
    context_residual = getattr(args, "context_residual", False)
    replace_sos = getattr(args, "replace_sos", False)
    num_encs = getattr(args, "num_encs", 1)
    return Decoder(
        args.eprojs,
        odim,
        args.dtype,
        args.dlayers,
        args.dunits,
        sos,
        eos,
        att,
        args.verbose,
        args.char_list,
        labeldist,
        args.lsm_weight,
        args.sampling_probability,
        args.dropout_rate_decoder,
        context_residual,
        replace_sos,
        num_encs,
    )
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/rnn/encoders.py
0 → 100644
View file @
60a2c57a
import
logging
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch.nn.utils.rnn
import
pack_padded_sequence
,
pad_packed_sequence
from
espnet.nets.e2e_asr_common
import
get_vgg2l_odim
from
espnet.nets.pytorch_backend.nets_utils
import
make_pad_mask
,
to_device
class RNNP(torch.nn.Module):
    """RNN with projection layer module

    Stacks ``elayers`` single-layer (bi)RNNs, each followed by a linear
    "bottleneck" projection, with optional time subsampling between layers.

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
    :param int hdim: number of projection units
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, subsample, dropout, typ="blstm"):
        super(RNNP, self).__init__()
        bidir = typ[0] == "b"
        for i in range(elayers):
            # first layer consumes the raw features, later layers the
            # previous projection output
            if i == 0:
                inputdim = idim
            else:
                inputdim = hdim
            RNN = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
            rnn = RNN(
                inputdim, cdim, num_layers=1, bidirectional=bidir, batch_first=True
            )
            setattr(self, "%s%d" % ("birnn" if bidir else "rnn", i), rnn)
            # bottleneck layer to merge (forward/backward) RNN outputs
            if bidir:
                setattr(self, "bt%d" % i, torch.nn.Linear(2 * cdim, hdim))
            else:
                setattr(self, "bt%d" % i, torch.nn.Linear(cdim, hdim))

        self.elayers = elayers
        self.cdim = cdim
        self.subsample = subsample
        self.typ = typ
        self.bidir = bidir
        self.dropout = dropout

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNNP forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, idim)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, hdim)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
        elayer_states = []
        for layer in range(self.elayers):
            if not isinstance(ilens, torch.Tensor):
                ilens = torch.tensor(ilens)
            xs_pack = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
            rnn = getattr(self, ("birnn" if self.bidir else "rnn") + str(layer))
            if self.training:
                rnn.flatten_parameters()
            if prev_state is not None and rnn.bidirectional:
                # streaming: backward state must not leak across windows
                prev_state = reset_backward_rnn_state(prev_state)
            ys, states = rnn(
                xs_pack, hx=None if prev_state is None else prev_state[layer]
            )
            elayer_states.append(states)
            # ys: utt list of frame x cdim x 2 (2: means bidirectional)
            ys_pad, ilens = pad_packed_sequence(ys, batch_first=True)
            sub = self.subsample[layer + 1]
            if sub > 1:
                # drop frames and shrink the lengths accordingly
                ys_pad = ys_pad[:, ::sub]
                ilens = torch.tensor([int(i + 1) // sub for i in ilens])
            # (sum _utt frame_utt) x dim
            projection_layer = getattr(self, "bt%d" % layer)
            projected = projection_layer(ys_pad.contiguous().view(-1, ys_pad.size(2)))
            xs_pad = projected.view(ys_pad.size(0), ys_pad.size(1), -1)
            if layer < self.elayers - 1:
                # FIX: pass training=self.training — F.dropout defaults to
                # training=True, which kept dropout active in eval mode.
                xs_pad = torch.tanh(
                    F.dropout(xs_pad, p=self.dropout, training=self.training)
                )

        return xs_pad, ilens, elayer_states  # x: utt list of frame x dim
class RNN(torch.nn.Module):
    """Multi-layer (bi)RNN encoder followed by a single tanh projection.

    :param int idim: dimension of inputs
    :param int elayers: number of encoder layers
    :param int cdim: number of rnn units (resulted in cdim * 2 if bidirectional)
    :param int hdim: number of final projection units
    :param float dropout: dropout rate
    :param str typ: The RNN type
    """

    def __init__(self, idim, elayers, cdim, hdim, dropout, typ="blstm"):
        super(RNN, self).__init__()
        bidir = typ[0] == "b"
        rnn_cls = torch.nn.LSTM if "lstm" in typ else torch.nn.GRU
        self.nbrnn = rnn_cls(
            idim,
            cdim,
            elayers,
            batch_first=True,
            dropout=dropout,
            bidirectional=bidir,
        )
        # bidirectional RNNs emit forward+backward features
        proj_in = cdim * 2 if bidir else cdim
        self.l_last = torch.nn.Linear(proj_in, hdim)
        self.typ = typ

    def forward(self, xs_pad, ilens, prev_state=None):
        """RNN forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous RNN states
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))
        if not isinstance(ilens, torch.Tensor):
            ilens = torch.tensor(ilens)
        packed = pack_padded_sequence(xs_pad, ilens.cpu(), batch_first=True)
        if self.training:
            self.nbrnn.flatten_parameters()
        if prev_state is not None and self.nbrnn.bidirectional:
            # We assume that when previous state is passed,
            # it means that we're streaming the input
            # and therefore cannot propagate backward BRNN state
            # (otherwise it goes in the wrong direction)
            prev_state = reset_backward_rnn_state(prev_state)
        rnn_out, states = self.nbrnn(packed, hx=prev_state)
        # rnn_out: utt list of frame x cdim x 2 (2: means bidirectional)
        padded, ilens = pad_packed_sequence(rnn_out, batch_first=True)
        # flatten to (sum_utt frame_utt) x dim for the projection
        flat = padded.contiguous().view(-1, padded.size(2))
        projected = torch.tanh(self.l_last(flat))
        xs_pad = projected.view(padded.size(0), padded.size(1), -1)
        return xs_pad, ilens, states  # x: utt list of frame x dim
def reset_backward_rnn_state(states):
    """Sets backward BRNN states to zeroes

    Useful in processing of sliding windows over the inputs
    """
    # layer axis interleaves forward/backward directions; odd indices
    # are the backward direction
    if not isinstance(states, (list, tuple)):
        states[1::2] = 0.0
    else:
        for layer_state in states:
            layer_state[1::2] = 0.0
    return states
class VGG2L(torch.nn.Module):
    """VGG-like module

    Two conv blocks (each: two 3x3 convs + one 2x2 ceil-mode max-pool),
    so both time and feature axes are downsampled by a factor of 4.

    :param int in_channel: number of input channels
    """

    def __init__(self, in_channel=1):
        super(VGG2L, self).__init__()
        # CNN layer (VGG motivated)
        self.conv1_1 = torch.nn.Conv2d(in_channel, 64, 3, stride=1, padding=1)
        self.conv1_2 = torch.nn.Conv2d(64, 64, 3, stride=1, padding=1)
        self.conv2_1 = torch.nn.Conv2d(64, 128, 3, stride=1, padding=1)
        self.conv2_2 = torch.nn.Conv2d(128, 128, 3, stride=1, padding=1)
        self.in_channel = in_channel

    def forward(self, xs_pad, ilens, **kwargs):
        """VGG2L forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :return: batch of padded hidden state sequences (B, Tmax // 4, 128 * D // 4)
        :rtype: torch.Tensor
        """
        logging.debug(self.__class__.__name__ + " input lengths: " + str(ilens))

        # (B, T, D) -> (B, in_channel, T, D // in_channel)
        batch = xs_pad.size(0)
        frames = xs_pad.size(1)
        feats = xs_pad.size(2) // self.in_channel
        xs_pad = xs_pad.view(batch, frames, self.in_channel, feats).transpose(1, 2)

        # first conv block
        xs_pad = F.relu(self.conv1_1(xs_pad))
        xs_pad = F.relu(self.conv1_2(xs_pad))
        xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)
        # second conv block
        xs_pad = F.relu(self.conv2_1(xs_pad))
        xs_pad = F.relu(self.conv2_2(xs_pad))
        xs_pad = F.max_pool2d(xs_pad, 2, stride=2, ceil_mode=True)

        # halve the lengths twice (ceil), mirroring the two pool layers
        if torch.is_tensor(ilens):
            ilens = ilens.cpu().numpy()
        else:
            ilens = np.array(ilens, dtype=np.float32)
        ilens = np.array(np.ceil(ilens / 2), dtype=np.int64)
        ilens = np.array(
            np.ceil(np.array(ilens, dtype=np.float32) / 2), dtype=np.int64
        ).tolist()

        # (B, C, T', D') -> (B, T', C * D')
        xs_pad = xs_pad.transpose(1, 2)
        xs_pad = xs_pad.contiguous().view(
            xs_pad.size(0), xs_pad.size(1), xs_pad.size(2) * xs_pad.size(3)
        )
        return xs_pad, ilens, None  # no state in this layer
class Encoder(torch.nn.Module):
    """Encoder module

    :param str etype: type of encoder network
    :param int idim: number of dimensions of encoder network
    :param int elayers: number of layers of encoder network
    :param int eunits: number of lstm units of encoder network
    :param int eprojs: number of projection units of encoder network
    :param np.ndarray subsample: list of subsampling numbers
    :param float dropout: dropout rate
    :param int in_channel: number of input channels
    """

    def __init__(self, etype, idim, elayers, eunits, eprojs, subsample, dropout, in_channel=1):
        super(Encoder, self).__init__()
        # strip the optional "vgg" prefix / "p" suffix to get the bare RNN type
        typ = etype.lstrip("vgg").rstrip("p")
        if typ not in ["lstm", "gru", "blstm", "bgru"]:
            # NOTE(review): only logs and continues; construction below
            # will fail later for an unknown type
            logging.error("Error: need to specify an appropriate encoder architecture")

        has_vgg_frontend = etype.startswith("vgg")
        per_layer_proj = etype[-1] == "p"
        if has_vgg_frontend:
            vgg = VGG2L(in_channel)
            rnn_idim = get_vgg2l_odim(idim, in_channel=in_channel)
            if per_layer_proj:
                rnn_module = RNNP(
                    rnn_idim, elayers, eunits, eprojs, subsample, dropout, typ=typ
                )
                logging.info("Use CNN-VGG + " + typ.upper() + "P for encoder")
            else:
                rnn_module = RNN(rnn_idim, elayers, eunits, eprojs, dropout, typ=typ)
                logging.info("Use CNN-VGG + " + typ.upper() + " for encoder")
            self.enc = torch.nn.ModuleList([vgg, rnn_module])
            # VGG front-end downsamples time by 4 (two 2x pools)
            self.conv_subsampling_factor = 4
        else:
            if per_layer_proj:
                rnn_module = RNNP(
                    idim, elayers, eunits, eprojs, subsample, dropout, typ=typ
                )
                logging.info(typ.upper() + " with every-layer projection for encoder")
            else:
                rnn_module = RNN(idim, elayers, eunits, eprojs, dropout, typ=typ)
                logging.info(typ.upper() + " without projection for encoder")
            self.enc = torch.nn.ModuleList([rnn_module])
            self.conv_subsampling_factor = 1

    def forward(self, xs_pad, ilens, prev_states=None):
        """Encoder forward

        :param torch.Tensor xs_pad: batch of padded input sequences (B, Tmax, D)
        :param torch.Tensor ilens: batch of lengths of input sequences (B)
        :param torch.Tensor prev_state: batch of previous encoder hidden states (?, ...)
        :return: batch of hidden state sequences (B, Tmax, eprojs)
        :rtype: torch.Tensor
        """
        if prev_states is None:
            prev_states = [None] * len(self.enc)
        assert len(prev_states) == len(self.enc)

        current_states = []
        # feed the output of each stage into the next one
        for module, prev_state in zip(self.enc, prev_states):
            xs_pad, ilens, states = module(xs_pad, ilens, prev_state=prev_state)
            current_states.append(states)

        # make mask to remove bias value in padded part
        mask = to_device(xs_pad, make_pad_mask(ilens).unsqueeze(-1))
        return xs_pad.masked_fill(mask, 0.0), ilens, current_states
def encoder_for(args, idim, subsample):
    """Instantiates an encoder module given the program arguments

    :param Namespace args: The arguments
    :param int or List of integer idim: dimension of input, e.g. 83, or
        List of dimensions of inputs, e.g. [83,83]
    :param List or List of List subsample: subsample factors, e.g. [1,2,2,1,1], or
        List of subsample factors of each encoder.
        e.g. [[1,2,2,1,1], [1,2,2,1,1]]
    :rtype torch.nn.Module
    :return: The encoder module
    """
    num_encs = getattr(args, "num_encs", 1)  # use getattr to keep compatibility
    if num_encs == 1:
        # compatible with single encoder asr mode
        return Encoder(
            args.etype,
            idim,
            args.elayers,
            args.eunits,
            args.eprojs,
            subsample,
            args.dropout_rate,
        )
    elif num_encs > 1:
        # one Encoder per input stream; per-encoder hyperparameters are
        # indexed lists, eprojs is shared
        enc_list = torch.nn.ModuleList()
        for idx in range(num_encs):
            enc = Encoder(
                args.etype[idx],
                idim[idx],
                args.elayers[idx],
                args.eunits[idx],
                args.eprojs,
                subsample[idx],
                args.dropout_rate[idx],
            )
            enc_list.append(enc)
        return enc_list
    else:
        # FIX: message previously said "more than one" although
        # num_encs == 1 is valid; the raise only fires for num_encs < 1
        raise ValueError(
            "Number of encoders needs to be more than zero. {}".format(num_encs)
        )
Prev
1
…
4
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment