"vscode:/vscode.git/clone" did not exist on "97277232432f3188e6d8f5b75f428ada21773254"
Unverified Commit efeb6b1a authored by Thomas Wolf's avatar Thomas Wolf Committed by GitHub
Browse files

Merge branch 'master' into develop

parents dbc318a4 a1126237
......@@ -388,10 +388,10 @@ class BertForSequenceClassification(nn.Module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(config.initializer_range)
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(config.initializer_range)
module.gamma.data.normal_(config.initializer_range)
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
......@@ -438,10 +438,10 @@ class BertForQuestionAnswering(nn.Module):
if isinstance(module, (nn.Linear, nn.Embedding)):
# Slightly different from the TF version which uses truncated_normal for initialization
# cf https://github.com/pytorch/pytorch/pull/5617
module.weight.data.normal_(config.initializer_range)
module.weight.data.normal_(mean=0.0, std=config.initializer_range)
elif isinstance(module, BERTLayerNorm):
module.beta.data.normal_(config.initializer_range)
module.gamma.data.normal_(config.initializer_range)
module.beta.data.normal_(mean=0.0, std=config.initializer_range)
module.gamma.data.normal_(mean=0.0, std=config.initializer_range)
if isinstance(module, nn.Linear):
module.bias.data.zero_()
self.apply(init_weights)
......@@ -459,7 +459,7 @@ class BertForQuestionAnswering(nn.Module):
start_positions = start_positions.squeeze(-1)
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) + 1
ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index)
end_positions.clamp_(0, ignored_index)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment