Unverified Commit 8f98aee5 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Update the dataset cache to factor input parameters (#6234)

* Update the dataset cache to factor in parameters from the args.

* Fix linter
parent d6e39ff7
......@@ -98,10 +98,11 @@ def evaluate(model, criterion, data_loader, device):
return metric_logger.acc1.global_avg
def _get_cache_path(filepath):
def _get_cache_path(filepath, args):
import hashlib
h = hashlib.sha1(filepath.encode()).hexdigest()
value = f"{filepath}-{args.clip_len}-{args.kinetics_version}-{args.frame_rate}"
h = hashlib.sha1(value.encode()).hexdigest()
cache_path = os.path.join("~", ".torch", "vision", "datasets", "kinetics", h[:10] + ".pt")
cache_path = os.path.expanduser(cache_path)
return cache_path
......@@ -135,7 +136,7 @@ def main(args):
print("Loading training data")
st = time.time()
cache_path = _get_cache_path(traindir)
cache_path = _get_cache_path(traindir, args)
transform_train = presets.VideoClassificationPresetTrain(crop_size=(112, 112), resize_size=(128, 171))
if args.cache_dataset and os.path.exists(cache_path):
......@@ -167,7 +168,7 @@ def main(args):
print("Took", time.time() - st)
print("Loading validation data")
cache_path = _get_cache_path(valdir)
cache_path = _get_cache_path(valdir, args)
if args.weights and args.test_only:
weights = torchvision.models.get_weight(args.weights)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment