Commit c84315ec authored by thomwolf

model fixes + ipynb fixes

parent 3ff2ec5e
@@ -26,6 +26,7 @@ import json
 import re
 import tokenization
 import torch
+from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
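
For context, here is a minimal, self-contained sketch (not part of the commit; shapes and the batch size are made up) of how the newly imported classes are used further down in the script: the feature tensors are wrapped in a TensorDataset, a SequentialSampler keeps the original example order on a single process, and a DataLoader yields the batches consumed by the eval loop.

```python
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

# Hypothetical stand-ins for the tensors built from `features` in the script.
all_input_ids = torch.randint(0, 1000, (8, 128), dtype=torch.long)
all_input_mask = torch.ones(8, 128, dtype=torch.long)
all_example_index = torch.arange(8, dtype=torch.long)

# Same construction as in the diff below: dataset -> sampler -> dataloader.
eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
eval_sampler = SequentialSampler(eval_data)   # deterministic order for feature extraction
eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=4)

for input_ids, input_mask, example_indices in eval_dataloader:
    print(input_ids.shape, example_indices.tolist())   # torch.Size([4, 128]) [0, 1, 2, 3] ...
```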
@@ -251,10 +252,9 @@ def main():
     all_input_ids = torch.tensor([f.input_ids for f in features], dtype=torch.long)
     all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
-    all_segment_ids = torch.tensor([f.segment_ids for f in features], dtype=torch.long)
     all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
-    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
+    eval_data = TensorDataset(all_input_ids, all_input_mask, all_example_index)
     if args.local_rank == -1:
         eval_sampler = SequentialSampler(eval_data)
     else:
@@ -263,12 +263,11 @@ def main():
     model.eval()
     with open(args.output_file, "w", encoding='utf-8') as writer:
-        for input_ids, input_mask, segment_ids, example_indices in eval_dataloader:
+        for input_ids, input_mask, example_indices in eval_dataloader:
             input_ids = input_ids.to(device)
             input_mask = input_mask.float().to(device)
-            segment_ids = segment_ids.to(device)
-            all_encoder_layers, _ = model(input_ids, segment_ids, input_mask)
+            all_encoder_layers, _ = model(input_ids, token_type_ids=None, attention_mask=input_mask)
             for enc_layers, example_index in zip(all_encoder_layers, example_indices):
                 feature = features[example_index.item()]
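The updated call above passes token_type_ids=None, relying on the defaulting added to BertModel.forward in the hunk below. A small sketch (the helper fill_defaults is hypothetical, written only to mirror that defaulting) of what the model substitutes when those arguments are omitted:

```python
import torch

def fill_defaults(input_ids, token_type_ids=None, attention_mask=None):
    # Mirrors the new defaults in BertModel.forward: a missing attention mask
    # means "attend to every position", missing segment ids mean "every token
    # belongs to sentence A".
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    if token_type_ids is None:
        token_type_ids = torch.zeros_like(input_ids)
    return token_type_ids, attention_mask

input_ids = torch.tensor([[101, 7592, 2088, 102]])   # hypothetical token ids
token_type_ids, attention_mask = fill_defaults(input_ids)
print(token_type_ids)    # tensor([[0, 0, 0, 0]])
print(attention_mask)    # tensor([[1, 1, 1, 1]])
```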
@@ -377,12 +377,17 @@ class BertModel(nn.Module):
         self.encoder = BERTEncoder(config)
         self.pooler = BERTPooler(config)

-    def forward(self, input_ids, token_type_ids, attention_mask):
+    def forward(self, input_ids, token_type_ids=None, attention_mask=None):
         # We create a 3D attention mask from a 2D tensor mask.
         # Sizes are [batch_size, 1, 1, from_seq_length]
         # so we can broadcast to [batch_size, num_heads, to_seq_length, from_seq_length].
         # This is simpler than the triangular masking of causal attention; we just need
         # to prepare the broadcast here.
+        if attention_mask is None:
+            attention_mask = torch.ones_like(input_ids)
+        if token_type_ids is None:
+            token_type_ids = torch.zeros_like(input_ids)
         attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
         attention_mask = (1.0 - attention_mask) * -10000.0
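To make the broadcasting comment concrete, here is a self-contained sketch (values and the 12-head shape are invented) of the transformation this hunk applies to the mask and its effect on the attention softmax:

```python
import torch

# 2D padding mask: 1 for real tokens, 0 for padding.
attention_mask = torch.tensor([[1., 1., 1., 0., 0.]])             # [batch=1, seq=5]

# Reshape to [batch, 1, 1, seq] so it broadcasts against attention scores of
# shape [batch, num_heads, seq, seq], then turn padded positions into a large
# negative additive bias, exactly as in the hunk above.
extended_mask = attention_mask.unsqueeze(1).unsqueeze(2)          # [1, 1, 1, 5]
extended_mask = (1.0 - extended_mask) * -10000.0                  # 0.0 or -10000.0

scores = torch.randn(1, 12, 5, 5)                                 # made-up raw attention scores
probs = torch.softmax(scores + extended_mask, dim=-1)
print(probs[0, 0, 0])   # the last two probabilities are ~0: padding is ignored
```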