#include #include #include #include "../torchvision/csrc/models/models.h" using namespace vision::models; template torch::Tensor forward_model(const std::string& input_path, torch::Tensor x) { Model network; torch::load(network, input_path); network->eval(); return network->forward(x); } torch::Tensor forward_alexnet(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_vgg11(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_vgg13(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_vgg16(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_vgg19(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_vgg11bn(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_vgg13bn(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_vgg16bn(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_vgg19bn(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_resnet18(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_resnet34(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_resnet50(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_resnet101( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_resnet152( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_resnext50_32x4d( 
const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_resnext101_32x8d( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_squeezenet1_0( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_squeezenet1_1( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_densenet121( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_densenet169( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_densenet201( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_densenet161( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_mobilenetv2( const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_googlenet( const std::string& input_path, torch::Tensor x) { GoogLeNet network; torch::load(network, input_path); network->eval(); return network->forward(x).output; } torch::Tensor forward_inceptionv3( const std::string& input_path, torch::Tensor x) { InceptionV3 network; torch::load(network, input_path); network->eval(); return network->forward(x).output; } torch::Tensor forward_mnasnet0_5(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_mnasnet0_75(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_mnasnet1_0(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } torch::Tensor forward_mnasnet1_3(const std::string& input_path, torch::Tensor x) { return forward_model(input_path, x); } 
// Python bindings: every forward_* helper is exported under its C++ name,
// and that same name is used as the function's docstring.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  // Registers `fn` on the module as `name`, reusing `name` as the docstring.
  const auto bind = [&m](const char* name, auto fn) { m.def(name, fn, name); };

  bind("forward_alexnet", &forward_alexnet);

  bind("forward_vgg11", &forward_vgg11);
  bind("forward_vgg13", &forward_vgg13);
  bind("forward_vgg16", &forward_vgg16);
  bind("forward_vgg19", &forward_vgg19);

  bind("forward_vgg11bn", &forward_vgg11bn);
  bind("forward_vgg13bn", &forward_vgg13bn);
  bind("forward_vgg16bn", &forward_vgg16bn);
  bind("forward_vgg19bn", &forward_vgg19bn);

  bind("forward_resnet18", &forward_resnet18);
  bind("forward_resnet34", &forward_resnet34);
  bind("forward_resnet50", &forward_resnet50);
  bind("forward_resnet101", &forward_resnet101);
  bind("forward_resnet152", &forward_resnet152);
  bind("forward_resnext50_32x4d", &forward_resnext50_32x4d);
  bind("forward_resnext101_32x8d", &forward_resnext101_32x8d);

  bind("forward_squeezenet1_0", &forward_squeezenet1_0);
  bind("forward_squeezenet1_1", &forward_squeezenet1_1);

  bind("forward_densenet121", &forward_densenet121);
  bind("forward_densenet169", &forward_densenet169);
  bind("forward_densenet201", &forward_densenet201);
  bind("forward_densenet161", &forward_densenet161);

  bind("forward_mobilenetv2", &forward_mobilenetv2);
  bind("forward_googlenet", &forward_googlenet);
  bind("forward_inceptionv3", &forward_inceptionv3);

  bind("forward_mnasnet0_5", &forward_mnasnet0_5);
  bind("forward_mnasnet0_75", &forward_mnasnet0_75);
  bind("forward_mnasnet1_0", &forward_mnasnet1_0);
  bind("forward_mnasnet1_3", &forward_mnasnet1_3);
}