ModelZoo / EfficientConformer_pytorch / Commits

Commit b28a0aaa, authored Dec 03, 2021 by burchim, parent 2f59ed25

bug fix

Showing 4 changed files with 30 additions and 30 deletions.
models/attentions.py  +19 -19
models/encoders.py     +2  -2
models/modules.py      +7  -7
models/schedules.py    +2  -2
models/attentions.py

@@ -102,7 +102,7 @@ class MultiHeadAttention(nn.Module):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w
+        return O, att_w.detach()

     def pad(self, Q, K, V, mask, chunk_size):
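The recurring change in this file is returning att_w.detach() instead of att_w: the attention weights are only returned for inspection, so handing them back while still attached to the autograd graph keeps the whole graph (and its memory) alive for as long as the caller stores them. A minimal sketch of the effect, using plain tensors rather than this repository's modules:

import torch

x = torch.randn(2, 4, requires_grad=True)
w = torch.softmax(x, dim=-1)    # stands in for attention weights

kept = w             # still part of the autograd graph
safe = w.detach()    # same values, no graph attached

print(kept.requires_grad)        # True  -> storing this keeps the graph alive
print(safe.requires_grad)        # False -> safe to return as a diagnostic or cache
print(torch.equal(kept, safe))   # True, the values are identical

The same reasoning applies to the cached hidden states changed further down: they are carried over to later calls, so detaching them keeps backpropagation from extending into previously processed segments.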
@@ -204,7 +204,7 @@ class GroupedMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w
+        return O, att_w.detach()

 class LocalMultiHeadAttention(MultiHeadAttention):

@@ -280,14 +280,14 @@ class LocalMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w
+        return O, att_w.detach()

 class StridedMultiHeadAttention(MultiHeadAttention):

     """Strided Mutli-Head Attention Layer

     Strided multi-head attention performs global sequence downsampling by striding
-    the attention query bedore aplying scaled dot-product attention. This results in
+    the attention query before aplying scaled dot-product attention. This results in
     strided attention maps where query positions can attend to the entire sequence
     context to perform downsampling.
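The docstring above describes the key idea: subsample the query by a stride before scaled dot-product attention, so the output is shorter than the input while each remaining query position still sees the full key/value context. A rough, self-contained sketch of that idea with plain tensors (shapes and the stride value are illustrative, not taken from this repository's classes):

import torch

B, T, D, stride = 2, 8, 16, 2
Q = torch.randn(B, T, D)
K = torch.randn(B, T, D)
V = torch.randn(B, T, D)

# Stride the query: (B, T, D) -> (B, T // stride, D)
Q = Q[:, ::stride]

# Scaled dot-product attention with the strided query
scores = Q @ K.transpose(1, 2) / D ** 0.5   # (B, T // stride, T)
att_w = scores.softmax(dim=-1)
out = att_w @ V                             # (B, T // stride, D): downsampled output

print(out.shape)   # torch.Size([2, 4, 16])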
@@ -314,7 +314,7 @@ class StridedMultiHeadAttention(MultiHeadAttention):
             mask = mask[:, :, ::self.stride]

         # Multi-Head Attention
-        return super(StridedMultiHeadAttention).forward(Q, K, V, mask)
+        return super(StridedMultiHeadAttention, self).forward(Q, K, V, mask)

 class StridedLocalMultiHeadAttention(MultiHeadAttention):
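The change above is a plain Python fix: super(StridedMultiHeadAttention) with a single argument creates an unbound super object, so looking up .forward on it fails at runtime; the two-argument form super(Cls, self) (equivalent to a bare super() inside the method) binds the proxy to the instance and dispatches to the parent class. A standalone illustration:

class Base:
    def forward(self, x):
        return x + 1

class Child(Base):
    def forward(self, x):
        # return super(Child).forward(x)      # AttributeError: 'super' object has no attribute 'forward'
        return super(Child, self).forward(x)  # same as super().forward(x)

print(Child().forward(1))   # 2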
@@ -393,7 +393,7 @@ class StridedLocalMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w
+        return O, att_w.detach()

 class MultiHeadLinearAttention(MultiHeadAttention):

@@ -442,7 +442,7 @@ class MultiHeadLinearAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, KV
+        return O, KV.detach()

 ###############################################################################
 # Multi-Head Self-Attention Layers with Relative Sinusoidal Poditional Encodings

@@ -578,7 +578,7 @@ class RelPosMultiHeadSelfAttention(MultiHeadAttention):
             V = torch.cat([hidden["V"], V], dim=1)

         # Update Hidden State
-        hidden = {"K": K, "V": V}
+        hidden = {"K": K.detach(), "V": V.detach()}

         # Add Bias
         Qu = Q + self.u

@@ -617,7 +617,7 @@ class RelPosMultiHeadSelfAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w, hidden
+        return O, att_w.detach(), hidden

 class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):

@@ -660,12 +660,12 @@ class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
             V = torch.cat([hidden["V"][:, hidden["V"].size(1) % self.group_size:], V], dim=1)

             # Update Hidden State
-            hidden = {"K": Kh, "V": Vh}
+            hidden = {"K": Kh.detach(), "V": Vh.detach()}

         else:

             # Update Hidden State
-            hidden = {"K": K, "V": V}
+            hidden = {"K": K.detach(), "V": V.detach()}

         # Chunk Padding
         Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.group_size)

@@ -715,7 +715,7 @@ class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w, hidden
+        return O, att_w.detach(), hidden

 class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):

@@ -861,7 +861,7 @@ class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w, hidden
+        return O, att_w.detach(), hidden

 class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):

@@ -954,18 +954,18 @@ class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
             V = torch.cat([hidden["V"], V], dim=1)

         # Update Hidden State
-        hidden = {"K": K, "V": V}
+        hidden = {"K": K.detach(), "V": V.detach()}

         # Chunk Padding
         Q, K, V, mask, _ = self.pad(Q, K, V, mask, chunk_size=self.stride)

+        # Query Subsampling (B, T, D) -> (B, T//S, D)
+        Q = Q[:, ::self.stride]
+
         # Add Bias
         Qu = Q + self.u
         Qv = Q + self.v

-        # Query Subsampling (B, T, D) -> (B, T//S, D)
-        Q = Q[:, ::self.stride]
-
         # Relative Positional Embeddings (B, Th + 2*T-1, D) / (B, Th + T, D)
         E = self.pos_layer(self.rel_pos_enc(batch_size, self.stride * Q.size(1), K.size(1) - self.stride * Q.size(1)))

@@ -1005,7 +1005,7 @@ class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w, hidden
+        return O, att_w.detach(), hidden

 class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):

@@ -1154,7 +1154,7 @@ class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)

-        return O, att_w, hidden
+        return O, att_w.detach(), hidden

 ###############################################################################
 # Positional Encodings
models/encoders.py

@@ -137,7 +137,7 @@ class ConformerEncoder(nn.Module):
             # Update Seq Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // block.stride + 1
+                x_len = torch.div(x_len - 1, block.stride, rounding_mode='floor') + 1

         return x, x_len, attentions
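The length updates now use torch.div(..., rounding_mode='floor') instead of the // operator. On the PyTorch releases current at the time of this commit (1.9/1.10), floor division on tensors emits a deprecation warning because the operator actually rounded toward zero; spelling out rounding_mode='floor' makes the intended behaviour explicit and, for the non-negative lengths used here, gives identical results. A quick check with made-up lengths:

import torch

x_len = torch.tensor([100, 57, 16])
stride = 4

old = (x_len - 1) // stride + 1                                 # floor division operator
new = torch.div(x_len - 1, stride, rounding_mode='floor') + 1   # explicit rounding mode

print(old)                     # tensor([25, 15,  4])
print(new)                     # tensor([25, 15,  4])
print(torch.equal(old, new))   # True for non-negative lengths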
@@ -204,7 +204,7 @@ class ConformerEncoderInterCTC(ConformerEncoder):
             # Update Seq Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // block.stride + 1
+                x_len = torch.div(x_len - 1, block.stride, rounding_mode='floor') + 1

             # Inter CTC Block
             if block_id in self.interctc_blocks:
models/modules.py

@@ -97,7 +97,7 @@ class AudioPreprocessing(nn.Module):
         # Compute Sequence lengths
         if x_len is not None:
-            x_len = x_len // self.hop_length + 1
+            x_len = torch.div(x_len, self.hop_length, rounding_mode='floor') + 1

         # Normalize
         if self.normalize:

@@ -194,7 +194,7 @@ class Conv1dSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1

         return x, x_len
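The (x_len - 1) // 2 + 1 expression is the usual output-length formula for a stride-2 convolution with "same"-style padding, out = floor((L - 1) / 2) + 1, so the tracked lengths stay in sync with the tensors the subsampling layers actually produce. A small sanity check with a throwaway layer (kernel size and padding are chosen here so the formula applies; they are not taken from this repository):

import torch
import torch.nn as nn

conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=3, stride=2, padding=1)

for L in (7, 8, 100):
    x = torch.randn(1, 1, L)
    out_len = conv(x).size(-1)                                            # actual output length
    formula = torch.div(torch.tensor(L - 1), 2, rounding_mode='floor') + 1
    print(L, out_len, int(formula))   # 7 -> 4 4, 8 -> 4 4, 100 -> 50 50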
@@ -240,7 +240,7 @@ class Conv2dSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1

         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()

@@ -291,7 +291,7 @@ class Conv2dPoolSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1

         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()

@@ -347,7 +347,7 @@ class VGGSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = x_len // 2
+            x_len = torch.div(x_len, 2, rounding_mode='floor')

         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()

@@ -589,8 +589,8 @@ class ContextNetSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1

         return x, x_len
models/schedules.py

@@ -88,10 +88,10 @@ class cosine_annealing_learning_rate_scheduler:
         s = self.model_step + 1

         # Compute LR
-        if self.model_step <= self.warmup_steps: # Warmup phase
+        if s <= self.warmup_steps: # Warmup phase
             lr = s / self.warmup_steps * self.lr_max
         else: # Annealing phase
-            lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (self.model_step - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min
+            lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (s - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min

         # Update LR
         self.optimizer.param_groups[0]['lr'] = lr
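The scheduler fix makes both branches work with the shifted step s = self.model_step + 1: previously the warmup ratio used s while the phase test and the cosine term still used self.model_step, an off-by-one between the two phases. A standalone sketch of the corrected schedule (the function and argument names here are illustrative, not the repository's API):

import math

def cosine_annealing_lr(model_step, warmup_steps, end_step, lr_max, lr_min):
    s = model_step + 1
    if s <= warmup_steps:
        # Warmup phase: linear ramp from 0 to lr_max
        return s / warmup_steps * lr_max
    # Annealing phase: cosine decay from lr_max down to lr_min
    return (lr_max - lr_min) * 0.5 * (1 + math.cos(math.pi * (s - warmup_steps) / (end_step - warmup_steps))) + lr_min

# The two phases now agree at the warmup boundary:
print(cosine_annealing_lr(9999, 10000, 100000, 1e-3, 1e-5))    # last warmup step, equals lr_max
print(cosine_annealing_lr(10000, 10000, 100000, 1e-3, 1e-5))   # first annealing step, just below lr_max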