Commit b32e7c19 authored by mibaumgartner's avatar mibaumgartner
Browse files

add balanced loader

parent d07c5323
......@@ -14,6 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Don't use these. Next nnDetection Version will introduce better/fixed implementations.
"""
import torch
import torch.nn as nn
......
......@@ -14,6 +14,10 @@ See the License for the specific language governing permissions and
limitations under the License.
"""
"""
Don't use these. Next nnDetection Version will introduce better/fixed implementations.
"""
import torch
import torch.nn as nn
......
......@@ -391,3 +391,50 @@ class DataLoader3DOffset(DataLoader3DFast):
slice(origins[1], origins[1] + self.patch_size_generator[1]),
slice(origins[2], origins[2] + self.patch_size_generator[2]),
]
@DATALOADER_REGISTRY.register
class DataLoader3DBalanced(DataLoader3DOffset):
def build_cache(self) -> Tuple[Dict[int, List[Tuple[str, int]]], List]:
"""
Build up cache for sampling
Returns:
Dict[int, List[Tuple[str, int]]]: foreground cache which contains
of list of tuple of case ids and instance ids for each class
List: background cache (all samples which do not have any
foreground)
"""
fg_cache = defaultdict(list)
logger.info("Building Sampling Cache for Dataloder")
for case_id, item in maybe_verbose_iterable(self._data.items(), desc="Sampling Cache"):
candidates = load_pickle(item['boxes_file'])
if candidates["instances"]:
for instance_id, instance_class in zip(candidates["instances"], candidates["labels"]):
fg_cache[int(instance_class)].append((case_id, instance_id))
return {"fg": fg_cache, "case": list(self._data.keys())}
def select(self) -> Tuple[List, List]:
"""
Foreground sampling: sample uniformly from all the foreground classes
and enforce the respective class while patch sampling.
Background sampling: We jsut sample a random case
"""
selected_classes = np.random.choice(
list(self.cache["fg"].keys()), self.batch_size, replace=True)
selected_cases = []
selected_instances = []
for idx in range(len(selected_classes)):
if idx < round(self.batch_size * (1 - self.oversample_foreground_percent)):
# sample bg / random case
selected_cases.append(np.random.choice(self.cache["case"]))
selected_instances.append(-1)
else:
# sample fg / select an instance
_i = np.random.choice(range(len(self.cache["fg"][selected_classes[idx]])))
_case, _instance_id = self.cache["fg"][selected_classes[idx]][_i]
selected_cases.append(_case)
selected_instances.append(int(_instance_id))
return selected_cases, selected_instances
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment