Commit 8ade2040 authored by thomwolf's avatar thomwolf
Browse files

fix tf

parent 47f0e3cf
......@@ -590,6 +590,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
config.num_labels = 1
self.transformer = OpenAIGPTModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.multiple_choice_head = SequenceSummary(config)
......
......@@ -538,6 +538,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
"""
def __init__(self, config, *inputs, **kwargs):
super(TFOpenAIGPTDoubleHeadsModel, self).__init__(config, *inputs, **kwargs)
config.num_labels = 1
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment