Commit b80684b2 authored by thomwolf

Fixing the run_openai_gpt example

parent 80607874
@@ -31,7 +31,9 @@ import torch
 from torch.utils.data import (DataLoader, RandomSampler, SequentialSampler,
                               TensorDataset)
-from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam
+from pytorch_pretrained_bert import OpenAIGPTDoubleHeadsModel, OpenAIGPTTokenizer, OpenAIAdam, cached_path
+
+ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
 
 logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                     datefmt = '%m/%d/%Y %H:%M:%S',
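Note: cached_path is pytorch_pretrained_bert's download helper. It fetches a URL into a local cache on first use and returns the path of the cached file on subsequent calls. A minimal sketch of fetching and unpacking the ROCStories archive — the extraction step is an illustration, not part of this commit:

    import tarfile
    from pytorch_pretrained_bert import cached_path

    ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

    # Downloads on the first call, then serves the on-disk copy from the cache.
    local_archive = cached_path(ROCSTORIES_URL)

    # One plausible way to unpack it (assumed, not shown in this commit):
    with tarfile.open(local_archive, "r:gz") as tar:
        tar.extractall(path="rocstories")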
@@ -63,7 +65,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
         n_batch = len(dataset)
         input_ids = np.zeros((n_batch, 2, input_len), dtype=np.int64)
         mc_token_mask = np.zeros((n_batch, 2, input_len), dtype=np.int64)
-        lm_labels = np.full((n_batch, 2, input_len), -1, dtype=np.int64)
+        lm_labels = np.full((n_batch, 2, input_len), fill_value=-1, dtype=np.int64)
         mc_labels = np.zeros((n_batch,), dtype=np.int64)
         for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
             with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
@@ -71,6 +73,7 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
             input_ids[i, 0, :len(with_cont1)] = with_cont1
             input_ids[i, 1, :len(with_cont2)] = with_cont2
             mc_token_mask[i, 0, len(with_cont1) - 1] = 1
+            mc_token_mask[i, 1, len(with_cont2) - 1] = 1
             lm_labels[i, 0, :len(with_cont1)-1] = with_cont1[1:]
             lm_labels[i, 1, :len(with_cont2)-1] = with_cont2[1:]
             mc_labels[i] = mc_label
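Note: the added mc_token_mask assignment is the substantive fix in this hunk — previously only the first continuation had its classification token marked, so the multiple-choice head always read an unmarked position for the second choice. A self-contained toy illustration of the tensor layout, using placeholder token ids:

    import numpy as np

    # Hypothetical special-token ids and tiny sequences, shapes (n_batch, 2 choices, input_len).
    start, delim, clf = 900, 901, 902
    story, cont1, cont2 = [1, 2, 3], [4, 5], [6, 7, 8]

    with_cont1 = [start] + story + [delim] + cont1 + [clf]
    with_cont2 = [start] + story + [delim] + cont2 + [clf]

    input_len = 12
    input_ids = np.zeros((1, 2, input_len), dtype=np.int64)
    mc_token_mask = np.zeros((1, 2, input_len), dtype=np.int64)
    lm_labels = np.full((1, 2, input_len), fill_value=-1, dtype=np.int64)

    input_ids[0, 0, :len(with_cont1)] = with_cont1
    input_ids[0, 1, :len(with_cont2)] = with_cont2
    # The fix: mark the classification token for *both* choices, not just the first.
    mc_token_mask[0, 0, len(with_cont1) - 1] = 1
    mc_token_mask[0, 1, len(with_cont2) - 1] = 1
    # LM targets are the inputs shifted left by one; -1 positions are ignored by the loss.
    lm_labels[0, 0, :len(with_cont1) - 1] = with_cont1[1:]
    lm_labels[0, 1, :len(with_cont2) - 1] = with_cont2[1:]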
@@ -86,8 +89,8 @@ def main():
     parser.add_argument("--do_eval", action='store_true', help="Whether to run eval on the dev set.")
     parser.add_argument("--output_dir", default=None, type=str, required=True,
                         help="The output directory where the model predictions and checkpoints will be written.")
-    parser.add_argument('--train_dataset', type=str, default='cloze_test_val__spring2016 - cloze_test_ALL_val.tsv')
-    parser.add_argument('--eval_dataset', type=str, default='test_spring2016.tsv')
+    parser.add_argument('--train_dataset', type=str, default='')
+    parser.add_argument('--eval_dataset', type=str, default='')
     parser.add_argument('--seed', type=int, default=42)
     parser.add_argument('--num_train_epochs', type=int, default=3)
     parser.add_argument('--train_batch_size', type=int, default=8)
@@ -97,7 +100,7 @@ def main():
     parser.add_argument('--warmup_proportion', type=float, default=0.002)
     parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
     parser.add_argument('--weight_decay', type=float, default=0.01)
-    parser.add_argument('--lm_coef', type=float, default=0.5)
+    parser.add_argument('--lm_coef', type=float, default=0.9)
     parser.add_argument('--n_valid', type=int, default=374)
     parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
@@ -137,6 +140,8 @@ def main():
     model.to(device)
 
     # Load and encode the datasets
+    if not args.train_dataset and not args.eval_dataset:
+        roc_stories = cached_path(ROCSTORIES_URL)
     def tokenize_and_encode(obj):
         """ Tokenize and encode a nested object """
         if isinstance(obj, str):
@@ -144,7 +149,6 @@ def main():
         elif isinstance(obj, int):
             return obj
         return list(tokenize_and_encode(o) for o in obj)
-
     logger.info("Encoding dataset...")
     train_dataset = load_rocstories_dataset(args.train_dataset)
     eval_dataset = load_rocstories_dataset(args.eval_dataset)
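Note: tokenize_and_encode recurses through arbitrarily nested lists and tuples, encoding strings and passing integer labels through unchanged. A self-contained sketch of that recursion, with a dummy encoder standing in for OpenAIGPTTokenizer (an assumption, for brevity):

    # Stand-in for tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text)).
    def dummy_encode(text):
        return [ord(c) for c in text]

    def tokenize_and_encode(obj):
        """Tokenize and encode a nested object."""
        if isinstance(obj, str):
            return dummy_encode(obj)
        elif isinstance(obj, int):
            return obj  # mc_label passes through untouched
        return list(tokenize_and_encode(o) for o in obj)

    print(tokenize_and_encode([("ab", "cd", "e", 1)]))
    # [[[97, 98], [99, 100], [101], 1]]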
@@ -152,13 +156,13 @@ def main():
     encoded_datasets = tokenize_and_encode(datasets)
 
     # Compute the max input length for the Transformer
-    input_length = max(len(story) + max(len(cont1), len(cont2)) + 3  \
+    max_length = model.config.n_positions // 2 - 2
+    input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3  \
                            for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
     input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model
-    max_sub_part_length = input_length // 2 - 2
 
     # Prepare inputs tensors and dataloaders
-    tensor_datasets = pre_process_datasets(encoded_datasets, input_length, max_sub_part_length, *special_tokens_ids)
+    tensor_datasets = pre_process_datasets(encoded_datasets, input_length, max_length, *special_tokens_ids)
     train_tensor_dataset, eval_tensor_dataset = tensor_datasets[0], tensor_datasets[1]
 
     train_data = TensorDataset(*train_tensor_dataset)
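Note: the length logic moves from "measure, then truncate" to "truncate, then measure". Capping each part at n_positions // 2 - 2 before measuring guarantees the assembled sequence always fits in the position embeddings. A worked example with OpenAI GPT's n_positions = 512:

    n_positions = 512
    max_length = n_positions // 2 - 2              # 254: cap for the story and for each continuation
    longest_input = max_length + max_length + 3    # [start] + story + [delim] + cont + [clf]
    assert longest_input == 511 and longest_input <= n_positions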
@@ -176,7 +180,7 @@ def main():
         {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
         {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
     ]
-    num_train_optimization_steps = len(train_data) // args.train_batch_size
+    num_train_optimization_steps = len(train_data) * args.num_train_epochs // args.train_batch_size
     optimizer = OpenAIAdam(optimizer_grouped_parameters,
                            lr=args.learning_rate,
                            warmup=args.warmup_proportion,
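Note: t_total drives OpenAIAdam's warmup_linear schedule, which anneals the learning rate toward zero at step t_total. The old value counted a single epoch, so the rate effectively bottomed out after the first pass over the data. With hypothetical sizes:

    # Hypothetical dataset/batch sizes, to show the difference the fix makes.
    len_train_data, batch_size, num_train_epochs = 1600, 8, 3
    old_t_total = len_train_data // batch_size                      # 200: lr reaches 0 after epoch 1
    new_t_total = len_train_data * num_train_epochs // batch_size   # 600: lr reaches 0 at end of training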
@@ -185,12 +189,11 @@ def main():
                            t_total=num_train_optimization_steps)
 
     if args.do_train:
-        nb_tr_steps = 0
-        tr_loss = 0
+        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
         model.train()
         for _ in trange(int(args.num_train_epochs), desc="Epoch"):
             tr_loss = 0
-            nb_tr_examples, nb_tr_steps = 0, 0
+            nb_tr_steps = 0
             tqdm_bar = tqdm(train_dataloader, desc="Training")
             for step, batch in enumerate(tqdm_bar):
                 batch = tuple(t.to(device) for t in batch)
@@ -200,21 +203,22 @@ def main():
                 loss.backward()
                 optimizer.step()
                 tr_loss += loss.item()
-                nb_tr_examples += input_ids.size(0)
+                exp_average_loss = loss.item() if exp_average_loss is None else 0.7*exp_average_loss+0.3*loss.item()
                 nb_tr_steps += 1
-                tqdm_bar.desc = "Training loss: {:.2e}".format(tr_loss/nb_tr_steps)
+                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, optimizer.get_lr()[0])
 
     # Save a trained model
-    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
-    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
     if args.do_train:
+        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
+        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
+        config = model.config
         torch.save(model_to_save.state_dict(), output_model_file)
 
         # Load a trained model that you have fine-tuned
         model_state_dict = torch.load(output_model_file)
-        model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_name, state_dict=model_state_dict,
-                                                          num_special_tokens=len(special_tokens))
+        model = OpenAIGPTDoubleHeadsModel(config)
+        model.load_state_dict(model_state_dict)
         model.to(device)
 
     if args.do_eval:
         model.eval()
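Note: the progress bar now shows an exponentially smoothed loss (decay 0.7) plus the current learning rate, instead of the running mean over the whole epoch. The smoothing is easy to reproduce in isolation:

    # Exponential moving average as used for the progress-bar loss display.
    def update_ema(ema, value, decay=0.7):
        return value if ema is None else decay * ema + (1 - decay) * value

    ema = None
    for batch_loss in [4.0, 3.0, 5.0, 2.0]:
        ema = update_ema(ema, batch_loss)
    print(round(ema, 3))  # 3.463 (after 4.0 -> 3.7 -> 4.09 -> 3.463)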
...
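Note: the reload path changes from from_pretrained(..., state_dict=...) to rebuilding the model from the config kept at save time, presumably so the reloaded architecture (including any resized special-token embeddings) matches the fine-tuned weights exactly. A minimal, self-contained sketch of the round-trip, with a tiny module standing in for OpenAIGPTDoubleHeadsModel (an assumption, for brevity):

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):
        def __init__(self, hidden):
            super().__init__()
            self.config = {"hidden": hidden}   # keep what is needed to rebuild the architecture
            self.linear = nn.Linear(hidden, hidden)

        def forward(self, x):
            return self.linear(x)

    model = TinyModel(hidden=4)
    torch.save(model.state_dict(), "model.bin")   # save only the weights

    # Rebuild from the saved config, then restore the fine-tuned weights.
    reloaded = TinyModel(**model.config)
    reloaded.load_state_dict(torch.load("model.bin"))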