Commit 1a57a76b authored by plyfager's avatar plyfager Committed by zhouzaida
Browse files

complete repr functions

parent f90567a0
...@@ -610,10 +610,10 @@ class CenterCrop(BaseTransform): ...@@ -610,10 +610,10 @@ class CenterCrop(BaseTransform):
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', crop_size = {self.crop_size}' repr_str += f'(crop_size = {self.crop_size}'
repr_str += f', auto_pad={self.auto_pad}' repr_str += f', auto_pad={self.auto_pad}'
repr_str += f', pad_cfg={self.pad_cfg}' repr_str += f', pad_cfg={self.pad_cfg}'
repr_str += f',clip_object_border = {self.clip_object_border}' repr_str += f',clip_object_border = {self.clip_object_border})'
return repr_str return repr_str
...@@ -705,10 +705,10 @@ class RandomGrayscale(BaseTransform): ...@@ -705,10 +705,10 @@ class RandomGrayscale(BaseTransform):
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', prob = {self.prob}' repr_str += f'(prob = {self.prob}'
repr_str += f', keep_channels = {self.keep_channels}' repr_str += f', keep_channels = {self.keep_channels}'
repr_str += f', channel_weights = {self.channel_weights}' repr_str += f', channel_weights = {self.channel_weights}'
repr_str += f', color_format = {self.color_format}' repr_str += f', color_format = {self.color_format})'
return repr_str return repr_str
...@@ -861,10 +861,10 @@ class MultiScaleFlipAug(BaseTransform): ...@@ -861,10 +861,10 @@ class MultiScaleFlipAug(BaseTransform):
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', transforms={self.transforms}' repr_str += f'(transforms={self.transforms}'
repr_str += f', scales={self.scales}' repr_str += f', scales={self.scales}'
repr_str += f', allow_flip={self.allow_flip}' repr_str += f', allow_flip={self.allow_flip}'
repr_str += f', flip_direction={self.flip_direction}' repr_str += f', flip_direction={self.flip_direction})'
return repr_str return repr_str
...@@ -975,8 +975,8 @@ class RandomChoiceResize(BaseTransform): ...@@ -975,8 +975,8 @@ class RandomChoiceResize(BaseTransform):
def __repr__(self) -> str: def __repr__(self) -> str:
repr_str = self.__class__.__name__ repr_str = self.__class__.__name__
repr_str += f', scales={self.scales}' repr_str += f'(scales={self.scales}'
repr_str += f', resize_cfg={self.resize_cfg}' repr_str += f', resize_cfg={self.resize_cfg})'
return repr_str return repr_str
......
...@@ -339,6 +339,15 @@ class KeyMapper(BaseTransform): ...@@ -339,6 +339,15 @@ class KeyMapper(BaseTransform):
results.update(outputs) # type: ignore results.update(outputs) # type: ignore
return results return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(transforms = {self.transforms}'
repr_str += f', mapping = {self.mapping}'
repr_str += f', remapping = {self.remapping}'
repr_str += f', auto_remap = {self.auto_remap}'
repr_str += f', allow_nonexist_keys = {self.allow_nonexist_keys})'
return repr_str
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class TransformBroadcaster(KeyMapper): class TransformBroadcaster(KeyMapper):
...@@ -518,6 +527,16 @@ class TransformBroadcaster(KeyMapper): ...@@ -518,6 +527,16 @@ class TransformBroadcaster(KeyMapper):
results.update(outputs) results.update(outputs)
return results return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(transforms = {self.transforms}'
repr_str += f', mapping = {self.mapping}'
repr_str += f', remapping = {self.remapping}'
repr_str += f', auto_remap = {self.auto_remap}'
repr_str += f', allow_nonexist_keys = {self.allow_nonexist_keys}'
repr_str += f', share_random_params = {self.share_random_params})'
return repr_str
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomChoice(BaseTransform): class RandomChoice(BaseTransform):
...@@ -573,6 +592,12 @@ class RandomChoice(BaseTransform): ...@@ -573,6 +592,12 @@ class RandomChoice(BaseTransform):
idx = self.random_pipeline_index() idx = self.random_pipeline_index()
return self.transforms[idx](results) return self.transforms[idx](results)
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(transforms = {self.transforms}'
repr_str += f'prob = {self.prob})'
return repr_str
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomApply(BaseTransform): class RandomApply(BaseTransform):
...@@ -615,3 +640,9 @@ class RandomApply(BaseTransform): ...@@ -615,3 +640,9 @@ class RandomApply(BaseTransform):
return self.transforms(results) # type: ignore return self.transforms(results) # type: ignore
else: else:
return results return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'(transforms = {self.transforms}'
repr_str += f', prob = {self.prob})'
return repr_str
...@@ -41,6 +41,11 @@ class AddToValue(BaseTransform): ...@@ -41,6 +41,11 @@ class AddToValue(BaseTransform):
def transform(self, results): def transform(self, results):
return self.add(results, self.addend) return self.add(results, self.addend)
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'addend = {self.addend}'
return repr_str
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class RandomAddToValue(AddToValue): class RandomAddToValue(AddToValue):
...@@ -59,6 +64,11 @@ class RandomAddToValue(AddToValue): ...@@ -59,6 +64,11 @@ class RandomAddToValue(AddToValue):
results = self.add(results, addend=self.get_random_addend()) results = self.add(results, addend=self.get_random_addend())
return results return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
repr_str += f'repeat = {self.repeat}'
return repr_str
@TRANSFORMS.register_module() @TRANSFORMS.register_module()
class SumTwoValues(BaseTransform): class SumTwoValues(BaseTransform):
...@@ -75,6 +85,10 @@ class SumTwoValues(BaseTransform): ...@@ -75,6 +85,10 @@ class SumTwoValues(BaseTransform):
results['sum'] = np.nan results['sum'] = np.nan
return results return results
def __repr__(self) -> str:
repr_str = self.__class__.__name__
return repr_str
def test_compose(): def test_compose():
...@@ -296,7 +310,11 @@ def test_key_mapper(): ...@@ -296,7 +310,11 @@ def test_key_mapper():
pass pass
# __repr__ # __repr__
_ = str(pipeline) assert repr(pipeline) == (
'KeyMapper(transforms = Compose(\n ' + 'AddToValueaddend = 1' +
'\n), mapping = {\'value\': \'v_in\'}, ' +
'remapping = {\'value\': \'v_out\'}, auto_remap = False, ' +
'allow_nonexist_keys = False)')
def test_transform_broadcaster(): def test_transform_broadcaster():
...@@ -394,7 +412,12 @@ def test_transform_broadcaster(): ...@@ -394,7 +412,12 @@ def test_transform_broadcaster():
assert 'b' not in results assert 'b' not in results
# Test repr # Test repr
_ = str(pipeline) assert repr(pipeline) == (
'TransformBroadcaster(transforms = Compose(\n' + ' SumTwoValues' +
'\n), mapping = {\'num_1\': [\'a_1\', \'b_1\'], ' +
'\'num_2\': [\'a_2\', \'b_2\']}, ' +
'remapping = {\'sum\': [\'a\', Ellipsis]}, auto_remap = False, ' +
'allow_nonexist_keys = False, share_random_params = False)')
def test_random_choice(): def test_random_choice():
...@@ -430,6 +453,16 @@ def test_random_choice(): ...@@ -430,6 +453,16 @@ def test_random_choice():
values = results['values'] values = results['values']
assert all(map(lambda x: x == values[0], values)) assert all(map(lambda x: x == values[0], values))
# repr
assert repr(pipeline) == (
'TransformBroadcaster(transforms = Compose(\n' +
' RandomChoice(transforms = [Compose(\n' +
' AddToValueaddend = 1.0' + '\n), Compose(\n' +
' AddToValueaddend = 2.0' + '\n)]prob = None)' +
'\n), mapping = {\'value\': \'values\'}, ' +
'remapping = {\'value\': \'values\'}, auto_remap = True, ' +
'allow_nonexist_keys = False, share_random_params = True)')
def test_random_apply(): def test_random_apply():
...@@ -459,6 +492,15 @@ def test_random_apply(): ...@@ -459,6 +492,15 @@ def test_random_apply():
for _ in pipeline: for _ in pipeline:
pass pass
# repr
assert repr(pipeline) == (
'TransformBroadcaster(transforms = Compose(\n' +
' RandomApply(transforms = Compose(\n' +
' AddToValueaddend = 1' + '\n), prob = 0.5)' +
'\n), mapping = {\'value\': \'values\'}, ' +
'remapping = {\'value\': \'values\'}, auto_remap = True, ' +
'allow_nonexist_keys = False, share_random_params = True)')
def test_utils(): def test_utils():
# Test cache_randomness: normal case # Test cache_randomness: normal case
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment