#pragma once

#include <cstdint>

#include <torch/nn.h>

#include "../macros.h"

namespace vision {
namespace models {
// AlexNet model architecture from the
// "One weird trick..." <https://arxiv.org/abs/1404.5997> paper.
10
struct VISION_API AlexNetImpl : torch::nn::Module {
Shahriar's avatar
Shahriar committed
11
12
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

13
  explicit AlexNetImpl(int64_t num_classes = 1000);
Shahriar's avatar
Shahriar committed
14
15
16
17
18
19
20
21

  torch::Tensor forward(torch::Tensor x);
};

TORCH_MODULE(AlexNet);

} // namespace models
} // namespace vision