"server/vscode:/vscode.git/clone" did not exist on "d461d955c3ac32efc195ec7871632162cd08e833"
Unverified commit 8d7e0a6d, authored by ZhangShilong and committed by GitHub

[Refactor]: add init_cfg in transformer base classes (#946)

parent 79f8cbd6
 import copy
 import warnings
+import torch
 import torch.nn as nn
 from mmcv import ConfigDict
@@ -53,7 +54,7 @@ class MultiheadAttention(BaseModule):
                  dropout=0.,
                  init_cfg=None,
                  **kwargs):
-        super(MultiheadAttention, self).__init__()
+        super(MultiheadAttention, self).__init__(init_cfg)
         self.embed_dims = embed_dims
         self.num_heads = num_heads
         self.dropout = dropout
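Forwarding init_cfg to BaseModule is what lets these bricks pick up weight initialization from config: BaseModule stores init_cfg and consumes it in init_weights(). A minimal usage sketch, assuming the constructor signature shown in this hunk; the specific init_cfg contents are illustrative, not part of this commit:

from mmcv.cnn.bricks.transformer import MultiheadAttention

# init_cfg is now stored by BaseModule instead of being dropped.
attn = MultiheadAttention(
    embed_dims=256,
    num_heads=8,
    dropout=0.1,
    init_cfg=dict(type='Xavier', layer='Linear', distribution='uniform'))
attn.init_weights()  # applies the Xavier init declared above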
@@ -162,7 +163,7 @@ class FFN(BaseModule):
                  dropout=0.,
                  add_residual=True,
                  init_cfg=None):
-        super(FFN, self).__init__()
+        super(FFN, self).__init__(init_cfg)
         assert num_fcs >= 2, 'num_fcs should be no less ' \
             f'than 2. got {num_fcs}.'
         self.embed_dims = embed_dims
@@ -193,7 +194,7 @@ class FFN(BaseModule):
         """
         out = self.layers(x)
         if not self.add_residual:
-            return out
+            return self.dropout(out)
         if residual is None:
             residual = x
         return residual + self.dropout(out)
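This hunk also fixes FFN.forward: with add_residual=False the output now still passes through dropout instead of bypassing it. A small behaviour sketch; constructor arguments other than add_residual and init_cfg are assumed from the surrounding code, not from this diff:

import torch
from mmcv.cnn.bricks.transformer import FFN

ffn = FFN(embed_dims=256, feedforward_channels=1024, dropout=0.5,
          add_residual=False, init_cfg=None)
x = torch.rand(2, 7, 256)
out = ffn(x)       # dropout is now applied to the projected output
print(out.shape)   # torch.Size([2, 7, 256])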
@@ -246,7 +247,7 @@ class BaseTransformerLayer(BaseModule):
                  ffn_num_fcs=2,
                  init_cfg=None):
-        super(BaseTransformerLayer, self).__init__()
+        super(BaseTransformerLayer, self).__init__(init_cfg)
         assert set(operation_order) & set(
             ['self_attn', 'norm', 'ffn', 'cross_attn']) == \
             set(operation_order), f'The operation_order of' \
@@ -338,6 +339,12 @@ class BaseTransformerLayer(BaseModule):
         inp_residual = query
         if attn_masks is None:
             attn_masks = [None for _ in range(self.num_attn)]
+        elif isinstance(attn_masks, torch.Tensor):
+            attn_masks = [
+                copy.deepcopy(attn_masks) for _ in range(self.num_attn)
+            ]
+            warnings.warn(f'Use same attn_mask in all attentions in '
+                          f'{self.__class__.__name__} ')
         else:
             assert len(attn_masks) == self.num_attn, f'The length of ' \
                 f'attn_masks {len(attn_masks)} must be equal ' \
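With the new branch, a caller can pass a single torch.Tensor as attn_masks and it is deep-copied once per attention operation (with a warning) instead of the caller having to build the list by hand. A standalone sketch of the same normalization logic, using a hypothetical helper name rather than the method itself:

import copy
import warnings

import torch


def normalize_attn_masks(attn_masks, num_attn):
    """Mirror of the hunk above: expand None or a single mask per attention."""
    if attn_masks is None:
        return [None for _ in range(num_attn)]
    if isinstance(attn_masks, torch.Tensor):
        warnings.warn('Use same attn_mask in all attentions')
        return [copy.deepcopy(attn_masks) for _ in range(num_attn)]
    assert len(attn_masks) == num_attn
    return attn_masks


masks = normalize_attn_masks(torch.zeros(4, 4, dtype=torch.bool), num_attn=2)
assert len(masks) == 2 and masks[0] is not masks[1]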
@@ -407,7 +414,7 @@ class TransformerLayerSequence(BaseModule):
     """
     def __init__(self, transformerlayers=None, num_layers=None, init_cfg=None):
-        super(TransformerLayerSequence, self).__init__()
+        super(TransformerLayerSequence, self).__init__(init_cfg)
         if isinstance(transformerlayers, ConfigDict):
             transformerlayers = [
                 copy.deepcopy(transformerlayers) for _ in range(num_layers)
...
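TransformerLayerSequence keeps its existing behaviour of deep-copying a single ConfigDict num_layers times, and now also forwards init_cfg to BaseModule. A hedged construction sketch; the layer cfg keys (attn_cfgs, feedforward_channels, operation_order) are assumptions about BaseTransformerLayer's options and are not taken from this diff:

from mmcv import ConfigDict
from mmcv.cnn.bricks.transformer import TransformerLayerSequence

layer_cfg = ConfigDict(
    type='BaseTransformerLayer',
    attn_cfgs=dict(type='MultiheadAttention', embed_dims=256, num_heads=8),
    feedforward_channels=1024,
    operation_order=('self_attn', 'norm', 'ffn', 'norm'))

# One cfg is deep-copied into num_layers identical layer configs internally.
decoder = TransformerLayerSequence(
    transformerlayers=layer_cfg, num_layers=6, init_cfg=None)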