Commit 2aef2f0b authored by Julien Chaumond's avatar Julien Chaumond
Browse files

[common attributes] Fix previous commit for transfo-xl

parent 2f174642
......@@ -72,6 +72,7 @@ if is_torch_available():
OpenAIGPTLMHeadModel, OpenAIGPTDoubleHeadsModel,
load_tf_weights_in_openai_gpt, OPENAI_GPT_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_transfo_xl import (TransfoXLPreTrainedModel, TransfoXLModel, TransfoXLLMHeadModel,
AdaptiveEmbedding,
load_tf_weights_in_transfo_xl, TRANSFO_XL_PRETRAINED_MODEL_ARCHIVE_MAP)
from .modeling_gpt2 import (GPT2PreTrainedModel, GPT2Model,
GPT2LMHeadModel, GPT2DoubleHeadsModel,
......
......@@ -35,7 +35,7 @@ if is_torch_available():
import torch
import numpy as np
from transformers import (PretrainedConfig, PreTrainedModel,
from transformers import (AdaptiveEmbedding, PretrainedConfig, PreTrainedModel,
BertModel, BertConfig, BERT_PRETRAINED_MODEL_ARCHIVE_MAP,
GPT2LMHeadModel, GPT2Config, GPT2_PRETRAINED_MODEL_ARCHIVE_MAP)
else:
......@@ -470,7 +470,7 @@ class CommonTestCases:
model = model_class(config)
self.assertIsInstance(
model.get_input_embeddings(),
torch.nn.Embedding
(torch.nn.Embedding, AdaptiveEmbedding)
)
model.set_input_embeddings(torch.nn.Embedding(10, 10))
x = model.get_output_embeddings()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment