Commit fa84ae26 authored by Aymeric Augustin

Reformat source code with black.

This is the result of:

    $ black --line-length 119 examples templates transformers utils hubconf.py setup.py

There are a lot of fairly long lines in the project. As a consequence, I'm
picking the longest widely accepted line length, 119 characters.

This is also Thomas' preference: it allows for explicit variable names,
which makes the code easier to understand.
parent 63e3827c
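
To sanity-check a snippet against the same setting without invoking the CLI, black's Python API can be used. The short sketch below is illustrative only (it assumes a recent black release that exposes format_str and Mode at the top level) and is not part of this commit:

    import black

    # Same line length as the command above; quotes get normalized to double quotes.
    mode = black.Mode(line_length=119)
    source = "device = 'cuda' if gpu else 'cpu'\n"
    formatted = black.format_str(source, mode=mode)
    print(formatted)  # device = "cuda" if gpu else "cpu"
    # Formatting is idempotent: re-running black leaves the output unchanged.
    assert formatted == black.format_str(formatted, mode=mode)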
@@ -247,7 +247,8 @@ the wall, slowly on into the Social Predestination Room.
as they entered."""


def create_setup_and_compute(
    model_names: List[str],
    gpu: bool = True,
    tensorflow: bool = False,
    average_over: int = 3,

@@ -256,7 +257,8 @@ def create_setup_and_compute(model_names: List[str],
    amp: bool = False,
    fp16: bool = False,
    save_to_csv: bool = False,
    csv_filename: str = f"results_{round(time())}.csv",
):
    if xla:
        tf.config.optimizer.set_jit(True)
    if amp:

@@ -266,7 +268,7 @@ def create_setup_and_compute(model_names: List[str],
        dictionary = {model_name: {} for model_name in model_names}
        results = _compute_tensorflow(model_names, dictionary, average_over, amp)
    else:
        device = "cuda" if (gpu and torch.cuda.is_available()) else "cpu"
        dictionary = {model_name: {} for model_name in model_names}
        results = _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16)

@@ -276,22 +278,40 @@ def create_setup_and_compute(model_names: List[str],
        for batch_size in results[model_name]["bs"]:
            print("\t\t" + f"===== BATCH SIZE: {batch_size} =====")
            for slice_size in results[model_name]["ss"]:
                result = results[model_name]["results"][batch_size][slice_size]
                if isinstance(result, str):
                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{result}")
                else:
                    print(f"\t\t{model_name}/{batch_size}/{slice_size}: " f"{(round(1000 * result) / 1000)}" f"s")

    if save_to_csv:
        with open(csv_filename, mode="w") as csv_file:
            fieldnames = [
                "model",
                "1x8",
                "1x64",
                "1x128",
                "1x256",
                "1x512",
                "1x1024",
                "2x8",
                "2x64",
                "2x128",
                "2x256",
                "2x512",
                "2x1024",
                "4x8",
                "4x64",
                "4x128",
                "4x256",
                "4x512",
                "4x1024",
                "8x8",
                "8x64",
                "8x128",
                "8x256",
                "8x512",
                "8x1024",
            ]

            writer = csv.DictWriter(csv_file, fieldnames=fieldnames)

@@ -299,11 +319,11 @@ def create_setup_and_compute(model_names: List[str],
            for model_name in model_names:
                model_results = {
                    f"{bs}x{ss}": results[model_name]["results"][bs][ss]
                    for bs in results[model_name]["results"]
                    for ss in results[model_name]["results"][bs]
                }
                writer.writerow({"model": model_name, **model_results})


def _compute_pytorch(model_names, dictionary, average_over, device, torchscript, fp16):

@@ -343,7 +363,7 @@ def _compute_pytorch(model_names, dictionary, average_over, device, torchscript,
                        print("Going through model with sequence of shape", sequence.shape)
                        runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
                        average_time = sum(runtimes) / float(len(runtimes)) / 3.0
                        dictionary[model_name]["results"][batch_size][slice_size] = average_time
                    except RuntimeError as e:
                        print("Doesn't fit on GPU.", e)

@@ -379,7 +399,9 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
                if max_input_size is not None and slice_size > max_input_size:
                    dictionary[model_name]["results"][batch_size][slice_size] = "N/A"
                else:
                    sequence = tf.stack(
                        [tf.squeeze(tf.constant(tokenized_sequence[:slice_size])[None, :])] * batch_size
                    )

                    try:
                        print("Going through model with sequence of shape", sequence.shape)

@@ -387,7 +409,7 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
                        inference(sequence)
                        runtimes = timeit.repeat(lambda: inference(sequence), repeat=average_over, number=3)
                        average_time = sum(runtimes) / float(len(runtimes)) / 3.0
                        dictionary[model_name]["results"][batch_size][slice_size] = average_time
                    except tf.errors.ResourceExhaustedError as e:
                        print("Doesn't fit on GPU.", e)

@@ -399,33 +421,64 @@ def _compute_tensorflow(model_names, dictionary, average_over, amp):
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        required=False,
        type=str,
        default="all",
        help="Model checkpoints to be provided "
        "to the AutoModel classes. Leave "
        "blank to benchmark the base version "
        "of all available model "
        "architectures.",
    )
    parser.add_argument(
        "--torch", required=False, action="store_true", help="Benchmark the Pytorch version of the " "models"
    )
    parser.add_argument(
        "--torch_cuda", required=False, action="store_true", help="Pytorch only: run on available " "cuda devices"
    )
    parser.add_argument(
        "--torchscript",
        required=False,
        action="store_true",
        help="Pytorch only: trace the models " "using torchscript",
    )
    parser.add_argument(
        "--tensorflow",
        required=False,
        action="store_true",
        help="Benchmark the TensorFlow version "
        "of the models. Will run on GPU if "
        "the correct dependencies are "
        "installed",
    )
    parser.add_argument("--xla", required=False, action="store_true", help="TensorFlow only: use XLA acceleration.")
    parser.add_argument(
        "--amp",
        required=False,
        action="store_true",
        help="TensorFlow only: use automatic mixed precision acceleration.",
    )
    parser.add_argument(
        "--fp16", required=False, action="store_true", help="PyTorch only: use FP16 to accelerate inference."
    )
    parser.add_argument(
        "--keras_predict",
        required=False,
        action="store_true",
        help="Whether to use model.predict " "instead of model() to do a " "forward pass.",
    )
    parser.add_argument("--save_to_csv", required=False, action="store_true", help="Save to a CSV file.")
    parser.add_argument(
        "--csv_filename", required=False, default=None, help="CSV filename used if saving results to csv."
    )
    parser.add_argument(
        "--average_over", required=False, default=30, type=int, help="Times an experiment will be run."
    )
    args = parser.parse_args()

    if args.models == "all":
        args.models = [
            "gpt2",
            "bert-base-cased",

@@ -436,7 +489,7 @@ def main():
            "distilbert-base-uncased",
            "distilgpt2",
            "roberta-base",
            "ctrl",
        ]
    else:
        args.models = args.models.split()

@@ -453,7 +506,7 @@ def main():
                fp16=args.fp16,
                save_to_csv=args.save_to_csv,
                csv_filename=args.csv_filename,
                average_over=args.average_over,
            )
        else:
            raise ImportError("Trying to run a PyTorch benchmark but PyTorch was not found in the environment.")

@@ -467,11 +520,11 @@ def main():
                amp=args.amp,
                save_to_csv=args.save_to_csv,
                csv_filename=args.csv_filename,
                average_over=args.average_over,
            )
        else:
            raise ImportError("Trying to run a TensorFlow benchmark but TensorFlow was not found in the environment.")


if __name__ == "__main__":
    main()
@@ -10,38 +10,37 @@ from transformers.modeling_camembert import CamembertForMaskedLM

def fill_mask(masked_input, model, tokenizer, topk=5):
    # Adapted from https://github.com/pytorch/fairseq/blob/master/fairseq/models/roberta/hub_interface.py
    assert masked_input.count("<mask>") == 1
    input_ids = torch.tensor(tokenizer.encode(masked_input, add_special_tokens=True)).unsqueeze(0)  # Batch size 1
    logits = model(input_ids)[0]  # The last hidden-state is the first element of the output tuple
    masked_index = (input_ids.squeeze() == tokenizer.mask_token_id).nonzero().item()
    logits = logits[0, masked_index, :]
    prob = logits.softmax(dim=0)
    values, indices = prob.topk(k=topk, dim=0)
    topk_predicted_token_bpe = " ".join(
        [tokenizer.convert_ids_to_tokens(indices[i].item()) for i in range(len(indices))]
    )
    masked_token = tokenizer.mask_token
    topk_filled_outputs = []
    for index, predicted_token_bpe in enumerate(topk_predicted_token_bpe.split(" ")):
        predicted_token = predicted_token_bpe.replace("\u2581", " ")
        if " {0}".format(masked_token) in masked_input:
            topk_filled_outputs.append(
                (
                    masked_input.replace(" {0}".format(masked_token), predicted_token),
                    values[index].item(),
                    predicted_token,
                )
            )
        else:
            topk_filled_outputs.append(
                (masked_input.replace(masked_token, predicted_token), values[index].item(), predicted_token,)
            )
    return topk_filled_outputs


tokenizer = CamembertTokenizer.from_pretrained("camembert-base")
model = CamembertForMaskedLM.from_pretrained("camembert-base")
model.eval()

masked_input = "Le camembert est <mask> :)"
@@ -36,34 +36,42 @@ from tqdm import tqdm, trange
import numpy as np
import torch
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler, TensorDataset

from transformers import (
    OpenAIGPTDoubleHeadsModel,
    OpenAIGPTTokenizer,
    AdamW,
    cached_path,
    WEIGHTS_NAME,
    CONFIG_NAME,
    get_linear_schedule_with_warmup,
)

ROCSTORIES_URL = "https://s3.amazonaws.com/datasets.huggingface.co/ROCStories.tar.gz"

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


def accuracy(out, labels):
    outputs = np.argmax(out, axis=1)
    return np.sum(outputs == labels)


def load_rocstories_dataset(dataset_path):
    """ Output a list of tuples(story, 1st continuation, 2nd continuation, label) """
    with open(dataset_path, encoding="utf_8") as f:
        f = csv.reader(f)
        output = []
        next(f)  # skip the first line
        for line in tqdm(f):
            output.append((" ".join(line[1:5]), line[5], line[6], int(line[-1]) - 1))
    return output


def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, delimiter_token, clf_token):
    """ Pre-process datasets containing lists of tuples(story, 1st continuation, 2nd continuation, label)

@@ -80,56 +88,68 @@ def pre_process_datasets(encoded_datasets, input_len, cap_length, start_token, d
        for i, (story, cont1, cont2, mc_label), in enumerate(dataset):
            with_cont1 = [start_token] + story[:cap_length] + [delimiter_token] + cont1[:cap_length] + [clf_token]
            with_cont2 = [start_token] + story[:cap_length] + [delimiter_token] + cont2[:cap_length] + [clf_token]
            input_ids[i, 0, : len(with_cont1)] = with_cont1
            input_ids[i, 1, : len(with_cont2)] = with_cont2
            mc_token_ids[i, 0] = len(with_cont1) - 1
            mc_token_ids[i, 1] = len(with_cont2) - 1
            lm_labels[i, 0, : len(with_cont1)] = with_cont1
            lm_labels[i, 1, : len(with_cont2)] = with_cont2
            mc_labels[i] = mc_label
        all_inputs = (input_ids, mc_token_ids, lm_labels, mc_labels)
        tensor_datasets.append(tuple(torch.tensor(t) for t in all_inputs))
    return tensor_datasets


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, default="openai-gpt", help="pretrained model name")
    parser.add_argument("--do_train", action="store_true", help="Whether to run training.")
    parser.add_argument("--do_eval", action="store_true", help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument("--train_dataset", type=str, default="")
    parser.add_argument("--eval_dataset", type=str, default="")
    parser.add_argument("--seed", type=int, default=42)
    parser.add_argument("--num_train_epochs", type=int, default=3)
    parser.add_argument("--train_batch_size", type=int, default=8)
    parser.add_argument("--eval_batch_size", type=int, default=16)
    parser.add_argument("--adam_epsilon", default=1e-8, type=float, help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", type=int, default=1)
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help="If > 0: set total number of training \
                        steps to perform. Override num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before\
                        performing a backward/update pass.",
    )
    parser.add_argument("--learning_rate", type=float, default=6.25e-5)
    parser.add_argument("--warmup_steps", default=0, type=int, help="Linear warmup over warmup_steps.")
    parser.add_argument("--lr_schedule", type=str, default="warmup_linear")
    parser.add_argument("--weight_decay", type=float, default=0.01)
    parser.add_argument("--lm_coef", type=float, default=0.9)
    parser.add_argument("--n_valid", type=int, default=374)
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

@@ -152,7 +172,7 @@ def main():
    # Load tokenizer and model
    # This loading functions also add new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    special_tokens = ["_start_", "_delimiter_", "_classify_"]
    tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_name)
    tokenizer.add_tokens(special_tokens)
    special_tokens_ids = tokenizer.convert_tokens_to_ids(special_tokens)

@@ -163,6 +183,7 @@ def main():
    # Load and encode the datasets
    if not args.train_dataset and not args.eval_dataset:
        roc_stories = cached_path(ROCSTORIES_URL)

    def tokenize_and_encode(obj):
        """ Tokenize and encode a nested object """
        if isinstance(obj, str):

@@ -170,6 +191,7 @@ def main():
        elif isinstance(obj, int):
            return obj
        return list(tokenize_and_encode(o) for o in obj)

    logger.info("Encoding dataset...")
    train_dataset = load_rocstories_dataset(args.train_dataset)
    eval_dataset = load_rocstories_dataset(args.eval_dataset)

@@ -178,8 +200,11 @@ def main():
    # Compute the max input length for the Transformer
    max_length = model.config.n_positions // 2 - 2
    input_length = max(
        len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
        for dataset in encoded_datasets
        for story, cont1, cont2, _ in dataset
    )
    input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare inputs tensors and dataloaders

@@ -198,20 +223,23 @@ def main():
    if args.do_train:
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

        param_optimizer = list(model.named_parameters())
        no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
                "weight_decay": args.weight_decay,
            },
            {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
        ]
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = get_linear_schedule_with_warmup(
            optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
        )

    if args.do_train:
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None

@@ -230,14 +258,16 @@ def main():
                optimizer.step()
                optimizer.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = (
                    loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
                )
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(exp_average_loss, scheduler.get_lr()[0])

    # Save a trained model
    if args.do_train:
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, "module") else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)

@@ -260,10 +290,12 @@ def main():
            batch = tuple(t.to(device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels = batch
            with torch.no_grad():
                _, mc_loss, _, mc_logits = model(
                    input_ids, mc_token_ids=mc_token_ids, lm_labels=lm_labels, mc_labels=mc_labels
                )

            mc_logits = mc_logits.detach().cpu().numpy()
            mc_labels = mc_labels.to("cpu").numpy()
            tmp_eval_accuracy = accuracy(mc_logits, mc_labels)

            eval_loss += mc_loss.mean().item()

@@ -274,10 +306,8 @@ def main():
        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        train_loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {"eval_loss": eval_loss, "eval_accuracy": eval_accuracy, "train_loss": train_loss}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:

@@ -286,5 +316,6 @@ def main():
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))


if __name__ == "__main__":
    main()
@@ -30,44 +30,36 @@ import torch

from transformers import TransfoXLLMHeadModel, TransfoXLCorpus, TransfoXLTokenizer

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(description="PyTorch Transformer Language Model")
    parser.add_argument("--model_name", type=str, default="transfo-xl-wt103", help="pretrained model name")
    parser.add_argument(
        "--split", type=str, default="test", choices=["all", "valid", "test"], help="which split to evaluate"
    )
    parser.add_argument("--batch_size", type=int, default=10, help="batch size")
    parser.add_argument("--tgt_len", type=int, default=128, help="number of tokens to predict")
    parser.add_argument("--ext_len", type=int, default=0, help="length of the extended context")
    parser.add_argument("--mem_len", type=int, default=1600, help="length of the retained previous heads")
    parser.add_argument("--clamp_len", type=int, default=1000, help="max positional embedding index")
    parser.add_argument("--no_cuda", action="store_true", help="Do not use CUDA even though CUA is available")
    parser.add_argument("--work_dir", type=str, required=True, help="path to the work_dir")
    parser.add_argument("--no_log", action="store_true", help="do not log the eval result")
    parser.add_argument("--same_length", action="store_true", help="set same length attention with masking")
    parser.add_argument("--server_ip", type=str, default="", help="Can be used for distant debugging.")
    parser.add_argument("--server_port", type=str, default="", help="Can be used for distant debugging.")
    args = parser.parse_args()
    assert args.ext_len >= 0, "extended context length must be non-negative"

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd

        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

@@ -84,17 +76,18 @@ def main():
    corpus = TransfoXLCorpus.from_pretrained(args.model_name)
    ntokens = len(corpus.vocab)

    va_iter = corpus.get_iterator("valid", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)
    te_iter = corpus.get_iterator("test", args.batch_size, args.tgt_len, device=device, ext_len=args.ext_len)

    # Load a pre-trained model
    model = TransfoXLLMHeadModel.from_pretrained(args.model_name)
    model = model.to(device)

    logger.info(
        "Evaluating with bsz {} tgt_len {} ext_len {} mem_len {} clamp_len {}".format(
            args.batch_size, args.tgt_len, args.ext_len, args.mem_len, args.clamp_len
        )
    )

    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
    if args.clamp_len > 0:

@@ -108,7 +101,7 @@ def main():
    def evaluate(eval_iter):
        # Turn on evaluation mode which disables dropout.
        model.eval()
        total_len, total_loss = 0, 0.0
        start_time = time.time()
        with torch.no_grad():
            mems = None

@@ -119,35 +112,34 @@ def main():
                total_loss += seq_len * loss.item()
                total_len += seq_len
            total_time = time.time() - start_time
        logger.info("Time : {:.2f}s, {:.2f}ms/segment".format(total_time, 1000 * total_time / (idx + 1)))
        return total_loss / total_len

    # Run on test data.
    if args.split == "all":
        test_loss = evaluate(te_iter)
        valid_loss = evaluate(va_iter)
    elif args.split == "valid":
        valid_loss = evaluate(va_iter)
        test_loss = None
    elif args.split == "test":
        test_loss = evaluate(te_iter)
        valid_loss = None

    def format_log(loss, split):
        log_str = "| {0} loss {1:5.2f} | {0} ppl {2:9.3f} ".format(split, loss, math.exp(loss))
        return log_str

    log_str = ""
    if valid_loss is not None:
        log_str += format_log(valid_loss, "valid")
    if test_loss is not None:
        log_str += format_log(test_loss, "test")

    logger.info("=" * 100)
    logger.info(log_str)
    logger.info("=" * 100)


if __name__ == "__main__":
    main()
@@ -23,12 +23,14 @@ from torch.utils.data.sampler import BatchSampler, Sampler

from utils import logger


def _quantize(x, bins):
    bins = copy.deepcopy(bins)
    bins = sorted(bins)
    quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
    return quantized


def create_lengths_groups(lengths, k=0):
    bins = np.arange(start=3, stop=k, step=4).tolist() if k > 0 else [10]
    groups = _quantize(lengths, bins)

@@ -39,6 +41,7 @@ def create_lengths_groups(lengths, k=0):
    logger.info("Count of instances per bin: {}".format(counts))
    return groups


class GroupedBatchSampler(BatchSampler):
    """
    Wraps another sampler to yield a mini-batch of indices.

@@ -53,11 +56,11 @@ class GroupedBatchSampler(BatchSampler):
        0, i.e. they must be in the range [0, num_groups).
        batch_size (int): Size of mini-batch.
    """

    def __init__(self, sampler, group_ids, batch_size):
        if not isinstance(sampler, Sampler):
            raise ValueError(
                "sampler should be an instance of " "torch.utils.data.Sampler, but got sampler={}".format(sampler)
            )
        self.sampler = sampler
        self.group_ids = group_ids

@@ -73,7 +76,7 @@ class GroupedBatchSampler(BatchSampler):
            buffer_per_group[group_id].append(idx)
            samples_per_group[group_id].append(idx)
            if len(buffer_per_group[group_id]) == self.batch_size:
                yield buffer_per_group[group_id]  # TODO
                num_batches += 1
                del buffer_per_group[group_id]
            assert len(buffer_per_group[group_id]) < self.batch_size

@@ -90,8 +93,8 @@ class GroupedBatchSampler(BatchSampler):
            for group_id, idxs in sorted(buffer_per_group.items(), key=lambda x: x[0]):
                batch_idx.extend(idxs)
                if len(batch_idx) >= self.batch_size:
                    yield batch_idx[: self.batch_size]
                    batch_idx = batch_idx[self.batch_size :]
                    num_remaining -= 1
            if len(batch_idx) > 0:
                yield batch_idx
@@ -21,6 +21,7 @@ from torch.utils.data import Dataset
import numpy as np

from utils import logger


class LmSeqsDataset(Dataset):
    """Custom Dataset wrapping language modeling sequences.

@@ -32,9 +33,7 @@ class LmSeqsDataset(Dataset):
            data: `List[np.array[int]]
    """

    def __init__(self, params, data):
        self.params = params
        self.token_ids = np.array(data)

@@ -65,17 +64,17 @@ class LmSeqsDataset(Dataset):
        """
        max_len = self.params.max_model_input_size
        indices = self.lengths > max_len
        logger.info(f"Splitting {sum(indices)} too long sequences.")

        def divide_chunks(l, n):
            return [l[i : i + n] for i in range(0, len(l), n)]

        new_tok_ids = []
        new_lengths = []
        if self.params.mlm:
            cls_id, sep_id = self.params.special_tok_ids["cls_token"], self.params.special_tok_ids["sep_token"]
        else:
            cls_id, sep_id = self.params.special_tok_ids["bos_token"], self.params.special_tok_ids["eos_token"]

        for seq_, len_ in zip(self.token_ids, self.lengths):
            assert (seq_[0] == cls_id) and (seq_[-1] == sep_id), seq_

@@ -84,7 +83,7 @@ class LmSeqsDataset(Dataset):
                new_lengths.append(len_)
            else:
                sub_seqs = []
                for sub_s in divide_chunks(seq_, max_len - 2):
                    if sub_s[0] != cls_id:
                        sub_s = np.insert(sub_s, 0, cls_id)
                    if sub_s[-1] != sep_id:

@@ -108,7 +107,7 @@ class LmSeqsDataset(Dataset):
        self.token_ids = self.token_ids[indices]
        self.lengths = self.lengths[indices]
        new_size = len(self)
        logger.info(f"Remove {init_size - new_size} too short (<=11 tokens) sequences.")

    def print_statistics(self):
        """

@@ -116,7 +115,7 @@ class LmSeqsDataset(Dataset):
        """
        if not self.params.is_master:
            return
        logger.info(f"{len(self)} sequences")
        # data_len = sum(self.lengths)
        # nb_unique_tokens = len(Counter(list(chain(*self.token_ids))))
        # logger.info(f'{data_len} tokens ({nb_unique_tokens} unique)')

@@ -125,8 +124,7 @@ class LmSeqsDataset(Dataset):
        # nb_unkown = sum([(t==unk_idx).sum() for t in self.token_ids])
        # logger.info(f'{nb_unkown} unknown tokens (covering {100*nb_unkown/data_len:.2f}% of the data)')

    def batch_sequences(self, batch):
        """
        Do the padding and transform into torch.tensor.
        """

@@ -139,10 +137,10 @@ class LmSeqsDataset(Dataset):
        # Pad token ids
        if self.params.mlm:
            pad_idx = self.params.special_tok_ids["pad_token"]
        else:
            pad_idx = self.params.special_tok_ids["unk_token"]
        tk_ = [list(t.astype(int)) + [pad_idx] * (max_seq_len_ - len(t)) for t in token_ids]
        assert len(tk_) == len(token_ids)
        assert all(len(t) == max_seq_len_ for t in tk_)
@@ -23,68 +23,65 @@ import numpy as np
from transformers import BertTokenizer, RobertaTokenizer, GPT2Tokenizer

import logging

logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
)
logger = logging.getLogger(__name__)


def main():
    parser = argparse.ArgumentParser(
        description="Preprocess the data to avoid re-doing it several times by (tokenization + token_to_ids)."
    )
    parser.add_argument("--file_path", type=str, default="data/dump.txt", help="The path to the data.")
    parser.add_argument("--tokenizer_type", type=str, default="bert", choices=["bert", "roberta", "gpt2"])
    parser.add_argument("--tokenizer_name", type=str, default="bert-base-uncased", help="The tokenizer to use.")
    parser.add_argument("--dump_file", type=str, default="data/dump", help="The dump file prefix.")
    args = parser.parse_args()

    logger.info(f"Loading Tokenizer ({args.tokenizer_name})")
    if args.tokenizer_type == "bert":
        tokenizer = BertTokenizer.from_pretrained(args.tokenizer_name)
        bos = tokenizer.special_tokens_map["cls_token"]  # `[CLS]`
        sep = tokenizer.special_tokens_map["sep_token"]  # `[SEP]`
    elif args.tokenizer_type == "roberta":
        tokenizer = RobertaTokenizer.from_pretrained(args.tokenizer_name)
        bos = tokenizer.special_tokens_map["cls_token"]  # `<s>`
        sep = tokenizer.special_tokens_map["sep_token"]  # `</s>`
    elif args.tokenizer_type == "gpt2":
        tokenizer = GPT2Tokenizer.from_pretrained(args.tokenizer_name)
        bos = tokenizer.special_tokens_map["bos_token"]  # `<|endoftext|>`
        sep = tokenizer.special_tokens_map["eos_token"]  # `<|endoftext|>`

    logger.info(f"Loading text from {args.file_path}")
    with open(args.file_path, "r", encoding="utf8") as fp:
        data = fp.readlines()

    logger.info(f"Start encoding")
    logger.info(f"{len(data)} examples to process.")

    rslt = []
    iter = 0
    interval = 10000
    start = time.time()
    for text in data:
        text = f"{bos} {text.strip()} {sep}"
        token_ids = tokenizer.encode(text, add_special_tokens=False)
        rslt.append(token_ids)

        iter += 1
        if iter % interval == 0:
            end = time.time()
            logger.info(f"{iter} examples processed. - {(end-start)/interval:.2f}s/expl")
            start = time.time()

    logger.info("Finished binarization")
    logger.info(f"{len(data)} examples processed.")

    dp_file = f"{args.dump_file}.{args.tokenizer_name}.pickle"
    rslt_ = [np.uint16(d) for d in rslt]
    random.shuffle(rslt_)
    logger.info(f"Dump to {dp_file}")
    with open(dp_file, "wb") as handle:
        pickle.dump(rslt_, handle, protocol=pickle.HIGHEST_PROTOCOL)
...@@ -20,70 +20,80 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel ...@@ -20,70 +20,80 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM, GPT2LMHeadModel
import torch import torch
import argparse import argparse
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation") parser = argparse.ArgumentParser(
description="Extraction some layers of the full RobertaForMaskedLM or GPT2LMHeadModel for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"]) parser.add_argument("--model_type", default="roberta", choices=["roberta", "gpt2"])
parser.add_argument("--model_name", default='roberta-large', type=str) parser.add_argument("--model_name", default="roberta-large", type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_roberta_048131723.pth', type=str) parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_roberta_048131723.pth", type=str)
parser.add_argument("--vocab_transform", action='store_true') parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args() args = parser.parse_args()
if args.model_type == "roberta":
if args.model_type == 'roberta':
model = RobertaForMaskedLM.from_pretrained(args.model_name) model = RobertaForMaskedLM.from_pretrained(args.model_name)
prefix = 'roberta' prefix = "roberta"
elif args.model_type == 'gpt2': elif args.model_type == "gpt2":
model = GPT2LMHeadModel.from_pretrained(args.model_name) model = GPT2LMHeadModel.from_pretrained(args.model_name)
prefix = 'transformer' prefix = "transformer"
state_dict = model.state_dict() state_dict = model.state_dict()
compressed_sd = {} compressed_sd = {}
### Embeddings ### ### Embeddings ###
if args.model_type == 'gpt2': if args.model_type == "gpt2":
for param_name in ['wte.weight', 'wpe.weight']: for param_name in ["wte.weight", "wpe.weight"]:
compressed_sd[f'{prefix}.{param_name}'] = state_dict[f'{prefix}.{param_name}'] compressed_sd[f"{prefix}.{param_name}"] = state_dict[f"{prefix}.{param_name}"]
else: else:
for w in ['word_embeddings', 'position_embeddings', 'token_type_embeddings']: for w in ["word_embeddings", "position_embeddings", "token_type_embeddings"]:
param_name = f'{prefix}.embeddings.{w}.weight' param_name = f"{prefix}.embeddings.{w}.weight"
compressed_sd[param_name] = state_dict[param_name] compressed_sd[param_name] = state_dict[param_name]
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
param_name = f'{prefix}.embeddings.LayerNorm.{w}' param_name = f"{prefix}.embeddings.LayerNorm.{w}"
compressed_sd[param_name] = state_dict[param_name] compressed_sd[param_name] = state_dict[param_name]
### Transformer Blocks ### ### Transformer Blocks ###
std_idx = 0 std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]: for teacher_idx in [0, 2, 4, 7, 9, 11]:
if args.model_type == 'gpt2': if args.model_type == "gpt2":
for layer in ['ln_1', 'attn.c_attn', 'attn.c_proj', 'ln_2', 'mlp.c_fc', 'mlp.c_proj']: for layer in ["ln_1", "attn.c_attn", "attn.c_proj", "ln_2", "mlp.c_fc", "mlp.c_proj"]:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'{prefix}.h.{std_idx}.{layer}.{w}'] = \ compressed_sd[f"{prefix}.h.{std_idx}.{layer}.{w}"] = state_dict[
state_dict[f'{prefix}.h.{teacher_idx}.{layer}.{w}'] f"{prefix}.h.{teacher_idx}.{layer}.{w}"
compressed_sd[f'{prefix}.h.{std_idx}.attn.bias'] = state_dict[f'{prefix}.h.{teacher_idx}.attn.bias'] ]
compressed_sd[f"{prefix}.h.{std_idx}.attn.bias"] = state_dict[f"{prefix}.h.{teacher_idx}.attn.bias"]
else: else:
for layer in ['attention.self.query', 'attention.self.key', 'attention.self.value', for layer in [
'attention.output.dense', 'attention.output.LayerNorm', "attention.self.query",
'intermediate.dense', 'output.dense', 'output.LayerNorm']: "attention.self.key",
for w in ['weight', 'bias']: "attention.self.value",
compressed_sd[f'{prefix}.encoder.layer.{std_idx}.{layer}.{w}'] = \ "attention.output.dense",
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}'] "attention.output.LayerNorm",
"intermediate.dense",
"output.dense",
"output.LayerNorm",
]:
for w in ["weight", "bias"]:
compressed_sd[f"{prefix}.encoder.layer.{std_idx}.{layer}.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.{layer}.{w}"
]
std_idx += 1 std_idx += 1
### Language Modeling Head ### ### Language Modeling Head ###
if args.model_type == 'roberta': if args.model_type == "roberta":
for layer in ['lm_head.decoder.weight', 'lm_head.bias']: for layer in ["lm_head.decoder.weight", "lm_head.bias"]:
compressed_sd[f'{layer}'] = state_dict[f'{layer}'] compressed_sd[f"{layer}"] = state_dict[f"{layer}"]
if args.vocab_transform: if args.vocab_transform:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'lm_head.dense.{w}'] = state_dict[f'lm_head.dense.{w}'] compressed_sd[f"lm_head.dense.{w}"] = state_dict[f"lm_head.dense.{w}"]
compressed_sd[f'lm_head.layer_norm.{w}'] = state_dict[f'lm_head.layer_norm.{w}'] compressed_sd[f"lm_head.layer_norm.{w}"] = state_dict[f"lm_head.layer_norm.{w}"]
elif args.model_type == 'gpt2': elif args.model_type == "gpt2":
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'{prefix}.ln_f.{w}'] = state_dict[f'{prefix}.ln_f.{w}'] compressed_sd[f"{prefix}.ln_f.{w}"] = state_dict[f"{prefix}.ln_f.{w}"]
compressed_sd[f'lm_head.weight'] = state_dict[f'lm_head.weight'] compressed_sd[f"lm_head.weight"] = state_dict[f"lm_head.weight"]
print(f'N layers selected for distillation: {std_idx}') print(f"N layers selected for distillation: {std_idx}")
print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}') print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
print(f'Save transferred checkpoint to {args.dump_checkpoint}.') print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
torch.save(compressed_sd, args.dump_checkpoint) torch.save(compressed_sd, args.dump_checkpoint)
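As a hedged aside, one quick way to sanity-check the dump produced by this script; the path below is just the script's default, and loading the weights into a concrete student configuration is out of scope here.

# Minimal sketch: load the compressed checkpoint and inspect its contents.
import torch

compressed_sd = torch.load("serialization_dir/tf_roberta_048131723.pth", map_location="cpu")
print(f"{len(compressed_sd)} tensors extracted")
for name, tensor in list(compressed_sd.items())[:5]:
    print(name, tuple(tensor.shape))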
...@@ -20,63 +20,70 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM ...@@ -20,63 +20,70 @@ from transformers import BertForMaskedLM, RobertaForMaskedLM
import torch import torch
import argparse import argparse
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Extract some layers of the full BertForMaskedLM or RobertaForMaskedLM for Transfer Learned Distillation") parser = argparse.ArgumentParser(
description="Extract some layers of the full BertForMaskedLM or RobertaForMaskedLM for Transfer Learned Distillation"
)
parser.add_argument("--model_type", default="bert", choices=["bert"]) parser.add_argument("--model_type", default="bert", choices=["bert"])
parser.add_argument("--model_name", default='bert-base-uncased', type=str) parser.add_argument("--model_name", default="bert-base-uncased", type=str)
parser.add_argument("--dump_checkpoint", default='serialization_dir/tf_bert-base-uncased_0247911.pth', type=str) parser.add_argument("--dump_checkpoint", default="serialization_dir/tf_bert-base-uncased_0247911.pth", type=str)
parser.add_argument("--vocab_transform", action='store_true') parser.add_argument("--vocab_transform", action="store_true")
args = parser.parse_args() args = parser.parse_args()
if args.model_type == "bert":
if args.model_type == 'bert':
model = BertForMaskedLM.from_pretrained(args.model_name) model = BertForMaskedLM.from_pretrained(args.model_name)
prefix = 'bert' prefix = "bert"
else: else:
raise ValueError(f'args.model_type should be "bert".') raise ValueError(f'args.model_type should be "bert".')
state_dict = model.state_dict() state_dict = model.state_dict()
compressed_sd = {} compressed_sd = {}
for w in ['word_embeddings', 'position_embeddings']: for w in ["word_embeddings", "position_embeddings"]:
compressed_sd[f'distilbert.embeddings.{w}.weight'] = \ compressed_sd[f"distilbert.embeddings.{w}.weight"] = state_dict[f"{prefix}.embeddings.{w}.weight"]
state_dict[f'{prefix}.embeddings.{w}.weight'] for w in ["weight", "bias"]:
for w in ['weight', 'bias']: compressed_sd[f"distilbert.embeddings.LayerNorm.{w}"] = state_dict[f"{prefix}.embeddings.LayerNorm.{w}"]
compressed_sd[f'distilbert.embeddings.LayerNorm.{w}'] = \
state_dict[f'{prefix}.embeddings.LayerNorm.{w}']
std_idx = 0 std_idx = 0
for teacher_idx in [0, 2, 4, 7, 9, 11]: for teacher_idx in [0, 2, 4, 7, 9, 11]:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}'] = \ compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.q_lin.{w}"] = state_dict[
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}'] f"{prefix}.encoder.layer.{teacher_idx}.attention.self.query.{w}"
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}'] = \ ]
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}'] compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.k_lin.{w}"] = state_dict[
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}'] = \ f"{prefix}.encoder.layer.{teacher_idx}.attention.self.key.{w}"
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}'] ]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.v_lin.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.self.value.{w}"
]
compressed_sd[f'distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}'] = \ compressed_sd[f"distilbert.transformer.layer.{std_idx}.attention.out_lin.{w}"] = state_dict[
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}'] f"{prefix}.encoder.layer.{teacher_idx}.attention.output.dense.{w}"
compressed_sd[f'distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}'] = \ ]
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}'] compressed_sd[f"distilbert.transformer.layer.{std_idx}.sa_layer_norm.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.attention.output.LayerNorm.{w}"
]
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}'] = \ compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin1.{w}"] = state_dict[
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}'] f"{prefix}.encoder.layer.{teacher_idx}.intermediate.dense.{w}"
compressed_sd[f'distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}'] = \ ]
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}'] compressed_sd[f"distilbert.transformer.layer.{std_idx}.ffn.lin2.{w}"] = state_dict[
compressed_sd[f'distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}'] = \ f"{prefix}.encoder.layer.{teacher_idx}.output.dense.{w}"
state_dict[f'{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}'] ]
compressed_sd[f"distilbert.transformer.layer.{std_idx}.output_layer_norm.{w}"] = state_dict[
f"{prefix}.encoder.layer.{teacher_idx}.output.LayerNorm.{w}"
]
std_idx += 1 std_idx += 1
compressed_sd[f'vocab_projector.weight'] = state_dict[f'cls.predictions.decoder.weight'] compressed_sd[f"vocab_projector.weight"] = state_dict[f"cls.predictions.decoder.weight"]
compressed_sd[f'vocab_projector.bias'] = state_dict[f'cls.predictions.bias'] compressed_sd[f"vocab_projector.bias"] = state_dict[f"cls.predictions.bias"]
if args.vocab_transform: if args.vocab_transform:
for w in ['weight', 'bias']: for w in ["weight", "bias"]:
compressed_sd[f'vocab_transform.{w}'] = state_dict[f'cls.predictions.transform.dense.{w}'] compressed_sd[f"vocab_transform.{w}"] = state_dict[f"cls.predictions.transform.dense.{w}"]
compressed_sd[f'vocab_layer_norm.{w}'] = state_dict[f'cls.predictions.transform.LayerNorm.{w}'] compressed_sd[f"vocab_layer_norm.{w}"] = state_dict[f"cls.predictions.transform.LayerNorm.{w}"]
print(f'N layers selected for distillation: {std_idx}') print(f"N layers selected for distillation: {std_idx}")
print(f'Number of params transferred for distillation: {len(compressed_sd.keys())}') print(f"Number of params transferred for distillation: {len(compressed_sd.keys())}")
print(f'Save transferred checkpoint to {args.dump_checkpoint}.') print(f"Save transferred checkpoint to {args.dump_checkpoint}.")
torch.save(compressed_sd, args.dump_checkpoint) torch.save(compressed_sd, args.dump_checkpoint)
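A hedged cross-check for the DistilBERT variant above: the extracted keys should line up with a student's state dict. Using DistilBertConfig defaults here is an assumption for illustration; the actual training script configures the student explicitly.

# Sketch only: compare extracted keys against a default DistilBERT student.
import torch
from transformers import DistilBertConfig, DistilBertForMaskedLM

student = DistilBertForMaskedLM(DistilBertConfig())
compressed_sd = torch.load("serialization_dir/tf_bert-base-uncased_0247911.pth", map_location="cpu")
missing = set(student.state_dict()) - set(compressed_sd)
print(f"{len(missing)} student parameters not covered by the extraction")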
...@@ -20,32 +20,36 @@ import argparse ...@@ -20,32 +20,36 @@ import argparse
import pickle import pickle
import logging import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - %(message)s', logging.basicConfig(
datefmt = '%m/%d/%Y %H:%M:%S', format="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S", level=logging.INFO
level = logging.INFO) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
if __name__ == '__main__': if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)") parser = argparse.ArgumentParser(
parser.add_argument("--data_file", type=str, default="data/dump.bert-base-uncased.pickle", description="Token Counts for smoothing the masking probabilities in MLM (cf XLM/word2vec)"
help="The binarized dataset.") )
parser.add_argument("--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", parser.add_argument(
help="The dump file.") "--data_file", type=str, default="data/dump.bert-base-uncased.pickle", help="The binarized dataset."
)
parser.add_argument(
"--token_counts_dump", type=str, default="data/token_counts.bert-base-uncased.pickle", help="The dump file."
)
parser.add_argument("--vocab_size", default=30522, type=int) parser.add_argument("--vocab_size", default=30522, type=int)
args = parser.parse_args() args = parser.parse_args()
logger.info(f'Loading data from {args.data_file}') logger.info(f"Loading data from {args.data_file}")
with open(args.data_file, 'rb') as fp: with open(args.data_file, "rb") as fp:
data = pickle.load(fp) data = pickle.load(fp)
logger.info('Counting occurrences for MLM.') logger.info("Counting occurrences for MLM.")
counter = Counter() counter = Counter()
for tk_ids in data: for tk_ids in data:
counter.update(tk_ids) counter.update(tk_ids)
counts = [0]*args.vocab_size counts = [0] * args.vocab_size
for k, v in counter.items(): for k, v in counter.items():
counts[k] = v counts[k] = v
logger.info(f'Dump to {args.token_counts_dump}') logger.info(f"Dump to {args.token_counts_dump}")
with open(args.token_counts_dump, 'wb') as handle: with open(args.token_counts_dump, "wb") as handle:
pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL) pickle.dump(counts, handle, protocol=pickle.HIGHEST_PROTOCOL)
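For context, a hedged sketch of how token counts like these are typically turned into smoothed MLM masking probabilities; the 0.7 exponent follows the XLM/word2vec-style smoothing named in the description, and the exact exponent and special-token handling used downstream are assumptions here.

import pickle

import numpy as np

with open("data/token_counts.bert-base-uncased.pickle", "rb") as fp:
    counts = pickle.load(fp)

# Rarer tokens get a relatively higher masking probability.
probs = np.maximum(np.array(counts, dtype=np.float64), 1) ** -0.7
probs = probs / probs.sum()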
...@@ -23,9 +23,12 @@ import torch ...@@ -23,9 +23,12 @@ import torch
import numpy as np import numpy as np
import logging import logging
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s',
datefmt = '%m/%d/%Y %H:%M:%S', logging.basicConfig(
level = logging.INFO) format="%(asctime)s - %(levelname)s - %(name)s - PID: %(process)d - %(message)s",
datefmt="%m/%d/%Y %H:%M:%S",
level=logging.INFO,
)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
...@@ -35,12 +38,12 @@ def git_log(folder_path: str): ...@@ -35,12 +38,12 @@ def git_log(folder_path: str):
""" """
repo = git.Repo(search_parent_directories=True) repo = git.Repo(search_parent_directories=True)
repo_infos = { repo_infos = {
'repo_id': str(repo), "repo_id": str(repo),
'repo_sha': str(repo.head.object.hexsha), "repo_sha": str(repo.head.object.hexsha),
'repo_branch': str(repo.active_branch) "repo_branch": str(repo.active_branch),
} }
with open(os.path.join(folder_path, 'git_log.json'), 'w') as f: with open(os.path.join(folder_path, "git_log.json"), "w") as f:
json.dump(repo_infos, f, indent=4) json.dump(repo_infos, f, indent=4)
...@@ -57,21 +60,21 @@ def init_gpu_params(params): ...@@ -57,21 +60,21 @@ def init_gpu_params(params):
assert torch.cuda.is_available() assert torch.cuda.is_available()
logger.info('Initializing GPUs') logger.info("Initializing GPUs")
if params.n_gpu > 1: if params.n_gpu > 1:
assert params.local_rank != -1 assert params.local_rank != -1
params.world_size = int(os.environ['WORLD_SIZE']) params.world_size = int(os.environ["WORLD_SIZE"])
params.n_gpu_per_node = int(os.environ['N_GPU_NODE']) params.n_gpu_per_node = int(os.environ["N_GPU_NODE"])
params.global_rank = int(os.environ['RANK']) params.global_rank = int(os.environ["RANK"])
# number of nodes / node ID # number of nodes / node ID
params.n_nodes = params.world_size // params.n_gpu_per_node params.n_nodes = params.world_size // params.n_gpu_per_node
params.node_id = params.global_rank // params.n_gpu_per_node params.node_id = params.global_rank // params.n_gpu_per_node
params.multi_gpu = True params.multi_gpu = True
assert params.n_nodes == int(os.environ['N_NODES']) assert params.n_nodes == int(os.environ["N_NODES"])
assert params.node_id == int(os.environ['NODE_RANK']) assert params.node_id == int(os.environ["NODE_RANK"])
# local job (single GPU) # local job (single GPU)
else: else:
...@@ -114,8 +117,7 @@ def init_gpu_params(params): ...@@ -114,8 +117,7 @@ def init_gpu_params(params):
if params.multi_gpu: if params.multi_gpu:
logger.info("Initializing PyTorch distributed") logger.info("Initializing PyTorch distributed")
torch.distributed.init_process_group( torch.distributed.init_process_group(
init_method='env://', init_method="env://", backend="nccl",
backend='nccl',
) )
......
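For reference, a hedged sketch of the environment the multi-GPU branch of init_gpu_params expects; the variable names come from the code above, while the concrete values (and setting them by hand rather than via a launcher) are assumptions for illustration.

import os

os.environ["WORLD_SIZE"] = "8"   # total processes across all nodes
os.environ["N_GPU_NODE"] = "8"   # processes (GPUs) per node
os.environ["RANK"] = "0"         # global rank of this process
os.environ["N_NODES"] = "1"      # number of nodes
os.environ["NODE_RANK"] = "0"    # index of this node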
...@@ -25,17 +25,7 @@ import torchvision ...@@ -25,17 +25,7 @@ import torchvision
import torchvision.transforms as transforms import torchvision.transforms as transforms
from torch.utils.data import Dataset from torch.utils.data import Dataset
POOLING_BREAKDOWN = { POOLING_BREAKDOWN = {1: (1, 1), 2: (2, 1), 3: (3, 1), 4: (2, 2), 5: (5, 1), 6: (3, 2), 7: (7, 1), 8: (4, 2), 9: (3, 3)}
1: (1, 1),
2: (2, 1),
3: (3, 1),
4: (2, 2),
5: (5, 1),
6: (3, 2),
7: (7, 1),
8: (4, 2),
9: (3, 3)
}
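As a hedged note, POOLING_BREAKDOWN maps a requested number of image embeddings to a pooling grid whose cells multiply out to that number; pairing it with AdaptiveAvgPool2d is an assumption about how ImageEncoder consumes it.

import torch.nn as nn

num_image_embeds = 6  # example value
pool = nn.AdaptiveAvgPool2d(POOLING_BREAKDOWN[num_image_embeds])  # 3x2 grid -> 6 pooled regions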
class ImageEncoder(nn.Module): class ImageEncoder(nn.Module):
...@@ -54,7 +44,6 @@ class ImageEncoder(nn.Module): ...@@ -54,7 +44,6 @@ class ImageEncoder(nn.Module):
return out # BxNx2048 return out # BxNx2048
class JsonlDataset(Dataset): class JsonlDataset(Dataset):
def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length): def __init__(self, data_path, tokenizer, transforms, labels, max_seq_length):
self.data = [json.loads(l) for l in open(data_path)] self.data = [json.loads(l) for l in open(data_path)]
...@@ -72,7 +61,7 @@ class JsonlDataset(Dataset): ...@@ -72,7 +61,7 @@ class JsonlDataset(Dataset):
def __getitem__(self, index): def __getitem__(self, index):
sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True)) sentence = torch.LongTensor(self.tokenizer.encode(self.data[index]["text"], add_special_tokens=True))
start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1] start_token, sentence, end_token = sentence[0], sentence[1:-1], sentence[-1]
sentence = sentence[:self.max_seq_length] sentence = sentence[: self.max_seq_length]
label = torch.zeros(self.n_classes) label = torch.zeros(self.n_classes)
label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1 label[[self.labels.index(tgt) for tgt in self.data[index]["label"]]] = 1
...@@ -80,8 +69,13 @@ class JsonlDataset(Dataset): ...@@ -80,8 +69,13 @@ class JsonlDataset(Dataset):
image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB") image = Image.open(os.path.join(self.data_dir, self.data[index]["img"])).convert("RGB")
image = self.transforms(image) image = self.transforms(image)
return {"image_start_token": start_token, "image_end_token": end_token, return {
"sentence": sentence, "image": image, "label": label} "image_start_token": start_token,
"image_end_token": end_token,
"sentence": sentence,
"image": image,
"label": label,
}
def get_label_frequencies(self): def get_label_frequencies(self):
label_freqs = Counter() label_freqs = Counter()
...@@ -110,10 +104,31 @@ def collate_fn(batch): ...@@ -110,10 +104,31 @@ def collate_fn(batch):
def get_mmimdb_labels(): def get_mmimdb_labels():
return ['Crime', 'Drama', 'Thriller', 'Action', 'Comedy', 'Romance', return [
'Documentary', 'Short', 'Mystery', 'History', 'Family', 'Adventure', "Crime",
'Fantasy', 'Sci-Fi', 'Western', 'Horror', 'Sport', 'War', 'Music', "Drama",
'Musical', 'Animation', 'Biography', 'Film-Noir'] "Thriller",
"Action",
"Comedy",
"Romance",
"Documentary",
"Short",
"Mystery",
"History",
"Family",
"Adventure",
"Fantasy",
"Sci-Fi",
"Western",
"Horror",
"Sport",
"War",
"Music",
"Musical",
"Animation",
"Biography",
"Film-Noir",
]
def get_image_transforms(): def get_image_transforms():
...@@ -122,9 +137,6 @@ def get_image_transforms(): ...@@ -122,9 +137,6 @@ def get_image_transforms():
transforms.Resize(256), transforms.Resize(256),
transforms.CenterCrop(224), transforms.CenterCrop(224),
transforms.ToTensor(), transforms.ToTensor(),
transforms.Normalize( transforms.Normalize(mean=[0.46777044, 0.44531429, 0.40661017], std=[0.12221994, 0.12145835, 0.14380469],),
mean=[0.46777044, 0.44531429, 0.40661017],
std=[0.12221994, 0.12145835, 0.14380469],
),
] ]
) )
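A hedged usage sketch for the transform pipeline above; the file name is a placeholder, and a PIL image is what torchvision transforms expect.

from PIL import Image

transform = get_image_transforms()
image = Image.open("poster.jpg").convert("RGB")
tensor = transform(image)  # float tensor of shape (3, 224, 224), normalized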
import torch import torch
class ClassificationHead(torch.nn.Module): class ClassificationHead(torch.nn.Module):
"""Classification Head for transformer encoders""" """Classification Head for transformer encoders"""
......