"docs/source/api/vscode:/vscode.git/clone" did not exist on "0248541deadfa187150fe7f96a575ff905ecddd7"
Commit 54abc67a authored by Thomas Wolf, committed by GitHub

Merge pull request #2255 from aaugustin/implement-best-practices

Implement some Python best practices
parents 645713e2 c11b3e29
@@ -86,6 +86,20 @@ jobs:
- run: sudo pip install --progress-bar off -r docs/requirements.txt
- run: sudo pip install --progress-bar off -r requirements.txt
- run: ./.circleci/deploy.sh
check_code_quality:
working_directory: ~/transformers
docker:
- image: circleci/python:3.6
resource_class: medium
parallelism: 1
steps:
- checkout
- run: sudo pip install --editable .
- run: sudo pip install torch tensorflow
- run: sudo pip install black git+git://github.com/timothycrosley/isort.git@e63ae06ec7d70b06df9e528357650281a3d3ec22#egg=isort flake8
- run: black --check --line-length 119 examples templates transformers utils
- run: isort --check-only --recursive examples templates transformers utils
- run: flake8 examples templates transformers utils
check_repository_consistency:
working_directory: ~/transformers
docker:
@@ -105,6 +119,7 @@ workflows:
version: 2
build_and_test:
jobs:
- check_code_quality
- check_repository_consistency
- run_examples_py3_torch
- run_tests_py3_custom_tokenizers
...
.PHONY: style
style:
black --line-length 119 examples templates transformers utils
isort --recursive examples templates transformers utils
from pathlib import Path
import tarfile
import urllib.request
import torch
from transformers.modeling_camembert import CamembertForMaskedLM
from transformers.tokenization_camembert import CamembertTokenizer
def fill_mask(masked_input, model, tokenizer, topk=5):
# Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
assert masked_input.count("<mask>") == 1
input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
logits = model(input_ids)[0]  # The last hidden-state is the first element of the output tuple
masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
logits = logits[0, masked_index, :]
prob = logits.softmax(dim=0)
values, indices = prob.topk(k=topk, dim=0)
topk_predicted_token_bpe = " ".join(
[tokenizer.convert_ids_to_tokens(indices[i].item()) for i in range(len(indices))]
)
masked_token = tokenizer.mask_token
topk_filled_outputs = []
for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
predicted_token = predicted_token_bpe.replace("\u2581", " ")
if " {0}".format(masked_token) in masked_input:
topk_filled_outputs.append(
(
masked_input.replace(" {0}".format(masked_token), predicted_token),
values[index].item(),
predicted_token,
)
)
else:
topk_filled_outputs.append(
(masked_input.replace(masked_token, predicted_token), values[index].item(), predicted_token,)
)
return topk_filled_outputs
tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
model = CamembertForMaskedLM.from_pretrained("camembert-base")
model.eval()
masked_input = "Le camembert est <mask> :)"
...
@@ -22,48 +22,57 @@
--model_name openai-gpt \
--do_train \
--do_eval \
--train_dataset "$ROC_STORIES_DIR/cloze_test_val__spring2016 - cloze_test_ALL_val.csv" \
--eval_dataset "$ROC_STORIES_DIR/cloze_test_test__spring2016 - cloze_test_ALL_test.csv" \
--output_dir ../log \
--train_batch_size 16 \
"""
import argparse
import csv
import logging
import os
import random
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset
from tqdm import tqdm, trange
from transformers import (
CONFIG_NAME,
WEIGHTS_NAME,
AdamW,
OpenAIGPTDoubleHeadsModel,
OpenAIGPTTokenizer,
cached_path,
get_linear_schedule_with_warmup,
)
ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
def accuracy(out, labels):
outputs = np.argmax(out, axis=1)
return np.sum(outputs == labels)
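As a standalone worked example of the accuracy helper above (editorial aside, not part of the diff): argmax over the class axis picks the predicted label per row, and matches against the gold labels are summed.

import numpy as np

out = np.array([[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]])
labels = np.array([1, 0, 0])
print(np.sum(np.argmax(out, axis=1) == labels))  # -> 2 (only the third row is mispredicted)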
def load_rocstories_dataset(dataset_path):
""" Output a list of tuples(story, 1st continuation, 2nd continuation, label) """
with open(dataset_path, encoding="utf_8") as f:
f = csv.reader(f)
output = []
next(f)  # skip the first line
for line in tqdm(f):
output.append((" ".join(line[1:5]), line[5], line[6], int(line[-1]) - 1))
return output
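For reference, a standalone sketch (made-up CSV row, editorial aside) of the tuple layout this produces: columns 1-4 are the story sentences, columns 5-6 the candidate endings, and the last column the 1-based answer.

line = ["id-123", "Sam was hungry.", "He opened the fridge.", "It was empty.", "He grabbed his keys.",
        "He drove to the store.", "He went back to bed.", "1"]
example = (" ".join(line[1:5]), line[5], line[6], int(line[-1]) - 1)
print(example)
# ('Sam was hungry. He opened the fridge. It was empty. He grabbed his keys.',
#  'He drove to the store.', 'He went back to bed.', 0)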
def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token):
""" Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label)
@@ -80,56 +89,68 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token]
input_ids[i, 0, : len(with_cont1)] = with_cont1
input_ids[i, 1, : len(with_cont2)] = with_cont2
mc_token_ids[i, 0] = len(with_cont1) - 1
mc_token_ids[i, 1] = len(with_cont2) - 1
lm_labels[i, 0, : len(with_cont1)] = with_cont1
lm_labels[i, 1, : len(with_cont2)] = with_cont2
mc_labels[i] = mc_label
all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
return tensor_datasets
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--model_name", type=str, default="openai-gpt", help="pretrained model name")
parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
parser.add_argument(
"--output_dir",
default=None,
type=str,
required=True,
help="The output directory where the model predictions and checkpoints will be written.",
)
parser.add_argument("--train_dataset", type=str, default="")
parser.add_argument("--eval_dataset", type=str, default="")
parser.add_argument("--seed", type=int, default=42)
parser.add_argument("--num_train_epochs", type=int, default=3)
parser.add_argument("--train_batch_size", type=int, default=8)
parser.add_argument("--eval_batch_size", type=int, default=16)
parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
parser.add_argument("--max_grad_norm", type=int, default=1)
parser.add_argument(
"--max_steps",
default=-1,
type=int,
help="If > 0: set total number of training \
steps to perform. Override num_train_epochs.",
)
parser.add_argument(
"--gradient_accumulation_steps",
type=int,
default=1,
help="Number of updates steps to accumulate before\
performing a backward/update pass.",
)
parser.add_argument("--learning_rate", type=float, default=6.25e-5)
parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
parser.add_argument("--lr_schedule", type=str, default="warmup_linear")
parser.add_argument("--weight_decay", type=float, default=0.01)
parser.add_argument("--lm_coef", type=float, default=0.9)
parser.add_argument("--n_valid", type=int, default=374)
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
args = parser.parse_args()
print(args)
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
@@ -152,7 +173,7 @@ def main():
# Load tokenizer and model
# This loading functions also add new tokens and embeddings called `special tokens`
# These new embeddings will be fine-tuned on the RocStories dataset
special_tokens = ["_start_", "_delimiter_", "_classify_"]
tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
tokenizer.add_tokens(special_tokens)
special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)
@@ -163,6 +184,7 @@ def main():
# Load and encode the datasets
if not args.train_dataset and not args.eval_dataset:
roc_stories = cached_path(ROCSTORIES_URL)
def tokenize_and_encode(obj):
""" Tokenize and encode a nested object """
if isinstance(obj, str):
@@ -170,6 +192,7 @@ def main():
elif isinstance(obj, int):
return obj
return list(tokenize_and_encode(o) for o in obj)
logger.info("Encoding dataset...")
train_dataset = load_rocstories_dataset(args.train_dataset)
eval_dataset = load_rocstories_dataset(args.eval_dataset)
@@ -178,8 +201,11 @@ def main():
# Compute the max input length for the Transformer
max_length = model.config.n_positions // 2 - 2
input_length = max(
len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
for dataset in encoded_datasets
for story, cont1, cont2, _ in dataset
)
input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model
# Prepare inputs tensors and dataloaders
@@ -198,20 +224,23 @@ def main():
if args.do_train:
if args.max_steps > 0:
t_total = args.max_steps
args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
else:
t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs
param_optimizer = list(model.named_parameters())
no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
optimizer_grouped_parameters = [
{
"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
"weight_decay": args.weight_decay,
},
{"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
]
optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
scheduler = get_linear_schedule_with_warmup(
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
)
if args.do_train:
nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
@@ -230,14 +259,16 @@ def main():
optimizer.step()
optimizer.zero_grad()
tr_loss += loss.item()
exp_average_loss = (
loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
)
nb_tr_steps += 1
tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])
# Save a trained model
if args.do_train:
# Save a trained model, configuration and tokenizer
model_to_save = model.module if hasattr(model, "module") else model  # Only save the model itself
# If we save using the predefined names, we can load using `from_pretrained`
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
@@ -260,10 +291,12 @@ def main():
batch = tuple(t.to(device) for t in batch)
input_ids, mc_token_ids, lm_labels, mc_labels = batch
with torch.no_grad():
_, mc_loss, _, mc_logits = model(
input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels
)
mc_logits = mc_logits.detach().cpu().numpy()
mc_labels = mc_labels.to("cpu").numpy()
tmp_eval_accuracy = accuracy(mc_logits, mc_labels)
eval_loss += mc_loss.mean().item()
@@ -274,10 +307,8 @@ def main():
eval_loss = eval_loss / nb_eval_steps
eval_accuracy = eval_accuracy / nb_eval_examples
train_loss = tr_loss / nb_tr_steps if args.do_train else None
result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy, "train_loss": train_loss}
output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
with open(output_eval_file, "w") as writer:
@@ -286,5 +317,6 @@ def main():
logger.info(" %s = %s", key, str(result[key]))
writer.write("%s = %s\n" % (key, str(result[key])))
if __name__ == "__main__":
main()
@@ -23,51 +23,44 @@ from __future__ import absolute_import, division, print_function, unicode_litera
import argparse
import logging
import math
import time
import torch
from transformers import TransfoXLCorpus, TransfoXLLMHeadModel, TransfoXLTokenizer
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
def main():
parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
parser.add_argument(
"--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
)
parser.add_argument("--batch_size", type=int, default=10, help="batch size")
parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUA is available")
parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
args = parser.parse_args()
assert args.ext_len >= 0, "extended context length must be non-negative"
if args.server_ip and args.server_port:
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import ptvsd
print("Waiting for debugger attach")
ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
ptvsd.wait_for_attach()
@@ -84,17 +77,18 @@ def main():
corpus = TransfoXLCorpus.from_pretrained(args.model_name)
ntokens = len(corpus.vocab)
va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
# Load a pre-trained model
model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
model = model.to(device)
logger.info(
"Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len
)
)
model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
if args.clamp_len > 0:
@@ -108,7 +102,7 @@ def main():
def evaluate(eval_iter):
# Turn on evaluation mode which disables dropout.
model.eval()
total_len, total_loss = 0, 0.0
start_time = time.time()
with torch.no_grad():
mems = None
@@ -119,35 +113,34 @@ def main():
total_loss += seq_len * loss.item()
total_len += seq_len
total_time = time.time() - start_time
logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / (idx + 1)))
return total_loss / total_len
# Run on test data.
if args.split == "all":
test_loss = evaluate(te_iter)
valid_loss = evaluate(va_iter)
elif args.split == "valid":
valid_loss = evaluate(va_iter)
test_loss = None
elif args.split == "test":
test_loss = evaluate(te_iter)
valid_loss = None
def format_log(loss, split):
log_str = "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))
return log_str
log_str = ""
if valid_loss is not None:
log_str += format_log(valid_loss, "valid")
if test_loss is not None:
log_str += format_log(test_loss, "test")
logger.info("=" * 100)
logger.info(log_str)
logger.info("=" * 100)
if __name__ == "__main__":
main()
@@ -17,18 +17,20 @@
import bisect
import copy
from collections import defaultdict
import numpy as np
from torch.utils.data.sampler import BatchSampler, Sampler
from utils import logger
def _quantize(x, bins):
bins = copy.deepcopy(bins)
bins = sorted(bins)
quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
return quantized
def create_lengths_groups(lengths, k=0):
bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
groups = _quantize(lengths, bins)
@@ -39,6 +41,7 @@ def create_lengths_groups(lengths, k=0):
logger.info("Count of instances per bin: {}".format(counts))
return groups
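As an illustrative aside (not part of the diff), here is how the quantization above buckets lengths when k=0, i.e. with the single default bin edge of 10: lengths below the edge land in group 0, the rest in group 1.

import bisect

bins = [10]
lengths = [4, 9, 12, 30]
print([bisect.bisect_right(bins, length) for length in lengths])  # -> [0, 0, 1, 1]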
class GroupedBatchSampler(BatchSampler):
"""
Wraps another sampler to yield a mini-batch of indices.
@@ -53,11 +56,11 @@ class GroupedBatchSampler(BatchSampler):
0, i.e. they must be in the range [0, num_groups).
batch_size (int): Size of mini-batch.
"""
def __init__(self, sampler, group_ids, batch_size):
if not isinstance(sampler, Sampler):
raise ValueError(
"sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
)
self.sampler = sampler
self.group_ids = group_ids
@@ -73,7 +76,7 @@ class GroupedBatchSampler(BatchSampler):
buffer_per_group[group_id].append(idx)
samples_per_group[group_id].append(idx)
if len(buffer_per_group[group_id]) == self.batch_size:
yield buffer_per_group[group_id]  # TODO
num_batches += 1
del buffer_per_group[group_id]
assert len(buffer_per_group[group_id]) < self.batch_size
@@ -90,8 +93,8 @@ class GroupedBatchSampler(BatchSampler):
for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
batch_idx.extend(idxs)
if len(batch_idx) >= self.batch_size:
yield batch_idx[: self.batch_size]
batch_idx = batch_idx[self.batch_size :]
num_remaining -= 1
if len(batch_idx) > 0:
yield batch_idx
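A minimal usage sketch of the sampler above (editorial, not part of the diff), assuming it is run from the directory that contains this module under the name grouped_batch_sampler along with its `utils` dependency; the dataset and lengths below are made up:

import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

from grouped_batch_sampler import GroupedBatchSampler, create_lengths_groups

dataset = TensorDataset(torch.arange(100))      # hypothetical toy dataset
lengths = np.random.randint(5, 200, size=100)   # hypothetical token counts per example
groups = create_lengths_groups(lengths, k=128)  # bucket examples by length
batch_sampler = GroupedBatchSampler(RandomSampler(dataset), group_ids=groups, batch_size=8)
loader = DataLoader(dataset, batch_sampler=batch_sampler)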
...
@@ -15,12 +15,13 @@
""" Dataset to distilled models
adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import numpy as np
import torch
from torch.utils.data import Dataset
from utils import logger
class LmSeqsDataset(Dataset):
"""Custom Dataset wrapping language modeling sequences.
@@ -32,9 +33,7 @@ class LmSeqsDataset(Dataset):
data: `List[np.array[int]]
"""
def __init__(self, params, data):
self.params = params
self.token_ids = np.array(data)
@@ -57,7 +56,7 @@ class LmSeqsDataset(Dataset):
Some sanity checks
"""
assert len(self.token_ids) == len(self.lengths)
assert all(self.lengths[i] == len(self.token_ids[i]) for i in range(len(self.lengths)))
def remove_long_sequences(self):
"""
@@ -65,17 +64,17 @@ class LmSeqsDataset(Dataset):
"""
max_len = self.params.max_model_input_size
indices = self.lengths > max_len
logger.info(f"Splitting {sum(indices)} too long sequences.")
def divide_chunks(l, n):
return [l[i : i + n] for i in range(0, len(l), n)]
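A quick standalone illustration of the chunking helper above (editorial aside): a 7-token sequence split into chunks of at most 3 tokens.

def divide_chunks(l, n):
    return [l[i : i + n] for i in range(0, len(l), n)]

print(divide_chunks([101, 5, 6, 7, 8, 9, 102], 3))  # -> [[101, 5, 6], [7, 8, 9], [102]]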
new_tok_ids = []
new_lengths = []
if self.params.mlm:
cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
else:
cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]
for seq_, len_ in zip(self.token_ids, self.lengths):
assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_
@@ -84,7 +83,7 @@ class LmSeqsDataset(Dataset):
new_lengths.append(len_)
else:
sub_seqs = []
for sub_s in divide_chunks(seq_, max_len - 2):
if sub_s[0] != cls_id:
sub_s = np.insert(sub_s, 0, cls_id)
if sub_s[-1] != sep_id:
@@ -108,7 +107,7 @@ class LmSeqsDataset(Dataset):
self.token_ids = self.token_ids[indices]
self.lengths = self.lengths[indices]
new_size = len(self)
logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")
def print_statistics(self):
"""
@@ -116,7 +115,7 @@ class LmSeqsDataset(Dataset):
"""
if not self.params.is_master:
return
logger.info(f"{len(self)} sequences")
# data_len = sum(self.lengths)
# nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
# logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')
@@ -125,8 +124,7 @@ class LmSeqsDataset(Dataset):
# nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
# logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')
def batch_sequences(self, batch):
"""
Do the padding and transform into torch.tensor.
"""
@@ -139,13 +137,13 @@ class LmSeqsDataset(Dataset):
# Pad token ids
if self.params.mlm:
pad_idx = self.params.special_tok_ids["pad_token"]
else:
pad_idx = self.params.special_tok_ids["unk_token"]
tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
assert len(tk_) == len(token_ids)
assert all(len(t) == max_seq_len_ for t in tk_)
tk_t = torch.tensor(tk_)  # (bs, max_seq_len_)
lg_t = torch.tensor(lengths)  # (bs)
return tk_t, lg_t
@@ -16,75 +16,75 @@
Preprocessing script before distillation.
"""
import argparse
import logging
import pickle
import random
import time
import numpy as np
from transformers import BertTokenizer, GPT2Tokenizer, RobertaTokenizer
logging.basicConfig(
format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)
def main():
parser = argparse.ArgumentParser(
description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
)
parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
args = parser.parse_args()
logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
if args.tokenizer_type == "bert":
tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map["cls_token"]  # `[CLS]`
sep = tokenizer.special_tokens_map["sep_token"]  # `[SEP]`
elif args.tokenizer_type == "roberta":
tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map["cls_token"]  # `<s>`
sep = tokenizer.special_tokens_map["sep_token"]  # `</s>`
elif args.tokenizer_type == "gpt2":
tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
bos = tokenizer.special_tokens_map["bos_token"]  # `<|endoftext|>`
sep = tokenizer.special_tokens_map["eos_token"]  # `<|endoftext|>`
logger.info(f"Loading text from {args.file_path}")
with open(args.file_path, "r", encoding="utf8") as fp:
data = fp.readlines()
logger.info(f"Start encoding")
logger.info(f"{len(data)} examples to process.")
rslt = []
iter = 0
interval = 10000
start = time.time()
for text in data:
text = f"{bos} {text.strip()} {sep}"
token_ids = tokenizer.encode(text, add_special_tokens=False)
rslt.append(token_ids)
iter += 1
if iter % interval == 0:
end = time.time()
logger.info(f"{iter} examples processed. - {(end-start)/interval:.2f}s/expl")
start = time.time()
logger.info("Finished binarization")
logger.info(f"{len(data)} examples processed.")
dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
rslt_ = [np.uint16(d) for d in rslt]
random.shuffle(rslt_)
logger.info(f"Dump to {dp_file}")
with open(dp_file, "wb") as handle:
pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
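For reference, a small editorial sketch (hypothetical output path built from the script's defaults) of reading the dump produced above: each entry is a np.uint16 array holding the token ids of one input line.

import pickle

with open("data/dump.bert-base-uncased.pickle", "rb") as fp:  # hypothetical path
    sequences = pickle.load(fp)

print(len(sequences), "sequences")
print(sequences[0].dtype, sequences[0][:10])  # uint16 token ids of the first (shuffled) line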
...
@@ -16,74 +16,87 @@
Preprocessing script before training the distilled model.
Specific to RoBERTa -> DistilRoBERTa and GPT2 -> DistilGPT2.
"""
import argparse
import torch
from transformers import GPT2LMHeadModel, RobertaForMaskedLM
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"]) parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
parser.add_argument("--model_name", default='roberta-large', type=str) parser.add_argument("--model_name", default="roberta-large", type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str) parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
parser.add_argument("--vocab_transform", action='store_true') parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args() args = parser.parse_args()
if args.model_type == "roberta":
if args.model_type == 'roberta':
model = RobertaForMaskedLM.from_pretrained(args.model_name) model = RobertaForMaskedLM.from_pretrained(args.model_name)
prefix = 'roberta' prefix = "roberta"
elif args.model_type == 'gpt2': elif args.model_type == "gpt2":
model = GPT2LMHeadModel.from_pretrained(args.model_name) model = GPT2LMHeadModel.from_pretrained(args.model_name)
prefix = 'transformer' prefix = "transformer"
state_dict = model.state_dict() state_dict = model.state_dict()
compressed_sd = {} compressed_sd = {}
### Embeddings ### # Embeddings #
if args.model_type == 'gpt2': if args.model_type == "gpt2":
for param_name in ['wte.weight', 'wpe.weight']: for param_name in ["wte.weight", "wpe.weight"]:
compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}'] compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
else: else:
for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']: for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
param_name = f'{prefix}.embeddings.{w}.weight' param_name = f"{prefix}.embeddings.{w}.weight"
compressed_sd[param_name] = state_dict[param_name] compressed_sd[param_name] = state_dict[param_name]
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
param_name = f'{prefix}.embeddings.LayerNorm.{w}' param_name = f"{prefix}.embeddings.LayerNorm.{w}"
compressed_sd[param_name] = state_dict[param_name] compressed_sd[param_name] = state_dict[param_name]
### Transformer Blocks ### # Transformer Blocks #
std_idx = 0 std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]: for teacher_idx in [0, 2, 4, 7, 9, 11]:
if args.model_type == 'gpt2': if args.model_type == "gpt2":
for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']: for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \ compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}'] f"{prefix}.h.{teacher_idx}.{layer}.{w}"
compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias'] ]
compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
else: else:
for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value', for layer in [
'attention.output.dense', 'attention.output.LayerNorm', "attention.self.query",
'intermediate.dense', 'output.dense', 'output.LayerNorm']: "attention.self.key",
for w in ['weight', 'bias']: "attention.self.value",
compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \ "attention.output.dense",
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}'] "attention.output.LayerNorm",
"intermediate.dense",
"output.dense",
"output.LayerNorm",
]:
for w in ["weight", "bias"]:
compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
]
std_idx += 1 std_idx += 1
### Language Modeling Head ###s # Language Modeling Head ###s
if args.model_type == 'roberta': if args.model_type == "roberta":
for layer in ['lm_head.decoder.weight', 'lm_head.bias']: for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
compressed_sd[f'{layer}'] = state_dict[f'{layer}'] compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
if args.vocab_transform: if args.vocab_transform:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}'] compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}'] compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
elif args.model_type == 'gpt2': elif args.model_type == "gpt2":
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}'] compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight'] compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]
print(f'N layers selected for distillation: {std_idx}') print(f"N layers selected for distillation: {std_idx}")
print(f'Number of params transfered for distillation: {len(compressed_sd.keys())}') print(f"Number of params transfered for distillation: {len(compressed_sd.keys())}")
print(f'Save transfered checkpoint to {args.dump_checkpoint}.') print(f"Save transfered checkpoint to {args.dump_checkpoint}.")
torch.save(compressed_sd, args.dump_checkpoint) torch.save(compressed_sd, args.dump_checkpoint)
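As an editorial aside, the file written above is a plain mapping of parameter names to tensors, so it can be inspected directly (hypothetical path taken from the script's default):

import torch

compressed_sd = torch.load("serialization_dir/tf_roberta_048131723.pth", map_location="cpu")
print(len(compressed_sd), "tensors extracted from the teacher")
print(sorted(compressed_sd)[:3])  # a few of the student parameter names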
@@ -16,67 +16,77 @@
Preprocessing script before training DistilBERT.
Specific to BERT -> DistilBERT.
"""
import argparse
import torch
from transformers import BertForMaskedLM
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Extraction some layers of the full BertForMaskedLM or RObertaForMaskedLM for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="bert", choices=["bert"]) parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default='bert-base-uncased', type=str) parser.add_argument("--model_name", default="bert-base-uncased", type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str) parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
parser.add_argument("--vocab_transform", action='store_true') parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args() args = parser.parse_args()
if args.model_type == "bert":
if args.model_type == 'bert':
        model = BertForMaskedLM.from_pretrained(args.model_name)
        prefix = "bert"
    else:
        raise ValueError('args.model_type should be "bert".')

    state_dict = model.state_dict()
    compressed_sd = {}

    for w in ["word_embeddings", "position_embeddings"]:
        compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
    for w in ["weight", "bias"]:
        compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]

    std_idx = 0
    for teacher_idx in [0, 2, 4, 7, 9, 11]:
        for w in ["weight", "bias"]:
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
            ]

            compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
            ]

            compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
            ]
            compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
                f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
            ]
        std_idx += 1

    compressed_sd["vocab_projector.weight"] = state_dict["cls.predictions.decoder.weight"]
    compressed_sd["vocab_projector.bias"] = state_dict["cls.predictions.bias"]
    if args.vocab_transform:
        for w in ["weight", "bias"]:
            compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
            compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]

    print(f"N layers selected for distillation: {std_idx}")
    print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")

    print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
    torch.save(compressed_sd, args.dump_checkpoint)
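As a hedged aside (not part of this diff), the compressed checkpoint produced above is typically loaded back into a DistilBERT student roughly as follows; the config values and file name are illustrative assumptions:

# Sketch only: load the compressed state dict into a 6-layer DistilBERT student.
import torch
from transformers import DistilBertConfig, DistilBertForMaskedLM

config = DistilBertConfig(n_layers=6)  # six teacher layers are selected above
student = DistilBertForMaskedLM(config)
state_dict = torch.load("serialization_dir/distilbert_init.pth")  # hypothetical --dump_checkpoint path
student.load_state_dict(state_dict, strict=False)  # vocab_transform keys are only present if dumped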
@@ -15,37 +15,42 @@
"""
Preprocessing script before training the distilled model.
"""
import argparse
import logging
import pickle
from collections import Counter


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
    )
    parser.add_argument(
        "--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
    )
    parser.add_argument(
        "--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
    )
    parser.add_argument("--vocab_size", default=30522, type=int)
    args = parser.parse_args()

    logger.info(f"Loading data from {args.data_file}")
    with open(args.data_file, "rb") as fp:
        data = pickle.load(fp)

    logger.info("Counting occurrences for MLM.")
    counter = Counter()
    for tk_ids in data:
        counter.update(tk_ids)
    counts = [0] * args.vocab_size
    for k, v in counter.items():
        counts[k] = v

    logger.info(f"Dump to {args.token_counts_dump}")
    with open(args.token_counts_dump, "wb") as handle:
        pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
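As a hedged illustration (not part of this PR), the counts dumped here are typically turned into smoothed masking probabilities along these lines; the smoothing exponent is an assumed placeholder:

import pickle

import numpy as np

with open("data/token_counts.bert-base-uncased.pickle", "rb") as fp:
    counts = pickle.load(fp)

smoothing = 0.7  # assumed value, in the spirit of the XLM/word2vec smoothing mentioned above
token_probs = np.maximum(np.array(counts, dtype=np.float64), 1) ** -smoothing
token_probs /= token_probs.sum()  # per-token probability of being selected for masking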
@@ -15,17 +15,21 @@
""" Utils to train DistilBERT
    adapted in part from Facebook, Inc XLM model (https://github.com/facebookresearch/XLM)
"""
import json
import logging
import os
import socket

import git
import numpy as np
import torch


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
@@ -35,12 +39,12 @@ def git_log(folder_path: str):
    """
    repo = git.Repo(search_parent_directories=True)
    repo_infos = {
        "repo_id": str(repo),
        "repo_sha": str(repo.head.object.hexsha),
        "repo_branch": str(repo.active_branch),
    }

    with open(os.path.join(folder_path, "git_log.json"), "w") as f:
        json.dump(repo_infos, f, indent=4)
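A hedged usage sketch for git_log; the output folder name is an assumption, not a value from this diff:

import os

dump_path = "serialization_dir/my_run"  # assumed output directory for a training run
os.makedirs(dump_path, exist_ok=True)
git_log(dump_path)  # writes serialization_dir/my_run/git_log.json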
@@ -57,21 +61,21 @@ def init_gpu_params(params):
    assert torch.cuda.is_available()

    logger.info("Initializing GPUs")
    if params.n_gpu > 1:
        assert params.local_rank != -1

        params.world_size = int(os.environ["WORLD_SIZE"])
        params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
        params.global_rank = int(os.environ["RANK"])

        # number of nodes / node ID
        params.n_nodes = params.world_size // params.n_gpu_per_node
        params.node_id = params.global_rank // params.n_gpu_per_node
        params.multi_gpu = True

        assert params.n_nodes == int(os.environ["N_NODES"])
        assert params.node_id == int(os.environ["NODE_RANK"])

    # local job (single GPU)
    else:
@@ -114,8 +118,7 @@ def init_gpu_params(params):
    if params.multi_gpu:
        logger.info("Initializing PyTorch distributed")
        torch.distributed.init_process_group(
            init_method="env://", backend="nccl",
        )
...
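For context, a minimal sketch of what typically follows init_gpu_params in a training script: pinning the process to its local GPU. The params object is the caller's argparse namespace; this snippet is an assumption, not code from this PR.

import torch

torch.cuda.set_device(params.local_rank)  # params.local_rank is set by the distributed launcher
device = torch.device("cuda", params.local_rank)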
@@ -17,25 +17,16 @@
import json
import os
from collections import Counter

import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import Dataset


POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}


class ImageEncoder(nn.Module):
@@ -54,7 +45,6 @@ class ImageEncoder(nn.Module):
        return out  # BxNx2048


class JsonlDataset(Dataset):
    def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
        self.data = [json.loads(l) for l in open(data_path)]
@@ -72,7 +62,7 @@ class JsonlDataset(Dataset):
    def __getitem__(self, index):
        sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
        start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
        sentence = sentence[: self.max_seq_length]

        label = torch.zeros(self.n_classes)
        label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1
@@ -80,8 +70,13 @@ class JsonlDataset(Dataset):
        image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
        image = self.transforms(image)

        return {
            "image_start_token": start_token,
            "image_end_token": end_token,
            "sentence": sentence,
            "image": image,
            "label": label,
        }

    def get_label_frequencies(self):
        label_freqs = Counter()
@@ -110,10 +105,31 @@ def collate_fn(batch):
def get_mmimdb_labels():
    return [
        "Crime",
        "Drama",
        "Thriller",
        "Action",
        "Comedy",
        "Romance",
        "Documentary",
        "Short",
        "Mystery",
        "History",
        "Family",
        "Adventure",
        "Fantasy",
        "Sci-Fi",
        "Western",
        "Horror",
        "Sport",
        "War",
        "Music",
        "Musical",
        "Animation",
        "Biography",
        "Film-Noir",
    ]


def get_image_transforms():
@@ -122,9 +138,6 @@ def get_image_transforms():
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469],),
        ]
    )
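A hedged end-to-end sketch of how these MM-IMDB helpers are typically wired together; the tokenizer name, data path, and sequence budget below are assumptions, not values from this diff:

from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")  # assumed tokenizer
labels = get_mmimdb_labels()
dataset = JsonlDataset(
    "mmimdb/dataset/train.jsonl",  # assumed path to a .jsonl split
    tokenizer,
    get_image_transforms(),
    labels,
    max_seq_length=510,  # assumed text-token budget
)
sample = dataset[0]
print(sample["sentence"].shape, sample["image"].shape, int(sample["label"].sum()))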
import torch


class ClassificationHead(torch.nn.Module):
    """Classification Head for transformer encoders"""
...
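The body of this class is collapsed in the diff view. As a rough sketch only (sizes and forward logic are assumptions, not the file's actual contents), such a head often reduces to a single linear projection over a pooled hidden state:

class MinimalClassificationHead(torch.nn.Module):
    """Assumed shape of a minimal classification head; not the implementation in this file."""

    def __init__(self, class_size, embed_size):
        super().__init__()
        self.mlp = torch.nn.Linear(embed_size, class_size)

    def forward(self, hidden_state):
        # hidden_state: (batch, embed_size) pooled representation -> (batch, class_size) logits
        return self.mlp(hidden_state)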