Commit 9343a231 authored by thomwolf's avatar thomwolf

model training loop working – still have to check that everything is exactly the same

parent f690f0e1
@@ -18,21 +18,17 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import copy
import json
import math
import re
import six
import tensorflow as tf
import torch
import torch.nn as nn
from torch.nn import CrossEntropyLoss
def gelu(x):
-    raise NotImplementedError
-    # TF BERT says: cdf = 0.5 * (1.0 + tf.erf(input_tensor / tf.sqrt(2.0)))
-    return 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
+    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
+    # OpenAI GPT gelu version was: 0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
class BertConfig(object):
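The hunk above swaps the OpenAI GPT tanh approximation for the exact erf form of GELU (the leading `x *` factor is required so the function returns x·Φ(x) rather than just the CDF that the TF comment describes). A minimal standalone sketch comparing the two variants:

```python
import math
import torch

def gelu_exact(x):
    # x * Phi(x): Phi is the standard normal CDF, computed via torch.erf
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

def gelu_tanh(x):
    # the OpenAI GPT approximation removed above
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))

x = torch.linspace(-3.0, 3.0, steps=7)
print((gelu_exact(x) - gelu_tanh(x)).abs().max())  # the curves agree to ~1e-3 or better
```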
@@ -152,12 +148,11 @@ class BERTEmbeddings(nn.Module):
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, input_ids, token_type_ids=None):
-        batch_size = input_ids.size(0)
        seq_length = input_ids.size(1)
-        # TODO finish that
-        position_ids = torch.range().view(batch_size, seq_length)
+        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
+        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)
        if token_type_ids is None:
-            token_type_ids = torch.zeros(batch_size, seq_length)
+            token_type_ids = torch.zeros_like(input_ids)
        words_embeddings = self.word_embeddings(input_ids)
        position_embeddings = self.position_embeddings(position_ids)
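The replacement lines build position ids with `torch.arange` plus broadcasting, instead of the unfinished `torch.range().view(...)` call. A small self-contained sketch of the shapes involved, using dummy ids (batch of 2, sequence length 5):

```python
import torch

input_ids = torch.zeros(2, 5, dtype=torch.long)   # [batch_size, seq_length]
seq_length = input_ids.size(1)

position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
position_ids = position_ids.unsqueeze(0).expand_as(input_ids)  # broadcast over the batch
print(position_ids.shape)  # torch.Size([2, 5])
print(position_ids[0])     # tensor([0, 1, 2, 3, 4])
```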
@@ -218,14 +213,14 @@ class BERTSelfAttention(nn.Module):
        # TODO clean up this (precompute)
        # MY PYTORCH: w = w * self.b + -1e9 * (1 - self.b)  # TF implem method: mask_attn_weights
        # `attention_mask` = [B, 1, F, T]
-        attention_mask = tf.expand_dims(attention_mask, axis=[1])
+        # attention_mask = tf.expand_dims(attention_mask, axis=[1])
        # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # masked positions, this operation will create a tensor which is 0.0 for
        # positions we want to attend and -10000.0 for masked positions.
-        adder = (1.0 - attention_mask) * -10000.0
+        # adder = (1.0 - attention_mask) * -10000.0
        # Since we are adding it to the raw scores before the softmax, this is
        # effectively the same as removing these entirely.
-        attention_scores += adder
+        attention_scores += attention_mask
        # Normalize the attention scores to probabilities.
        # `attention_probs` = [B, N, F, T]
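The commented-out TF lines describe the additive-mask trick that the precomputed `attention_mask` now carries (the precomputation moves to `BertModel.forward`, in a later hunk). A quick numeric check that adding -10000.0 before the softmax is effectively the same as removing the masked positions:

```python
import torch

scores = torch.tensor([2.0, 1.0, 3.0, 0.5])   # raw attention scores for one query
mask = torch.tensor([1.0, 1.0, 0.0, 0.0])     # 1.0 = attend, 0.0 = masked
adder = (1.0 - mask) * -10000.0

probs = torch.softmax(scores + adder, dim=-1)
print(probs)  # masked positions come out as ~0.0
```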
@@ -289,7 +284,7 @@ class BERTOutput(nn.Module):
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, hidden_states, input_tensor):
-        hidden_states = self.dense(input_tensor)
+        hidden_states = self.dense(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.LayerNorm(hidden_states + input_tensor)
        return hidden_states
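The one-line fix matters: `self.dense` must consume `hidden_states` (the sub-layer output), while `input_tensor` should only re-enter through the residual connection; the old line silently discarded the sub-layer output. A standalone sketch of the corrected block, with placeholder sizes and `nn.LayerNorm` standing in for the repo's own layer-norm class:

```python
import torch
import torch.nn as nn

class Output(nn.Module):
    def __init__(self, intermediate_size=128, hidden_size=32, dropout_prob=0.1):
        super(Output, self).__init__()
        self.dense = nn.Linear(intermediate_size, hidden_size)
        self.LayerNorm = nn.LayerNorm(hidden_size)   # stand-in for BERTLayerNorm
        self.dropout = nn.Dropout(dropout_prob)

    def forward(self, hidden_states, input_tensor):
        hidden_states = self.dense(hidden_states)            # project the sub-layer output
        hidden_states = self.dropout(hidden_states)
        return self.LayerNorm(hidden_states + input_tensor)  # residual, then norm

out = Output()(torch.randn(2, 4, 128), torch.randn(2, 4, 32))
print(out.shape)  # torch.Size([2, 4, 32])
```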
@@ -390,6 +385,14 @@ class BertModel(nn.Module):
        self.pooler = BERTPooler(config)

    def forward(self, input_ids, token_type_ids, attention_mask):
+        # We create a 3D attention mask from a 2D tensor mask.
+        # Sizes are [batch_size, 1, 1, from_seq_length],
+        # so we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length].
+        # This is simpler than the triangular masking of causal attention; we just
+        # need to prepare the broadcast here.
+        attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+        attention_mask = (1.0 - attention_mask) * -10000.0
        embedding_output = self.embeddings(input_ids, token_type_ids)
        all_encoder_layers = self.encoder(embedding_output, attention_mask)
        sequence_output = all_encoder_layers[-1]
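The added lines precompute the additive mask once per forward pass, so every attention layer can simply add it to its scores. A sketch of the shape bookkeeping, assuming 12 heads and a toy batch:

```python
import torch

attention_mask = torch.tensor([[1, 1, 1, 0, 0]], dtype=torch.float)  # [batch, seq] = [1, 5]
extended = attention_mask.unsqueeze(1).unsqueeze(2)                  # [1, 1, 1, 5]
extended = (1.0 - extended) * -10000.0

scores = torch.zeros(1, 12, 5, 5)  # [batch, heads, from_seq, to_seq]
print((scores + extended).shape)   # broadcasts to torch.Size([1, 12, 5, 5])
```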
@@ -404,11 +407,11 @@ class BertForSequenceClassification(nn.Module):
        self.classifier = nn.Linear(config.hidden_size, num_labels)

        def init_weights(m):
-            if isinstance(m) == nn.Linear or isinstance(m) == nn.Embedding:
+            if isinstance(m, nn.Linear) or isinstance(m, nn.Embedding):
                print("Initializing {}".format(m))
                # Slight difference here with the TF version, which uses truncated_normal
                # cf https://github.com/pytorch/pytorch/pull/5617
-                m.weight.normal_(config.initializer_range)
+                m.weight.data.normal_(mean=0.0, std=config.initializer_range)
        self.apply(init_weights)

    def forward(self, input_ids, token_type_ids, attention_mask, labels=None):
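On the initializer: `isinstance` takes `(object, class)`, and in-place initialization at construction time goes through `.data` so autograd is not involved. Passing the range by keyword as `std` above is my correction rather than the commit's literal line (the first positional argument of `Tensor.normal_` is the mean); the TF original uses `truncated_normal_initializer(stddev=initializer_range)`. A minimal runnable sketch:

```python
import torch.nn as nn

initializer_range = 0.02  # BERT's default stddev

def init_weights(m):
    # isinstance takes (object, class); .data mutates weights without autograd tracking
    if isinstance(m, (nn.Linear, nn.Embedding)):
        m.weight.data.normal_(mean=0.0, std=initializer_range)

model = nn.Sequential(nn.Linear(4, 4), nn.Tanh())
model.apply(init_weights)  # nn.Module.apply visits every submodule
```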
@@ -484,7 +484,7 @@ def main():
    num_train_steps = int(
        len(train_examples) / args.train_batch_size * args.num_train_epochs)

-    model = BertForSequenceClassification(bert_config)
+    model = BertForSequenceClassification(bert_config, len(label_list))
    if args.init_checkpoint is not None:
        model.bert.load_state_dict(torch.load(args.init_checkpoint, map_location='cpu'))
    model.to(device)
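The constructor now receives the label count, and the checkpoint is deserialized onto CPU before the whole model moves to the target device. A self-contained sketch of that load pattern, where a throwaway `nn.Linear` stands in for `model.bert` and the file name is a placeholder:

```python
import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

bert = nn.Linear(8, 8)                      # stand-in for model.bert
torch.save(bert.state_dict(), "init_checkpoint.bin")

# map_location='cpu' lets GPU-saved checkpoints load on CPU-only machines
state_dict = torch.load("init_checkpoint.bin", map_location='cpu')
bert.load_state_dict(state_dict)
bert.to(device)                             # nn.Module.to moves the module in place
```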
@@ -504,10 +504,10 @@ def main():
    logger.info(" Batch size = %d", args.train_batch_size)
    logger.info(" Num steps = %d", num_train_steps)

-    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.Long)
-    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.Long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.Long)
-    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.Long)
+    all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
+    all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    if args.local_rank == -1:
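`torch.Long` does not exist; the dtype object is lowercase `torch.long` (int64), which embedding lookups and `CrossEntropyLoss` targets both require. A sketch with toy stand-ins for the feature values:

```python
import torch
from torch.utils.data import TensorDataset, DataLoader

# toy stand-ins for f.input_ids / f.label_id
input_ids = torch.tensor([[101, 7592, 102], [101, 2088, 102]], dtype=torch.long)
label_ids = torch.tensor([0, 1], dtype=torch.long)

train_data = TensorDataset(input_ids, label_ids)
for ids, y in DataLoader(train_data, batch_size=2):
    print(ids.dtype, y.dtype)  # torch.int64 torch.int64
```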
@@ -519,12 +519,12 @@ def main():
    model.train()
    global_step = 0
    for input_ids, input_mask, segment_ids, label_ids in train_dataloader:
-        input_ids.to(device)
-        input_mask.to(device)
-        segment_ids.to(device)
-        label_ids.to(device)
+        input_ids = input_ids.to(device)
+        input_mask = input_mask.float().to(device)
+        segment_ids = segment_ids.to(device)
+        label_ids = label_ids.to(device)
-        loss = model(input_ids, segment_ids, input_mask, label_ids)
+        loss, _ = model(input_ids, segment_ids, input_mask, label_ids)
        loss.backward()
        optimizer.step()
        global_step += 1
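The loop bug being fixed: `Tensor.to(...)` returns a new tensor instead of mutating its receiver, so the result must be reassigned (modules are the exception, as `model.to(device)` above moves parameters in place). A quick check, using a dtype move so it runs without a GPU:

```python
import torch

t = torch.zeros(2)
t.to(torch.float64)       # return value discarded; t is unchanged
print(t.dtype)            # torch.float32

t = t.to(torch.float64)   # reassignment is required
print(t.dtype)            # torch.float64
```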
@@ -538,10 +538,10 @@ def main():
    logger.info(" Num examples = %d", len(eval_examples))
    logger.info(" Batch size = %d", args.eval_batch_size)

-    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.Long)
-    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.Long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.Long)
-    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.Long)
+    all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
+    all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
+    all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
+    all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
    if args.local_rank == -1:
@@ -554,10 +554,10 @@ def main():
    eval_loss = 0
    eval_accuracy = 0
    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
-        input_ids.to(device)
-        input_mask.to(device)
-        segment_ids.to(device)
-        label_ids.to(device)
+        input_ids = input_ids.to(device)
+        input_mask = input_mask.float().to(device)
+        segment_ids = segment_ids.to(device)
+        label_ids = label_ids.to(device)
        tmp_eval_loss, logits = model(input_ids, segment_ids, input_mask, label_ids)
        tmp_eval_accuracy = accuracy(logits, label_ids)
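The eval loop unpacks `(tmp_eval_loss, logits)` and scores the logits against integer label ids. A hedged sketch of an accuracy helper matching that call shape; this torch version is illustrative only, and the script's own `accuracy` helper may differ (for example, it may operate on numpy arrays):

```python
import torch

def accuracy(logits, label_ids):
    # logits: [batch, num_labels]; label_ids: [batch]
    preds = torch.argmax(logits, dim=-1)
    return (preds == label_ids).sum().item()

logits = torch.tensor([[0.1, 0.9], [0.8, 0.2]])
label_ids = torch.tensor([1, 1])
print(accuracy(logits, label_ids))  # 1 correct out of 2
```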