Commit 2aef2f0b authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[common attributes] Fix previous commit for transfo-xl

parent 2f174642
...@@ -72,6 +72,7 @@ if is_torch_available(): ...@@ -72,6 +72,7 @@ if is_torch_available():
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel, OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP) load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel, from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
AdaptiveEmbedding,
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP) load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model, from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2DoubleHeadsModel,
......
...@@ -35,7 +35,7 @@ if is_torch_available(): ...@@ -35,7 +35,7 @@ if is_torch_available():
import torch import torch
import numpy as np import numpy as np
from transformers import (PretrainedConfig, PreTrainedModel, from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP, BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP) GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
else: else:
...@@ -470,7 +470,7 @@ class CommonTestCases: ...@@ -470,7 +470,7 @@ class CommonTestCases:
model = model_class(config) model = model_class(config)
self.assertIsInstance( self.assertIsInstance(
model.get_input_embeddings(), model.get_input_embeddings(),
torch.nn.Embedding (torch.nn.Embedding, AdaptiveEmbedding)
) )
model.set_input_embeddings(torch.nn.Embedding(10, 10)) model.set_input_embeddings(torch.nn.Embedding(10, 10))
x = model.get_output_embeddings() x = model.get_output_embeddings()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment