Commit e5b63fb5 authored by Ananya Harsh Jha's avatar Ananya Harsh Jha
Browse files

Merge branch 'master' of https://github.com/ananyahjha93/pytorch-pretrained-BERT

pull current master to local
parents 8a4e90ff f3e54048
...@@ -857,7 +857,6 @@ def main(): ...@@ -857,7 +857,6 @@ def main():
optimizer.zero_grad() optimizer.zero_grad()
global_step += 1 global_step += 1
if args.do_train:
# Save a trained model and the associated configuration # Save a trained model and the associated configuration
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME) output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
......
...@@ -471,7 +471,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -471,7 +471,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
prelim_predictions = [] prelim_predictions = []
# keep track of the minimum score of null start+end of position 0 # keep track of the minimum score of null start+end of position 0
score_null = 1000000 # large and positive score_null = 1000000 # large and positive
min_null_feature_index = 0 # the paragraph slice with min mull score min_null_feature_index = 0 # the paragraph slice with min null score
null_start_logit = 0 # the start logit at the slice with min null score null_start_logit = 0 # the start logit at the slice with min null score
null_end_logit = 0 # the end logit at the slice with min null score null_end_logit = 0 # the end logit at the slice with min null score
for (feature_index, feature) in enumerate(features): for (feature_index, feature) in enumerate(features):
...@@ -620,7 +620,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size, ...@@ -620,7 +620,7 @@ def write_predictions(all_examples, all_features, all_results, n_best_size,
all_predictions[example.qas_id] = "" all_predictions[example.qas_id] = ""
else: else:
all_predictions[example.qas_id] = best_non_null_entry.text all_predictions[example.qas_id] = best_non_null_entry.text
all_nbest_json[example.qas_id] = nbest_json all_nbest_json[example.qas_id] = nbest_json
with open(output_prediction_file, "w") as writer: with open(output_prediction_file, "w") as writer:
writer.write(json.dumps(all_predictions, indent=4) + "\n") writer.write(json.dumps(all_predictions, indent=4) + "\n")
...@@ -657,8 +657,8 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False): ...@@ -657,8 +657,8 @@ def get_final_text(pred_text, orig_text, do_lower_case, verbose_logging=False):
# #
# What we really want to return is "Steve Smith". # What we really want to return is "Steve Smith".
# #
# Therefore, we have to apply a semi-complicated alignment heruistic between # Therefore, we have to apply a semi-complicated alignment heuristic between
# `pred_text` and `orig_text` to get a character-to-charcter alignment. This # `pred_text` and `orig_text` to get a character-to-character alignment. This
# can fail in certain cases in which case we just return `orig_text`. # can fail in certain cases in which case we just return `orig_text`.
def _strip_spaces(text): def _strip_spaces(text):
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
"""BERT finetuning runner.""" """BERT finetuning runner."""
from __future__ import absolute_import
import argparse import argparse
import csv import csv
import logging import logging
...@@ -31,7 +33,7 @@ from torch.utils.data.distributed import DistributedSampler ...@@ -31,7 +33,7 @@ from torch.utils.data.distributed import DistributedSampler
from tqdm import tqdm, trange from tqdm import tqdm, trange
from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE from pytorch_pretrained_bert.file_utils import PYTORCH_PRETRAINED_BERT_CACHE
from pytorch_pretrained_bert.modeling import BertForMultipleChoice from pytorch_pretrained_bert.modeling import (BertForMultipleChoice, BertConfig, WEIGHTS_NAME, CONFIG_NAME)
from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear from pytorch_pretrained_bert.optimization import BertAdam, warmup_linear
from pytorch_pretrained_bert.tokenization import BertTokenizer from pytorch_pretrained_bert.tokenization import BertTokenizer
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch OpenAI GPT-2 model.""" """PyTorch OpenAI GPT-2 model."""
from __future__ import absolute_import, division, print_function, unicode_literals
import collections import collections
import copy import copy
import json import json
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
# limitations under the License. # limitations under the License.
"""PyTorch OpenAI GPT model.""" """PyTorch OpenAI GPT model."""
from __future__ import absolute_import, division, print_function, unicode_literals
import collections import collections
import copy import copy
import json import json
......
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
""" """
from __future__ import absolute_import, division, print_function, unicode_literals
import os import os
import copy import copy
import json import json
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment