# https://openreview.net/forum?id=TVHS5Y4dNvM import torch.nn as nn class Residual(nn.Module): def __init__(self, fn): super().__init__() self.fn = fn def forward(self, x): return self.fn(x) + x def ConvMixer(dim, depth, kernel_size=9, patch_size=7, n_classes=1000): return nn.Sequential( nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size), nn.GELU(), nn.BatchNorm2d(dim), *[nn.Sequential( Residual(nn.Sequential( nn.Conv2d(dim, dim, kernel_size, groups=dim, padding="same"), nn.GELU(), nn.BatchNorm2d(dim) )), nn.Conv2d(dim, dim, kernel_size=1), nn.GELU(), nn.BatchNorm2d(dim) ) for i in range(depth)], nn.AdaptiveAvgPool2d((1,1)), nn.Flatten(), nn.Linear(dim, n_classes) )