Commit d82e5dee authored by thomwolf

set find_unused_parameters=True in DDP

parent a59abedf
@@ -1468,12 +1468,13 @@ python -m torch.distributed.launch --nproc_per_node=8 \
   --do_lower_case \
   --train_file $SQUAD_DIR/train-v1.1.json \
   --predict_file $SQUAD_DIR/dev-v1.1.json \
-  --train_batch_size 12 \
   --learning_rate 3e-5 \
-  --num_train_epochs 2.0 \
+  --num_train_epochs 2 \
   --max_seq_length 384 \
   --doc_stride 128 \
-  --output_dir /tmp/debug_squad/
+  --output_dir /tmp/debug_squad/ \
+  --train_batch_size 24 \
+  --gradient_accumulation_steps 2
 ```
 ## Notebooks
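The first hunk updates the distributed SQuAD fine-tuning command: the per-step `--train_batch_size 12` is replaced by `--train_batch_size 24` combined with `--gradient_accumulation_steps 2`. As a rough sketch of what gradient accumulation does (a toy model and stand-in hyperparameters, not this repo's actual training loop): gradients from several micro-batches are summed before a single optimizer step, so per-step memory stays at one micro-batch while the effective batch per update grows.

```python
# Toy sketch of gradient accumulation; model, data, and lr are placeholders.
import torch

accumulation_steps = 2
model = torch.nn.Linear(10, 2)                      # stand-in for the real model
optimizer = torch.optim.SGD(model.parameters(), lr=3e-5)
loss_fn = torch.nn.CrossEntropyLoss()

optimizer.zero_grad()
for step in range(8):                               # stand-in for the train dataloader
    inputs = torch.randn(4, 10)
    labels = torch.randint(0, 2, (4,))
    loss = loss_fn(model(inputs), labels)
    (loss / accumulation_steps).backward()          # scale so the summed grads average out
    if (step + 1) % accumulation_steps == 0:
        optimizer.step()                            # one parameter update per 2 micro-batches
        optimizer.zero_grad()
```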
@@ -907,7 +907,10 @@ def main():
         # except ImportError:
         #     raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")
-        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)
+        model = torch.nn.parallel.DistributedDataParallel(model,
+                                                          device_ids=[args.local_rank],
+                                                          output_device=args.local_rank,
+                                                          find_unused_parameters=True)
     elif n_gpu > 1:
         model = torch.nn.DataParallel(model)
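The second hunk passes `find_unused_parameters=True` when wrapping the model in `DistributedDataParallel` inside the training script's `main()`. With this flag, DDP traverses the autograd graph after each forward pass and marks parameters that received no gradient as ready, so the gradient all-reduce does not stall or error when part of the model (for example an unused pooling head) does not contribute to the loss. Below is a minimal sketch of the resulting wrapping; the helper name is hypothetical, and it assumes `torch.distributed.launch` and `init_process_group` have already set up the process group as in the commit's script.

```python
# Hypothetical helper mirroring the DDP call in this commit; process-group setup
# (torch.distributed.launch env vars + dist.init_process_group) is assumed done.
import torch

def wrap_model_for_ddp(model: torch.nn.Module, local_rank: int) -> torch.nn.Module:
    model = model.to(torch.device("cuda", local_rank))
    return torch.nn.parallel.DistributedDataParallel(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        # Allow some parameters to receive no gradient in a given iteration
        # without blocking the reducer.
        find_unused_parameters=True,
    )
```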