Commit 076602bd authored by Rémi Louf's avatar Rémi Louf Committed by Julien Chaumond
Browse files

prevent BERT weights from being downloaded twice

parent 5909f710
......@@ -158,7 +158,8 @@ class Bert(nn.Module):
def __init__(self):
super(Bert, self).__init__()
self.model = BertModel.from_pretrained("bert-base-uncased")
config = BertConfig.from_pretrained("bert-base-uncased")
self.model = BertModel(config)
def forward(self, input_ids, attention_mask=None, token_type_ids=None, **kwargs):
self.eval()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment