alexnet.h 490 Bytes
Newer Older
Shahriar's avatar
Shahriar committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
#ifndef ALEXNET_H
#define ALEXNET_H

#include <torch/torch.h>

namespace vision {
namespace models {
// AlexNet model architecture from the
// "One weird trick..." <https://arxiv.org/abs/1404.5997> paper.
//
// Only declarations live in this header; the constructor and forward() are
// defined out of line (presumably in alexnet.cpp — TODO confirm).
struct AlexNetImpl : torch::nn::Module {
  // Sub-networks, declared as empty (null) module holders. NOTE(review): they
  // are presumably built and registered by the constructor — confirm in the
  // out-of-line definition.
  torch::nn::Sequential features{nullptr}, classifier{nullptr};

  /// @brief Constructs the network.
  /// @param num_classes Number of output classes (defaults to 1000).
  ///
  /// Marked `explicit` so a bare integer can never be silently converted into
  /// a whole model (e.g. `AlexNetImpl m = 10;`, or an int passed where a module
  /// is expected). Direct construction such as `AlexNetImpl(10)` or
  /// `AlexNet(10)` is unaffected.
  explicit AlexNetImpl(int64_t num_classes = 1000);

  /// @brief Runs the model on an input batch.
  /// @param x Input tensor — presumably (N, 3, H, W) images; TODO confirm.
  /// @return Output tensor — presumably (N, num_classes) scores; TODO confirm.
  torch::Tensor forward(torch::Tensor x);
};

// Generates `AlexNet`, the libtorch module-holder type wrapping AlexNetImpl.
// User code normally constructs and passes around this holder (e.g.
// `AlexNet model(10);`) rather than AlexNetImpl directly.
TORCH_MODULE(AlexNet);

} // namespace models
} // namespace vision

#endif // ALEXNET_H