Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
sunzhq2
yidong-infer
Commits
60a2c57a
Commit
60a2c57a
authored
Jan 27, 2026
by
sunzhq2
Committed by
xuxo
Jan 27, 2026
Browse files
update conformer
parent
4a699441
Changes
216
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
1886 additions
and
0 deletions
+1886
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/longformer_attention.py
.../nets/pytorch_backend/transformer/longformer_attention.py
+55
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/mask.py
...build/lib/espnet/nets/pytorch_backend/transformer/mask.py
+35
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/multi_layer_conv.py
...pnet/nets/pytorch_backend/transformer/multi_layer_conv.py
+105
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/optimizer.py
.../lib/espnet/nets/pytorch_backend/transformer/optimizer.py
+75
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/plot.py
...build/lib/espnet/nets/pytorch_backend/transformer/plot.py
+158
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py
.../pytorch_backend/transformer/positionwise_feed_forward.py
+32
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/repeat.py
...ild/lib/espnet/nets/pytorch_backend/transformer/repeat.py
+46
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/subsampling.py
...ib/espnet/nets/pytorch_backend/transformer/subsampling.py
+318
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py
...pytorch_backend/transformer/subsampling_without_posenc.py
+62
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/wavenet.py
...20240621/build/lib/espnet/nets/pytorch_backend/wavenet.py
+447
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorer_interface.py
...202304_20240621/build/lib/espnet/nets/scorer_interface.py
+186
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/__init__.py
...202304_20240621/build/lib/espnet/nets/scorers/__init__.py
+1
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/ctc.py
...et-v.202304_20240621/build/lib/espnet/nets/scorers/ctc.py
+157
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/length_bonus.py
...04_20240621/build/lib/espnet/nets/scorers/length_bonus.py
+59
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/ngram.py
...-v.202304_20240621/build/lib/espnet/nets/scorers/ngram.py
+101
-0
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/uasr.py
...t-v.202304_20240621/build/lib/espnet/nets/scorers/uasr.py
+49
-0
No files found.
Too many changes to show.
To preserve performance only
216 of 216+
files are displayed.
Plain diff
Email patch
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/longformer_attention.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2022 Roshan Sharma (Carnegie Mellon University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Longformer based Local Attention Definition."""
from
longformer.longformer
import
LongformerConfig
,
LongformerSelfAttention
from
torch
import
nn
class LongformerAttention(nn.Module):
    """Longformer based Local Attention Definition."""

    def __init__(self, config: LongformerConfig, layer_id: int):
        """Compute Longformer based Self-Attention.

        Args:
            config : Longformer attention configuration
            layer_id: Integer representing the layer index
        """
        super().__init__()
        self.attention_window = config.attention_window[layer_id]
        self.attention_layer = LongformerSelfAttention(config, layer_id=layer_id)
        # Attention weights of the most recent forward pass (set in forward()).
        self.attention = None

    def forward(self, query, key, value, mask):
        """Compute Longformer Self-Attention with masking.

        Expects `len(hidden_states)` to be multiple of `attention_window`.
        Padding to `attention_window` happens in :meth:`encoder.forward`
        to avoid redoing the padding on each layer.

        Note:
            Only ``query`` is fed to the underlying LongformerSelfAttention;
            ``key`` and ``value`` are accepted for interface compatibility
            with other attention modules and are ignored.

        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size). Unused.
            value (torch.Tensor): Value tensor (#batch, time2, size). Unused.
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).

        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
        """
        # Longformer mask convention: 0 = local attention, negative = masked.
        # Map boolean mask (1 = valid, 0 = padded) accordingly.
        attention_mask = mask.int()
        attention_mask[mask == 0] = -1
        attention_mask[mask == 1] = 0
        output, self.attention = self.attention_layer(
            hidden_states=query,
            attention_mask=attention_mask.unsqueeze(1),
            head_mask=None,
            output_attentions=True,
        )
        return output
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/mask.py
0 → 100644
View file @
60a2c57a
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Mask module."""
import
torch
def subsequent_mask(size, device="cpu", dtype=torch.bool):
    """Create mask for subsequent steps (size, size).

    :param int size: size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor
    >>> subsequent_mask(3)
    [[1, 0, 0],
     [1, 1, 0],
     [1, 1, 1]]
    """
    # Build an all-ones square and zero everything above the diagonal,
    # writing the result back into the same buffer.
    square = torch.ones(size, size, device=device, dtype=dtype)
    return torch.tril(square, out=square)
