Commit 8d580a1f authored by Shahriar, committed by Francisco Massa

Add WideResNet50_2/WideResNet101_2 and parameterize MobileNetV2 in the C++ models; tidy VGG and SqueezeNet (#1115)

parent d84fee6d
@@ -87,6 +87,12 @@ class Tester(unittest.TestCase):
     def test_resnext101_32x8d(self):
         process_model(models.resnext101_32x8d(), self.image, _C_tests.forward_resnext101_32x8d, 'ResNext101_32x8d')
 
+    def test_wide_resnet50_2(self):
+        process_model(models.wide_resnet50_2(), self.image, _C_tests.forward_wide_resnet50_2, 'WideResNet50_2')
+
+    def test_wide_resnet101_2(self):
+        process_model(models.wide_resnet101_2(), self.image, _C_tests.forward_wide_resnet101_2, 'WideResNet101_2')
+
     def test_squeezenet1_0(self):
         process_model(models.squeezenet1_0(self.pretrained), self.image,
                       _C_tests.forward_squeezenet1_0, 'Squeezenet1.0')
...
@@ -73,6 +73,16 @@ torch::Tensor forward_resnext101_32x8d(
     torch::Tensor x) {
   return forward_model<ResNext101_32x8d>(input_path, x);
 }
 
+torch::Tensor forward_wide_resnet50_2(
+    const std::string& input_path,
+    torch::Tensor x) {
+  return forward_model<WideResNet50_2>(input_path, x);
+}
+
+torch::Tensor forward_wide_resnet101_2(
+    const std::string& input_path,
+    torch::Tensor x) {
+  return forward_model<WideResNet101_2>(input_path, x);
+}
+
 torch::Tensor forward_squeezenet1_0(
     const std::string& input_path,
@@ -168,6 +178,14 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
       "forward_resnext101_32x8d",
       &forward_resnext101_32x8d,
       "forward_resnext101_32x8d");
+  m.def(
+      "forward_wide_resnet50_2",
+      &forward_wide_resnet50_2,
+      "forward_wide_resnet50_2");
+  m.def(
+      "forward_wide_resnet101_2",
+      &forward_wide_resnet101_2,
+      "forward_wide_resnet101_2");
   m.def(
       "forward_squeezenet1_0", &forward_squeezenet1_0, "forward_squeezenet1_0");
...
@@ -41,6 +41,10 @@ int main(int argc, const char* argv[]) {
       "resnext50_32x4d_python.pt", "resnext50_32x4d_cpp.pt");
   convert_and_save_model<ResNext101_32x8d>(
       "resnext101_32x8d_python.pt", "resnext101_32x8d_cpp.pt");
+  convert_and_save_model<WideResNet50_2>(
+      "wide_resnet50_2_python.pt", "wide_resnet50_2_cpp.pt");
+  convert_and_save_model<WideResNet101_2>(
+      "wide_resnet101_2_python.pt", "wide_resnet101_2_cpp.pt");
   convert_and_save_model<SqueezeNet1_0>(
       "squeezenet1_0_python.pt", "squeezenet1_0_cpp.pt");
...
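For context, these calls are one leg of the weight round trip the tests rely on: Python exports each model's parameters to <name>_python.pt, this program loads them into the matching C++ module and writes <name>_cpp.pt, and the unit tests above then compare the two implementations' outputs. convert_and_save_model itself is not part of this diff; a minimal sketch of what such a helper could look like (the name suffix, serialization calls, and eval step here are assumptions, not the repository's actual implementation):

#include <string>
#include <torch/torch.h>

// Hypothetical stand-in for the real convert_and_save_model; the actual
// helper may read the Python-exported parameters differently.
template <typename Model>
void convert_and_save_model_sketch(
    const std::string& input_path,
    const std::string& output_path) {
  Model model;
  torch::load(model, input_path);   // load weights exported from Python
  model->eval();                    // inference mode for reproducible outputs
  torch::save(model, output_path);  // re-serialize via the C++ frontend
}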
@@ -6,6 +6,19 @@ namespace vision {
 namespace models {
 
 using Options = torch::nn::Conv2dOptions;
 
+int64_t make_divisible(
+    double value,
+    int64_t divisor,
+    c10::optional<int64_t> min_value = {}) {
+  if (!min_value.has_value())
+    min_value = divisor;
+  auto new_value = std::max(
+      min_value.value(), (int64_t(value + divisor / 2) / divisor) * divisor);
+  if (new_value < .9 * value)
+    new_value += divisor;
+  return new_value;
+}
+
 struct ConvBNReLUImpl : torch::nn::SequentialImpl {
   ConvBNReLUImpl(
       int64_t in_planes,
@@ -69,28 +82,40 @@ struct MobileNetInvertedResidualImpl : torch::nn::Module {
 
 TORCH_MODULE(MobileNetInvertedResidual);
 
-MobileNetV2Impl::MobileNetV2Impl(int64_t num_classes, double width_mult) {
+MobileNetV2Impl::MobileNetV2Impl(
+    int64_t num_classes,
+    double width_mult,
+    std::vector<std::vector<int64_t>> inverted_residual_settings,
+    int64_t round_nearest) {
   using Block = MobileNetInvertedResidual;
   int64_t input_channel = 32;
   int64_t last_channel = 1280;
 
-  std::vector<std::vector<int64_t>> inverted_residual_settings = {
-      // t, c, n, s
-      {1, 16, 1, 1},
-      {6, 24, 2, 2},
-      {6, 32, 3, 2},
-      {6, 64, 4, 2},
-      {6, 96, 3, 1},
-      {6, 160, 3, 2},
-      {6, 320, 1, 1},
-  };
+  if (inverted_residual_settings.empty())
+    inverted_residual_settings = {
+        // t, c, n, s
+        {1, 16, 1, 1},
+        {6, 24, 2, 2},
+        {6, 32, 3, 2},
+        {6, 64, 4, 2},
+        {6, 96, 3, 1},
+        {6, 160, 3, 2},
+        {6, 320, 1, 1},
+    };
 
-  input_channel = int64_t(input_channel * width_mult);
-  this->last_channel = int64_t(last_channel * std::max(1.0, width_mult));
+  if (inverted_residual_settings[0].size() != 4) {
+    std::cerr << "inverted_residual_settings should contain 4-element vectors";
+    assert(false);
+  }
+
+  input_channel = make_divisible(input_channel * width_mult, round_nearest);
+  this->last_channel =
+      make_divisible(last_channel * std::max(1.0, width_mult), round_nearest);
 
   features->push_back(ConvBNReLU(3, input_channel, 3, 2));
 
   for (auto setting : inverted_residual_settings) {
-    auto output_channel = int64_t(setting[1] * width_mult);
+    auto output_channel =
+        make_divisible(setting[1] * width_mult, round_nearest);
 
     for (int64_t i = 0; i < setting[2]; ++i) {
       auto stride = i == 0 ? setting[3] : 1;
...
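The new make_divisible mirrors the _make_divisible helper in the Python MobileNetV2: it rounds a scaled channel count to the nearest multiple of divisor (the value + divisor / 2 trick, where divisor / 2 is integer division), never returns less than min_value, and bumps up one step if rounding down would lose more than 10% of the original value. Concrete calls with the default round_nearest = 8:

int64_t a = make_divisible(24.0, 8);  // (24 + 4) / 8 * 8 = 24 -> 24
int64_t b = make_divisible(35.0, 8);  // (35 + 4) / 8 * 8 = 32; 32 >= 0.9 * 35, so 32
int64_t c = make_divisible(11.0, 8);  // (11 + 4) / 8 * 8 = 8; 8 < 0.9 * 11 = 9.9, bumped to 16

This is why the constructor above replaces the plain int64_t(...) truncations: truncation could produce channel counts that are not multiples of 8 or that undershoot the intended width at small width_mult.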
@@ -10,7 +10,11 @@ struct VISION_API MobileNetV2Impl : torch::nn::Module {
   int64_t last_channel;
   torch::nn::Sequential features, classifier;
 
-  MobileNetV2Impl(int64_t num_classes = 1000, double width_mult = 1.0);
+  MobileNetV2Impl(
+      int64_t num_classes = 1000,
+      double width_mult = 1.0,
+      std::vector<std::vector<int64_t>> inverted_residual_settings = {},
+      int64_t round_nearest = 8);
 
   torch::Tensor forward(torch::Tensor x);
 };
...
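With the widened constructor, callers can substitute a custom {t, c, n, s} block table or simply scale widths; an empty inverted_residual_settings (the default) selects the stock table. A brief usage sketch, assuming the usual TORCH_MODULE(MobileNetV2) holder accompanies MobileNetV2Impl and an include path of "models/mobilenet.h" (both assumptions, not shown in this diff):

#include <iostream>
#include <torch/torch.h>
#include "models/mobilenet.h"  // assumed include path

int main() {
  // Half-width MobileNetV2 for 10 classes; {} keeps the default
  // {t, c, n, s} table, and channels round to multiples of 8.
  vision::models::MobileNetV2 slim(
      /*num_classes=*/10,
      /*width_mult=*/0.5,
      /*inverted_residual_settings=*/{},
      /*round_nearest=*/8);
  slim->eval();
  auto logits = slim->forward(torch::randn({1, 3, 224, 224}));
  std::cout << logits.sizes() << std::endl;  // expect [1, 10]
}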
@@ -145,5 +145,15 @@ ResNext101_32x8dImpl::ResNext101_32x8dImpl(
     bool zero_init_residual)
     : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 32, 8) {}
 
+WideResNet50_2Impl::WideResNet50_2Impl(
+    int64_t num_classes,
+    bool zero_init_residual)
+    : ResNetImpl({3, 4, 6, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}
+
+WideResNet101_2Impl::WideResNet101_2Impl(
+    int64_t num_classes,
+    bool zero_init_residual)
+    : ResNetImpl({3, 4, 23, 3}, num_classes, zero_init_residual, 1, 64 * 2) {}
+
 } // namespace models
 } // namespace vision
...
@@ -217,6 +217,18 @@ struct VISION_API ResNext101_32x8dImpl : ResNetImpl<_resnetimpl::Bottleneck> {
     bool zero_init_residual = false);
 };
 
+struct VISION_API WideResNet50_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
+  WideResNet50_2Impl(
+      int64_t num_classes = 1000,
+      bool zero_init_residual = false);
+};
+
+struct VISION_API WideResNet101_2Impl : ResNetImpl<_resnetimpl::Bottleneck> {
+  WideResNet101_2Impl(
+      int64_t num_classes = 1000,
+      bool zero_init_residual = false);
+};
+
 template <typename Block>
 struct VISION_API ResNet : torch::nn::ModuleHolder<ResNetImpl<Block>> {
   using torch::nn::ModuleHolder<ResNetImpl<Block>>::ModuleHolder;
@@ -229,6 +241,8 @@ TORCH_MODULE(ResNet101);
 TORCH_MODULE(ResNet152);
 TORCH_MODULE(ResNext50_32x4d);
 TORCH_MODULE(ResNext101_32x8d);
+TORCH_MODULE(WideResNet50_2);
+TORCH_MODULE(WideResNet101_2);
 
 } // namespace models
 } // namespace vision
...
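Per the constructors in resnet.cpp above, the wide variants keep the Bottleneck ResNet-50/101 layouts ({3, 4, 6, 3} and {3, 4, 23, 3}) with groups = 1 and width_per_group doubled to 64 * 2 = 128, matching the Python wide_resnet50_2/wide_resnet101_2 factories. A minimal usage sketch of the new holders (the include path is an assumption):

#include <iostream>
#include <torch/torch.h>
#include "models/resnet.h"  // assumed include path

int main() {
  // Same depth as ResNet-50, but bottleneck widths doubled (64 * 2).
  vision::models::WideResNet50_2 model(/*num_classes=*/1000);
  model->eval();
  torch::NoGradGuard no_grad;  // inference only
  auto out = model->forward(torch::randn({1, 3, 224, 224}));
  std::cout << out.sizes() << std::endl;  // expect [1, 1000]
}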
@@ -65,8 +65,8 @@ SqueezeNetImpl::SqueezeNetImpl(double version, int64_t num_classes)
         Fire(384, 64, 256, 256),
         Fire(512, 64, 256, 256));
   } else {
-    std::cerr << "Wrong version number is passed th SqueeseNet constructor!"
-              << std::endl;
+    std::cerr << "Unsupported SqueezeNet version " << version
+              << ". 1_0 or 1_1 expected" << std::endl;
     assert(false);
   }
...
@@ -79,7 +79,7 @@ torch::Tensor VGGImpl::forward(torch::Tensor x) {
 }
 
 // clang-format off
-static std::unordered_map<char, std::vector<int>> cfg = {
+static std::unordered_map<char, std::vector<int>> cfgs = {
     {'A', {64, -1, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
     {'B', {64, 64, -1, 128, 128, -1, 256, 256, -1, 512, 512, -1, 512, 512, -1}},
     {'D', {64, 64, -1, 128, 128, -1, 256, 256, 256, -1, 512, 512, 512, -1, 512, 512, 512, -1}},
@@ -87,28 +87,28 @@ static std::unordered_map<char, std::vector<int>> cfgs = {
 // clang-format on
 
 VGG11Impl::VGG11Impl(int64_t num_classes, bool initialize_weights)
-    : VGGImpl(makeLayers(cfg['A']), num_classes, initialize_weights) {}
+    : VGGImpl(makeLayers(cfgs['A']), num_classes, initialize_weights) {}
 
 VGG13Impl::VGG13Impl(int64_t num_classes, bool initialize_weights)
-    : VGGImpl(makeLayers(cfg['B']), num_classes, initialize_weights) {}
+    : VGGImpl(makeLayers(cfgs['B']), num_classes, initialize_weights) {}
 
 VGG16Impl::VGG16Impl(int64_t num_classes, bool initialize_weights)
-    : VGGImpl(makeLayers(cfg['D']), num_classes, initialize_weights) {}
+    : VGGImpl(makeLayers(cfgs['D']), num_classes, initialize_weights) {}
 
 VGG19Impl::VGG19Impl(int64_t num_classes, bool initialize_weights)
-    : VGGImpl(makeLayers(cfg['E']), num_classes, initialize_weights) {}
+    : VGGImpl(makeLayers(cfgs['E']), num_classes, initialize_weights) {}
 
 VGG11BNImpl::VGG11BNImpl(int64_t num_classes, bool initialize_weights)
-    : VGGImpl(makeLayers(cfg['A'], true), num_classes, initialize_weights) {}
+    : VGGImpl(makeLayers(cfgs['A'], true), num_classes, initialize_weights) {}
 
 VGG13BNImpl::VGG13BNImpl(int64_t num_classes, bool initialize_weights)
-    : VGGImpl(makeLayers(cfg['B'], true), num_classes, initialize_weights) {}
+    : VGGImpl(makeLayers(cfgs['B'], true), num_classes, initialize_weights) {}
 
 VGG16BNImpl::VGG16BNImpl(int64_t num_classes, bool initialize_weights)
-    : VGGImpl(makeLayers(cfg['D'], true), num_classes, initialize_weights) {}
+    : VGGImpl(makeLayers(cfgs['D'], true), num_classes, initialize_weights) {}
 
 VGG19BNImpl::VGG19BNImpl(int64_t num_classes, bool initialize_weights)
-    : VGGImpl(makeLayers(cfg['E'], true), num_classes, initialize_weights) {}
+    : VGGImpl(makeLayers(cfgs['E'], true), num_classes, initialize_weights) {}
 
 } // namespace models
 } // namespace vision
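The rename from cfg to cfgs matches the Python torchvision.models.vgg naming. In each cfgs entry, a positive number is the output-channel count of a padded 3x3 convolution and -1 marks a 2x2 max-pool, the C++ stand-in for 'M' in the Python dict. makeLayers itself is not shown in this diff; a hedged sketch of how such an expansion works (illustrative only, using stock libtorch modules rather than whatever helpers the file actually uses):

#include <torch/torch.h>
#include <vector>

// Hypothetical re-implementation of the cfg expansion, for illustration.
torch::nn::Sequential make_layers_sketch(
    const std::vector<int>& cfg,
    bool batch_norm = false) {
  torch::nn::Sequential seq;
  int64_t channels = 3;  // RGB input
  for (int v : cfg) {
    if (v == -1) {  // -1: max-pool marker ('M' in the Python cfgs)
      seq->push_back(torch::nn::Functional(
          [](torch::Tensor x) { return torch::max_pool2d(x, {2, 2}, {2, 2}); }));
    } else {  // v: output channels of a padded 3x3 convolution
      seq->push_back(torch::nn::Conv2d(
          torch::nn::Conv2dOptions(channels, v, 3).padding(1)));
      if (batch_norm)
        seq->push_back(torch::nn::BatchNorm2d(v));
      seq->push_back(torch::nn::Functional(torch::relu));
      channels = v;
    }
  }
  return seq;
}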