Commit 2c5d993b authored by thomwolf's avatar thomwolf
Browse files

update readme - fix SQuAD model on multi-GPU

parent 4850ec58
...@@ -194,3 +194,8 @@ python run_squad.py \ ...@@ -194,3 +194,8 @@ python run_squad.py \
--doc_stride 128 \ --doc_stride 128 \
--output_dir ../debug_squad/ --output_dir ../debug_squad/
``` ```
Training with the previous hyper-parameters and a batch size of 32 (on 4 GPUs) for 2 epochs gave us the following results:
```bash
{"f1": 88.19829549714827, "exact_match": 80.75685903500474}
```
...@@ -455,9 +455,11 @@ class BertForQuestionAnswering(nn.Module): ...@@ -455,9 +455,11 @@ class BertForQuestionAnswering(nn.Module):
end_logits = end_logits.squeeze(-1) end_logits = end_logits.squeeze(-1)
if start_positions is not None and end_positions is not None: if start_positions is not None and end_positions is not None:
# If we are on multi-GPU, split add a dimension - if not this is a no-op # If we are on multi-GPU, split add a dimension
start_positions = start_positions.squeeze(-1) if len(start_positions.size()) > 1:
end_positions = end_positions.squeeze(-1) start_positions = start_positions.squeeze(-1)
if len(end_positions.size()) > 1:
end_positions = end_positions.squeeze(-1)
# sometimes the start/end positions are outside our model inputs, we ignore these terms # sometimes the start/end positions are outside our model inputs, we ignore these terms
ignored_index = start_logits.size(1) ignored_index = start_logits.size(1)
start_positions.clamp_(0, ignored_index) start_positions.clamp_(0, ignored_index)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment