Unverified Commit 845391cd authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

enable selective benchmarks (#4960)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 408c9bea
...@@ -24,31 +24,58 @@ from torchvision.prototype import datasets as new_datasets ...@@ -24,31 +24,58 @@ from torchvision.prototype import datasets as new_datasets
from torchvision.transforms import PILToTensor from torchvision.transforms import PILToTensor
def main(name, *, number=5, temp_root=None, num_workers=0): def main(
name,
*,
legacy=True,
new=True,
start=True,
iteration=True,
num_starts=3,
num_samples=10_000,
temp_root=None,
num_workers=0,
):
for benchmark in DATASET_BENCHMARKS: for benchmark in DATASET_BENCHMARKS:
if benchmark.name == name: if benchmark.name == name:
break break
else: else:
raise ValueError(f"No DatasetBenchmark available for dataset '{name}'") raise ValueError(f"No DatasetBenchmark available for dataset '{name}'")
print( if legacy and start:
"legacy", print(
"cold_start", "legacy",
Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=number), "cold_start",
) Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts),
print( )
"legacy", print(
"warm_start", "legacy",
Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=number), "warm_start",
) Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts),
print( )
"legacy",
"iter", if legacy and iteration:
Measurement.iterations_per_time(benchmark.legacy_iteration(temp_root, num_workers=num_workers), number=number), print(
) "legacy",
"iteration",
Measurement.iterations_per_time(
benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples)
),
)
if new and start:
print(
"new",
"cold_start",
Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts),
)
print("new", "cold_start", Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=number)) if new and iteration:
print("new", "iter", Measurement.iterations_per_time(benchmark.new_iter(num_workers=num_workers), number=number)) print(
"new",
"iteration",
Measurement.iterations_per_time(benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples)),
)
class DatasetBenchmark: class DatasetBenchmark:
...@@ -91,16 +118,17 @@ class DatasetBenchmark: ...@@ -91,16 +118,17 @@ class DatasetBenchmark:
return fn return fn
def new_iter(self, *, num_workers): def new_iteration(self, *, num_samples, num_workers):
def fn(timer): def fn(timer):
dataset = self.new_dataset(num_workers=num_workers) dataset = self.new_dataset(num_workers=num_workers)
num_samples = 0 num_sample = 0
with timer: with timer:
for _ in dataset: for _ in dataset:
num_samples += 1 num_sample += 1
if num_sample == num_samples:
break
return num_samples return num_sample
return fn return fn
...@@ -155,7 +183,7 @@ class DatasetBenchmark: ...@@ -155,7 +183,7 @@ class DatasetBenchmark:
@contextlib.contextmanager @contextlib.contextmanager
def legacy_root(self, temp_root): def legacy_root(self, temp_root):
new_root = new_datasets.home() / self.name new_root = pathlib.Path(new_datasets.home()) / self.name
legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root)) legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root))
if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev: if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev:
...@@ -198,15 +226,16 @@ class DatasetBenchmark: ...@@ -198,15 +226,16 @@ class DatasetBenchmark:
return fn return fn
def legacy_iteration(self, temp_root, *, num_workers): def legacy_iteration(self, temp_root, *, num_samples, num_workers):
def fn(timer): def fn(timer):
with self.legacy_root(temp_root) as root: with self.legacy_root(temp_root) as root:
dataset = self.legacy_dataset(root, num_workers=num_workers) dataset = self.legacy_dataset(root, num_workers=num_workers)
with timer: with timer:
for _ in dataset: for num_sample, _ in enumerate(dataset, 1):
pass if num_sample == num_samples:
break
return len(dataset) return num_sample
return fn return fn
...@@ -261,24 +290,14 @@ class Measurement: ...@@ -261,24 +290,14 @@ class Measurement:
@classmethod @classmethod
def time(cls, fn, *, number): def time(cls, fn, *, number):
results = Measurement._timeit(fn, number=number) results = Measurement._timeit(fn, number=number)
times = torch.tensor(tuple(zip(*results))[1]) times = torch.tensor(tuple(zip(*results))[1])
return cls._format(times, unit="s")
mean, std = Measurement._compute_mean_and_std(times)
# TODO format that into engineering format
return f"{mean:.3g} ± {std:.3g} s"
@classmethod @classmethod
def iterations_per_time(cls, fn, *, number): def iterations_per_time(cls, fn):
outputs, times = zip(*Measurement._timeit(fn, number=number)) num_samples, time = Measurement._timeit(fn, number=1)[0]
iterations_per_second = torch.tensor(num_samples) / torch.tensor(time)
num_samples = outputs[0] return cls._format(iterations_per_second, unit="it/s")
assert all(other_num_samples == num_samples for other_num_samples in outputs[1:])
iterations_per_time = torch.tensor(num_samples) / torch.tensor(times)
mean, std = Measurement._compute_mean_and_std(iterations_per_time)
# TODO format that into engineering format
return f"{mean:.1f} ± {std:.1f} it/s"
class Timer: class Timer:
def __init__(self): def __init__(self):
...@@ -300,7 +319,7 @@ class Measurement: ...@@ -300,7 +319,7 @@ class Measurement:
return self._stop - self._start return self._stop - self._start
@classmethod @classmethod
def _timeit(cls, fn, *, number): def _timeit(cls, fn, number):
results = [] results = []
for _ in range(number): for _ in range(number):
timer = cls.Timer() timer = cls.Timer()
...@@ -308,9 +327,19 @@ class Measurement: ...@@ -308,9 +327,19 @@ class Measurement:
results.append((output, timer.delta)) results.append((output, timer.delta))
return results return results
@classmethod
def _format(cls, measurements, *, unit):
measurements = torch.as_tensor(measurements).to(torch.float64).flatten()
if measurements.numel() == 1:
# TODO format that into engineering format
return f"{float(measurements):.3f} {unit}"
mean, std = Measurement._compute_mean_and_std(measurements)
# TODO format that into engineering format
return f"{mean:.3f} ± {std:.3f} {unit}"
@classmethod @classmethod
def _compute_mean_and_std(cls, t): def _compute_mean_and_std(cls, t):
t = t.flatten()
mean = float(t.mean()) mean = float(t.mean())
std = float(t.std(0, unbiased=t.numel() > 1)) std = float(t.std(0, unbiased=t.numel() > 1))
return mean, std return mean, std
...@@ -476,6 +505,7 @@ DATASET_BENCHMARKS = [ ...@@ -476,6 +505,7 @@ DATASET_BENCHMARKS = [
), ),
), ),
DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection), DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
DatasetBenchmark("imagenet", legacy_cls=legacy_datasets.ImageNet),
] ]
...@@ -487,13 +517,51 @@ def parse_args(argv=None): ...@@ -487,13 +517,51 @@ def parse_args(argv=None):
) )
parser.add_argument("name", help="Name of the dataset to benchmark.") parser.add_argument("name", help="Name of the dataset to benchmark.")
parser.add_argument( parser.add_argument(
"-n", "-n",
"--number", "--num-starts",
type=int, type=int,
default=5, default=3,
help="Number of iterations of each benchmark.", help="Number of warm and cold starts of each benchmark. Default to 3.",
)
parser.add_argument(
"-N",
"--num-samples",
type=int,
default=10_000,
help="Maximum number of samples to draw during iteration benchmarks. Defaults to 10_000.",
)
parser.add_argument(
"--nl",
"--no-legacy",
dest="legacy",
action="store_false",
help="Skip legacy benchmarks.",
)
parser.add_argument(
"--nn",
"--no-new",
dest="new",
action="store_false",
help="Skip new benchmarks.",
)
parser.add_argument(
"--ns",
"--no-start",
dest="start",
action="store_false",
help="Skip start benchmarks.",
) )
parser.add_argument(
"--ni",
"--no-iteration",
dest="iteration",
action="store_false",
help="Skip iteration benchmarks.",
)
parser.add_argument( parser.add_argument(
"-t", "-t",
"--temp-root", "--temp-root",
...@@ -509,8 +577,8 @@ def parse_args(argv=None): ...@@ -509,8 +577,8 @@ def parse_args(argv=None):
type=int, type=int,
default=0, default=0,
help=( help=(
"Number of subprocesses used to load the data. Setting this to 0 will load all data in the main process " "Number of subprocesses used to load the data. Setting this to 0 (default) will load all data in the main "
"and thus disable multi-processing." "process and thus disable multi-processing."
), ),
) )
...@@ -521,7 +589,17 @@ if __name__ == "__main__": ...@@ -521,7 +589,17 @@ if __name__ == "__main__":
args = parse_args() args = parse_args()
try: try:
main(args.name, number=args.number, temp_root=args.temp_root, num_workers=args.num_workers) main(
args.name,
legacy=args.legacy,
new=args.new,
start=args.start,
iteration=args.iteration,
num_starts=args.num_starts,
num_samples=args.num_samples,
temp_root=args.temp_root,
num_workers=args.num_workers,
)
except Exception as error: except Exception as error:
msg = str(error) msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr) print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment