Commit b28a0aaa authored by burchim

bug fix

Detach the attention weights and cached K/V hidden states returned by the attention layers, call the parent forward through super(StridedMultiHeadAttention, self), subsample the strided query before adding the relative positional biases, replace floor division on length tensors with torch.div(..., rounding_mode='floor'), and use the incremented step s consistently in the cosine annealing scheduler.

parent 2f59ed25
@@ -102,7 +102,7 @@ class MultiHeadAttention(nn.Module):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w
+        return O, att_w.detach()
 
     def pad(self, Q, K, V, mask, chunk_size):
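The added .detach() means callers that keep the returned attention maps (for plotting or logging) no longer hold on to the autograd graph of the forward pass. A minimal sketch of the effect with toy tensors, not the repository's layer:

    import torch

    Q = torch.randn(2, 5, 8, requires_grad=True)
    K = torch.randn(2, 5, 8, requires_grad=True)
    att_w = (Q @ K.transpose(1, 2)).softmax(dim=-1)

    kept_for_logging = att_w.detach()   # no grad_fn, safe to accumulate across steps
    print(att_w.requires_grad, kept_for_logging.requires_grad)  # True False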
@@ -204,7 +204,7 @@ class GroupedMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w
+        return O, att_w.detach()
 
 class LocalMultiHeadAttention(MultiHeadAttention):
@@ -280,14 +280,14 @@ class LocalMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w
+        return O, att_w.detach()
 
 class StridedMultiHeadAttention(MultiHeadAttention):
 
     """Strided Mutli-Head Attention Layer
 
     Strided multi-head attention performs global sequence downsampling by striding
-    the attention query bedore aplying scaled dot-product attention. This results in
+    the attention query before aplying scaled dot-product attention. This results in
     strided attention maps where query positions can attend to the entire sequence
     context to perform downsampling.
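As the docstring describes, only the query is strided, so each remaining query position still attends over the full key/value sequence. A rough shape sketch with assumed toy dimensions, not the repository's module:

    import torch

    stride = 2
    B, T, D = 4, 16, 64
    Q, K = torch.randn(B, T, D), torch.randn(B, T, D)

    Q_strided = Q[:, ::stride]                        # (B, T//stride, D): downsampled queries
    scores = Q_strided @ K.transpose(1, 2) / D**0.5   # (B, T//stride, T): strided attention map
    print(scores.shape)                               # torch.Size([4, 8, 16])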
@@ -314,7 +314,7 @@ class StridedMultiHeadAttention(MultiHeadAttention):
             mask = mask[:, :, ::self.stride]
 
         # Multi-Head Attention
-        return super(StridedMultiHeadAttention).forward(Q, K, V, mask)
+        return super(StridedMultiHeadAttention, self).forward(Q, K, V, mask)
 
 class StridedLocalMultiHeadAttention(MultiHeadAttention):
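The super() fix is needed because one-argument super(Cls) returns an unbound proxy with no attribute lookup, so the old call raised AttributeError instead of running the parent forward. A generic standalone illustration with toy classes, not the repository's:

    class Base:
        def forward(self, x):
            return x + 1

    class Child(Base):
        def forward(self, x):
            # Correct: bound proxy, dispatches to Base.forward on this instance.
            return super(Child, self).forward(x)

        def forward_buggy(self, x):
            # Old pattern: unbound super object, no 'forward' attribute -> AttributeError.
            return super(Child).forward(x)

    print(Child().forward(1))  # 2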
@@ -393,7 +393,7 @@ class StridedLocalMultiHeadAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w
+        return O, att_w.detach()
 
 class MultiHeadLinearAttention(MultiHeadAttention):
@@ -442,7 +442,7 @@ class MultiHeadLinearAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, KV
+        return O, KV.detach()
 
 ###############################################################################
 # Multi-Head Self-Attention Layers with Relative Sinusoidal Poditional Encodings
@@ -578,7 +578,7 @@ class RelPosMultiHeadSelfAttention(MultiHeadAttention):
             V = torch.cat([hidden["V"], V], dim=1)
 
         # Update Hidden State
-        hidden = {"K": K, "V": V}
+        hidden = {"K": K.detach(), "V": V.detach()}
 
         # Add Bias
         Qu = Q + self.u
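Detaching the cached K/V (the Transformer-XL style hidden state used for streaming) before storing it stops gradients from flowing into earlier segments and keeps old computation graphs from being retained across calls. A minimal sketch of the caching pattern with a hypothetical helper:

    import torch

    def update_cache(hidden, K, V):
        """Concatenate new keys/values onto the cache and store a detached copy."""
        if hidden is not None:
            K = torch.cat([hidden["K"], K], dim=1)
            V = torch.cat([hidden["V"], V], dim=1)
        # Detach so the next segment neither backprops into nor keeps alive this one.
        return K, V, {"K": K.detach(), "V": V.detach()}

    hidden = None
    for _ in range(3):  # three consecutive segments
        K_new = torch.randn(2, 10, 64, requires_grad=True)
        V_new = torch.randn(2, 10, 64, requires_grad=True)
        K, V, hidden = update_cache(hidden, K_new, V_new)

    print(K.size(1), hidden["K"].requires_grad)  # 30 False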
@@ -617,7 +617,7 @@ class RelPosMultiHeadSelfAttention(MultiHeadAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -660,12 +660,12 @@ class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
             V = torch.cat([hidden["V"][:, hidden["V"].size(1)%self.group_size:], V], dim=1)
 
             # Update Hidden State
-            hidden = {"K": Kh, "V": Vh}
+            hidden = {"K": Kh.detach(), "V": Vh.detach()}
 
         else:
 
             # Update Hidden State
-            hidden = {"K": K, "V": V}
+            hidden = {"K": K.detach(), "V": V.detach()}
 
         # Chunk Padding
         Q, K, V, mask, padding = self.pad(Q, K, V, mask, chunk_size=self.group_size)
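Note how the cache concatenated here is first trimmed by size % group_size, so the combined length stays a multiple of group_size and the grouped reshape stays aligned across segments. A small arithmetic sketch with toy sizes:

    import torch

    group_size = 3
    cache_V = torch.randn(1, 7, 8)   # frames cached from previous segments
    new_V = torch.randn(1, 6, 8)     # frames of the current segment

    # Drop the 7 % 3 = 1 oldest cached frame so the total stays divisible by group_size.
    V = torch.cat([cache_V[:, cache_V.size(1) % group_size:], new_V], dim=1)
    print(V.size(1), V.size(1) % group_size)  # 12 0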
@@ -715,7 +715,7 @@ class GroupedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -861,7 +861,7 @@ class LocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -954,18 +954,18 @@ class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
             V = torch.cat([hidden["V"], V], dim=1)
 
         # Update Hidden State
-        hidden = {"K": K, "V": V}
+        hidden = {"K": K.detach(), "V": V.detach()}
 
         # Chunk Padding
         Q, K, V, mask, _ = self.pad(Q, K, V, mask, chunk_size=self.stride)
 
+        # Query Subsampling (B, T, D) -> (B, T//S, D)
+        Q = Q[:, ::self.stride]
+
         # Add Bias
         Qu = Q + self.u
         Qv = Q + self.v
 
-        # Query Subsampling (B, T, D) -> (B, T//S, D)
-        Q = Q[:, ::self.stride]
-
         # Relative Positional Embeddings (B, Th + 2*T-1, D) / (B, Th + T, D)
         E = self.pos_layer(self.rel_pos_enc(batch_size, self.stride * Q.size(1), K.size(1) - self.stride * Q.size(1)))
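Beyond the cache detach, the substantive change in this hunk is moving the query subsampling above the bias addition (the ordering shown above is the one consistent with the strided shapes used in the attention scores): Qu and Qv are separate tensors, so when Q was only strided afterwards they kept the full sequence length and the layer no longer downsampled its queries. A toy shape check under assumed sizes, not the repository's module:

    import torch

    B, T, D, stride = 2, 12, 8, 3
    Q = torch.randn(B, T, D)
    u, v = torch.randn(D), torch.randn(D)

    # Fixed order: subsample first, then add the relative positional biases.
    Qs = Q[:, ::stride]       # (B, T//stride, D)
    Qu, Qv = Qs + u, Qs + v
    print(Qu.shape)           # torch.Size([2, 4, 8])

    # Buggy order: Qu/Qv are built from the full-length Q and stay (B, T, D),
    # even though Q itself is strided afterwards.
    Qu_bug, Qv_bug = Q + u, Q + v
    Q = Q[:, ::stride]
    print(Qu_bug.shape)       # torch.Size([2, 12, 8])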
@@ -1005,7 +1005,7 @@ class StridedRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
@@ -1154,7 +1154,7 @@ class StridedLocalRelPosMultiHeadSelfAttention(RelPosMultiHeadSelfAttention):
         # Output linear layer
         O = self.output_layer(O)
 
-        return O, att_w, hidden
+        return O, att_w.detach(), hidden
 
 ###############################################################################
 # Positional Encodings
@@ -137,7 +137,7 @@ class ConformerEncoder(nn.Module):
 
             # Update Seq Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // block.stride + 1
+                x_len = torch.div(x_len - 1, block.stride, rounding_mode='floor') + 1
 
         return x, x_len, attentions
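Switching from Python floor division on tensors to torch.div(..., rounding_mode='floor') keeps the same subsampled lengths for these non-negative length tensors while avoiding the floor-division deprecation warning that some PyTorch versions emit for the // tensor operator. A quick equivalence check with example lengths:

    import torch

    x_len = torch.tensor([100, 57, 33])   # example sequence lengths
    stride = 2

    old = (x_len - 1) // stride + 1                                # warned on some PyTorch versions
    new = torch.div(x_len - 1, stride, rounding_mode='floor') + 1  # explicit and warning-free

    print(old.tolist(), new.tolist())   # [50, 29, 17] [50, 29, 17]
    assert torch.equal(old, new)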
@@ -204,7 +204,7 @@ class ConformerEncoderInterCTC(ConformerEncoder):
 
             # Update Seq Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // block.stride + 1
+                x_len = torch.div(x_len - 1, block.stride, rounding_mode='floor') + 1
 
             # Inter CTC Block
             if block_id in self.interctc_blocks:
@@ -97,7 +97,7 @@ class AudioPreprocessing(nn.Module):
 
         # Compute Sequence lengths
         if x_len is not None:
-            x_len = x_len // self.hop_length + 1
+            x_len = torch.div(x_len, self.hop_length, rounding_mode='floor') + 1
 
         # Normalize
         if self.normalize:
@@ -194,7 +194,7 @@ class Conv1dSubsampling(nn.Module):
 
             # Update Sequence Lengths
             if x_len is not None:
-                x_len = (x_len - 1) // 2 + 1
+                x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
 
         return x, x_len
@@ -240,7 +240,7 @@ class Conv2dSubsampling(nn.Module):
 
         # Update Sequence Lengths
        if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
 
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()
@@ -291,7 +291,7 @@ class Conv2dPoolSubsampling(nn.Module):
 
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
 
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()
@@ -347,7 +347,7 @@ class VGGSubsampling(nn.Module):
 
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = x_len // 2
+            x_len = torch.div(x_len, 2, rounding_mode='floor')
 
         # (B, C, D // S, T // S) -> (B, C * D // S, T // S)
         batch_size, channels, subsampled_dim, subsampled_length = x.size()
@@ -589,8 +589,8 @@ class ContextNetSubsampling(nn.Module):
 
         # Update Sequence Lengths
         if x_len is not None:
-            x_len = (x_len - 1) // 2 + 1
-            x_len = (x_len - 1) // 2 + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
+            x_len = torch.div(x_len - 1, 2, rounding_mode='floor') + 1
 
         return x, x_len
@@ -88,10 +88,10 @@ class cosine_annealing_learning_rate_scheduler:
         s = self.model_step + 1
 
         # Compute LR
-        if self.model_step <= self.warmup_steps: # Warmup phase
+        if s <= self.warmup_steps: # Warmup phase
             lr = s / self.warmup_steps * self.lr_max
         else: # Annealing phase
-            lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (self.model_step - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min
+            lr = (self.lr_max - self.lr_min) * 0.5 * (1 + math.cos(math.pi * (s - self.warmup_steps) / (self.end_step - self.warmup_steps))) + self.lr_min
 
         # Update LR
         self.optimizer.param_groups[0]['lr'] = lr
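Using the incremented step s in both the comparison and the cosine term makes the two phases consistent: warmup now ends exactly at s == warmup_steps with lr == lr_max, whereas comparing against self.model_step let the warmup branch run one step too long and overshoot lr_max. A standalone sketch of the schedule under hypothetical hyperparameters:

    import math

    def cosine_annealing_lr(model_step, warmup_steps=10, end_step=100, lr_max=1e-3, lr_min=1e-5):
        """Linear warmup followed by cosine annealing, using the 1-indexed step s throughout."""
        s = model_step + 1
        if s <= warmup_steps:  # warmup phase
            return s / warmup_steps * lr_max
        # annealing phase
        return (lr_max - lr_min) * 0.5 * (1 + math.cos(math.pi * (s - warmup_steps) / (end_step - warmup_steps))) + lr_min

    print(cosine_annealing_lr(9))    # 0.001: last warmup step reaches lr_max exactly
    print(cosine_annealing_lr(10))   # ~0.00099: annealing starts just below lr_max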