Commit 9343a231 authored by thomwolf

model training loop working – still have to check that everything is exactly the same

parent f690f0e1
@@ -18,21 +18,17 @@ from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import collections
 import copy
 import json
 import math
-import re
 import six
-import tensorflow as tf
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
 
 def gelu(x):
-    raise NotImplementedError
-    # TF BERT says: cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+    # OpenAI GPT gelu version was: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
 
 
 class BertConfig(object):
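For reference, the erf-based gelu above and the OpenAI GPT tanh approximation mentioned in the comment can be compared numerically. A minimal standalone sketch (illustration only, not part of this commit):

    import math
    import torch

    def gelu_erf(x):
        # Exact form: x * Phi(x), where Phi is the standard normal CDF.
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

    def gelu_tanh(x):
        # Tanh approximation used in the OpenAI GPT code.
        return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

    x = torch.linspace(-3.0, 3.0, steps=7)
    print(torch.max(torch.abs(gelu_erf(x) - gelu_tanh(x))))  # small, well under 1e-2

The two variants agree closely over typical activation ranges, which is why either is seen in practice.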
@@ -152,12 +148,11 @@ class BERTEmbeddings(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, input_ids, token_type_ids=None):
-        batch_size = input_ids.size(0)
         seq_length = input_ids.size(1)
-        # TODO finich that
-        position_ids = torch.range().view(batch_size, seq_length)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
         if token_type_ids is None:
-            token_type_ids = torch.zeros(batch_size, seq_length)
+            token_type_ids = torch.zeros_like(input_ids)
 
         words_embeddings = self.word_embeddings(input_ids)
         position_embeddings = self.position_embeddings(position_ids)
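The rewritten embedding lookup builds position_ids with torch.arange and broadcasts it over the batch. A quick sketch with made-up sizes showing the shapes involved:

    import torch

    input_ids = torch.randint(0, 30522, (2, 5))      # hypothetical [batch_size, seq_length] batch
    seq_length = input_ids.size(1)
    position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
    position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
    token_type_ids = torch.zeros_like(input_ids)

    print(position_ids[0])                            # tensor([0, 1, 2, 3, 4])
    print(position_ids.shape, token_type_ids.shape)   # torch.Size([2, 5]) torch.Size([2, 5])

Unlike the deprecated torch.range, torch.arange stops before the end value and yields the integer indices that nn.Embedding expects.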
@@ -218,14 +213,14 @@ class BERTSelfAttention(nn.Module):
         # TODO clean up this (precompute)
         # MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
         # `attention_mask` = [B, 1, F, T]
-        attention_mask = tf.expand_dims(attention_mask, axis=[1])
+        # attention_mask = tf.expand_dims(attention_mask, axis=[1])
         # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
         # masked positions, this operation will create a tensor which is 0.0 for
         # positions we want to attend and -10000.0 for masked positions.
-        adder = (1.0 - attention_mask) * -10000.0
+        # adder = (1.0 - attention_mask) * -10000.0
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely.
-        attention_scores += adder
+        attention_scores += attention_mask
 
         # Normalize the attention scores to probabilities.
         # `attention_probs` = [B, N, F, T]
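The comment block above describes the additive mask trick: the mask now arrives already converted to 0.0 / -10000.0 values, so adding it to the raw scores before the softmax effectively removes the masked positions. A toy check (illustrative values only):

    import torch
    import torch.nn.functional as F

    attention_scores = torch.tensor([[2.0, 1.0, 0.5, -0.3]])  # raw scores for one query
    mask = torch.tensor([[1.0, 1.0, 0.0, 0.0]])               # 1.0 = attend, 0.0 = masked
    additive_mask = (1.0 - mask) * -10000.0

    attention_probs = F.softmax(attention_scores + additive_mask, dim=-1)
    print(attention_probs)  # masked positions receive ~0 probability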
@@ -289,7 +284,7 @@ class BERTOutput(nn.Module):
         self.dropout = nn.Dropout(config.hidden_dropout_prob)
 
     def forward(self, hidden_states, input_tensor):
-        hidden_states = self.dense(input_tensor)
+        hidden_states = self.dense(hidden_states)
         hidden_states = self.dropout(hidden_states)
         hidden_states = self.LayerNorm(hidden_states + input_tensor)
         return hidden_states
@@ -390,6 +385,14 @@ class BertModel(nn.Module):
         self.pooler = BERTPooler(config)
 
     def forward(self, input_ids, token_type_ids, attention_mask):
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, from_seq_length]
+        # so we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length].
+        # It's simpler than the triangular masking of causal attention; we just need
+        # to prepare the broadcast here.
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+        attention_mask = (1.0 - attention_mask) * -10000.0
+
         embedding_output = self.embeddings(input_ids, token_type_ids)
         all_encoder_layers = self.encoder(embedding_output, attention_mask)
         sequence_output = all_encoder_layers[-1]
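The two unsqueeze calls give the mask shape [batch_size, 1, 1, seq_length], which broadcasts against the [batch_size, num_heads, seq_length, seq_length] attention scores. A shape-only sketch with assumed sizes:

    import torch

    batch_size, num_heads, seq_length = 2, 12, 8          # hypothetical sizes
    attention_mask = torch.ones(batch_size, seq_length)   # 2D input mask
    extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)
    extended_mask = (1.0 - extended_mask) * -10000.0
    print(extended_mask.shape)                             # torch.Size([2, 1, 1, 8])

    attention_scores = torch.randn(batch_size, num_heads, seq_length, seq_length)
    print((attention_scores + extended_mask).shape)        # torch.Size([2, 12, 8, 8])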
@@ -404,11 +407,11 @@ class BertForSequenceClassification(nn.Module):
         self.classifier = nn.Linear(config.hidden_size, num_labels)
 
         def init_weights(m):
-            if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding:
+            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                 print("Initializing {}".format(m))
                 # Slight difference here with the TF version which uses truncated_normal
                 # cf https://github.com/pytorch/pytorch/pull/5617
-                m.weight.normal_(config.initializer_range)
+                m.weight.data.normal_(mean=0.0, std=config.initializer_range)
         self.apply(init_weights)
 
     def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
...
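self.apply(init_weights) walks every submodule and calls init_weights on each, so the isinstance test is what restricts the re-initialization to Linear and Embedding weights. A minimal sketch of the same pattern outside the model (the 0.02 value is an assumption standing in for config.initializer_range):

    import torch.nn as nn

    initializer_range = 0.02  # assumed value; in the model it comes from config.initializer_range

    def init_weights(m):
        if isinstance(m, (nn.Linear, nn.Embedding)):
            # Plain normal init; the TF reference uses truncated_normal instead.
            m.weight.data.normal_(mean=0.0, std=initializer_range)

    model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 4))
    model.apply(init_weights)
    print(model[1].weight.std())  # roughly 0.02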
@@ -484,7 +484,7 @@ def main():
     num_train_steps = int(
         len(train_examples) / args.train_batch_size * args.num_train_epochs)
 
-    model = BertForSequenceClassification(bert_config)
+    model = BertForSequenceClassification(bert_config, len(label_list))
     if args.init_checkpoint is not None:
         model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
     model.to(device)
@@ -504,10 +504,10 @@ def main():
     logger.info("  Batch size = %d", args.train_batch_size)
     logger.info("  Num steps = %d", num_train_steps)
 
-    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long)
-    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long)
-    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long)
+    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
+    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
     train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
 
     if args.local_rank == -1:
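torch.long (lowercase) is the correct dtype constant; the id and label tensors then feed straight into TensorDataset and a DataLoader. A small self-contained sketch with made-up feature values:

    import torch
    from torch.utils.data import TensorDataset, DataLoader, RandomSampler

    # Hypothetical stand-ins for the lists built from train_features.
    all_input_ids = torch.tensor([[101, 7592, 102, 0], [101, 2088, 999, 102]], dtype=torch.long)
    all_input_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 1, 1]], dtype=torch.long)
    all_segment_ids = torch.zeros_like(all_input_ids)
    all_label_ids = torch.tensor([0, 1], dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    train_dataloader = DataLoader(train_data, sampler=RandomSampler(train_data), batch_size=2)
    for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
        print(input_ids.shape, label_ids.shape)  # torch.Size([2, 4]) torch.Size([2])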
@@ -519,12 +519,12 @@ def main():
     model.train()
     global_step = 0
 
     for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
-        input_ids.to(device)
-        input_mask.to(device)
-        segment_ids.to(device)
-        label_ids.to(device)
-        loss = model(input_ids, segment_ids, input_mask, label_ids)
+        input_ids = input_ids.to(device)
+        input_mask = input_mask.float().to(device)
+        segment_ids = segment_ids.to(device)
+        label_ids = label_ids.to(device)
+        loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
         loss.backward()
         optimizer.step()
         global_step += 1
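The reassignments matter because Tensor.to returns a new tensor rather than modifying the original, so the earlier bare .to(device) calls had no effect on the tensors actually passed to the model. A quick CPU-only illustration:

    import torch

    device = torch.device('cpu')             # stand-in; the same logic applies for 'cuda'
    input_mask = torch.tensor([[1, 1, 0]])   # integer mask as it comes out of the dataset

    input_mask.float().to(device)                # result discarded: input_mask keeps its int dtype
    input_mask = input_mask.float().to(device)   # result kept: now a float tensor on device
    print(input_mask.dtype)                      # torch.float32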
@@ -538,10 +538,10 @@ def main():
     logger.info("  Num examples = %d", len(eval_examples))
     logger.info("  Batch size = %d", args.eval_batch_size)
 
-    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long)
-    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long)
-    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long)
+    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
+    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
     eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
 
     if args.local_rank == -1:
@@ -554,10 +554,10 @@ def main():
     eval_loss = 0
     eval_accuracy = 0
 
     for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
-        input_ids.to(device)
-        input_mask.to(device)
-        segment_ids.to(device)
-        label_ids.to(device)
+        input_ids = input_ids.to(device)
+        input_mask = input_mask.float().to(device)
+        segment_ids = segment_ids.to(device)
+        label_ids = label_ids.to(device)
         tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
         tmp_eval_accuracy = accuracy(logits, label_ids)
...
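The accuracy helper is outside this hunk; a plausible argmax-based sketch of what it computes (an assumption, not the script's actual implementation):

    import torch

    def accuracy(logits, label_ids):
        # Number of correct top-1 predictions in one batch.
        preds = torch.argmax(logits, dim=1)
        return (preds == label_ids).sum().item()

    logits = torch.tensor([[0.1, 2.0], [1.5, -0.3]])
    label_ids = torch.tensor([1, 0])
    print(accuracy(logits, label_ids))  # 2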