"...git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "d1656d5da8af2943dc5201969cd7eed808d3db8d"
Commit 17177e73 authored by Rémi Louf's avatar Rémi Louf
Browse files

add is_decoder as an attribute to Config class

parent df85a0ff
...@@ -322,7 +322,7 @@ class BertLayer(nn.Module): ...@@ -322,7 +322,7 @@ class BertLayer(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertLayer, self).__init__() super(BertLayer, self).__init__()
self.self_attention = BertAttention(config) self.self_attention = BertAttention(config)
if config.get("is_decoder", False): if getattr(config, "is_decoder", False):
self.attention = BertAttention(config) self.attention = BertAttention(config)
self.intermediate = BertIntermediate(config) self.intermediate = BertIntermediate(config)
self.output = BertOutput(config) self.output = BertOutput(config)
...@@ -380,7 +380,7 @@ class BertEncoder(nn.Module): ...@@ -380,7 +380,7 @@ class BertEncoder(nn.Module):
class BertDecoder(nn.Module): class BertDecoder(nn.Module):
def __init__(self, config): def __init__(self, config):
super(BertDecoder, self).__init__() super(BertDecoder, self).__init__()
config["is_decoder"] = True config.is_decoder = True
self.output_attentions = config.output_attentions self.output_attentions = config.output_attentions
self.output_hidden_states = config.output_hidden_states self.output_hidden_states = config.output_hidden_states
self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)]) self.layers = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment