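// Host-side bindings for a fused multi-layer perceptron (MLP) PyTorch extension.
// mlp_forward / mlp_backward marshal the tensors and dispatch into mlp_fp / mlp_bp,
// which are assumed to be implemented in a companion CUDA translation unit compiled
// and linked into the same extension.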
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

#include <stdio.h>

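// Number of elements of intermediate state the forward pass stashes for the
// backward pass (expected to be defined alongside the kernels).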
size_t get_mlp_reserved_space(int64_t batch_size, int num_layers, const int* output_features);

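// Scratch-space requirement, in bytes, of the backward pass for a given problem shape.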
template <typename T>
size_t get_mlp_bp_workspace_in_bytes(int batch_size, int num_layers, const int* output_features);

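// Fused forward pass: starting from X, each layer applies a GEMM with WPtr[i],
// an optional bias BPtr[i] (when use_bias is set), and the selected activation;
// the final result is written to Y and per-layer intermediates are stashed in
// reserved_space.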
template <typename T>
int mlp_fp(
    T* X,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T** BPtr,
    T* Y,
    T* reserved_space,
    int use_bias,
    int activation,
    void* lt_workspace);

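// Fused backward pass: consumes dY together with the intermediates stashed in
// reserved_space and produces dX plus per-layer weight/bias gradients through
// dwPtr / dbPtr, using work_space as scratch.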
template <typename T>
int mlp_bp(
    T* X,
    T* Y,
    int input_features,
    int batch_size,
    T** WPtr,
    int num_layers,
    int* output_features,
    T* dY,
    T* reserved_space,
    T* work_space,
    T* dX,
    T** dwPtr,
    T** dbPtr,
    bool requires_grad,
    int use_bias,
    int activation);

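// Python-facing forward entry point. `inputs` is laid out as
// [input, W_0, ..., W_{n-1}] and, when use_bias is set, additionally
// [b_0, ..., b_{n-1}]. Returns {output, reserved_space}; the latter must be
// passed back to mlp_backward.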
std::vector<at::Tensor> mlp_forward(int use_bias, int activation, std::vector<at::Tensor> inputs) {

  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }
  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }

  auto reserved_size = get_mlp_reserved_space(batch_size, num_layers, output_features.data());

  // create output/workspace tensor
  auto out = at::empty({batch_size, output_features.back()}, inputs[0].type());
  auto reserved_space = at::empty({reserved_size}, inputs[0].type());
  // Allocate a fixed cuBLASLt workspace of 1 << 22 elements of the input dtype,
  // i.e. at least 4 MB for any supported dtype (8 MB for half, 16 MB for float).
  auto lt_workspace = at::empty({1 << 22}, inputs[0].type());

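  // Dispatch on the input dtype; inside the lambda, scalar_t is the matching C++ element type.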
  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_forward", [&] {
    std::vector<scalar_t*> w_ptr;
    std::vector<scalar_t*> b_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
      if (use_bias) {
        b_ptr.push_back(inputs[i + 1 + num_layers].data_ptr<scalar_t>());
      }
    }
    auto result = mlp_fp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        b_ptr.data(),
        out.data_ptr<scalar_t>(),
        reserved_space.data_ptr<scalar_t>(),
        use_bias,
        activation,
        (void*) (lt_workspace.data_ptr<scalar_t>()));
  });

  return {out, reserved_space};
}

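// Python-facing backward entry point. grad_o is the gradient w.r.t. the forward
// output, fprop_outputs is the {output, reserved_space} pair returned by
// mlp_forward, and `inputs` uses the same layout as in mlp_forward. Returns
// gradients with the same shapes and ordering as `inputs`.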
std::vector<at::Tensor> mlp_backward(
  int use_bias,
  int activation,
  at::Tensor grad_o,
  std::vector<at::Tensor> fprop_outputs,
  std::vector<at::Tensor> inputs) {

  auto num_layers = inputs.size() - 1;
  if (use_bias) {
    // inputs contains (input, weights, biases)
    num_layers /= 2;
  }

  auto batch_size = inputs[0].size(0);
  auto input_features = inputs[0].size(1);

  bool requires_grad = inputs[0].requires_grad();

  std::vector<int> output_features;
  for (int i = 0; i < num_layers; i++) {
    output_features.push_back(inputs[i + 1].size(0));
  }
  // Allocate gradient outputs, one tensor per entry of `inputs`
  // (dX, then the dW_i, then the db_i when biases are used).
  std::vector<at::Tensor> outputs;
  for (int i = 0; i < inputs.size(); i++) {
    outputs.push_back(at::empty(inputs[i].sizes(), inputs[i].type()));  // clone for testing now
  }

  AT_DISPATCH_FLOATING_TYPES_AND_HALF(inputs[0].type(), "mlp_backward", [&] {
    std::vector<scalar_t*> w_ptr;
    for (int i = 0; i < num_layers; i++) {
      w_ptr.push_back(inputs[i + 1].data_ptr<scalar_t>());
    }
    std::vector<scalar_t*> outputs_ptr;
    for (int i = 0; i < inputs.size(); i++) {
      outputs_ptr.push_back(outputs[i].data_ptr<scalar_t>());
    }

    auto work_size =
        get_mlp_bp_workspace_in_bytes<scalar_t>(batch_size, num_layers, output_features.data());

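    // get_mlp_bp_workspace_in_bytes reports bytes; the scratch tensor is allocated
    // in units of scalar_t so it can share the input dtype.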
    // auto work_space = at::empty({work_size*4}, at::kByte);
    auto work_space = at::empty({work_size / sizeof(scalar_t)}, inputs[0].type());

    auto result = mlp_bp<scalar_t>(
        inputs[0].data_ptr<scalar_t>(),
        fprop_outputs[0].data_ptr<scalar_t>(),
        input_features,
        batch_size,
        w_ptr.data(),
        num_layers,
        output_features.data(),
        grad_o.contiguous().data_ptr<scalar_t>(),
        fprop_outputs[1].data_ptr<scalar_t>(),
        work_space.data_ptr<scalar_t>(),
        outputs_ptr[0],
        outputs_ptr.data() + 1,
        outputs_ptr.data() + 1 + num_layers,
        requires_grad,
        use_bias,
        activation);
  });

  return outputs;
}

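// Module registration. When built with torch.utils.cpp_extension, TORCH_EXTENSION_NAME
// expands to the module name chosen by the build script, and these bindings appear in
// Python as <module>.forward(use_bias, activation, inputs) and
// <module>.backward(use_bias, activation, grad_o, fprop_outputs, inputs).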
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &mlp_forward, "MLP forward");
  m.def("backward", &mlp_backward, "MLP backward");
}