Commit ad2973b2 authored by Jake Popham's avatar Jake Popham Committed by Facebook GitHub Bot
Browse files

Refactor Preprocessors, Add CoordConv

Summary:
Refactors the `MODEL.REGRESSOR.PREPROCESSORS` usage to allow for multiple preprocessors, and adds a new `ADD_COORD_CHANNELS` preprocessor.

Note: `MODEL.FBNET_V2.STEM_IN_CHANNELS` should be modified in your config to reflect the preprocessors that are enabled. Specifically, `ADD_COORD_CHANNELS` increases the input channels by 2, while `SPLIT_AND_CONCAT` decreases by a factor of the chunk size (typically 2). See the included `quick_pupil_3d_*` configs as an example.

Differential Revision: D30459924

fbshipit-source-id: dd8e3293a416a1a556e091cecc058a1be5288cc0
parent a11cb507
......@@ -33,3 +33,48 @@ class SplitAndConcat(nn.Module):
f"split_dim={self.split_dim}, concat_dim={self.concat_dim}, "
f"chunk={self.chunk}"
)
class AddCoordChannels(nn.Module):
"""Appends coordinate location values to the channel dimension.
@param with_r include radial distance from centroid as additional channel (default: False)
"""
def __init__(self, with_r: bool = False) -> None:
super().__init__()
self.with_r = with_r
def forward(self, input_tensor):
batch_size_shape, channel_in_shape, dim_y, dim_x = input_tensor.shape
device = input_tensor.device
xx_ones = torch.ones([1, 1, 1, dim_x], dtype=torch.int32)
yy_ones = torch.ones([1, 1, 1, dim_y], dtype=torch.int32)
xx_range = torch.arange(dim_y, dtype=torch.int32)
yy_range = torch.arange(dim_x, dtype=torch.int32)
xx_range = xx_range[None, None, :, None]
yy_range = yy_range[None, None, :, None]
xx_channel = torch.matmul(xx_range, xx_ones)
yy_channel = torch.matmul(yy_range, yy_ones)
# transpose y
yy_channel = yy_channel.permute(0, 1, 3, 2)
xx_channel = xx_channel.float() / (dim_y - 1)
yy_channel = yy_channel.float() / (dim_x - 1)
xx_channel = xx_channel * 2 - 1
yy_channel = yy_channel * 2 - 1
xx_channel = xx_channel.repeat(batch_size_shape, 1, 1, 1)
yy_channel = yy_channel.repeat(batch_size_shape, 1, 1, 1)
out = torch.cat([input_tensor, xx_channel.to(device), yy_channel.to(device)], dim=1)
if self.with_r:
rr = torch.sqrt(torch.pow(xx_channel - 0.5, 2) + torch.pow(yy_channel - 0.5, 2))
out = torch.cat([out, rr], dim=1)
return out
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment