Commit b28a0aaa authored by burchim

bug fix

parent 2f59ed25
@@ -102,7 +102,7 @@ class MultiHeadAttention(nn.Module):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w
+        return O, att_w.detach()
     def pad(self, Q, K, V, mask, chunk_size):
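Most hunks in this file apply the same pattern: the attention weights returned alongside the output are detached, since they are only used for logging/visualisation and keeping them attached would retain the autograd graph needlessly. A minimal sketch of the idea, using a toy module with hypothetical names rather than the repository's exact class:

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    """Toy scaled dot-product attention that returns detached weights."""

    def __init__(self, dim):
        super().__init__()
        self.output_layer = nn.Linear(dim, dim)

    def forward(self, Q, K, V):
        # Attention weights (B, T, T)
        att_w = (Q @ K.transpose(1, 2) / Q.size(-1) ** 0.5).softmax(dim=-1)
        # Attention output followed by the output projection
        O = self.output_layer(att_w @ V)
        # Detach the weights: they are returned only for inspection
        return O, att_w.detach()

x = torch.randn(2, 5, 8)
O, w = ToyAttention(8)(x, x, x)
O.sum().backward()       # gradients still flow through O
print(w.requires_grad)   # False: the returned weights hold no graph
```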
@@ -204,7 +204,7 @@ class GroupedMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w
+        return O, att_w.detach()
 class LocalMultiHeadAttention(MultiHeadAttention):
@@ -280,14 +280,14 @@ class LocalMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w
+        return O, att_w.detach()
 class StridedMultiHeadAttention(MultiHeadAttention):
     """Strided Mutli-Head Attention Layer
     Strided multi-head attention performs global sequence downsampling by striding
-    the attention query bedore aplying scaled dot-product attention. This results in
+    the attention query before aplying scaled dot-product attention. This results in
     strided attention maps where query positions can attend to the entire sequence
     context to perform downsampling.
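The docstring above describes downsampling by striding the query: only every stride-th query position attends to the full key/value sequence, so the output is shorter by the stride factor. A hedged sketch of just that step with toy tensors (not the class's full forward, which also handles heads, masking and projections):

```python
import torch
import torch.nn.functional as F

B, T, D, stride = 2, 8, 16, 2
Q, K, V = torch.randn(B, T, D), torch.randn(B, T, D), torch.randn(B, T, D)

# Stride the query before attention: (B, T, D) -> (B, T//stride, D)
Q_strided = Q[:, ::stride]

# Scaled dot-product attention with the strided query and the full K/V,
# giving a downsampled output of length T//stride.
scores = Q_strided @ K.transpose(1, 2) / D ** 0.5   # (B, T//stride, T)
out = F.softmax(scores, dim=-1) @ V                 # (B, T//stride, D)
print(out.shape)  # torch.Size([2, 4, 16])
```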
@@ -314,7 +314,7 @@ class StridedMultiHeadAttention(MultiHeadAttention):
             mask = mask[:, :, ::self.stride]
         # Multi-Head Attention
-        return super(StridedMultiHeadAttention).forward(Q, K, V, mask)
+        return super(StridedMultiHeadAttention, self).forward(Q, K, V, mask)
 class StridedLocalMultiHeadAttention(MultiHeadAttention):
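The `super()` change in the hunk above (for `StridedMultiHeadAttention.forward`) is a genuine runtime fix: `super(Cls)` without a second argument creates an unbound proxy, so attribute lookup on it fails, while `super(Cls, self)` binds the parent's `forward` to the instance. A small standalone illustration with toy classes:

```python
class Base:
    def forward(self, x):
        return x + 1

class Child(Base):
    def forward(self, x):
        # Correct: the two-argument super binds the parent method to self.
        return super(Child, self).forward(x) * 10

    def broken_forward(self, x):
        # Raises AttributeError: the one-argument form is an unbound proxy.
        return super(Child).forward(x) * 10

c = Child()
print(c.forward(1))         # 20
try:
    c.broken_forward(1)
except AttributeError as e:
    print("broken:", e)
```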
@@ -393,7 +393,7 @@ class StridedLocalMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w
+        return O, att_w.detach()
 class MultiHeadLinearAttention(MultiHeadAttention):
@@ -442,7 +442,7 @@ class MultiHeadLinearAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, KV
+        return O, KV.detach()
 ###############################################################################
 # Multi-Head Self-Attention Layers with Relative Sinusoidal Poditional Encodings
@@ -578,7 +578,7 @@ class RelPosMultiHeadSelfAttention(MultiHeadAttention):
             V = torch.cat([hidden["V"], V], dim=1)
             # Update Hidden State
-            hidden = {"K": K, "V": V}
+            hidden = {"K": K.detach(), "V": V.detach()}
         # Add Bias
         Qu = Q + self.u
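Detaching the cached keys/values above serves the same purpose for the streaming, Transformer-XL-style hidden state: the cache is reused as left context for the next chunk, but gradients should not flow back into previous segments. A rough sketch of the caching pattern with a hypothetical helper, not the repository's exact code:

```python
import torch

def update_kv_cache(hidden, K, V):
    """Concatenate new keys/values onto the cache and store them detached,
    so backprop for the current chunk stops at the segment boundary."""
    if hidden is not None:
        K = torch.cat([hidden["K"], K], dim=1)
        V = torch.cat([hidden["V"], V], dim=1)
    hidden = {"K": K.detach(), "V": V.detach()}
    return K, V, hidden

K = torch.randn(2, 4, 8, requires_grad=True)
V = torch.randn(2, 4, 8, requires_grad=True)
K_full, V_full, hidden = update_kv_cache(None, K, V)
print(K_full.requires_grad, hidden["K"].requires_grad)  # True False
```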
@@ -617,7 +617,7 @@ class RelPosMultiHeadSelfAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -660,12 +660,12 @@ class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
             V = torch.cat([hidden["V"][:, hidden["V"].size(1)%self.group_size:], V], dim=1)
             # Update Hidden State
-            hidden = {"K": Kh, "V": Vh}
+            hidden = {"K": Kh.detach(), "V": Vh.detach()}
         else:
             # Update Hidden State
-            hidden = {"K": K, "V": V}
+            hidden = {"K": K.detach(), "V": V.detach()}
         # Chunk Padding
         Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.group_size)
@@ -715,7 +715,7 @@ class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -861,7 +861,7 @@ class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -954,18 +954,18 @@ class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
             V = torch.cat([hidden["V"], V], dim=1)
             # Update Hidden State
-            hidden = {"K": K, "V": V}
+            hidden = {"K": K.detach(), "V": V.detach()}
         # Chunk Padding
         Q, K, V, mask, _ = self.pad(Q, K, V, mask, chunk_size=self.stride)
+        # Query Subsampling (B, T, D) -> (B, T//S, D)
+        Q = Q[:, ::self.stride]
         # Add Bias
         Qu = Q + self.u
         Qv = Q + self.v
-        # Query Subsampling (B, T, D) -> (B, T//S, D)
-        Q = Q[:, ::self.stride]
         # Relative Positional Embeddings (B, Th + 2*T-1, D) / (B, Th + T, D)
         E = self.pos_layer(self.rel_pos_enc(batch_size, self.stride * Q.size(1), K.size(1) - self.stride * Q.size(1)))
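The reordering above appears to be the substantive fix in this hunk: with the old ordering the query was subsampled only after `Qu`/`Qv` had been formed, so the biased queries actually used for attention were never strided; subsampling first keeps every downstream tensor at the strided length. A quick shape check under toy sizes (zero biases for simplicity):

```python
import torch

B, T, D, stride = 2, 8, 16, 2
Q = torch.randn(B, T, D)
u, v = torch.zeros(D), torch.zeros(D)

# Fixed order: stride first, then add the relative-position biases.
Q_sub = Q[:, ::stride]          # (B, T//stride, D)
Qu, Qv = Q_sub + u, Q_sub + v
print(Qu.shape)                 # torch.Size([2, 4, 16]) -- strided, as intended

# Old order: the biases were added to the full-length query, so Qu/Qv
# stayed at length T even though Q itself was strided afterwards.
Qu_old = Q + u
print(Qu_old.shape)             # torch.Size([2, 8, 16])
```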
@@ -1005,7 +1005,7 @@ class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -1154,7 +1154,7 @@ class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 ###############################################################################
 # Positional Encodings
@@ -137,7 +137,7 @@ class ConformerEncoder(nn.Module):
             # Update Seq Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // block.stride + 1
+                x_len = torch.div(x_len - 1, block.stride, rounding_mode='floor') + 1
         return x, x_len, attentions
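All of the sequence-length hunks in this commit swap Python's `//` on tensors for `torch.div(..., rounding_mode='floor')`, which computes the same subsampled lengths while avoiding the floor-divide deprecation warning that some PyTorch versions emit for the operator form on integer tensors. A quick equivalence check with made-up lengths:

```python
import torch

x_len = torch.tensor([100, 37, 1])
stride = 4

old = (x_len - 1) // stride + 1
new = torch.div(x_len - 1, stride, rounding_mode='floor') + 1
print(old, new, torch.equal(old, new))  # identical subsampled lengths
```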
@@ -204,7 +204,7 @@ class ConformerEncoderInterCTC(ConformerEncoder):
             # Update Seq Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // block.stride + 1
+                x_len = torch.div(x_len - 1, block.stride, rounding_mode='floor') + 1
             # Inter CTC Block
             if block_id in self.interctc_blocks:
@@ -97,7 +97,7 @@ class AudioPreprocessing(nn.Module):
         # Compute Sequence lengths
         if x_len is not None:
-            x_len = x_len // self.hop_length + 1
+            x_len = torch.div(x_len, self.hop_length, rounding_mode='floor') + 1
         # Normalize
         if self.normalize:
@@ -194,7 +194,7 @@ class Conv1dSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
         return x, x_len
@@ -240,7 +240,7 @@ class Conv2dSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()
@@ -291,7 +291,7 @@ class Conv2dPoolSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()
@@ -347,7 +347,7 @@ class VGGSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = x_len // 2
+            x_len = torch.div(x_len, 2, rounding_mode='floor')
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()
@@ -589,8 +589,8 @@ class ContextNetSubsampling(nn.Module):
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
         return x, x_len
@@ -88,10 +88,10 @@ class cosine_annealing_learning_rate_scheduler:
         s = self.model_step + 1
         # Compute LR
-        if self.model_step <= self.warmup_steps: # Warmup phase
+        if s <= self.warmup_steps: # Warmup phase
             lr = s / self.warmup_steps * self.lr_max
         else: # Annealing phase
-            lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (self.model_step - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min
+            lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (s - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min
         # Update LR
         self.optimizer.param_groups[0]['lr'] = lr
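The scheduler fix above makes the warmup test and the annealing phase use the same 1-indexed step `s`, so the warmup boundary and the cosine phase line up. A hedged standalone sketch of the schedule as a plain function with hypothetical argument names:

```python
import math

def warmup_cosine_lr(step, warmup_steps, end_step, lr_max, lr_min=0.0):
    """Linear warmup to lr_max, then cosine annealing down to lr_min."""
    s = step + 1
    if s <= warmup_steps:                      # warmup phase
        return s / warmup_steps * lr_max
    # annealing phase
    progress = (s - warmup_steps) / (end_step - warmup_steps)
    return (lr_max - lr_min) * 0.5 * (1 + math.cos(math.pi * progress)) + lr_min

print(warmup_cosine_lr(0, 10, 100, 1e-3))      # 1e-4, first warmup step
print(warmup_cosine_lr(9, 10, 100, 1e-3))      # 1e-3, end of warmup
print(warmup_cosine_lr(99, 10, 100, 1e-3))     # ~0.0, fully annealed
```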