Commit 17177e73 authored by Rémi Louf's avatar Rémi Louf
Browse files

add is_decoder as an attribute to Config class

parent df85a0ff
......@@ -322,7 +322,7 @@ class BertLayer(nn.Module):
def __init__(self, config):
super(BertLayer, self).__init__()
self.self_attention = BertAttention(config)
if config.get("is_decoder", False):
if getattr(config, "is_decoder", False):
self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config)
self.output = BertOutput(config)
......@@ -380,7 +380,7 @@ class BertEncoder(nn.Module):
class BertDecoder(nn.Module):
def __init__(self, config):
super(BertDecoder, self).__init__()
config["is_decoder"] = True
config.is_decoder = True
self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states
self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment