Commit 1017eb4f authored by mibaumgartner's avatar mibaumgartner
Browse files

add compose

parent 31d06041
from nndet.io.transforms.base import AbstractTransform
from nndet.io.transforms.base import (
AbstractTransform,
Compose,
)
from nndet.io.transforms.instances import (
Instances2Boxes,
Instances2Segmentation,
......
from typing import Any
from typing import Any, Sequence
import torch
......@@ -30,4 +30,27 @@ class AbstractTransform(torch.nn.Module):
context = torch.no_grad()
with context:
return super().__call__(*args, **kwargs)
\ No newline at end of file
return super().__call__(*args, **kwargs)
class Compose(AbstractTransform):
def __init__(self, *transforms):
"""
Compose multiple transforms to one
Args:
transforms: transformations to compose
"""
super().__init__(grad=False)
if len(transforms) == 1 and isinstance(transforms[0], Sequence):
transforms = transforms[0]
self.transforms = torch.nn.ModuleList(list(transforms))
def forward(self, **batch):
"""
Augment batch
"""
for t in self.transforms:
batch = t(**batch)
return batch
......@@ -62,8 +62,12 @@ from nndet.inference.helper import predict_dir
from nndet.inference.ensembler.segmentation import SegmentationEnsembler
from nndet.inference.ensembler.detection import BoxEnsemblerSelective
from rising.transforms import Compose
from nndet.io.transforms import Instances2Boxes, Instances2Segmentation, FindInstances
from nndet.io.transforms import (
Compose,
Instances2Boxes,
Instances2Segmentation,
FindInstances,
)
class RetinaUNetModule(LightningBaseModuleSWA):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment