Commit 8ade2040 authored by thomwolf's avatar thomwolf
Browse files

fix tf

parent 47f0e3cf
...@@ -590,6 +590,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel): ...@@ -590,6 +590,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
def __init__(self, config): def __init__(self, config):
super(OpenAIGPTDoubleHeadsModel, self).__init__(config) super(OpenAIGPTDoubleHeadsModel, self).__init__(config)
config.num_labels = 1
self.transformer = OpenAIGPTModel(config) self.transformer = OpenAIGPTModel(config)
self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False) self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
self.multiple_choice_head = SequenceSummary(config) self.multiple_choice_head = SequenceSummary(config)
......
...@@ -538,6 +538,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel): ...@@ -538,6 +538,7 @@ class TFOpenAIGPTDoubleHeadsModel(TFOpenAIGPTPreTrainedModel):
""" """
def __init__(self, config, *inputs, **kwargs): def __init__(self, config, *inputs, **kwargs):
super(TFOpenAIGPTDoubleHeadsModel, self).__init__(config, *inputs, **kwargs) super(TFOpenAIGPTDoubleHeadsModel, self).__init__(config, *inputs, **kwargs)
config.num_labels = 1
self.transformer = TFOpenAIGPTMainLayer(config, name='transformer') self.transformer = TFOpenAIGPTMainLayer(config, name='transformer')
self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head') self.multiple_choice_head = TFSequenceSummary(config, initializer_range=config.initializer_range, name='multiple_choice_head')
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment