#ifndef DENSENET_H #define DENSENET_H #include namespace vision { namespace models { // Densenet-BC model class, based on // "Densely Connected Convolutional Networks" // // Args: // num_classes (int) - number of classification classes // growth_rate (int) - how many filters to add each layer (`k` in paper) // block_config (list of 4 ints) - how many layers in each pooling block // num_init_features (int) - the number of filters to learn in the first // convolution layer // bn_size (int) - multiplicative factor for number of bottle neck layers // (i.e. bn_size * k features in the bottleneck layer) // drop_rate (float) - dropout rate after each dense layer struct DenseNetImpl : torch::nn::Module { torch::nn::Sequential features{nullptr}; torch::nn::Linear classifier{nullptr}; DenseNetImpl( int64_t num_classes = 1000, int64_t growth_rate = 32, std::vector block_config = {6, 12, 24, 16}, int64_t num_init_features = 64, int64_t bn_size = 4, double drop_rate = 0); torch::Tensor forward(torch::Tensor x); }; struct DenseNet121Impl : DenseNetImpl { DenseNet121Impl( int64_t num_classes = 1000, int64_t growth_rate = 32, std::vector block_config = {6, 12, 24, 16}, int64_t num_init_features = 64, int64_t bn_size = 4, double drop_rate = 0); }; struct DenseNet169Impl : DenseNetImpl { DenseNet169Impl( int64_t num_classes = 1000, int64_t growth_rate = 32, std::vector block_config = {6, 12, 32, 32}, int64_t num_init_features = 64, int64_t bn_size = 4, double drop_rate = 0); }; struct DenseNet201Impl : DenseNetImpl { DenseNet201Impl( int64_t num_classes = 1000, int64_t growth_rate = 32, std::vector block_config = {6, 12, 48, 32}, int64_t num_init_features = 64, int64_t bn_size = 4, double drop_rate = 0); }; struct DenseNet161Impl : DenseNetImpl { DenseNet161Impl( int64_t num_classes = 1000, int64_t growth_rate = 48, std::vector block_config = {6, 12, 36, 24}, int64_t num_init_features = 96, int64_t bn_size = 4, double drop_rate = 0); }; TORCH_MODULE(DenseNet); TORCH_MODULE(DenseNet121); TORCH_MODULE(DenseNet169); TORCH_MODULE(DenseNet201); TORCH_MODULE(DenseNet161); } // namespace models } // namespace vision #endif // DENSENET_H