"git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "28d16e7ac59960fac54950d87759f78ebd86046e"
Commit 0b7a20c6 authored by thomwolf's avatar thomwolf
Browse files

add tqdm, clean up logging

parent d4e3cf35
...@@ -435,7 +435,6 @@ class BertForSequenceClassification(nn.Module): ...@@ -435,7 +435,6 @@ class BertForSequenceClassification(nn.Module):
def init_weights(m): def init_weights(m):
if isinstance(m, (nn.Linear, nn.Embedding)): if isinstance(m, (nn.Linear, nn.Embedding)):
print("Initializing {}".format(m))
# Slight difference here with the TF version which uses truncated_normal # Slight difference here with the TF version which uses truncated_normal
# cf https://github.com/pytorch/pytorch/pull/5617 # cf https://github.com/pytorch/pytorch/pull/5617
m.weight.data.normal_(config.initializer_range) m.weight.data.normal_(config.initializer_range)
...@@ -481,7 +480,6 @@ class BertForQuestionAnswering(nn.Module): ...@@ -481,7 +480,6 @@ class BertForQuestionAnswering(nn.Module):
def init_weights(m): def init_weights(m):
if isinstance(m, (nn.Linear, nn.Embedding)): if isinstance(m, (nn.Linear, nn.Embedding)):
print("Initializing {}".format(m))
# Slight difference here with the TF version which uses truncated_normal for initialization # Slight difference here with the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617 # cf https://github.com/pytorch/pytorch/pull/5617
m.weight.data.normal_(config.initializer_range) m.weight.data.normal_(config.initializer_range)
......
...@@ -912,9 +912,9 @@ def main(): ...@@ -912,9 +912,9 @@ def main():
model.eval() model.eval()
all_results = [] all_results = []
logger.info("Start evaulating") logger.info("Start evaluating")
#for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader: #for input_ids, input_mask, segment_ids, label_ids, example_index in eval_dataloader:
for input_ids, input_mask, segment_ids, example_index in eval_dataloader: for input_ids, input_mask, segment_ids, example_index in tqdm(eval_dataloader, descr="Evaluating"):
if len(all_results) % 1000 == 0: if len(all_results) % 1000 == 0:
logger.info("Processing example: %d" % (len(all_results))) logger.info("Processing example: %d" % (len(all_results)))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment