# Implementation of this model is borrowed and modified # (from torch to paddle) from here: # https://github.com/tamasino52/UNETR/blob/main/unetr.py # Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. from typing import List, Tuple import numpy as np import paddle import paddle.nn as nn from paddle import Tensor from medicalseg.cvlibs import manager # green block in Fig.1 class TranspConv3DBlock(nn.Layer): def __init__(self, in_planes, out_planes): super(TranspConv3DBlock, self).__init__() self.block = nn.Conv3DTranspose( in_planes, out_planes, kernel_size=2, stride=2, padding=0, output_padding=0, bias_attr=False) def forward(self, x): y = self.block(x) return y class TranspConv3DConv3D(nn.Layer): def __init__(self, in_planes, out_planes, layers=1, conv_block=False): """ blue box in Fig.1 Args: in_planes: in channels of transpose convolution out_planes: out channels of transpose convolution layers: number of blue blocks, transpose convs conv_block: whether to include a conv block after each transpose conv. deafaults to False """ super(TranspConv3DConv3D, self).__init__() self.blocks = nn.LayerList([TranspConv3DBlock(in_planes, out_planes), ]) if conv_block: self.blocks.append( Conv3DBlock( out_planes, out_planes, double=False)) if int(layers) >= 2: for _ in range(int(layers) - 1): self.blocks.append(TranspConv3DBlock(out_planes, out_planes)) if conv_block: self.blocks.append( Conv3DBlock( out_planes, out_planes, double=False)) def forward(self, x): for blk in self.blocks: x = blk(x) return x # yellow block in Fig.1 class Conv3DBlock(nn.Layer): def __init__(self, in_planes, out_planes, kernel_size=3, double=True, norm=nn.BatchNorm3D, skip=True): super(Conv3DBlock, self).__init__() self.skip = skip self.downsample = in_planes != out_planes self.final_activation = nn.LeakyReLU(negative_slope=0.01) padding = (kernel_size - 1) // 2 if double: self.conv_block = nn.Sequential( nn.Conv3D( in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=padding), norm(out_planes), nn.LeakyReLU(negative_slope=0.01), nn.Conv3D( out_planes, out_planes, kernel_size=kernel_size, stride=1, padding=padding), norm(out_planes)) else: self.conv_block = nn.Sequential( nn.Conv3D( in_planes, out_planes, kernel_size=kernel_size, stride=1, padding=padding), norm(out_planes)) if self.skip and self.downsample: self.conv_down = nn.Sequential( nn.Conv3D( in_planes, out_planes, kernel_size=1, stride=1, padding=0), norm(out_planes)) def forward(self, x): y = self.conv_block(x) if self.skip: res = x if self.downsample: res = self.conv_down(res) y = y + res return self.final_activation(y) class AbsPositionalEncoding1D(nn.Layer): def __init__(self, tokens, dim): super(AbsPositionalEncoding1D, self).__init__() params = paddle.randn(shape=[1, tokens, dim]) self.abs_pos_enc = paddle.create_parameter( shape=params.shape, dtype=str(params.numpy().dtype), default_initializer=paddle.nn.initializer.Assign(params)) def forward(self, x): batch = x.shape[0] tile = batch // self.abs_pos_enc.shape[0] expb = paddle.tile(self.abs_pos_enc, repeat_times=(tile, 1, 1)) return x + expb class Embeddings3D(nn.Layer): def __init__(self, input_dim, embed_dim, cube_size, patch_size=16, dropout=0.1): super().__init__() self.n_patches = int((cube_size[0] * cube_size[1] * cube_size[2]) / (patch_size * patch_size * patch_size)) self.patch_size = patch_size self.embed_dim = embed_dim self.patch_embeddings = nn.Conv3D( in_channels=input_dim, out_channels=embed_dim, kernel_size=patch_size, stride=patch_size, bias_attr=False) self.position_embeddings = AbsPositionalEncoding1D(self.n_patches, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x): """ x is a 5D tensor """ patch_embeddings = self.patch_embeddings(x) shape = patch_embeddings.shape _d = paddle.reshape(patch_embeddings, [shape[0], shape[1], -1]) _d = paddle.transpose(_d, perm=[0, 2, 1]) embeddings = self.dropout(self.position_embeddings(_d)) return embeddings def compute_mhsa(q, k, v, scale_factor=1, mask=None): # resulted shape will be: [batch, heads, tokens, tokens] k = paddle.transpose(k, perm=[0, 1, 3, 2]) scaled_dot_prod = paddle.matmul(q, k) * scale_factor if mask is not None: assert mask.shape == scaled_dot_prod.shape[2:] scaled_dot_prod = scaled_dot_prod.masked_fill(mask, -np.inf) attention = paddle.nn.functional.softmax(scaled_dot_prod, axis=-1) # calc result per head return paddle.matmul(attention, v) class MultiHeadSelfAttention(nn.Layer): def __init__(self, dim, heads=8, dim_head=None): """ Implementation of multi-head attention layer of the original transformer model. einsum and einops.rearrange is used whenever possible Args: dim: token's dimension, i.e. word embedding vector size heads: the number of distinct representations to learn dim_head: the dim of the head. In general dim_head