Unverified Commit 5b81c05c authored by Nicolas Hug, committed by GitHub

Reduce variance of classification references evaluation (#4609)

parent 23f413c2
 import datetime
 import os
 import time
+import warnings
 import presets
 import torch
@@ -54,6 +55,8 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
     model.eval()
     metric_logger = utils.MetricLogger(delimiter="  ")
     header = f"Test: {log_suffix}"
+
+    num_processed_samples = 0
     with torch.no_grad():
         for image, target in metric_logger.log_every(data_loader, print_freq, header):
             image = image.to(device, non_blocking=True)
@@ -68,7 +71,23 @@ def evaluate(model, criterion, data_loader, device, print_freq=100, log_suffix="
             metric_logger.update(loss=loss.item())
             metric_logger.meters["acc1"].update(acc1.item(), n=batch_size)
             metric_logger.meters["acc5"].update(acc5.item(), n=batch_size)
+            num_processed_samples += batch_size
     # gather the stats from all processes
+
+    num_processed_samples = utils.reduce_across_processes(num_processed_samples)
+    if (
+        hasattr(data_loader.dataset, "__len__")
+        and len(data_loader.dataset) != num_processed_samples
+        and torch.distributed.get_rank() == 0
+    ):
+        # See FIXME above
+        warnings.warn(
+            f"It looks like the dataset has {len(data_loader.dataset)} samples, but {num_processed_samples} "
+            "samples were used for the validation, which might bias the results. "
+            "Try adjusting the batch size and / or the world size. "
+            "Setting the world size to 1 is always a safe bet."
+        )
+
     metric_logger.synchronize_between_processes()

     print(f"{header} Acc@1 {metric_logger.acc1.global_avg:.3f} Acc@5 {metric_logger.acc5.global_avg:.3f}")
@@ -147,7 +166,7 @@ def load_data(traindir, valdir, args):
print("Creating data loaders") print("Creating data loaders")
if args.distributed: if args.distributed:
train_sampler = torch.utils.data.distributed.DistributedSampler(dataset) train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test) test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test, shuffle=False)
else: else:
train_sampler = torch.utils.data.RandomSampler(dataset) train_sampler = torch.utils.data.RandomSampler(dataset)
test_sampler = torch.utils.data.SequentialSampler(dataset_test) test_sampler = torch.utils.data.SequentialSampler(dataset_test)
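Note: a hedged illustration of why shuffle=False matters for the test sampler. With the default shuffle=True, the permutation, and with it the padded duplicates each rank receives, changes with the epoch seed, which adds run-to-run variance to the evaluation; everything below is local to this sketch.

import torch

dataset = torch.utils.data.TensorDataset(torch.arange(10))
sampler = torch.utils.data.distributed.DistributedSampler(dataset, num_replicas=4, rank=0)
sampler.set_epoch(0)
first = list(sampler)
sampler.set_epoch(1)
second = list(sampler)
print(first, second)  # typically different index sets on the same rank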
@@ -164,6 +183,10 @@ def main(args):
     device = torch.device(args.device)

-    torch.backends.cudnn.benchmark = True
+    if args.use_deterministic_algorithms:
+        torch.backends.cudnn.benchmark = False
+        torch.use_deterministic_algorithms(True)
+    else:
+        torch.backends.cudnn.benchmark = True

     train_dir = os.path.join(args.data_path, "train")
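Note: a rough, illustrative sketch of what the new flag turns on; nothing here is part of the commit itself.

import torch

torch.use_deterministic_algorithms(True)
# From here on, ops that lack a deterministic implementation raise a
# RuntimeError instead of silently varying between runs; cudnn benchmarking
# is disabled above because its autotuned kernel choice can change results.
x = torch.randn(4, 4)
print(x @ x.t())  # deterministic ops keep working as usual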
@@ -277,6 +300,10 @@ def main(args):
model_ema.load_state_dict(checkpoint["model_ema"]) model_ema.load_state_dict(checkpoint["model_ema"])
if args.test_only: if args.test_only:
# We disable the cudnn benchmarking because it can noticeably affect the accuracy
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
evaluate(model, criterion, data_loader_test, device=device) evaluate(model, criterion, data_loader_test, device=device)
return return
@@ -394,6 +421,9 @@ def get_args_parser(add_help=True):
         default=0.9,
         help="decay factor for Exponential Moving Average of model parameters(default: 0.9)",
     )
+    parser.add_argument(
+        "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
+    )

     return parser
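Note: a standalone sketch of how the new flag surfaces through argparse; the parser below is a local stand-in for the one built by get_args_parser.

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-deterministic-algorithms", action="store_true", help="Forces the use of deterministic algorithms only."
)
args = parser.parse_args(["--use-deterministic-algorithms"])
print(args.use_deterministic_algorithms)  # True (dashes become underscores)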
@@ -32,11 +32,7 @@ class SmoothedValue(object):
""" """
Warning: does not synchronize the deque! Warning: does not synchronize the deque!
""" """
if not is_dist_avail_and_initialized(): t = reduce_across_processes([self.count, self.total])
return
t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
dist.barrier()
dist.all_reduce(t)
t = t.tolist() t = t.tolist()
self.count = int(t[0]) self.count = int(t[0])
self.total = t[1] self.total = t[1]
@@ -400,3 +396,12 @@ def store_model_weights(model, checkpoint_path, checkpoint_key="model", strict=T
     os.replace(tmp_path, output_path)

     return output_path
+
+
+def reduce_across_processes(val):
+    if not is_dist_avail_and_initialized():
+        return val
+    t = torch.tensor(val, device="cuda")
+    dist.barrier()
+    dist.all_reduce(t)
+    return t
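Note: a hedged usage sketch for the new helper. Outside a torch.distributed run it is a pass-through; under distributed execution every rank gets back the all-reduced sum as a CUDA tensor.

count = reduce_across_processes(3)
# single process: count == 3, the plain Python int passed in
# 4-process job, each rank passing 3: count == tensor(12, device='cuda')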