Unverified Commit bdfb7901 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Refactor Augmentation Space calls to speed up. (#5402)

parent 98a5f3ad
......@@ -268,9 +268,9 @@ class AutoAugment(torch.nn.Module):
transform_id, probs, signs = self.get_params(len(self.policies))
op_meta = self._augmentation_space(10, F.get_image_size(img))
for i, (op_name, p, magnitude_id) in enumerate(self.policies[transform_id]):
if probs[i] <= p:
op_meta = self._augmentation_space(10, F.get_image_size(img))
magnitudes, signed = op_meta[op_name]
magnitude = float(magnitudes[magnitude_id].item()) if magnitude_id is not None else 0.0
if signed and signs[i] == 0:
......@@ -350,8 +350,8 @@ class RandAugment(torch.nn.Module):
elif fill is not None:
fill = [float(f) for f in fill]
op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
for _ in range(self.num_ops):
op_meta = self._augmentation_space(self.num_magnitude_bins, F.get_image_size(img))
op_index = int(torch.randint(len(op_meta), (1,)).item())
op_name = list(op_meta.keys())[op_index]
magnitudes, signed = op_meta[op_name]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment