Commit 3e3e1454 authored by LysandreJik's avatar LysandreJik
Browse files

Added GPT to the generative fine-tuning.

parent 47975ed5
......@@ -30,7 +30,8 @@ from torch.utils.data.distributed import DistributedSampler
from tensorboardX import SummaryWriter
from tqdm import tqdm, trange
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,)
from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
from pytorch_transformers import AdamW, WarmupLinearSchedule
from utils_lm import WikiTextDataset
......@@ -40,7 +41,8 @@ logger = logging.getLogger(__name__)
ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())
MODEL_CLASSES = {
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
}
......
......@@ -28,8 +28,6 @@ class WikiTextDataset(Dataset):
# Sort the array by example length.
self.examples.sort(key=len)
print("nice")
def __len__(self):
return len(self.examples)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment