/* scaled_upper_triang_masked_softmax.cpp
 * This code is from NVIDIA Megatron, with minor changes. */

#include <cuda_fp16.h>
#include <torch/extension.h>

#include <vector>
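
// Needed only by the illustrative reference sketches added below, not by the
// extension code itself.
#include <algorithm>  // std::max
#include <cmath>      // std::exp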

namespace multihead_attn {
namespace fused_softmax {
namespace scaled_upper_triang_masked_softmax {

torch::Tensor fwd_cuda(torch::Tensor const& input, float scale_factor);

torch::Tensor bwd_cuda(torch::Tensor const& output_grads,
                       torch::Tensor const& softmax_results,
                       float scale_factor);
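
// fwd_cuda and bwd_cuda are the CUDA kernel launchers, defined in the
// companion .cu file; the wrappers below only validate shape and dtype
// before dispatching to them. The 3D tensors are assumed (by usage, not
// asserted here) to be laid out as [attn_batches, seq_len, seq_len], where
// each seq_len x seq_len slice is a causal attention-score matrix whose
// "upper triangle" (positions j > i, i.e. future time steps) is masked out.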

torch::Tensor fwd(torch::Tensor const& input, float scale_factor) {
  AT_ASSERTM(input.dim() == 3, "expected 3D tensor");
  AT_ASSERTM((input.scalar_type() == at::ScalarType::Half) ||
                 (input.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return fwd_cuda(input, scale_factor);
}
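
// A hedged CPU sketch, for illustration only, of what fwd_cuda is understood
// to compute: a numerically stable softmax of scale * x along the last
// dimension, restricted to the causal prefix j <= i of each row i, with
// masked positions written as exact zeros. The function name and the
// row-major float layout are assumptions made for this sketch, not part of
// the extension.
inline void fwd_reference(const float* x, float* y, int batches, int seq_len,
                          float scale) {
  for (int b = 0; b < batches; ++b) {
    for (int i = 0; i < seq_len; ++i) {
      const float* row_x = x + (static_cast<long>(b) * seq_len + i) * seq_len;
      float* row_y = y + (static_cast<long>(b) * seq_len + i) * seq_len;
      // Running max over the unmasked prefix, for numerical stability.
      float m = row_x[0] * scale;
      for (int j = 1; j <= i; ++j) m = std::max(m, row_x[j] * scale);
      // Normalizer over the unmasked prefix only.
      float denom = 0.f;
      for (int j = 0; j <= i; ++j) denom += std::exp(row_x[j] * scale - m);
      // Positions j > i (the upper triangle) are masked to zero.
      for (int j = 0; j < seq_len; ++j)
        row_y[j] = (j <= i) ? std::exp(row_x[j] * scale - m) / denom : 0.f;
    }
  }
}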

torch::Tensor bwd(torch::Tensor const& output_grads,
                  torch::Tensor const& softmax_results, float scale_factor) {
  AT_ASSERTM(output_grads.dim() == 3, "expected 3D tensor");
  AT_ASSERTM(softmax_results.dim() == 3, "expected 3D tensor");

  AT_ASSERTM((output_grads.scalar_type() == at::ScalarType::Half) ||
                 (output_grads.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");
  AT_ASSERTM((softmax_results.scalar_type() == at::ScalarType::Half) ||
                 (softmax_results.scalar_type() == at::ScalarType::BFloat16),
             "Only fp16 and bf16 are supported");

  return bwd_cuda(output_grads, softmax_results, scale_factor);
}
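
// Likewise a hedged sketch of the assumed backward pass: the standard softmax
// Jacobian combined with the chain rule through the scale factor,
// dL/dx_j = scale * y_j * (dy_j - sum_k dy_k * y_k) over the causal prefix,
// with zero gradient at masked positions. Names and layout are assumptions,
// as above; the real kernel lives in the .cu file.
inline void bwd_reference(const float* output_grads,
                          const float* softmax_results, float* input_grads,
                          int batches, int seq_len, float scale) {
  for (int b = 0; b < batches; ++b) {
    for (int i = 0; i < seq_len; ++i) {
      const long base = (static_cast<long>(b) * seq_len + i) * seq_len;
      const float* dy = output_grads + base;
      const float* y = softmax_results + base;
      float* dx = input_grads + base;
      // Row-wise dot product of the incoming gradient with the softmax output.
      float dot = 0.f;
      for (int j = 0; j <= i; ++j) dot += dy[j] * y[j];
      // Masked positions (j > i) receive no gradient.
      for (int j = 0; j < seq_len; ++j)
        dx[j] = (j <= i) ? scale * y[j] * (dy[j] - dot) : 0.f;
    }
  }
}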

}  // end namespace scaled_upper_triang_masked_softmax
}  // end namespace fused_softmax
}  // end namespace multihead_attn

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::fwd,
        "Self Multihead Attention scaled, time masked softmax -- Forward.");
  m.def("backward",
        &multihead_attn::fused_softmax::scaled_upper_triang_masked_softmax::bwd,
        "Self Multihead Attention scaled, time masked softmax -- Backward.");
}
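
// Typical consumption, stated here as an assumption about usage rather than
// anything this file enforces: the extension is built together with the
// companion .cu file via PyTorch's C++-extension machinery (e.g.
// torch.utils.cpp_extension), then called from Python as
// module.forward(input, scale_factor) and
// module.backward(output_grads, softmax_results, scale_factor) on fp16 or
// bf16 CUDA tensors of shape [attn_batches, seq_len, seq_len].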