Commit c84315ec authored by thomwolf

model fixes + ipynb fixes

parent 3ff2ec5e
@@ -26,6 +26,7 @@ import json
 import re
 import tokenization
+import torch
 from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
@@ -251,10 +252,9 @@ def main():
     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
     all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
     all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
+    eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
     if args.local_rank == -1:
         eval_sampler = SequentialSampler(eval_data)
     else:
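
For context, a minimal runnable sketch of the batching pattern this hunk leaves behind: only input ids, the attention mask, and an example index go into the TensorDataset, with segment ids dropped. The Feature class and the toy id values below are illustrative stand-ins, not the script's own feature objects.

import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

class Feature:
    """Illustrative stand-in for the script's feature objects."""
    def __init__(self, input_ids, input_mask):
        self.input_ids = input_ids
        self.input_mask = input_mask

# Two toy examples, padded to length 4 (id values are arbitrary).
features = [Feature([101, 2023, 102, 0], [1, 1, 1, 0]),
            Feature([101, 2003, 2009, 102], [1, 1, 1, 1])]

all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)

# No segment ids in the dataset anymore; the index tensor lets the eval
# loop map each batch row back to its original feature.
eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
eval_dataloader = DataLoader(eval_data, sampler=SequentialSampler(eval_data), batch_size=2)

for input_ids, input_mask, example_indices in eval_dataloader:
    print(input_ids.shape, input_mask.shape, example_indices)
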
@@ -263,12 +263,11 @@ def main():
     model.eval()
     with open(args.output_file, "w", encoding='utf-8') as writer:
-        for input_ids, input_mask, segment_ids, example_indices in eval_dataloader:
+        for input_ids, input_mask, example_indices in eval_dataloader:
             input_ids = input_ids.to(device)
             input_mask = input_mask.float().to(device)
-            segment_ids = segment_ids.to(device)
-            all_encoder_layers, _ = model(input_ids, segment_ids, input_mask)
+            all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
             for enc_layers, example_index in zip(all_encoder_layers, example_indices):
                 feature = features[example_index.item()]
...
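
The call-site change above relies on the optional arguments added to BertModel.forward in the next hunk. A hedged sketch of the new calling convention, using a dummy module with the same signature rather than the real BertModel:

import torch

class DummyBert(torch.nn.Module):
    """Stand-in with BertModel's new signature; returns inputs, not encodings."""
    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)
        if token_type_ids is None:
            token_type_ids = torch.zeros_like(input_ids)
        # A real model would run embeddings + encoder here.
        return [input_ids.float()], input_ids.float()

model = DummyBert()
model.eval()
input_ids = torch.tensor([[101, 2023, 102]])
input_mask = torch.tensor([[1, 1, 1]], dtype=torch.float)

# Keyword arguments make the call explicit; omitting them entirely now
# falls back to an all-ones mask and all-zero segment ids.
all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
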
@@ -377,12 +377,17 @@ class BertModel(nn.Module):
         self.encoder = BERTEncoder(config)
         self.pooler = BERTPooler(config)
 
-    def forward(self, input_ids, token_type_ids, attention_mask):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
         # We create 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, from_seq_length]
         # So we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length]
         # It's more simple than the triangular masking of causal attention, just need to
         # prepare the broadcast here
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         attention_mask = (1.0 - attention_mask) * -10000.0
...
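
To make the comment in that hunk concrete, here is a standalone sketch of the mask preparation, assuming a [batch_size, seq_length] 0/1 mask: two unsqueezes give [batch_size, 1, 1, seq_length], which broadcasts against attention scores of shape [batch_size, num_heads, seq_length, seq_length], and the (1.0 - mask) * -10000.0 trick drives padded keys to near-zero weight after softmax. The shapes and random scores below are illustrative.

import torch

batch_size, num_heads, seq_length = 2, 12, 4
attention_mask = torch.tensor([[1., 1., 1., 0.],
                               [1., 1., 0., 0.]])

extended = attention_mask.unsqueeze(1).unsqueeze(2)   # [2, 1, 1, 4]
extended = (1.0 - extended) * -10000.0                # 0 where kept, -10000 where padded

scores = torch.randn(batch_size, num_heads, seq_length, seq_length)
probs = torch.softmax(scores + extended, dim=-1)      # mask added before softmax

print(extended.shape)            # torch.Size([2, 1, 1, 4])
print(probs[0, 0, 0])            # padded positions get ~0 attention weight
print(probs[0, 0].sum(dim=-1))   # each row still sums to 1
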