Commit 3e3e1454 authored by LysandreJik

Added OpenAI GPT to the generative fine-tuning script.

parent 47975ed5
@@ -30,7 +30,8 @@ from torch.utils.data.distributed import DistributedSampler
 from tensorboardX import SummaryWriter
 from tqdm import tqdm, trange
-from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,)
+from pytorch_transformers import (WEIGHTS_NAME, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
+                                  OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
 from pytorch_transformers import AdamW, WarmupLinearSchedule
 from utils_lm import WikiTextDataset
@@ -40,7 +41,8 @@ logger = logging.getLogger(__name__)
 ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config,)), ())
 MODEL_CLASSES = {
-    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer)
+    'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
+    'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
 }
...
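For context on the MODEL_CLASSES change above: each entry maps a model-type string (e.g. a --model_type argument) to a (config, model, tokenizer) triple, so the script can instantiate any supported architecture through one code path. Below is a minimal sketch of that lookup; the load_model helper and its model_type / model_name_or_path parameter names are illustrative assumptions, not part of this commit, while the from_pretrained calls are the pytorch_transformers API.

    # Sketch (not from this commit): resolving a model type to its classes.
    from pytorch_transformers import (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer,
                                      OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)

    MODEL_CLASSES = {
        'gpt2': (GPT2Config, GPT2LMHeadModel, GPT2Tokenizer),
        'openai-gpt': (OpenAIGPTConfig, OpenAIGPTLMHeadModel, OpenAIGPTTokenizer)
    }

    def load_model(model_type, model_name_or_path):
        # Pick the classes for the requested architecture, then build each
        # piece from the same pretrained checkpoint name or path.
        config_class, model_class, tokenizer_class = MODEL_CLASSES[model_type]
        config = config_class.from_pretrained(model_name_or_path)
        tokenizer = tokenizer_class.from_pretrained(model_name_or_path)
        model = model_class.from_pretrained(model_name_or_path, config=config)
        return model, tokenizer

    # e.g. model, tokenizer = load_model('openai-gpt', 'openai-gpt')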
@@ -28,8 +28,6 @@ class WikiTextDataset(Dataset):
         # Sort the array by example length.
         self.examples.sort(key=len)
-        print("nice")
-
     def __len__(self):
         return len(self.examples)
...
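On the utils_lm change above: sorting self.examples by length keeps sequences of similar size adjacent, so batches drawn from neighboring examples need little padding (the stray print("nice") debug line is also dropped). Below is a minimal sketch of a pad-to-batch-max collate function under that setup; the collate_pad name and the pad_token_id=0 default are illustrative assumptions, not part of this commit.

    import torch

    def collate_pad(batch, pad_token_id=0):
        # Pad each example (a list of token ids) to the batch's longest
        # sequence. With length-sorted data, neighbors have similar lengths,
        # so the padding overhead per batch stays small. pad_token_id=0 is
        # an assumed default, not taken from this commit.
        max_len = max(len(ex) for ex in batch)
        padded = [ex + [pad_token_id] * (max_len - len(ex)) for ex in batch]
        return torch.tensor(padded, dtype=torch.long)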