Unverified Commit 8f98aee5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Update the dataset cache to factor input parameters (#6234)

* Update the dataset cache to factor in parameters from the args.

* Fix linter
parent d6e39ff7
...@@ -98,10 +98,11 @@ def evaluate(model, criterion, data_loader, device): ...@@ -98,10 +98,11 @@ def evaluate(model, criterion, data_loader, device):
return metric_logger.acc1.global_avg return metric_logger.acc1.global_avg
def _get_cache_path(filepath): def _get_cache_path(filepath, args):
import hashlib import hashlib
h = hashlib.sha1(filepath.encode()).hexdigest() value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}"
h = hashlib.sha1(value.encode()).hexdigest()
cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt") cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")
cache_path = os.path.expanduser(cache_path) cache_path = os.path.expanduser(cache_path)
return cache_path return cache_path
...@@ -135,7 +136,7 @@ def main(args): ...@@ -135,7 +136,7 @@ def main(args):
print("Loading training data") print("Loading training data")
st = time.time() st = time.time()
cache_path = _get_cache_path(traindir) cache_path = _get_cache_path(traindir, args)
transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171)) transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171))
if args.cache_dataset and os.path.exists(cache_path): if args.cache_dataset and os.path.exists(cache_path):
...@@ -167,7 +168,7 @@ def main(args): ...@@ -167,7 +168,7 @@ def main(args):
print("Took", time.time() - st) print("Took", time.time() - st)
print("Loading validation data") print("Loading validation data")
cache_path = _get_cache_path(valdir) cache_path = _get_cache_path(valdir, args)
if args.weights and args.test_only: if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights) weights = torchvision.models.get_weight(args.weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment