Commit 6994f168 authored by Hang Zhang, committed by Facebook GitHub Bot

Remove Pre-norm option, since it is not used

Summary: As in the title

Reviewed By: XiaoliangDai

Differential Revision: D33413849

fbshipit-source-id: b891849c175edc7b8916bff2fcc40c76c4658f14
parent 9200cbe8
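For orientation: the deleted forward_pre path applied LayerNorm before each sub-layer (pre-norm), while the surviving forward (formerly forward_post) applies it after the residual addition (post-norm). Below is a minimal, illustrative sketch of the two residual patterns; the helper names post_norm_block and pre_norm_block are hypothetical and do not appear in the changed file.

import torch
from torch import nn

def post_norm_block(x, sublayer, norm, dropout):
    # Kept pattern: sub-layer -> residual add -> LayerNorm (what forward / forward_post does).
    return norm(x + dropout(sublayer(x)))

def pre_norm_block(x, sublayer, norm, dropout):
    # Removed pattern: LayerNorm -> sub-layer -> residual add (what forward_pre did).
    return x + dropout(sublayer(norm(x)))

if __name__ == "__main__":
    d_model = 256
    x = torch.rand(10, 2, d_model)  # (L, B, C), matching the decoder's return-shape comment below
    ffn = nn.Sequential(nn.Linear(d_model, 1024), nn.ReLU(), nn.Linear(1024, d_model))
    norm, drop = nn.LayerNorm(d_model), nn.Dropout(0.1)
    print(post_norm_block(x, ffn, norm, drop).shape)  # torch.Size([10, 2, 256])
    print(pre_norm_block(x, ffn, norm, drop).shape)   # torch.Size([10, 2, 256])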
@@ -203,12 +203,12 @@ class TransformerEncoderLayer(nn.Module):
         self.dropout2 = nn.Dropout(dropout)
 
         self.activation = _get_activation_fn(activation)
-        self.normalize_before = normalize_before
+        assert not normalize_before, "normalize_before is not supported"
 
     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
         return tensor if pos is None else tensor + pos
 
-    def forward_post(
+    def forward(
         self,
         src,
         src_mask: Optional[Tensor] = None,
@@ -228,35 +228,6 @@ class TransformerEncoderLayer(nn.Module):
         src = self.norm2(src)
         return src
-
-    def forward_pre(
-        self,
-        src,
-        src_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-        pos: Optional[Tensor] = None,
-    ):
-        src2 = self.norm1(src)
-        q = k = self.with_pos_embed(src2, pos)
-        src2 = self.self_attn(
-            q, k, src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
-        )[0]
-        src = src + self.dropout1(src2)
-        src2 = self.norm2(src)
-        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
-        src = src + self.dropout2(src2)
-        return src
-
-    def forward(
-        self,
-        src,
-        src_mask: Optional[Tensor] = None,
-        src_key_padding_mask: Optional[Tensor] = None,
-        pos: Optional[Tensor] = None,
-    ):
-        if self.normalize_before:
-            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
-        return self.forward_post(src, src_mask, src_key_padding_mask, pos)
 class TransformerDecoderLayer(nn.Module):
     def __init__(
@@ -284,12 +255,12 @@ class TransformerDecoderLayer(nn.Module):
         self.dropout3 = nn.Dropout(dropout)
 
         self.activation = _get_activation_fn(activation)
-        self.normalize_before = normalize_before
+        assert not normalize_before, "normalize_before is not supported"
 
     def with_pos_embed(self, tensor, pos: Optional[Tensor]):
         return tensor if pos is None else tensor + pos
 
-    def forward_post(
+    def forward(
         self,
         tgt,
         memory,
@@ -327,70 +298,6 @@ class TransformerDecoderLayer(nn.Module):
         # return tgt shape (L, B, C)
         return tgt
-
-    def forward_pre(
-        self,
-        tgt,
-        memory,
-        tgt_mask: Optional[Tensor] = None,
-        memory_mask: Optional[Tensor] = None,
-        tgt_key_padding_mask: Optional[Tensor] = None,
-        memory_key_padding_mask: Optional[Tensor] = None,
-        pos: Optional[Tensor] = None,
-        query_pos: Optional[Tensor] = None,
-    ):
-        tgt2 = self.norm1(tgt)
-        q = k = self.with_pos_embed(tgt2, query_pos)
-        tgt2 = self.self_attn(
-            q, k, tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
-        )[0]
-        tgt = tgt + self.dropout1(tgt2)
-        tgt2 = self.norm2(tgt)
-        tgt2 = self.multihead_attn(
-            self.with_pos_embed(tgt2, query_pos),
-            self.with_pos_embed(memory, pos),
-            memory,
-            attn_mask=memory_mask,
-            key_padding_mask=memory_key_padding_mask,
-        )[0]
-        tgt = tgt + self.dropout2(tgt2)
-        tgt2 = self.norm3(tgt)
-        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
-        tgt = tgt + self.dropout3(tgt2)
-        return tgt
-
-    def forward(
-        self,
-        tgt,
-        memory,
-        tgt_mask: Optional[Tensor] = None,
-        memory_mask: Optional[Tensor] = None,
-        tgt_key_padding_mask: Optional[Tensor] = None,
-        memory_key_padding_mask: Optional[Tensor] = None,
-        pos: Optional[Tensor] = None,
-        query_pos: Optional[Tensor] = None,
-    ):
-        if self.normalize_before:
-            return self.forward_pre(
-                tgt,
-                memory,
-                tgt_mask,
-                memory_mask,
-                tgt_key_padding_mask,
-                memory_key_padding_mask,
-                pos,
-                query_pos,
-            )
-        return self.forward_post(
-            tgt,
-            memory,
-            tgt_mask,
-            memory_mask,
-            tgt_key_padding_mask,
-            memory_key_padding_mask,
-            pos,
-            query_pos,
-        )
 
 
 def _get_clones(module, N):
     return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
...
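Practical effect of the change: the pre-norm branch can no longer be selected, and requesting it now fails at construction time instead of silently storing an unused flag. A small self-contained stand-in follows; TinyEncoderLayer is not the real class, and the real constructor's full signature is not shown in this diff, so treat this only as an illustration of the new fail-fast behaviour.

from torch import nn

class TinyEncoderLayer(nn.Module):
    """Illustrative stand-in, not the class from the diff above."""

    def __init__(self, d_model=256, nhead=8, normalize_before=False):
        super().__init__()
        # Mirrors the change: reject the pre-norm option up front instead of
        # keeping an unused self.normalize_before attribute around.
        assert not normalize_before, "normalize_before is not supported"
        self.self_attn = nn.MultiheadAttention(d_model, nhead)

layer = TinyEncoderLayer()                       # fine: post-norm is the only path
try:
    TinyEncoderLayer(normalize_before=True)      # the removed option now fails fast
except AssertionError as err:
    print(err)                                   # normalize_before is not supported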