"docker/requirements/mminstall.txt" did not exist on "b527cbcea33faaa9fc3c87a71b95f0514f9dbb8a"
fused_dense.cpp 8.02 KB
Newer Older
1
2
3
4
// Adapted from https://github.com/NVIDIA/apex/blob/master/csrc/fused_dense.cpp
// We make it work for bfloat16
#include <torch/extension.h>
#include <torch/torch.h>
#include <c10/cuda/CUDAGuard.h>
#include <vector>

#include <stdio.h>

#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

// https://github.com/NVIDIA/apex/blob/master/csrc/type_shim.h
// #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#define DISPATCH_HALF_AND_BF16(TYPE, NAME, ...)                                \
  switch (TYPE) {                                                              \
  case at::ScalarType::Half: {                                                 \
    using scalar_t = at::Half;                                                 \
    __VA_ARGS__();                                                             \
    break;                                                                     \
  }                                                                            \
  case at::ScalarType::BFloat16: {                                             \
    using scalar_t = at::BFloat16;                                             \
    __VA_ARGS__();                                                             \
    break;                                                                     \
  }                                                                            \
  default:                                                                     \
    AT_ERROR(#NAME, " not implemented for '", toString(TYPE), "'");            \
  }
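
// DISPATCH_HALF_AND_BF16 maps the runtime dtype (fp16/bf16) to a compile-time
// scalar_t and invokes the given lambda, e.g.:
//   DISPATCH_HALF_AND_BF16(x.scalar_type(), "my_op", [&] { my_op_cuda<scalar_t>(...); });
//
// The *_cuda functions declared below are the CUDA implementations (expected
// to live in a companion .cu translation unit); each returns 0 on success.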

template <typename T>
int linear_bias_wgrad_cuda(const T *input, const T *d_output, int64_t in_features, int64_t batch_size, int64_t out_features, T *d_weight, T *d_bias);

template <typename T>
int linear_act_forward_cuda(const T *input, const T *weight, const T *bias, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *output, void *pre_act);

template <typename T>
int bias_act_linear_dgrad_bgrad_cuda(const T *weight, const T *d_output, const void *pre_act, int64_t in_features, int64_t batch_size, int64_t out_features, bool is_gelu, int heuristic, T *d_input, T *d_bias);

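// Weight gradient (and optionally bias gradient) of a linear layer:
//   d_weight = d_output^T @ input   -> (out_features, in_features)
//   d_bias   = d_output.sum(0)      -> (out_features,), only if has_d_bias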
std::vector<at::Tensor> linear_bias_wgrad(at::Tensor input, at::Tensor d_output, bool has_d_bias) {

  int64_t batch_size = input.size(0);
  int64_t in_features = input.size(1);
  int64_t out_features = d_output.size(1);

  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.dtype() == d_output.dtype());
  TORCH_CHECK(input.is_cuda());
  TORCH_CHECK(d_output.is_cuda());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(d_output.is_contiguous());
  CHECK_SHAPE(input, batch_size, in_features);
  CHECK_SHAPE(d_output, batch_size, out_features);

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};

  // create output/workspace tensor
  auto opts = input.options();
  auto d_weight = at::empty({out_features, in_features}, opts);
  at::Tensor d_bias;
  if (has_d_bias) {
#if defined(CUBLAS_VERSION) && CUBLAS_VERSION < 11600
    d_bias = d_output.view({-1, out_features}).sum(0, false);
#else
    d_bias = at::empty({out_features}, opts);
#endif
  }

  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_bias_wgrad", [&] {
    auto result = linear_bias_wgrad_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        d_output.data_ptr<scalar_t>(),
        in_features,
        batch_size,
        out_features,
        d_weight.data_ptr<scalar_t>(),
        has_d_bias ? d_bias.data_ptr<scalar_t>() : nullptr);
    TORCH_CHECK(result == 0, "linear_bias_wgrad failed.");
  });

  return {d_weight, d_bias};
}

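// Fused linear + activation forward: output = act(input @ weight^T + bias),
// where act is GELU if is_gelu, ReLU otherwise. With save_pre_act, the
// pre-activation is also returned so the backward pass can recompute act':
// the GEMM output for GELU, or a packed uint8 bit-mask for ReLU. `heuristic`
// is passed through to the CUDA implementation (cuBLASLt algorithm choice).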
std::vector<at::Tensor> linear_act_forward(at::Tensor input, at::Tensor weight,
                                           c10::optional<at::Tensor> bias_,
                                           bool is_gelu, bool save_pre_act, int heuristic) {

  int64_t batch_size = input.size(0);
  int64_t in_features = input.size(1);
  int64_t out_features = weight.size(0);

  TORCH_CHECK(input.dtype() == torch::kFloat16 || input.dtype() == torch::kBFloat16);
  TORCH_CHECK(input.dtype() == weight.dtype());
  TORCH_CHECK(input.is_cuda());
  TORCH_CHECK(weight.is_cuda());
  TORCH_CHECK(input.is_contiguous());
  TORCH_CHECK(weight.is_contiguous());
  CHECK_SHAPE(input, batch_size, in_features);
  CHECK_SHAPE(weight, out_features, in_features);
  if (bias_.has_value()) {
    auto bias = bias_.value();
    TORCH_CHECK(bias.dtype() == input.dtype());
    TORCH_CHECK(bias.is_cuda());
    TORCH_CHECK(bias.is_contiguous());
    CHECK_SHAPE(bias, out_features);
  }

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)input.get_device()};

  // create output/workspace tensor
  auto opts = input.options();
  auto output = at::empty({batch_size, out_features}, opts);
  at::Tensor pre_act;
  // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
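  // (8 mask bits are packed per uint8 element, hence the out_features / 8 columns)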
  if (save_pre_act) { pre_act = at::empty({batch_size, is_gelu ? out_features : out_features / 8},
                                          is_gelu ? opts : opts.dtype(torch::kUInt8)); }

  DISPATCH_HALF_AND_BF16(input.scalar_type(), "linear_act_forward", [&] {
    auto result = linear_act_forward_cuda<scalar_t>(
        input.data_ptr<scalar_t>(),
        weight.data_ptr<scalar_t>(),
        bias_.has_value() ? bias_.value().data_ptr<scalar_t>() : nullptr,
        in_features,
        batch_size,
        out_features,
        is_gelu,
        heuristic,
        output.data_ptr<scalar_t>(),
        save_pre_act ? pre_act.data_ptr() : nullptr);
    TORCH_CHECK(result == 0, "linear_act_forward failed.");
  });

  std::vector<at::Tensor> result = {output};
  if (save_pre_act) { result.push_back(pre_act); }
  return result;
}

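// Fused backward through the activation into the preceding linear layer:
//   d_input = (d_output @ weight) * act'(pre_act)   -> (batch_size, in_features)
//   d_bias  = d_input.sum(0)                        -> (in_features,)
// pre_act is the tensor saved by linear_act_forward.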
std::vector<at::Tensor> bias_act_linear_dgrad_bgrad(
  at::Tensor weight, at::Tensor d_output, at::Tensor pre_act, bool is_gelu, int heuristic
) {

  int64_t batch_size = d_output.size(0);
  int64_t out_features = d_output.size(1);
  int64_t in_features = weight.size(1);

  TORCH_CHECK(weight.dtype() == torch::kFloat16 || weight.dtype() == torch::kBFloat16);
  TORCH_CHECK(weight.dtype() == d_output.dtype());
  TORCH_CHECK(is_gelu ? (pre_act.dtype() == weight.dtype()) : (pre_act.dtype() == torch::kUInt8));
  TORCH_CHECK(weight.is_cuda());
  TORCH_CHECK(d_output.is_cuda());
  TORCH_CHECK(pre_act.is_cuda());
  TORCH_CHECK(weight.is_contiguous());
  TORCH_CHECK(d_output.is_contiguous());
  TORCH_CHECK(pre_act.is_contiguous());
  CHECK_SHAPE(weight, out_features, in_features);
  CHECK_SHAPE(d_output, batch_size, out_features);
  // If ReLU, cuBlasLT stores a bit-mask (1 bit per element)
  CHECK_SHAPE(pre_act, batch_size, is_gelu ? in_features : in_features / 8);

  // Otherwise the kernel will be launched from cuda:0 device
  // Cast to char to avoid compiler warning about narrowing
  at::cuda::CUDAGuard device_guard{(char)weight.get_device()};

  // create output/workspace tensor
  auto opts = weight.options();
  auto d_bias = at::empty({in_features}, opts);
  auto d_input = at::empty({batch_size, in_features}, opts);

  DISPATCH_HALF_AND_BF16(weight.scalar_type(), "bias_act_linear_dgrad_bgrad", [&] {
    auto result = bias_act_linear_dgrad_bgrad_cuda<scalar_t>(
        weight.data_ptr<scalar_t>(),
        d_output.data_ptr<scalar_t>(),
        pre_act.data_ptr(),
        in_features,
        batch_size,
        out_features,
        is_gelu,
        heuristic,
        d_input.data_ptr<scalar_t>(),
        d_bias.data_ptr<scalar_t>());
    TORCH_CHECK(result == 0, "bias_act_linear_dgrad_bgrad failed.");
  });

  return {d_input, d_bias};
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("linear_bias_wgrad", &linear_bias_wgrad, "linear bias wgrad");
  m.def("linear_act_forward", &linear_act_forward, "linear gelu/relu forward");
  m.def("bias_act_linear_dgrad_bgrad", &bias_act_linear_dgrad_bgrad, "bias gelu/relu linear dgrad bgrad");
}
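
// Minimal Python usage sketch (illustrative only; the extension module name
// "fused_dense_lib" and the heuristic value 0 are assumptions, not defined here):
//   import torch, fused_dense_lib
//   x = torch.randn(8, 512, device="cuda", dtype=torch.bfloat16)
//   w = torch.randn(1024, 512, device="cuda", dtype=torch.bfloat16)
//   b = torch.randn(1024, device="cuda", dtype=torch.bfloat16)
//   out, pre_act = fused_dense_lib.linear_act_forward(x, w, b, True, True, 0)
//   dw, db = fused_dense_lib.linear_bias_wgrad(x, torch.randn_like(out), True)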