"llama/vscode:/vscode.git/clone" did not exist on "369de832cdca7680c8f50ba196d39172a895fcad"
Unverified Commit 845391cd authored by Philip Meier's avatar Philip Meier Committed by GitHub
Browse files

enable selective benchmarks (#4960)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 408c9bea
......@@ -24,31 +24,58 @@ from torchvision.prototype import datasets as new_datasets
from torchvision.transforms import PILToTensor
def main(name, *, number=5, temp_root=None, num_workers=0):
def main(
name,
*,
legacy=True,
new=True,
start=True,
iteration=True,
num_starts=3,
num_samples=10_000,
temp_root=None,
num_workers=0,
):
for benchmark in DATASET_BENCHMARKS:
if benchmark.name == name:
break
else:
raise ValueError(f"No DatasetBenchmark available for dataset '{name}'")
if legacy and start:
print(
"legacy",
"cold_start",
Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=number),
Measurement.time(benchmark.legacy_cold_start(temp_root, num_workers=num_workers), number=num_starts),
)
print(
"legacy",
"warm_start",
Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=number),
Measurement.time(benchmark.legacy_warm_start(temp_root, num_workers=num_workers), number=num_starts),
)
if legacy and iteration:
print(
"legacy",
"iter",
Measurement.iterations_per_time(benchmark.legacy_iteration(temp_root, num_workers=num_workers), number=number),
"iteration",
Measurement.iterations_per_time(
benchmark.legacy_iteration(temp_root, num_workers=num_workers, num_samples=num_samples)
),
)
if new and start:
print(
"new",
"cold_start",
Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=num_starts),
)
print("new", "cold_start", Measurement.time(benchmark.new_cold_start(num_workers=num_workers), number=number))
print("new", "iter", Measurement.iterations_per_time(benchmark.new_iter(num_workers=num_workers), number=number))
if new and iteration:
print(
"new",
"iteration",
Measurement.iterations_per_time(benchmark.new_iteration(num_workers=num_workers, num_samples=num_samples)),
)
class DatasetBenchmark:
......@@ -91,16 +118,17 @@ class DatasetBenchmark:
return fn
def new_iter(self, *, num_workers):
def new_iteration(self, *, num_samples, num_workers):
def fn(timer):
dataset = self.new_dataset(num_workers=num_workers)
num_samples = 0
num_sample = 0
with timer:
for _ in dataset:
num_samples += 1
num_sample += 1
if num_sample == num_samples:
break
return num_samples
return num_sample
return fn
......@@ -155,7 +183,7 @@ class DatasetBenchmark:
@contextlib.contextmanager
def legacy_root(self, temp_root):
new_root = new_datasets.home() / self.name
new_root = pathlib.Path(new_datasets.home()) / self.name
legacy_root = pathlib.Path(tempfile.mkdtemp(dir=temp_root))
if os.stat(new_root).st_dev != os.stat(legacy_root).st_dev:
......@@ -198,15 +226,16 @@ class DatasetBenchmark:
return fn
def legacy_iteration(self, temp_root, *, num_workers):
def legacy_iteration(self, temp_root, *, num_samples, num_workers):
def fn(timer):
with self.legacy_root(temp_root) as root:
dataset = self.legacy_dataset(root, num_workers=num_workers)
with timer:
for _ in dataset:
pass
for num_sample, _ in enumerate(dataset, 1):
if num_sample == num_samples:
break
return len(dataset)
return num_sample
return fn
......@@ -261,24 +290,14 @@ class Measurement:
@classmethod
def time(cls, fn, *, number):
results = Measurement._timeit(fn, number=number)
times = torch.tensor(tuple(zip(*results))[1])
mean, std = Measurement._compute_mean_and_std(times)
# TODO format that into engineering format
return f"{mean:.3g} ± {std:.3g} s"
return cls._format(times, unit="s")
@classmethod
def iterations_per_time(cls, fn, *, number):
outputs, times = zip(*Measurement._timeit(fn, number=number))
num_samples = outputs[0]
assert all(other_num_samples == num_samples for other_num_samples in outputs[1:])
iterations_per_time = torch.tensor(num_samples) / torch.tensor(times)
mean, std = Measurement._compute_mean_and_std(iterations_per_time)
# TODO format that into engineering format
return f"{mean:.1f} ± {std:.1f} it/s"
def iterations_per_time(cls, fn):
num_samples, time = Measurement._timeit(fn, number=1)[0]
iterations_per_second = torch.tensor(num_samples) / torch.tensor(time)
return cls._format(iterations_per_second, unit="it/s")
class Timer:
def __init__(self):
......@@ -300,7 +319,7 @@ class Measurement:
return self._stop - self._start
@classmethod
def _timeit(cls, fn, *, number):
def _timeit(cls, fn, number):
results = []
for _ in range(number):
timer = cls.Timer()
......@@ -308,9 +327,19 @@ class Measurement:
results.append((output, timer.delta))
return results
@classmethod
def _format(cls, measurements, *, unit):
measurements = torch.as_tensor(measurements).to(torch.float64).flatten()
if measurements.numel() == 1:
# TODO format that into engineering format
return f"{float(measurements):.3f} {unit}"
mean, std = Measurement._compute_mean_and_std(measurements)
# TODO format that into engineering format
return f"{mean:.3f} ± {std:.3f} {unit}"
@classmethod
def _compute_mean_and_std(cls, t):
t = t.flatten()
mean = float(t.mean())
std = float(t.std(0, unbiased=t.numel() > 1))
return mean, std
......@@ -476,6 +505,7 @@ DATASET_BENCHMARKS = [
),
),
DatasetBenchmark("voc", legacy_cls=legacy_datasets.VOCDetection),
DatasetBenchmark("imagenet", legacy_cls=legacy_datasets.ImageNet),
]
......@@ -487,13 +517,51 @@ def parse_args(argv=None):
)
parser.add_argument("name", help="Name of the dataset to benchmark.")
parser.add_argument(
"-n",
"--number",
"--num-starts",
type=int,
default=5,
help="Number of iterations of each benchmark.",
default=3,
help="Number of warm and cold starts of each benchmark. Default to 3.",
)
parser.add_argument(
"-N",
"--num-samples",
type=int,
default=10_000,
help="Maximum number of samples to draw during iteration benchmarks. Defaults to 10_000.",
)
parser.add_argument(
"--nl",
"--no-legacy",
dest="legacy",
action="store_false",
help="Skip legacy benchmarks.",
)
parser.add_argument(
"--nn",
"--no-new",
dest="new",
action="store_false",
help="Skip new benchmarks.",
)
parser.add_argument(
"--ns",
"--no-start",
dest="start",
action="store_false",
help="Skip start benchmarks.",
)
parser.add_argument(
"--ni",
"--no-iteration",
dest="iteration",
action="store_false",
help="Skip iteration benchmarks.",
)
parser.add_argument(
"-t",
"--temp-root",
......@@ -509,8 +577,8 @@ def parse_args(argv=None):
type=int,
default=0,
help=(
"Number of subprocesses used to load the data. Setting this to 0 will load all data in the main process "
"and thus disable multi-processing."
"Number of subprocesses used to load the data. Setting this to 0 (default) will load all data in the main "
"process and thus disable multi-processing."
),
)
......@@ -521,7 +589,17 @@ if __name__ == "__main__":
args = parse_args()
try:
main(args.name, number=args.number, temp_root=args.temp_root, num_workers=args.num_workers)
main(
args.name,
legacy=args.legacy,
new=args.new,
start=args.start,
iteration=args.iteration,
num_starts=args.num_starts,
num_samples=args.num_samples,
temp_root=args.temp_root,
num_workers=args.num_workers,
)
except Exception as error:
msg = str(error)
print(msg or f"Unspecified {type(error)} was raised during execution.", file=sys.stderr)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment