// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp
// We make it work for bfloat16
#include <torch/extension.h>
#include <torch/torch.h>
#include <vector>

#include <stdio.h>

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
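// Dispatch on the runtime scalar type: the lambda body runs with scalar_t bound to
// at::Half or at::BFloat16; any other dtype raises an error. This is the fp16/bf16-only
// variant of apex's type-dispatch macros.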
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)                                \
  switch (TYPE) {                                                              \
  case at::ScalarType::Half: {                                                 \
    using scalar_t = at::Half;                                                 \
    __VA_ARGS__();                                                             \
    break;                                                                     \
  }                                                                            \
  case at::ScalarType::BFloat16: {                                             \
    using scalar_t = at::BFloat16;                                             \
    __VA_ARGS__();                                                             \
    break;                                                                     \
  }                                                                            \
  default:                                                                     \
    AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");            \
  }

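// Forward declarations of the cuBLASLt-based implementations, defined in the accompanying
// CUDA source; each returns 0 on success and a nonzero status on failure.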
template <typename T>
int linear_bias_wgrad_cuda(T *input, T *d_output, int in_features, int batch_size, int out_features, T *d_weight, T *d_bias, void *lt_workspace);

template <typename T>
int linear_gelu_forward_cuda(T *input, T *weight, T *bias, int in_features, int batch_size, int out_features, int heuristic, T *output, T *gelu_in, void *lt_workspace);

template <typename T>
int bias_gelu_linear_dgrad_bgrad_cuda(T *weight, T *d_output, T *gelu_in, int in_features, int batch_size, int out_features, int heuristic, T *d_input, T *d_bias, void *lt_workspace);

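// Weight gradient (and optionally bias gradient) of a linear layer:
// d_weight = d_output^T @ input and d_bias = d_output.sum(0), with
// input of shape (batch_size, in_features) and d_output of shape (batch_size, out_features).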
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {

  int batch_size = input.size(0);
  int in_features = input.size(1);
  int out_features = d_output.size(1);

  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.dtype() == d_output.dtype());
  TORCH_CHECK(input.is_cuda());
  TORCH_CHECK(d_output.is_cuda());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(d_output.is_contiguous());
  CHECK_SHAPE(input, batch_size, in_features);
  CHECK_SHAPE(d_output, batch_size, out_features);

  // create output/workspace tensor
  auto opts = input.options();
  auto d_weight = at::empty({out_features, in_features}, opts);
  at::Tensor d_bias;
  if (has_d_bias) {
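    // cuBLASLt's fused bias-gradient epilogue needs cuBLAS >= 11.6 (CUBLAS_VERSION 11600);
    // on older versions compute d_bias up front with a plain PyTorch reduction instead.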
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
    d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
    d_bias = at::empty({out_features}, opts);
#endif
  }
  // allocate a fixed cuBLASLt workspace for now: 2^22 elements of the working dtype, i.e. at least 4 MB for fp16/bf16
  auto lt_workspace = at::empty({1 << 22}, opts);

  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
    auto result = linear_bias_wgrad_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        d_output.data_ptr<scalar_t>(),
        in_features,
        batch_size,
        out_features,
        d_weight.data_ptr<scalar_t>(),
        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr,
        (void*) (lt_workspace.data_ptr<scalar_t>()));
    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
  });

  return {d_weight, d_bias};
}

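// Fused forward pass: output = GELU(input @ weight^T + bias). If save_gelu_in is true,
// the pre-GELU activations are also returned for reuse in the backward pass.
// `heuristic` is forwarded to the CUDA side to select a cuBLASLt algorithm.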
std::vector<at::Tensor> linear_gelu_forward(at::Tensor input, at::Tensor weight,
                                            c10::optional<at::Tensor> bias_,
                                            bool save_gelu_in, int heuristic) {

  int batch_size = input.size(0);
  int in_features = input.size(1);
  int out_features = weight.size(0);

  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.dtype() == weight.dtype());
  TORCH_CHECK(input.is_cuda());
  TORCH_CHECK(weight.is_cuda());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  CHECK_SHAPE(input, batch_size, in_features);
  CHECK_SHAPE(weight, out_features, in_features);
  if (bias_.has_value()) {
    auto bias = bias_.value();
    TORCH_CHECK(bias.dtype() == input.dtype());
    TORCH_CHECK(bias.is_cuda());
    TORCH_CHECK(bias.is_contiguous());
    CHECK_SHAPE(bias, out_features);
  }

  // create output/workspace tensor
  auto opts = input.options();
  auto output = at::empty({batch_size, out_features}, opts);
  at::Tensor gelu_in;
  if (save_gelu_in) { gelu_in = at::empty({batch_size, out_features}, opts); }
  // allocate a fixed cuBLASLt workspace for now: 2^22 elements of the working dtype, i.e. at least 4 MB for fp16/bf16
  auto lt_workspace = at::empty({1 << 22}, opts);

  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_gelu_forward", [&] {
    auto result = linear_gelu_forward_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        bias_.has_value() ? bias_.value().data_ptr<scalar_t>() : nullptr,
        in_features,
        batch_size,
        out_features,
        heuristic,
        output.data_ptr<scalar_t>(),
        save_gelu_in ? gelu_in.data_ptr<scalar_t>() : nullptr,
        (void*) (lt_workspace.data_ptr<scalar_t>()));
    TORCH_CHECK(result == 0, "linear_gelu_forward failed.");
  });

  std::vector<at::Tensor> result = {output};
  if (save_gelu_in) { result.push_back(gelu_in); }
  return result;
}

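// Fused backward for the bias -> GELU -> linear chain: propagates d_output back through
// the second linear layer and the GELU nonlinearity (using the saved pre-activations
// gelu_in), returning d_input = GELU'(gelu_in) * (d_output @ weight) together with the
// corresponding bias gradient d_bias, reduced over the batch dimension.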
std::vector<at::Tensor> bias_gelu_linear_dgrad_bgrad(
  at::Tensor weight, at::Tensor d_output, at::Tensor gelu_in, int heuristic
) {

  int batch_size = d_output.size(0);
  int out_features = d_output.size(1);
  int in_features = weight.size(1);

  TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
  TORCH_CHECK(weight.dtype() == d_output.dtype());
  TORCH_CHECK(weight.dtype() == gelu_in.dtype());
  TORCH_CHECK(weight.is_cuda());
  TORCH_CHECK(d_output.is_cuda());
  TORCH_CHECK(gelu_in.is_cuda());
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(d_output.is_contiguous());
  TORCH_CHECK(gelu_in.is_contiguous());
  CHECK_SHAPE(weight, out_features, in_features);
  CHECK_SHAPE(d_output, batch_size, out_features);
  CHECK_SHAPE(gelu_in, batch_size, in_features);

  // create output/workspace tensor
  auto opts = weight.options();
  auto d_bias = at::empty({in_features}, opts);
  auto d_input = at::empty({batch_size, in_features}, opts);
  // allocate a fixed cuBLASLt workspace for now: 2^22 elements of the working dtype, i.e. at least 4 MB for fp16/bf16
  auto lt_workspace = at::empty({1 << 22}, opts);

  DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_gelu_linear_dgrad_bgrad", [&] {
    auto result = bias_gelu_linear_dgrad_bgrad_cuda<scalar_t>(
        weight.data_ptr<scalar_t>(),
        d_output.data_ptr<scalar_t>(),
        gelu_in.data_ptr<scalar_t>(),
        in_features,
        batch_size,
        out_features,
        heuristic,
        d_input.data_ptr<scalar_t>(),
        d_bias.data_ptr<scalar_t>(),
        (void*) (lt_workspace.data_ptr<scalar_t>()));
    TORCH_CHECK(result == 0, "bias_gelu_linear_dgrad_bgrad failed.");
  });

  return {d_input, d_bias};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
  m.def("linear_gelu_forward", &linear_gelu_forward, "linear gelu forward");
  m.def("bias_gelu_linear_dgrad_bgrad", &bias_gelu_linear_dgrad_bgrad, "bias gelu linear dgrad bgrad");
}
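
// Rough usage sketch (illustrative only; the tensor names and sizes below are made up,
// and all tensors must be contiguous fp16/bf16 CUDA tensors as checked above):
//
//   auto x   = torch::randn({8, 1024}, torch::dtype(torch::kFloat16).device(torch::kCUDA));
//   auto w1  = torch::randn({4096, 1024}, x.options());
//   auto b1  = torch::randn({4096}, x.options());
//   auto fwd = linear_gelu_forward(x, w1, b1, /*save_gelu_in=*/true, /*heuristic=*/0);
//   // fwd[0] = GELU(x @ w1.T + b1); fwd[1] = the pre-GELU activations (gelu_in)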