def target_mask(ys_in_pad, ignore_id):
    """Create mask for decoder self-attention.

    :param torch.Tensor ys_in_pad: batch of padded target sequences (B, Lmax)
    :param int ignore_id: index of padding
    :rtype: torch.Tensor (B, Lmax, Lmax)
    """
    # Valid (non-padding) positions per sequence.
    nonpad = ys_in_pad != ignore_id
    # Lower-triangular causal mask shared across the batch.
    causal = subsequent_mask(nonpad.size(-1), device=nonpad.device).unsqueeze(0)
    # Position (i, j) is attendable iff j is valid and j <= i.
    return nonpad.unsqueeze(-2) & causal
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/multi_layer_conv.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Layer modules for FFT block in FastSpeech (Feed-forward Transformer)."""
import
torch
class MultiLayeredConv1d(torch.nn.Module):
    """Multi-layered conv1d for Transformer block.

    This is a module of multi-layered conv1d designed
    to replace positionwise feed-forward network
    in Transformer block, which is introduced in
    `FastSpeech: Fast, Robust and Controllable Text to Speech`_.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize MultiLayeredConv1d module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.
        """
        super().__init__()
        # "same" padding (assumes odd kernel_size) so time length is preserved.
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Conv1d(
            hidden_chans,
            in_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        # Conv1d expects (B, C, T): transpose in, convolve, transpose back.
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x).transpose(-1, 1)).transpose(-1, 1)
class Conv1dLinear(torch.nn.Module):
    """Conv1D + Linear for Transformer block.

    A variant of MultiLayeredConv1d, which replaces second conv-layer to linear.
    """

    def __init__(self, in_chans, hidden_chans, kernel_size, dropout_rate):
        """Initialize Conv1dLinear module.

        Args:
            in_chans (int): Number of input channels.
            hidden_chans (int): Number of hidden channels.
            kernel_size (int): Kernel size of conv1d.
            dropout_rate (float): Dropout rate.
        """
        super().__init__()
        # "same" padding (assumes odd kernel_size) so time length is preserved.
        self.w_1 = torch.nn.Conv1d(
            in_chans,
            hidden_chans,
            kernel_size,
            stride=1,
            padding=(kernel_size - 1) // 2,
        )
        self.w_2 = torch.nn.Linear(hidden_chans, in_chans)
        self.dropout = torch.nn.Dropout(dropout_rate)

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (torch.Tensor): Batch of input tensors (B, T, in_chans).

        Returns:
            torch.Tensor: Batch of output tensors (B, T, in_chans).
        """
        # Conv1d expects (B, C, T); Linear works on the last dim directly.
        x = torch.relu(self.w_1(x.transpose(-1, 1))).transpose(-1, 1)
        return self.w_2(self.dropout(x))
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/optimizer.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Optimizer module."""
import
torch
class NoamOpt(object):
    """Optimizer wrapper implementing the Noam learning-rate schedule.

    The rate warms up linearly for ``warmup`` steps and then decays
    proportionally to the inverse square root of the step number.
    """

    def __init__(self, model_size, factor, warmup, optimizer):
        """Construct an NoamOpt object."""
        self.optimizer = optimizer
        self._step = 0
        self.warmup = warmup
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and rate."""
        self._step += 1
        new_rate = self.rate()
        for group in self.optimizer.param_groups:
            group["lr"] = new_rate
        self._rate = new_rate
        self.optimizer.step()

    def rate(self, step=None):
        """Return the learning rate at ``step`` (current step if None)."""
        if step is None:
            step = self._step
        scale = self.factor * self.model_size ** (-0.5)
        return scale * min(step ** (-0.5), step * self.warmup ** (-1.5))

    def zero_grad(self):
        """Reset gradient."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return state_dict."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Load state_dict."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            else:
                setattr(self, key, value)
def get_std_opt(model_params, d_model, warmup, factor):
    """Get standard NoamOpt.

    Wraps Adam (lr driven entirely by the Noam schedule, hence lr=0 here)
    with the standard Transformer betas/eps from "Attention Is All You Need".
    """
    adam = torch.optim.Adam(model_params, lr=0, betas=(0.9, 0.98), eps=1e-9)
    return NoamOpt(d_model, factor, warmup, adam)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/plot.py
0 → 100644
View file @
60a2c57a
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import
logging
import
os
import
numpy
from
espnet.asr
import
asr_utils
def _plot_and_save_attention(att_w, filename, xtokens=None, ytokens=None):
    # Build (but do not save) a figure with one subplot per attention head.
    # att_w: iterable of 2D arrays; optional token lists label the axes.
    # The figure is returned so the caller decides how to persist it.
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt
    from matplotlib.ticker import MaxNLocator

    # Ensure the output directory exists before the caller saves the figure.
    d = os.path.dirname(filename)
    if not os.path.exists(d):
        os.makedirs(d)
    w, h = plt.figaspect(1.0 / len(att_w))
    fig = plt.Figure(figsize=(w * 2, h * 2))
    axes = fig.subplots(1, len(att_w))
    # subplots() returns a bare Axes (not a list) when there is one head.
    if len(att_w) == 1:
        axes = [axes]
    for ax, aw in zip(axes, att_w):
        # plt.subplot(1, len(att_w), h)
        ax.imshow(aw.astype(numpy.float32), aspect="auto")
        ax.set_xlabel("Input")
        ax.set_ylabel("Output")
        ax.xaxis.set_major_locator(MaxNLocator(integer=True))
        ax.yaxis.set_major_locator(MaxNLocator(integer=True))
        # Labels for major ticks
        if xtokens is not None:
            ax.set_xticks(numpy.linspace(0, len(xtokens), len(xtokens) + 1))
            ax.set_xticks(numpy.linspace(0, len(xtokens), 1), minor=True)
            ax.set_xticklabels(xtokens + [""], rotation=40)
        if ytokens is not None:
            ax.set_yticks(numpy.linspace(0, len(ytokens), len(ytokens) + 1))
            ax.set_yticks(numpy.linspace(0, len(ytokens), 1), minor=True)
            ax.set_yticklabels(ytokens + [""])
    fig.tight_layout()
    return fig
def savefig(plot, filename):
    """Write a matplotlib figure to ``filename`` and clear the pyplot state."""
    import matplotlib

    matplotlib.use("Agg")
    import matplotlib.pyplot as plt

    plot.savefig(filename)
    plt.clf()
def plot_multi_head_attention(
    data,
    uttid_list,
    attn_dict,
    outdir,
    suffix="png",
    savefn=savefig,
    ikey="input",
    iaxis=0,
    okey="output",
    oaxis=0,
    subsampling_factor=4,
):
    """Plot multi head attentions.

    :param dict data: utts info from json file
    :param List uttid_list: utterance IDs
    :param dict[str, torch.Tensor] attn_dict: multi head attention dict.
        values should be torch.Tensor (head, input_length, output_length)
    :param str outdir: dir to save fig
    :param str suffix: filename suffix including image type (e.g., png)
    :param savefn: function to save
    :param str ikey: key to access input
    :param int iaxis: dimension to access input
    :param str okey: key to access output
    :param int oaxis: dimension to access output
    :param subsampling_factor: subsampling factor in encoder
    """
    for name, att_ws in attn_dict.items():
        for idx, att_w in enumerate(att_ws):
            data_i = data[uttid_list[idx]]
            filename = "%s/%s.%s.%s" % (outdir, uttid_list[idx], name, suffix)
            dec_len = int(data_i[okey][oaxis]["shape"][0]) + 1  # +1 for <eos>
            enc_len = int(data_i[ikey][iaxis]["shape"][0])
            # A "token" entry on the input side marks a text (MT) input.
            is_mt = "token" in data_i[ikey][iaxis].keys()
            # for ASR/ST
            if not is_mt:
                # Speech input: encoder length is after frontend subsampling.
                enc_len //= subsampling_factor
            xtokens, ytokens = None, None
            if "encoder" in name:
                # Trim padding from both axes of encoder self-attention.
                att_w = att_w[:, :enc_len, :enc_len]
                # for MT
                if is_mt:
                    xtokens = data_i[ikey][iaxis]["token"].split()
                    ytokens = xtokens[:]
            elif "decoder" in name:
                if "self" in name:
                    # self-attention
                    att_w = att_w[:, :dec_len, :dec_len]
                    if "token" in data_i[okey][oaxis].keys():
                        ytokens = data_i[okey][oaxis]["token"].split() + ["<eos>"]
                        xtokens = ["<sos>"] + data_i[okey][oaxis]["token"].split()
                else:
                    # cross-attention
                    att_w = att_w[:, :dec_len, :enc_len]
                    if "token" in data_i[okey][oaxis].keys():
                        ytokens = data_i[okey][oaxis]["token"].split() + ["<eos>"]
                    # for MT
                    if is_mt:
                        xtokens = data_i[ikey][iaxis]["token"].split()
            else:
                # Unrecognized attention name: plot the raw (untrimmed) weights.
                logging.warning("unknown name for shaping attention")
            fig = _plot_and_save_attention(att_w, filename, xtokens, ytokens)
            savefn(fig, filename)
class PlotAttentionReport(asr_utils.PlotAttentionReport):
    # Transformer variant of asr_utils.PlotAttentionReport. Relies on
    # attributes set by the base-class __init__ (ikey, iaxis, okey, oaxis,
    # factor, data, data_dict, transform, converter, device, att_vis_fn,
    # outdir) — TODO confirm against asr_utils.

    def plotfn(self, *args, **kwargs):
        # Delegate to plot_multi_head_attention using the access keys/axes
        # and encoder subsampling factor configured on this reporter.
        kwargs["ikey"] = self.ikey
        kwargs["iaxis"] = self.iaxis
        kwargs["okey"] = self.okey
        kwargs["oaxis"] = self.oaxis
        kwargs["subsampling_factor"] = self.factor
        plot_multi_head_attention(*args, **kwargs)

    def __call__(self, trainer):
        # Trainer-extension entry point: dump attention plots for this epoch.
        attn_dict, uttid_list = self.get_attention_weights()
        suffix = "ep.{.updater.epoch}.png".format(trainer)
        self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, suffix, savefig)

    def get_attention_weights(self):
        # Transform one batch, run the attention visualization function on it,
        # and return (attention dict, utterance ids).
        return_batch, uttid_list = self.transform(self.data, return_uttid=True)
        batch = self.converter([return_batch], self.device)
        if isinstance(batch, tuple):
            att_ws = self.att_vis_fn(*batch)
        elif isinstance(batch, dict):
            att_ws = self.att_vis_fn(**batch)
        # NOTE(review): if batch is neither tuple nor dict, att_ws is unbound
        # and this raises UnboundLocalError — presumably unreachable.
        return att_ws, uttid_list

    def log_attentions(self, logger, step):
        # Render attention figures into a tensorboard-style logger instead of
        # writing image files (suffix "" disables the file-type suffix).
        def log_fig(plot, filename):
            import matplotlib

            matplotlib.use("Agg")
            import matplotlib.pyplot as plt

            logger.add_figure(os.path.basename(filename), plot, step)
            plt.clf()

        attn_dict, uttid_list = self.get_attention_weights()
        self.plotfn(self.data_dict, uttid_list, attn_dict, self.outdir, "", log_fig)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/positionwise_feed_forward.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Positionwise feed forward layer definition."""
