Commit 2e311765 authored by ronakice

fix multi-gpu eval

parent 8aba81a0
@@ -224,6 +224,10 @@ def evaluate(args, model, tokenizer, prefix=""):
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
     eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+    # multi-gpu eval
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
     # Eval!
     logger.info("***** Running evaluation {} *****".format(prefix))
     logger.info(" Num examples = %d", len(eval_dataset))
...
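
Every hunk in this commit applies the same three-line pattern: when more than one GPU is visible, wrap the model in torch.nn.DataParallel so each evaluation batch is split across devices. One caveat, sketched below purely as an assumption and not as code from this commit: if evaluate() is ever handed a model that is already wrapped, the bare args.n_gpu > 1 check would wrap it a second time. A small helper can guard on the type first (the helper name is hypothetical):

import torch


def wrap_for_multi_gpu(model: torch.nn.Module, n_gpu: int) -> torch.nn.Module:
    # Wrap once for multi-GPU eval; skip if the model is already a
    # DataParallel instance (hypothetical guard, not part of this commit).
    if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    return model
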
@@ -300,6 +300,10 @@ def evaluate(args, model, tokenizer, prefix=""):
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
     eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+    # multi-gpu evaluate
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
     # Eval!
     logger.info("***** Running evaluation {} *****".format(prefix))
     logger.info(" Num examples = %d", len(eval_dataset))
...
@@ -229,6 +229,10 @@ def evaluate(args, model, tokenizer, prefix="", test=False):
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
     eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+    # multi-gpu evaluate
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
     # Eval!
     logger.info("***** Running evaluation {} *****".format(prefix))
     logger.info(" Num examples = %d", len(eval_dataset))
...
@@ -191,6 +191,10 @@ def evaluate(args, model, tokenizer, labels, pad_token_label_id, mode, prefix=""
     eval_sampler = SequentialSampler(eval_dataset) if args.local_rank == -1 else DistributedSampler(eval_dataset)
     eval_dataloader = DataLoader(eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+    # multi-gpu evaluate
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
     # Eval!
     logger.info("***** Running evaluation %s *****", prefix)
     logger.info(" Num examples = %d", len(eval_dataset))
...
@@ -217,6 +217,10 @@ def evaluate(args, model, tokenizer, prefix=""):
     eval_sampler = SequentialSampler(dataset) if args.local_rank == -1 else DistributedSampler(dataset)
     eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=args.eval_batch_size)
+    # multi-gpu evaluate
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
     # Eval!
     logger.info("***** Running evaluation {} *****".format(prefix))
     logger.info(" Num examples = %d", len(dataset))
...
@@ -275,6 +275,10 @@ def evaluate(args, model, tokenizer, prefix=""):
         eval_dataset, sampler=eval_sampler, batch_size=args.eval_batch_size
     )
+    # multi-gpu evaluate
+    if args.n_gpu > 1:
+        model = torch.nn.DataParallel(model)
     logger.info("***** Running evaluation {} *****".format(prefix))
     logger.info(" Num examples = %d", len(eval_dataset))
     logger.info(" Batch size = %d", args.eval_batch_size)
...
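
For context on how the added block interacts with the surrounding evaluation code, here is a self-contained sketch of the whole pattern, under the assumption (common in these example scripts, but not shown in the hunks) that the eval batch size is scaled by the number of GPUs so DataParallel can split each batch evenly. All names and the toy model are illustrative, not code from this commit:

import torch
from torch.utils.data import DataLoader, SequentialSampler, TensorDataset


def evaluate_multi_gpu(model, dataset, per_gpu_eval_batch_size=8):
    # Hypothetical standalone version of the evaluation pattern touched above.
    n_gpu = torch.cuda.device_count()
    # DataParallel splits a batch across GPUs, so feed it n_gpu * per-GPU items.
    eval_batch_size = per_gpu_eval_batch_size * max(1, n_gpu)

    eval_sampler = SequentialSampler(dataset)
    eval_dataloader = DataLoader(dataset, sampler=eval_sampler, batch_size=eval_batch_size)

    # multi-gpu evaluate (same idea as the hunks, plus an isinstance guard)
    if n_gpu > 1 and not isinstance(model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)

    model.eval()
    preds = []
    with torch.no_grad():
        for (inputs,) in eval_dataloader:
            if n_gpu > 0:
                inputs = inputs.cuda()
            preds.append(model(inputs).cpu())
    return torch.cat(preds)


if __name__ == "__main__":
    # Toy model and data so the sketch runs on CPU or GPU alike.
    toy_model = torch.nn.Linear(4, 2)
    if torch.cuda.is_available():
        toy_model = toy_model.cuda()
    toy_data = TensorDataset(torch.randn(32, 4))
    print(evaluate_multi_gpu(toy_model, toy_data).shape)

Note that DataParallel only parallelizes the forward pass within a single process; the local_rank / DistributedSampler branch in the unchanged lines above covers the distributed one-process-per-GPU case instead.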