Unverified Commit 8d7e0a6d authored by ZhangShilong, committed by GitHub

[Refactor]: add init_cfg in transformer base classes (#946)

parent 79f8cbd6
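
With this change the transformer base classes (MultiheadAttention, FFN, BaseTransformerLayer, TransformerLayerSequence) forward their init_cfg argument to BaseModule instead of discarding it, so weight initialization can be configured declaratively. A minimal sketch of what that enables; the init_cfg values below are illustrative and not taken from this commit:

```python
from mmcv.cnn.bricks.transformer import MultiheadAttention

# init_cfg is now passed through to BaseModule, so init_weights() can
# apply the configured initializer to the module's Linear layers.
attn = MultiheadAttention(
    embed_dims=256,
    num_heads=8,
    dropout=0.1,
    init_cfg=dict(type='Xavier', layer='Linear', distribution='uniform'))
attn.init_weights()
```
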
import copy
import warnings
import torch
import torch.nn as nn
from mmcv import ConfigDict
@@ -53,7 +54,7 @@ class MultiheadAttention(BaseModule):
dropout=0.,
init_cfg=None,
**kwargs):
- super(MultiheadAttention, self).__init__()
+ super(MultiheadAttention, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.num_heads = num_heads
self.dropout = dropout
@@ -162,7 +163,7 @@ class FFN(BaseModule):
dropout=0.,
add_residual=True,
init_cfg=None):
- super(FFN, self).__init__()
+ super(FFN, self).__init__(init_cfg)
assert num_fcs >= 2, 'num_fcs should be no less ' \
f'than 2. got {num_fcs}.'
self.embed_dims = embed_dims
@@ -193,7 +194,7 @@ class FFN(BaseModule):
"""
out = self.layers(x)
if not self.add_residual:
- return out
+ return self.dropout(out)
if residual is None:
residual = x
return residual + self.dropout(out)
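
The hunk above also applies dropout on the non-residual path, so FFN behaves consistently whether or not add_residual is set. A minimal sketch of calling FFN with residuals disabled; the parameter values (and the feedforward_channels name) are assumed here for illustration:

```python
import torch
from mmcv.cnn.bricks.transformer import FFN

ffn = FFN(embed_dims=256, feedforward_channels=1024,
          dropout=0.1, add_residual=False)
x = torch.randn(2, 100, 256)
# With the fix, this returns dropout(layers(x)); no residual is added.
out = ffn(x)
```
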
@@ -246,7 +247,7 @@ class BaseTransformerLayer(BaseModule):
ffn_num_fcs=2,
init_cfg=None):
- super(BaseTransformerLayer, self).__init__()
+ super(BaseTransformerLayer, self).__init__(init_cfg)
assert set(operation_order) & set(
['self_attn', 'norm', 'ffn', 'cross_attn']) == \
set(operation_order), f'The operation_order of' \
@@ -338,6 +339,12 @@ class BaseTransformerLayer(BaseModule):
inp_residual = query
if attn_masks is None:
attn_masks = [None for _ in range(self.num_attn)]
+ elif isinstance(attn_masks, torch.Tensor):
+ attn_masks = [
+ copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+ ]
+ warnings.warn(f'Use same attn_mask in all attentions in '
+ f'{self.__class__.__name__} ')
else:
assert len(attn_masks) == self.num_attn, f'The length of ' \
f'attn_masks {len(attn_masks)} must be equal ' \
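
The six added lines let callers pass a single torch.Tensor as attn_masks; it is deep-copied once per attention in the layer, with a warning. Below is a standalone sketch of that rule; _expand_attn_masks is a hypothetical helper written only to illustrate the branch and is not part of mmcv:

```python
import copy
import warnings

import torch


def _expand_attn_masks(attn_masks, num_attn, cls_name='BaseTransformerLayer'):
    """Illustrative restatement of the new attn_masks handling."""
    if attn_masks is None:
        return [None for _ in range(num_attn)]
    if isinstance(attn_masks, torch.Tensor):
        # A bare Tensor is broadcast (deep-copied) to every attention.
        warnings.warn(f'Use same attn_mask in all attentions in {cls_name}')
        return [copy.deepcopy(attn_masks) for _ in range(num_attn)]
    assert len(attn_masks) == num_attn
    return attn_masks


# e.g. a layer with one self-attention and one cross-attention
masks = _expand_attn_masks(torch.zeros(100, 100, dtype=torch.bool), num_attn=2)
assert len(masks) == 2
```
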
@@ -407,7 +414,7 @@ class TransformerLayerSequence(BaseModule):
"""
def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
- super(TransformerLayerSequence, self).__init__()
+ super(TransformerLayerSequence, self).__init__(init_cfg)
if isinstance(transformerlayers, ConfigDict):
transformerlayers = [
copy.deepcopy(transformerlayers) for _ in range(num_layers)
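
TransformerLayerSequence keeps its existing behaviour of replicating a single ConfigDict num_layers times while now also forwarding init_cfg. A standalone sketch of the replication rule; the layer config values below are illustrative only:

```python
import copy

from mmcv import ConfigDict

num_layers = 6
transformerlayers = ConfigDict(
    type='BaseTransformerLayer',
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    feedforward_channels=2048,
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))

# One ConfigDict is deep-copied per layer, so a whole sequence can be
# described by a single layer config instead of an explicit list.
if isinstance(transformerlayers, ConfigDict):
    transformerlayers = [
        copy.deepcopy(transformerlayers) for _ in range(num_layers)
    ]
assert len(transformerlayers) == num_layers
```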