import
torch
class PositionwiseFeedForward(torch.nn.Module):
    """Positionwise feed forward layer.

    Args:
        idim (int): Input dimension.
        hidden_units (int): The number of hidden units.
        dropout_rate (float): Dropout rate.
        activation (torch.nn.Module): Activation applied between the two
            linear layers (defaults to ReLU).
    """

    # NOTE: the default activation module is created once at class-definition
    # time and shared across instances; ReLU is stateless/parameter-free, so
    # this is harmless here, but pass a fresh module for stateful activations.
    def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
        """Construct an PositionwiseFeedForward object."""
        super().__init__()
        self.w_1 = torch.nn.Linear(idim, hidden_units)
        self.w_2 = torch.nn.Linear(hidden_units, idim)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.activation = activation

    def forward(self, x):
        """Project to hidden_units, activate, dropout, project back to idim."""
        return self.w_2(self.dropout(self.activation(self.w_1(x))))
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/repeat.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Repeat the same layer definition."""
import
torch
class MultiSequential(torch.nn.Sequential):
    """Multi-input multi-output torch.nn.Sequential."""

    def __init__(self, *args, layer_drop_rate=0.0):
        """Initialize MultiSequential with layer_drop.

        Args:
            layer_drop_rate (float): Probability of dropping out each fn (layer).
        """
        super().__init__(*args)
        self.layer_drop_rate = layer_drop_rate

    def forward(self, *args):
        """Apply each layer in turn, stochastically skipping layers in training."""
        # One uniform draw per layer decides layer drop (training mode only).
        keep_probs = torch.empty(len(self)).uniform_()
        for layer_index, layer in enumerate(self):
            dropped = self.training and keep_probs[layer_index] < self.layer_drop_rate
            if not dropped:
                args = layer(*args)
        return args
def repeat(N, fn, layer_drop_rate=0.0):
    """Repeat module N times.

    Args:
        N (int): Number of repeat time.
        fn (Callable): Function to generate module.
        layer_drop_rate (float): Probability of dropping out each fn (layer).

    Returns:
        MultiSequential: Repeated model instance.
    """
    layers = (fn(index) for index in range(N))
    return MultiSequential(*layers, layer_drop_rate=layer_drop_rate)
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/subsampling.py
0 → 100644
View file @
60a2c57a
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2019 Shigeki Karita
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import
torch
from
espnet.nets.pytorch_backend.transformer.embedding
import
PositionalEncoding
class
TooShortUttError
(
Exception
):
"""Raised when the utt is too short for subsampling.
Args:
message (str): Message for error catch
actual_size (int): the short size that cannot pass the subsampling
limit (int): the limit size for subsampling
"""
def
__init__
(
self
,
message
,
actual_size
,
limit
):
"""Construct a TooShortUttError for error handler."""
super
().
__init__
(
message
)
self
.
actual_size
=
actual_size
self
.
limit
=
limit
def check_short_utt(ins, size):
    """Check if the utterance is too short for subsampling."""
    # Minimum input lengths implied by each frontend's conv stack; checked in
    # the same order as the original chain (order matters with subclassing).
    minimum_lengths = (
        (Conv2dSubsampling1, 5),
        (Conv2dSubsampling2, 7),
        (Conv2dSubsampling, 7),
        (Conv2dSubsampling6, 11),
        (Conv2dSubsampling8, 15),
    )
    for frontend_cls, limit in minimum_lengths:
        if isinstance(ins, frontend_cls) and size < limit:
            return True, limit
    return False, -1
class Conv2dSubsampling(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/4 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling object."""
        super().__init__()
        # Two (kernel=3, stride=2) convs halve time and frequency twice.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        reduced_idim = ((idim - 1) // 2 - 1) // 2
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * reduced_idim, odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 4.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 4.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        batch, channels, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, channels * freq)
        output = self.out(flat)
        if x_mask is None:
            return output, None
        # Each slice mirrors one stride-2 conv on the time axis.
        return output, x_mask[:, :, :-2:2][:, :, :-2:2]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class Conv2dSubsampling1(torch.nn.Module):
    """Similar to Conv2dSubsampling module, but without any subsampling performed.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling1 object."""
        super().__init__()
        # Two (kernel=3, stride=1) convs: no subsampling, but each trims 2
        # frames/bins, so time and frequency shrink by 4 in total.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * (idim - 4), odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Pass x through 2 Conv2d layers without subsampling.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim).
                where time' = time - 4.
            torch.Tensor: Subsampled mask (#batch, 1, time').
                where time' = time - 4.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        batch, channels, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, channels * freq)
        output = self.out(flat)
        if x_mask is None:
            return output, None
        return output, x_mask[:, :, :-4]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class Conv2dSubsampling2(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/2 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling2 object."""
        super().__init__()
        # One stride-2 conv (halves) followed by one stride-1 conv (trims 2).
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 1),
            torch.nn.ReLU(),
        )
        reduced_idim = (idim - 1) // 2 - 2
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * reduced_idim, odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 2.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 2.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        batch, channels, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, channels * freq)
        output = self.out(flat)
        if x_mask is None:
            return output, None
        # stride-2 slice for the first conv, stride-1 trim for the second.
        return output, x_mask[:, :, :-2:2][:, :, :-2:1]

    def __getitem__(self, key):
        """Get item.

        When reset_parameters() is called, if use_scaled_pos_enc is used,
        return the positioning encoding.
        """
        if key != -1:
            raise NotImplementedError("Support only `-1` (for `reset_parameters`).")
        return self.out[key]
class Conv2dSubsampling6(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/6 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling6 object."""
        super().__init__()
        # stride-2 then stride-3 convs: combined time reduction of 1/6.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 5, 3),
            torch.nn.ReLU(),
        )
        reduced_idim = ((idim - 1) // 2 - 2) // 3
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * reduced_idim, odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 6.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 6.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        batch, channels, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, channels * freq)
        output = self.out(flat)
        if x_mask is None:
            return output, None
        # First slice matches (k=3, s=2); second matches (k=5, s=3).
        return output, x_mask[:, :, :-2:2][:, :, :-4:3]
class Conv2dSubsampling8(torch.nn.Module):
    """Convolutional 2D subsampling (to 1/8 length).

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        pos_enc (torch.nn.Module): Custom position encoding layer.
    """

    def __init__(self, idim, odim, dropout_rate, pos_enc=None):
        """Construct an Conv2dSubsampling8 object."""
        super().__init__()
        # Three (kernel=3, stride=2) convs: time and frequency reduced to ~1/8.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
            torch.nn.Conv2d(odim, odim, 3, 2),
            torch.nn.ReLU(),
        )
        reduced_idim = (((idim - 1) // 2 - 1) // 2 - 1) // 2
        self.out = torch.nn.Sequential(
            torch.nn.Linear(odim * reduced_idim, odim),
            pos_enc if pos_enc is not None else PositionalEncoding(odim, dropout_rate),
        )

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim),
                where time' = time // 8.
            torch.Tensor: Subsampled mask (#batch, 1, time'),
                where time' = time // 8.
        """
        feats = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        batch, channels, frames, freq = feats.size()
        flat = feats.transpose(1, 2).contiguous().view(batch, frames, channels * freq)
        output = self.out(flat)
        if x_mask is None:
            return output, None
        # One stride-2 slice per conv layer on the time axis.
        return output, x_mask[:, :, :-2:2][:, :, :-2:2][:, :, :-2:2]
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/transformer/subsampling_without_posenc.py
0 → 100644
View file @
60a2c57a
# Copyright 2020 Emiru Tsunoo
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""Subsampling layer definition."""
import
math
import
torch
class Conv2dSubsamplingWOPosEnc(torch.nn.Module):
    """Convolutional 2D subsampling without positional encoding.

    Args:
        idim (int): Input dimension.
        odim (int): Output dimension.
        dropout_rate (float): Dropout rate.
        kernels (list): kernel sizes
        strides (list): stride sizes
    """

    def __init__(self, idim, odim, dropout_rate, kernels, strides):
        """Construct an Conv2dSubsamplingWOPosEnc object."""
        assert len(kernels) == len(strides)
        super().__init__()
        layers = []
        olen = idim
        in_channels = 1
        for k, s in zip(kernels, strides):
            layers.append(torch.nn.Conv2d(in_channels, odim, k, s))
            layers.append(torch.nn.ReLU())
            in_channels = odim
            # Track the frequency-axis size after this convolution.
            olen = math.floor((olen - k) / s + 1)
        self.conv = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(odim * olen, odim)
        # Kept so forward() can subsample the mask with the same pattern.
        self.strides = strides
        self.kernels = kernels

    def forward(self, x, x_mask):
        """Subsample x.

        Args:
            x (torch.Tensor): Input tensor (#batch, time, idim).
            x_mask (torch.Tensor): Input mask (#batch, 1, time).

        Returns:
            torch.Tensor: Subsampled tensor (#batch, time', odim).
            torch.Tensor: Subsampled mask (#batch, 1, time'), or None
                when no mask was given.
        """
        hidden = self.conv(x.unsqueeze(1))  # (b, c, t, f)
        b, c, t, f = hidden.size()
        projected = self.out(hidden.transpose(1, 2).contiguous().view(b, t, c * f))
        if x_mask is None:
            return projected, None
        mask = x_mask
        for k, s in zip(self.kernels, self.strides):
            mask = mask[:, :, : -k + 1 : s]
        return projected, mask
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/pytorch_backend/wavenet.py
0 → 100644
View file @
60a2c57a
# -*- coding: utf-8 -*-
# Copyright 2019 Tomoki Hayashi (Nagoya University)
# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""This code is based on https://github.com/kan-bayashi/PytorchWaveNetVocoder."""
import
logging
import
sys
import
time
import
numpy
as
np
import
torch
import
torch.nn.functional
as
F
from
torch
import
nn
def encode_mu_law(x, mu=256):
    """Perform mu-law encoding.

    Args:
        x (ndarray): Audio signal with the range from -1 to 1.
        mu (int): Quantized level.

    Returns:
        ndarray: Quantized audio signal with the range from 0 to mu - 1.
    """
    levels = mu - 1
    # Companding: compress the dynamic range before quantization.
    compressed = np.sign(x) * np.log(1 + levels * np.abs(x)) / np.log(1 + levels)
    # Map [-1, 1] onto integer bins [0, levels].
    return np.floor((compressed + 1) / 2 * levels + 0.5).astype(np.int64)
def decode_mu_law(y, mu=256):
    """Perform mu-law decoding.

    Args:
        y (ndarray): Quantized audio signal with the range from 0 to mu - 1.
        mu (int): Quantized level.

    Returns:
        ndarray: Audio signal with the range from -1 to 1.
    """
    levels = mu - 1
    # Map integer bins back onto [-1, 1].
    scaled = (y - 0.5) / levels * 2 - 1
    # Inverse companding: expand the compressed dynamic range.
    return np.sign(scaled) / levels * ((1 + levels) ** np.abs(scaled) - 1)
def initialize(m):
    """Initialize conv layers with xavier.

    Intended for use with ``module.apply(initialize)``.

    Args:
        m (torch.nn.Module): Torch module.
    """
    if isinstance(m, nn.Conv1d):
        nn.init.xavier_uniform_(m.weight)
        # Guard: Conv1d may be constructed with bias=False, in which case
        # m.bias is None and nn.init.constant_ would raise.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
    if isinstance(m, nn.ConvTranspose2d):
        nn.init.constant_(m.weight, 1.0)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)
class OneHot(nn.Module):
    """Convert to one-hot vector.

    Args:
        depth (int): Dimension of one-hot vector.
    """

    def __init__(self, depth):
        super(OneHot, self).__init__()
        self.depth = depth

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (LongTensor): long tensor variable with the shape (B, T)

        Returns:
            Tensor: float tensor variable with the shape (B, depth, T)
        """
        # Wrap indices into [0, depth) so scatter_ never goes out of range.
        idx = (x % self.depth).unsqueeze(2)
        onehot = idx.new_zeros(idx.size(0), idx.size(1), self.depth).float()
        return onehot.scatter_(2, idx, 1)
class CausalConv1d(nn.Module):
    """1D dilated causal convolution."""

    def __init__(self, in_channels, out_channels, kernel_size, dilation=1, bias=True):
        super(CausalConv1d, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.dilation = dilation
        # Pad by the full receptive field; the surplus on the right is
        # trimmed in forward() so each output only sees past inputs.
        self.padding = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(
            in_channels,
            out_channels,
            kernel_size,
            padding=self.padding,
            dilation=dilation,
            bias=bias,
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor with the shape (B, in_channels, T).

        Returns:
            Tensor: Tensor with the shape (B, out_channels, T)
        """
        out = self.conv(x)
        if self.padding == 0:
            return out
        # Drop trailing frames introduced by the symmetric padding.
        return out[:, :, : -self.padding]
class UpSampling(nn.Module):
    """Upsampling layer with deconvolution.

    Args:
        upsampling_factor (int): Upsampling factor.
    """

    def __init__(self, upsampling_factor, bias=True):
        super(UpSampling, self).__init__()
        self.upsampling_factor = upsampling_factor
        self.bias = bias
        # Transposed conv over the time axis only: stretches T by the factor.
        self.conv = nn.ConvTranspose2d(
            1,
            1,
            kernel_size=(1, self.upsampling_factor),
            stride=(1, self.upsampling_factor),
            bias=self.bias,
        )

    def forward(self, x):
        """Calculate forward propagation.

        Args:
            x (Tensor): Input tensor with the shape (B, C, T)

        Returns:
            Tensor: Tensor with the shape (B, C, T') where T' = T * upsampling_factor.
        """
        expanded = self.conv(x.unsqueeze(1))  # B x 1 x C x T'
        return expanded.squeeze(1)
class WaveNet(nn.Module):
    """Conditional wavenet.

    Args:
        n_quantize (int): Number of quantization.
        n_aux (int): Number of aux feature dimension.
        n_resch (int): Number of filter channels for residual block.
        n_skipch (int): Number of filter channels for skip connection.
        dilation_depth (int): Number of dilation depth
            (e.g. if set 10, max dilation = 2^(10-1)).
        dilation_repeat (int): Number of dilation repeat.
        kernel_size (int): Filter size of dilated causal convolution.
        upsampling_factor (int): Upsampling factor.
    """

    def __init__(
        self,
        n_quantize=256,
        n_aux=28,
        n_resch=512,
        n_skipch=256,
        dilation_depth=10,
        dilation_repeat=3,
        kernel_size=2,
        upsampling_factor=0,
    ):
        super(WaveNet, self).__init__()
        self.n_aux = n_aux
        self.n_quantize = n_quantize
        self.n_resch = n_resch
        self.n_skipch = n_skipch
        self.kernel_size = kernel_size
        self.dilation_depth = dilation_depth
        self.dilation_repeat = dilation_repeat
        self.upsampling_factor = upsampling_factor
        # Dilation schedule: [1, 2, 4, ..., 2^(depth-1)] repeated `dilation_repeat` times.
        self.dilations = [2**i for i in range(self.dilation_depth)] * self.dilation_repeat
        # Number of past samples each output depends on through the stacked dilated convs.
        self.receptive_field = (self.kernel_size - 1) * sum(self.dilations) + 1
        # for preprocessing
        self.onehot = OneHot(self.n_quantize)
        self.causal = CausalConv1d(self.n_quantize, self.n_resch, self.kernel_size)
        if self.upsampling_factor > 0:
            self.upsampling = UpSampling(self.upsampling_factor)
        # for residual blocks
        self.dil_sigmoid = nn.ModuleList()
        self.dil_tanh = nn.ModuleList()
        self.aux_1x1_sigmoid = nn.ModuleList()
        self.aux_1x1_tanh = nn.ModuleList()
        self.skip_1x1 = nn.ModuleList()
        self.res_1x1 = nn.ModuleList()
        for d in self.dilations:
            self.dil_sigmoid += [CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)]
            self.dil_tanh += [CausalConv1d(self.n_resch, self.n_resch, self.kernel_size, d)]
            self.aux_1x1_sigmoid += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
            self.aux_1x1_tanh += [nn.Conv1d(self.n_aux, self.n_resch, 1)]
            self.skip_1x1 += [nn.Conv1d(self.n_resch, self.n_skipch, 1)]
            self.res_1x1 += [nn.Conv1d(self.n_resch, self.n_resch, 1)]
        # for postprocessing
        self.conv_post_1 = nn.Conv1d(self.n_skipch, self.n_skipch, 1)
        self.conv_post_2 = nn.Conv1d(self.n_skipch, self.n_quantize, 1)

    def forward(self, x, h):
        """Calculate forward propagation.

        Args:
            x (LongTensor): Quantized input waveform tensor with the shape (B, T).
            h (Tensor): Auxiliary feature tensor with the shape (B, n_aux, T).

        Returns:
            Tensor: Logits with the shape (B, T, n_quantize).
        """
        # preprocess
        output = self._preprocess(x)
        if self.upsampling_factor > 0:
            h = self.upsampling(h)
        # residual block
        skip_connections = []
        for i in range(len(self.dilations)):
            output, skip = self._residual_forward(
                output,
                h,
                self.dil_sigmoid[i],
                self.dil_tanh[i],
                self.aux_1x1_sigmoid[i],
                self.aux_1x1_tanh[i],
                self.skip_1x1[i],
                self.res_1x1[i],
            )
            skip_connections.append(skip)
        # skip-connection part: logits come from the sum of all skip outputs,
        # not from the final residual output.
        output = sum(skip_connections)
        output = self._postprocess(output)
        return output

    def generate(self, x, h, n_samples, interval=None, mode="sampling"):
        """Generate a waveform with fast genration algorithm.

        This generation based on `Fast WaveNet Generation Algorithm`_.

        Args:
            x (LongTensor): Initial waveform tensor with the shape (T,).
            h (Tensor): Auxiliary feature tensor with the shape (n_samples + T, n_aux).
            n_samples (int): Number of samples to be generated.
            interval (int, optional): Log interval.
            mode (str, optional): "sampling" or "argmax".

        Return:
            ndarray: Generated quantized waveform (n_samples).

        .. _`Fast WaveNet Generation Algorithm`: https://arxiv.org/abs/1611.09482
        """
        # reshape inputs
        assert len(x.shape) == 1
        assert len(h.shape) == 2 and h.shape[1] == self.n_aux
        x = x.unsqueeze(0)
        h = h.transpose(0, 1).unsqueeze(0)
        # perform upsampling
        if self.upsampling_factor > 0:
            h = self.upsampling(h)
        # padding for shortage
        if n_samples > h.shape[2]:
            h = F.pad(h, (0, n_samples - h.shape[2]), "replicate")
        # padding if the length less than
        n_pad = self.receptive_field - x.size(1)
        if n_pad > 0:
            # Pad the seed waveform on the left with the mid quantization level.
            x = F.pad(x, (n_pad, 0), "constant", self.n_quantize // 2)
            h = F.pad(h, (n_pad, 0), "replicate")
        # prepare buffer: run the full network once over the seed and cache
        # each layer's trailing outputs so later steps reuse them (fast generation).
        output = self._preprocess(x)
        h_ = h[:, :, : x.size(1)]
        output_buffer = []
        buffer_size = []
        for i, d in enumerate(self.dilations):
            output, _ = self._residual_forward(
                output,
                h_,
                self.dil_sigmoid[i],
                self.dil_tanh[i],
                self.aux_1x1_sigmoid[i],
                self.aux_1x1_tanh[i],
                self.skip_1x1[i],
                self.res_1x1[i],
            )
            if d == 2 ** (self.dilation_depth - 1):
                buffer_size.append(self.kernel_size - 1)
            else:
                buffer_size.append(d * 2 * (self.kernel_size - 1))
            output_buffer.append(output[:, :, -buffer_size[i] - 1 : -1])
        # generate
        samples = x[0]
        start_time = time.time()
        for i in range(n_samples):
            # Only a short tail of past samples is fed per step; the cached
            # buffers supply the rest of the receptive field.
            output = samples[-self.kernel_size * 2 + 1 :].unsqueeze(0)
            output = self._preprocess(output)
            h_ = h[:, :, samples.size(0) - 1].contiguous().view(1, self.n_aux, 1)
            output_buffer_next = []
            skip_connections = []
            for j, d in enumerate(self.dilations):
                output, skip = self._generate_residual_forward(
                    output,
                    h_,
                    self.dil_sigmoid[j],
                    self.dil_tanh[j],
                    self.aux_1x1_sigmoid[j],
                    self.aux_1x1_tanh[j],
                    self.skip_1x1[j],
                    self.res_1x1[j],
                )
                # Append the new frame to the cached context for this layer.
                output = torch.cat([output_buffer[j], output], dim=2)
                output_buffer_next.append(output[:, :, -buffer_size[j] :])
                skip_connections.append(skip)
            # update buffer
            output_buffer = output_buffer_next
            # get predicted sample
            output = sum(skip_connections)
            output = self._postprocess(output)[0]
            if mode == "sampling":
                posterior = F.softmax(output[-1], dim=0)
                dist = torch.distributions.Categorical(posterior)
                sample = dist.sample().unsqueeze(0)
            elif mode == "argmax":
                sample = output.argmax(-1)
            else:
                logging.error("mode should be sampling or argmax")
                sys.exit(1)
            samples = torch.cat([samples, sample], dim=0)
            # show progress
            if interval is not None and (i + 1) % interval == 0:
                elapsed_time_per_sample = (time.time() - start_time) / interval
                logging.info(
                    "%d/%d estimated time = %.3f sec (%.3f sec / sample)"
                    % (
                        i + 1,
                        n_samples,
                        (n_samples - i - 1) * elapsed_time_per_sample,
                        elapsed_time_per_sample,
                    )
                )
                start_time = time.time()
        return samples[-n_samples:].cpu().numpy()

    def _preprocess(self, x):
        # One-hot encode the quantized waveform, then (B, T, C) -> (B, C, T)
        # for the causal convolution.
        x = self.onehot(x).transpose(1, 2)
        output = self.causal(x)
        return output

    def _postprocess(self, x):
        # Two 1x1 convs with ReLU turn summed skip activations into logits.
        output = F.relu(x)
        output = self.conv_post_1(output)
        output = F.relu(output)  # B x C x T
        output = self.conv_post_2(output).transpose(1, 2)  # B x T x C
        return output

    def _residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        # Gated activation unit conditioned on the auxiliary features:
        # sigmoid(gate) * tanh(filter).
        output_sigmoid = dil_sigmoid(x)
        output_tanh = dil_tanh(x)
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        # Residual connection back to the block input.
        output = output + x
        return output, skip

    def _generate_residual_forward(
        self,
        x,
        h,
        dil_sigmoid,
        dil_tanh,
        aux_1x1_sigmoid,
        aux_1x1_tanh,
        skip_1x1,
        res_1x1,
    ):
        # Single-step variant: only the last time frame of each conv is kept.
        output_sigmoid = dil_sigmoid(x)[:, :, -1:]
        output_tanh = dil_tanh(x)[:, :, -1:]
        aux_output_sigmoid = aux_1x1_sigmoid(h)
        aux_output_tanh = aux_1x1_tanh(h)
        output = torch.sigmoid(output_sigmoid + aux_output_sigmoid) * torch.tanh(
            output_tanh + aux_output_tanh
        )
        skip = skip_1x1(output)
        output = res_1x1(output)
        output = output + x[:, :, -1:]  # B x C x 1
        return output, skip
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorer_interface.py
0 → 100644
View file @
60a2c57a
"""Scorer interface module."""
import
warnings
from
typing
import
Any
,
List
,
Tuple
import
torch
class ScorerInterface:
    """Scorer interface for beam search.

    The scorer performs scoring of the all tokens in vocabulary.

    Examples:
        * Search heuristics
            * :class:`espnet.nets.scorers.length_bonus.LengthBonus`
        * Decoder networks of the sequence-to-sequence models
            * :class:`espnet.nets.pytorch_backend.nets.transformer.decoder.Decoder`
            * :class:`espnet.nets.pytorch_backend.nets.rnn.decoders.Decoder`
        * Neural language models
            * :class:`espnet.nets.pytorch_backend.lm.transformer.TransformerLM`
            * :class:`espnet.nets.pytorch_backend.lm.default.DefaultRNNLM`
            * :class:`espnet.nets.pytorch_backend.lm.seq_rnn.SequentialRNNLM`
    """

    def init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        return None

    def select_state(self, state: Any, i: int, new_id: int = None) -> Any:
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label index to select a state if necessary

        Returns:
            state: pruned state
        """
        if state is None:
            return None
        return state[i]

    def score(self, y: torch.Tensor, state: Any, x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                scores for next token that has a shape of `(n_vocab)`
                and next state for ys
        """
        raise NotImplementedError

    def final_score(self, state: Any) -> float:
        """Score eos (optional).

        Args:
            state: Scorer state for prefix tokens

        Returns:
            float: final score
        """
        return 0.0
class BatchScorerInterface(ScorerInterface):
    """Batch scorer interface."""

    def batch_init_state(self, x: torch.Tensor) -> Any:
        """Get an initial state for decoding (optional).

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        return self.init_state(x)

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        warnings.warn(
            "{} batch score is implemented through for loop not parallelized".format(
                self.__class__.__name__
            )
        )
        scores = list()
        outstates = list()
        # Fallback: score each hypothesis independently via the non-batch API.
        for y, state, x in zip(ys, states, xs):
            score, outstate = self.score(y, state, x)
            outstates.append(outstate)
            scores.append(score)
        stacked = torch.cat(scores, 0).view(ys.shape[0], -1)
        return stacked, outstates
class PartialScorerInterface(ScorerInterface):
    """Partial scorer interface for beam search.

    The partial scorer performs scoring when non-partial scorer finished scoring,
    and receives pre-pruned next tokens to score because it is too heavy to score
    all the tokens.

    Examples:
        * Prefix search for connectionist-temporal-classification models
            * :class:`espnet.nets.scorers.ctc.CTCPrefixScorer`
    """

    def score_partial(
        self, y: torch.Tensor, next_tokens: torch.Tensor, state: Any, x: torch.Tensor
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            y (torch.Tensor): 1D prefix token
            next_tokens (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): The encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(next_tokens),)`
                and next state for ys
        """
        raise NotImplementedError
class BatchPartialScorerInterface(BatchScorerInterface, PartialScorerInterface):
    """Batch partial scorer interface for beam search."""

    def batch_score_partial(
        self,
        ys: torch.Tensor,
        next_tokens: torch.Tensor,
        states: List[Any],
        xs: torch.Tensor,
    ) -> Tuple[torch.Tensor, Any]:
        """Score new token (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            next_tokens (torch.Tensor): torch.int64 tokens to score (n_batch, n_token).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for ys that has a shape `(n_batch, n_vocab)`
                and next states for ys
        """
        raise NotImplementedError
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/__init__.py
0 → 100644
View file @
60a2c57a
"""Initialize sub package."""
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/ctc.py
0 → 100644
View file @
60a2c57a
"""ScorerInterface implementation for CTC."""
import
numpy
as
np
import
torch
from
espnet.nets.ctc_prefix_score
import
CTCPrefixScore
,
CTCPrefixScoreTH
from
espnet.nets.scorer_interface
import
BatchPartialScorerInterface
class CTCPrefixScorer(BatchPartialScorerInterface):
    """Decoder interface wrapper for CTCPrefixScore."""

    def __init__(self, ctc: torch.nn.Module, eos: int):
        """Initialize class.

        Args:
            ctc (torch.nn.Module): The CTC implementation.
                For example, :class:`espnet.nets.pytorch_backend.ctc.CTC`
            eos (int): The end-of-sequence id.
        """
        self.ctc = ctc
        self.eos = eos
        # Lazily built by init_state()/batch_init_state() per utterance.
        self.impl = None

    def init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        logp = self.ctc.log_softmax(x.unsqueeze(0)).detach().squeeze(0).cpu().numpy()
        # TODO(karita): use CTCPrefixScoreTH
        self.impl = CTCPrefixScore(logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def select_state(self, state, i, new_id=None):
        """Select state with relative ids in the main beam search.

        Args:
            state: Decoder state for prefix tokens
            i (int): Index to select a state in the main beam search
            new_id (int): New label id to select a state if necessary

        Returns:
            state: pruned state
        """
        # isinstance instead of `type(state) == tuple` so tuple subclasses
        # (e.g. namedtuples) are dispatched correctly as well.
        if isinstance(state, tuple):
            if len(state) == 2:  # for CTCPrefixScore
                sc, st = state
                return sc[i], st[i]
            else:  # for CTCPrefixScoreTH (need new_id > 0)
                r, log_psi, f_min, f_max, scoring_idmap = state
                s = log_psi[i, new_id].expand(log_psi.size(1))
                if scoring_idmap is not None:
                    return r[:, :, i, scoring_idmap[i, new_id]], s, f_min, f_max
                else:
                    return r[:, :, i, new_id], s, f_min, f_max
        return None if state is None else state[i]

    def score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(ids),)`
                and next state for ys
        """
        prev_score, state = state
        presub_score, new_st = self.impl(y.cpu(), ids.cpu(), state)
        # Convert cumulative prefix scores into per-step deltas.
        tscore = torch.as_tensor(
            presub_score - prev_score, device=x.device, dtype=x.dtype
        )
        return tscore, (presub_score, new_st)

    def batch_init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))  # assuming batch_size = 1
        xlen = torch.tensor([logp.size(1)])
        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
        return None

    def batch_score_partial(self, y, ids, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D prefix token
            ids (torch.Tensor): torch.int64 next token to score
            state: decoder state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys

        Returns:
            tuple[torch.Tensor, Any]:
                Tuple of a score tensor for y that has a shape `(len(ids),)`
                and next state for ys
        """
        # Re-batch the per-hypothesis states into a single CTCPrefixScoreTH state.
        batch_state = (
            (
                torch.stack([s[0] for s in state], dim=2),
                torch.stack([s[1] for s in state]),
                state[0][2],
                state[0][3],
            )
            if state[0] is not None
            else None
        )
        return self.impl(y, batch_state, ids)

    def extend_prob(self, x: torch.Tensor):
        """Extend probs for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            x (torch.Tensor): The encoded feature tensor
        """
        logp = self.ctc.log_softmax(x.unsqueeze(0))
        self.impl.extend_prob(logp)

    def extend_state(self, state):
        """Extend state for decoding.

        This extension is for streaming decoding
        as in Eq (14) in https://arxiv.org/abs/2006.14941

        Args:
            state: The states of hyps

        Returns: exteded state
        """
        new_state = []
        for s in state:
            new_state.append(self.impl.extend_state(s))
        return new_state
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/length_bonus.py
0 → 100644
View file @
60a2c57a
"""Length bonus module."""
from
typing
import
Any
,
List
,
Tuple
import
torch
from
espnet.nets.scorer_interface
import
BatchScorerInterface
class LengthBonus(BatchScorerInterface):
    """Length bonus in beam search."""

    def __init__(self, n_vocab: int):
        """Initialize class.

        Args:
            n_vocab (int): The number of tokens in vocabulary for beam search
        """
        self.n = n_vocab

    def score(self, y, state, x):
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): 2D encoder feature that generates ys.

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and None
        """
        # Constant +1 bonus per emitted token, broadcast over the vocabulary.
        bonus = torch.tensor([1.0], device=x.device, dtype=x.dtype)
        return bonus.expand(self.n), None

    def batch_score(
        self, ys: torch.Tensor, states: List[Any], xs: torch.Tensor
    ) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch.

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor):
                The encoder feature that generates ys (n_batch, xlen, n_feat).

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        bonus = torch.tensor([1.0], device=xs.device, dtype=xs.dtype)
        return bonus.expand(ys.shape[0], self.n), None
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/ngram.py
0 → 100644
View file @
60a2c57a
"""Ngram lm implement."""
from
abc
import
ABC
import
kenlm
import
torch
from
espnet.nets.scorer_interface
import
BatchScorerInterface
,
PartialScorerInterface
class Ngrambase(ABC):
    """Ngram base implemented through ScorerInterface."""

    def __init__(self, ngram_model, token_list):
        """Initialize Ngrambase.

        Args:
            ngram_model: ngram model path
            token_list: token list from dict or model.json
        """
        # kenlm uses "</s>" as its end-of-sentence symbol.
        self.chardict = ["</s>" if tok == "<eos>" else tok for tok in token_list]
        self.charlen = len(self.chardict)
        self.lm = kenlm.LanguageModel(ngram_model)
        # Scratch state reused by score_partial_ to avoid per-token allocation.
        self.tmpkenlmstate = kenlm.State()

    def init_state(self, x):
        """Initialize tmp state."""
        state = kenlm.State()
        self.lm.NullContextWrite(state)
        return state

    def score_partial_(self, y, next_token, state, x):
        """Score interface for both full and partial scorer.

        Args:
            y: previous char
            next_token: next token need to be score
            state: previous state
            x: encoded feature

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        out_state = kenlm.State()
        # Advance the LM by the most recent token ("<s>" when the prefix is
        # just the start symbol).
        if y.shape[0] > 1:
            prev_tok = self.chardict[y[-1]]
        else:
            prev_tok = "<s>"
        self.lm.BaseScore(state, prev_tok, out_state)
        scores = torch.empty_like(next_token, dtype=x.dtype, device=y.device)
        for idx, tok in enumerate(next_token):
            scores[idx] = self.lm.BaseScore(
                out_state, self.chardict[tok], self.tmpkenlmstate
            )
        return scores, out_state
class NgramFullScorer(Ngrambase, BatchScorerInterface):
    """Fullscorer for ngram."""

    def score(self, y, state, x):
        """Score interface for both full and partial scorer.

        Args:
            y: previous char
            state: previous state
            x: encoded feature

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        # Full scoring: evaluate every token id in the vocabulary.
        all_ids = torch.tensor(range(self.charlen))
        return self.score_partial_(y, all_ids, state, x)
class NgramPartScorer(Ngrambase, PartialScorerInterface):
    """Partialscorer for ngram."""

    def score_partial(self, y, next_token, state, x):
        """Score interface for both full and partial scorer.

        Args:
            y: previous char
            next_token: next token need to be score
            state: previous state
            x: encoded feature

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchfied scores for next token with shape of `(n_batch, n_vocab)`
                and next state list for ys.
        """
        return self.score_partial_(y, next_token, state, x)

    def select_state(self, state, i):
        """Empty select state for scorer interface."""
        # A kenlm state is shared across hypotheses; nothing to prune.
        return state
conformer/espnet-v.202304_20240621/build/lib/espnet/nets/scorers/uasr.py
0 → 100644
View file @
60a2c57a
"""ScorerInterface implementation for UASR."""
import
numpy
as
np
import
torch
from
espnet.nets.ctc_prefix_score
import
CTCPrefixScore
,
CTCPrefixScoreTH
from
espnet.nets.scorers.ctc
import
CTCPrefixScorer
class UASRPrefixScorer(CTCPrefixScorer):
    """Decoder interface wrapper for CTCPrefixScore."""

    def __init__(self, eos: int):
        """Initialize class."""
        # NOTE(review): unlike CTCPrefixScorer, no ctc module is stored;
        # scores are computed directly from the encoded features in init_state.
        self.eos = eos

    def init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        # NOTE(review): this mutates the caller's tensor in place — confirm
        # callers do not reuse x afterwards.
        x[:, 0] = x[:, 0] - 100000000000  # simulate a no-blank CTC
        self.logp = (
            torch.nn.functional.log_softmax(x, dim=1).detach().squeeze(0).cpu().numpy()
        )
        # TODO(karita): use CTCPrefixScoreTH
        self.impl = CTCPrefixScore(self.logp, 0, self.eos, np)
        return 0, self.impl.initial_state()

    def batch_init_state(self, x: torch.Tensor):
        """Get an initial state for decoding.

        Args:
            x (torch.Tensor): The encoded feature tensor

        Returns: initial state
        """
        # NOTE(review): in-place mutation of x, as in init_state above.
        x[:, 0] = x[:, 0] - 100000000000  # simulate a no-blank CTC
        logp = torch.nn.functional.log_softmax(x, dim=1).unsqueeze(0)
        # assuming batch_size = 1
        xlen = torch.tensor([logp.size(1)])
        self.impl = CTCPrefixScoreTH(logp, xlen, 0, self.eos)
        return None
Prev
1
…
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment