// alexnet.h: declaration of the AlexNet model for the vision C++ models API.
#ifndef ALEXNET_H
#define ALEXNET_H

#include <torch/torch.h>
#include "general.h"

namespace vision {
namespace models {

// AlexNet model architecture from the
// "One weird trick..." <https://arxiv.org/abs/1404.5997> paper.
//
// The module holds two sequential stages, `features` and `classifier`.
// Both are declared here as empty holders (initialized with `nullptr`).
// NOTE(review): they are presumably built and registered by the constructor
// definition, which is not in this header; confirm against the .cpp file.
struct VISION_API AlexNetImpl : torch::nn::Module {
  // Sub-networks; empty (`nullptr`) until assigned elsewhere.
  torch::nn::Sequential features{nullptr};
  torch::nn::Sequential classifier{nullptr};

  // Constructs the model.
  // @param num_classes Size of the model output, presumably the number of
  //                    target classes (defaults to 1000).
  // `explicit` prevents an integer from silently converting into a model.
  explicit AlexNetImpl(int64_t num_classes = 1000);

  // Runs the network on `x` and returns the resulting tensor.
  // NOTE(review): `x` is presumably an image batch (N, C, H, W); confirm
  // against the definition.
  torch::Tensor forward(torch::Tensor x);
};

// Declares `AlexNet`, the torch module holder (shared-pointer wrapper) for
// AlexNetImpl.
TORCH_MODULE(AlexNet);

} // namespace models
} // namespace vision

#endif // ALEXNET_H