ModelZoo / EfficientConformer_pytorch / Commits

Commit b28a0aaa, authored Dec 03, 2021 by burchim
Parent: 2f59ed25

bug fix

Showing 4 changed files with 30 additions and 30 deletions (+30 -30):
models/attentions.py  +19 -19
models/encoders.py     +2  -2
models/modules.py      +7  -7
models/schedules.py    +2  -2
models/attentions.py

@@ -102,7 +102,7 @@ class MultiHeadAttention(nn.Module):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w
+        return O, att_w.detach()
 
     def pad(self, Q, K, V, mask, chunk_size):
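This first hunk is the pattern repeated across the file: the attention map att_w is returned for inspection or logging only, so it is detached from the autograd graph before leaving the layer. Otherwise a caller that keeps the returned maps around (for example to plot them later) also keeps every forward graph alive, which leaks memory over training. A minimal sketch of the pattern; MiniAttention and its names are illustrative, not part of this repository:

    import torch
    import torch.nn as nn

    class MiniAttention(nn.Module):
        """Toy single-head attention that returns its attention map."""
        def __init__(self, dim):
            super(MiniAttention, self).__init__()
            self.qkv = nn.Linear(dim, 3 * dim)

        def forward(self, x):
            Q, K, V = self.qkv(x).chunk(3, dim=-1)
            att_w = (Q @ K.transpose(-2, -1) / Q.size(-1) ** 0.5).softmax(dim=-1)
            O = att_w @ V
            # Detaching the returned map means a caller that stores it
            # (e.g. for visualization) does not keep the whole graph alive.
            return O, att_w.detach()

    att_maps = []  # accumulated across steps, e.g. for plotting later
    layer = MiniAttention(8)
    for _ in range(3):
        out, att_w = layer(torch.randn(2, 5, 8))
        att_maps.append(att_w)  # safe: holds no reference to the graph
        out.sum().backward()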
@@ -204,7 +204,7 @@ class GroupedMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w
+        return O, att_w.detach()
 
 class LocalMultiHeadAttention(MultiHeadAttention):
@@ -280,14 +280,14 @@ class LocalMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w
+        return O, att_w.detach()
 
 class StridedMultiHeadAttention(MultiHeadAttention):
 
     """Strided Mutli-Head Attention Layer
 
     Strided multi-head attention performs global sequence downsampling by striding
-    the attention query bedore aplying scaled dot-product attention. This results in
+    the attention query before aplying scaled dot-product attention. This results in
     strided attention maps where query positions can attend to the entire sequence
     context to perform downsampling.
@@ -314,7 +314,7 @@ class StridedMultiHeadAttention(MultiHeadAttention):
             mask = mask[:, :, ::self.stride]
 
         # Multi-Head Attention
-        return super(StridedMultiHeadAttention).forward(Q, K, V, mask)
+        return super(StridedMultiHeadAttention, self).forward(Q, K, V, mask)
 
 class StridedLocalMultiHeadAttention(MultiHeadAttention):
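The fix above is a classic super() mistake: the one-argument form super(StridedMultiHeadAttention) returns an unbound super object, so calling .forward(...) on it raises AttributeError at runtime; the two-argument form binds self and correctly dispatches to the parent class. A self-contained illustration with toy classes, not the repository's:

    class Base:
        def forward(self, x):
            return x + 1

    class Child(Base):
        def broken(self, x):
            # Unbound super object: no access to self, so this fails with
            # "AttributeError: 'super' object has no attribute 'forward'".
            return super(Child).forward(x)

        def fixed(self, x):
            # Bound super object: dispatches to Base.forward with self.
            return super(Child, self).forward(x)

    c = Child()
    print(c.fixed(1))    # 2
    try:
        c.broken(1)
    except AttributeError as e:
        print(e)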
@@ -393,7 +393,7 @@ class StridedLocalMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w
+        return O, att_w.detach()
 
 class MultiHeadLinearAttention(MultiHeadAttention):
@@ -442,7 +442,7 @@ class MultiHeadLinearAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, KV
+        return O, KV.detach()
 
 ###############################################################################
 # Multi-Head Self-Attention Layers with Relative Sinusoidal Poditional Encodings
@@ -578,7 +578,7 @@ class RelPosMultiHeadSelfAttention(MultiHeadAttention):
             V = torch.cat([hidden["V"], V], dim=1)
 
         # Update Hidden State
-        hidden = {"K": K, "V": V}
+        hidden = {"K": K.detach(), "V": V.detach()}
 
         # Add Bias
         Qu = Q + self.u
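Here the detach() moves to the cached hidden states. These keys and values are carried across streaming chunks, so detaching them truncates backpropagation at the chunk boundary and keeps the autograd graph from growing without bound over a long stream. A minimal sketch of the caching pattern; the chunk loop and shapes are assumptions for illustration, not this class's API:

    import torch

    proj = torch.nn.Linear(8, 8)
    hidden = None  # {"K": ..., "V": ...} carried across chunks

    for step in range(4):
        x = torch.randn(2, 10, 8)
        K = V = proj(x)
        if hidden is not None:
            K = torch.cat([hidden["K"], K], dim=1)
            V = torch.cat([hidden["V"], V], dim=1)
        # Detach the cache: gradients stop at the chunk boundary, so each
        # backward() only traverses the current chunk's graph.
        hidden = {"K": K.detach(), "V": V.detach()}
        loss = (K * V).sum()
        loss.backward()  # without the detach above, step > 0 would raise
                         # "Trying to backward through the graph a second time"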
@@ -617,7 +617,7 @@ class RelPosMultiHeadSelfAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -660,12 +660,12 @@ class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
             V = torch.cat([hidden["V"][:, hidden["V"].size(1) % self.group_size:], V], dim=1)
 
             # Update Hidden State
-            hidden = {"K": Kh, "V": Vh}
+            hidden = {"K": Kh.detach(), "V": Vh.detach()}
 
         else:
 
             # Update Hidden State
-            hidden = {"K": K, "V": V}
+            hidden = {"K": K.detach(), "V": V.detach()}
 
         # Chunk Padding
         Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.group_size)
@@ -715,7 +715,7 @@ class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -861,7 +861,7 @@ class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -954,18 +954,18 @@ class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
             V = torch.cat([hidden["V"], V], dim=1)
 
         # Update Hidden State
-        hidden = {"K": K, "V": V}
+        hidden = {"K": K.detach(), "V": V.detach()}
 
         # Chunk Padding
         Q, K, V, mask, _ = self.pad(Q, K, V, mask, chunk_size=self.stride)
 
+        # Query Subsampling (B, T, D) -> (B, T//S, D)
+        Q = Q[:, ::self.stride]
+
         # Add Bias
         Qu = Q + self.u
         Qv = Q + self.v
 
-        # Query Subsampling (B, T, D) -> (B, T//S, D)
-        Q = Q[:, ::self.stride]
-
         # Relative Positional Embeddings (B, Th + 2*T-1, D) / (B, Th + T, D)
         E = self.pos_layer(self.rel_pos_enc(batch_size, self.stride * Q.size(1), K.size(1) - self.stride * Q.size(1)))
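In this hunk the query subsampling moves above the bias addition. Since Qu and Qv are derived from Q, subsampling first makes them (and therefore the attention maps and the layer output) the downsampled length T//S; in the old order they kept the full length T, defeating the striding. A shape-only sketch of the two orderings, with toy sizes rather than this class's actual code:

    import torch

    B, T, D, S = 2, 12, 8, 3
    Q = torch.randn(B, T, D)
    u = torch.randn(D)

    # Fixed order: subsample first, then add the bias.
    Q_sub = Q[:, ::S]          # (B, T//S, D)
    Qu = Q_sub + u             # (B, T//S, D) -> strided attention maps
    print(Qu.shape)            # torch.Size([2, 4, 8])

    # Buggy order: bias added to the full-length query; subsampling Q
    # afterwards leaves Qu at the original length.
    Qu_buggy = Q + u           # (B, T, D)
    Q = Q[:, ::S]
    print(Qu_buggy.shape)      # torch.Size([2, 12, 8]) -- not downsampled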
@@ -1005,7 +1005,7 @@ class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -1154,7 +1154,7 @@ class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 ###############################################################################
 # Positional Encodings
models/encoders.py

@@ -137,7 +137,7 @@ class ConformerEncoder(nn.Module):
             # Update Seq Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // block.stride + 1
+                x_len = torch.div(x_len - 1, block.stride, rounding_mode='floor') + 1
 
         return x, x_len, attentions

@@ -204,7 +204,7 @@ class ConformerEncoderInterCTC(ConformerEncoder):
             # Update Seq Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // block.stride + 1
+                x_len = torch.div(x_len - 1, block.stride, rounding_mode='floor') + 1
 
             # Inter CTC Block
             if block_id in self.interctc_blocks:
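Both hunks replace tensor floor division with an explicit torch.div(..., rounding_mode='floor'), which avoids the deprecation warning newer PyTorch versions (around 1.8) emit for tensor // arithmetic and pins down the rounding behavior. For the non-negative lengths used here the two forms agree. A quick check, assuming only a reasonably recent PyTorch:

    import torch

    x_len = torch.tensor([100, 157, 1])
    stride = 4

    old = (x_len - 1) // stride + 1  # warned as deprecated on older versions
    new = torch.div(x_len - 1, stride, rounding_mode='floor') + 1

    assert torch.equal(old, new)
    print(new)  # tensor([25, 40, 1]) -- ceil(len / stride) for positive lengths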
models/modules.py

@@ -97,7 +97,7 @@ class AudioPreprocessing(nn.Module):
         # Compute Sequence lengths
         if x_len is not None:
-            x_len = x_len // self.hop_length + 1
+            x_len = torch.div(x_len, self.hop_length, rounding_mode='floor') + 1
 
         # Normalize
         if self.normalize:

@@ -194,7 +194,7 @@ class Conv1dSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
 
         return x, x_len

@@ -240,7 +240,7 @@ class Conv2dSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
 
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()

@@ -291,7 +291,7 @@ class Conv2dPoolSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
 
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()

@@ -347,7 +347,7 @@ class VGGSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = x_len // 2
+            x_len = torch.div(x_len, 2, rounding_mode='floor')
 
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()

@@ -589,8 +589,8 @@ class ContextNetSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
 
         return x, x_len
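The recurring update (x_len - 1) // 2 + 1 is the output-length formula of a stride-2 convolution and equals ceil(x_len / 2) for positive lengths; ContextNetSubsampling applies it twice for an overall 4x reduction. A small worked check; the kernel size and padding are illustrative assumptions, not necessarily this repository's layer configuration:

    import torch
    import torch.nn as nn

    x_len = torch.tensor([16, 17])

    # Two stride-2 stages, as in the ContextNet subsampling path.
    for _ in range(2):
        x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
    print(x_len)  # tensor([4, 5]) -- 4x reduction, rounded up

    # Matches an actual stride-2 conv with kernel 3 and padding 1.
    conv = nn.Conv1d(1, 1, kernel_size=3, stride=2, padding=1)
    out = conv(torch.randn(1, 1, 17))
    print(out.shape[-1])  # 9 == (17 - 1) // 2 + 1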
models/schedules.py

@@ -88,10 +88,10 @@ class cosine_annealing_learning_rate_scheduler:
         s = self.model_step + 1
 
         # Compute LR
-        if self.model_step <= self.warmup_steps: # Warmup phase
+        if s <= self.warmup_steps: # Warmup phase
             lr = s / self.warmup_steps * self.lr_max
         else: # Annealing phase
-            lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (self.model_step - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min
+            lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (s - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min
 
         # Update LR
         self.optimizer.param_groups[0]['lr'] = lr
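The scheduler fix makes both the branch test and the annealing formula use the 1-indexed step s = self.model_step + 1 that the warmup expression already used. Before, model_step == warmup_steps still took the warmup branch with s = warmup_steps + 1, pushing the LR slightly above lr_max, and the annealing phase mixed the two indexings. A standalone sketch of the corrected schedule; the function name and default values are illustrative:

    import math

    def cosine_lr(model_step, warmup_steps=10000, end_step=100000,
                  lr_max=1e-3, lr_min=1e-6):
        """Linear warmup then cosine annealing, with a consistent 1-indexed step."""
        s = model_step + 1
        if s <= warmup_steps:  # Warmup phase
            return s / warmup_steps * lr_max
        # Annealing phase
        return (lr_max - lr_min) * 0.5 * (
            1 + math.cos(math.pi * (s - warmup_steps) / (end_step - warmup_steps))
        ) + lr_min

    # The two phases now meet at lr_max at the warmup boundary:
    print(cosine_lr(10000 - 1))  # last warmup step: 1e-3
    print(cosine_lr(10000))      # first annealing step, just below 1e-3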