/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/softmax.h"

#include "extensions.h"

namespace transformer_engine {
namespace jax {

// Custom-call entry point: applies the scaled softmax forward kernel.
// buffers[0] is the input, buffers[1] the output; both share the
// (batch, heads, q_seqlen, k_seqlen) shape and dtype described by the
// SoftmaxDescriptor packed into `opaque`.
void ScaledSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                          size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const std::vector<size_t> io_shape{desc.batch_size, desc.head_dim, desc.q_seqlen,
                                     desc.k_seqlen};

  auto in_tensor = TensorWrapper(buffers[0], io_shape, desc.dtype);
  auto out_tensor = TensorWrapper(buffers[1], io_shape, desc.dtype);

  nvte_scaled_softmax_forward(in_tensor.data(), out_tensor.data(), desc.scale_factor, stream);
}

// Custom-call entry point: applies the scaled softmax backward kernel.
// buffers[0] is the incoming gradient, buffers[1] the softmax output saved
// from the forward pass, and buffers[2] receives the input gradient. All
// three share the shape/dtype described by the SoftmaxDescriptor in `opaque`.
void ScaledSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                           size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const std::vector<size_t> io_shape{desc.batch_size, desc.head_dim, desc.q_seqlen,
                                     desc.k_seqlen};

  auto dy_tensor = TensorWrapper(buffers[0], io_shape, desc.dtype);
  auto y_tensor = TensorWrapper(buffers[1], io_shape, desc.dtype);
  auto dx_tensor = TensorWrapper(buffers[2], io_shape, desc.dtype);

  nvte_scaled_softmax_backward(dy_tensor.data(), y_tensor.data(), dx_tensor.data(),
                               desc.scale_factor, stream);
}

// Custom-call entry point: applies the scaled, explicitly-masked softmax
// forward kernel. buffers[0] is the input, buffers[1] the mask, and
// buffers[2] the output.
void ScaledMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const std::vector<size_t> io_shape{desc.batch_size, desc.head_dim, desc.q_seqlen,
                                     desc.k_seqlen};
  // The mask carries a leading padding_size dimension and a broadcast (size-1)
  // head dimension, unlike the I/O tensors.
  const std::vector<size_t> mask_shape{desc.padding_size, 1, desc.q_seqlen, desc.k_seqlen};

  auto in_tensor = TensorWrapper(buffers[0], io_shape, desc.dtype);
  // The mask buffer is always interpreted as uint8, independent of the I/O dtype.
  auto mask_tensor = TensorWrapper(buffers[1], mask_shape, DType::kByte);
  auto out_tensor = TensorWrapper(buffers[2], io_shape, desc.dtype);

  nvte_scaled_masked_softmax_forward(in_tensor.data(), mask_tensor.data(), out_tensor.data(),
                                     desc.scale_factor, stream);
}

// Custom-call entry point for the masked-softmax backward pass. The gradient
// computation is equivalent to the unmasked case, so this simply delegates to
// ScaledSoftmaxBackward with the same buffer layout and descriptor.
void ScaledMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                 size_t opaque_len) {
  ScaledSoftmaxBackward(stream, buffers, opaque, opaque_len);
}

// Custom-call entry point: applies the scaled upper-triangular-masked softmax
// forward kernel (causal attention). buffers[0] is the input, buffers[1] the
// output. This kernel takes 3-D tensors, so batch and head dimensions are
// folded into one leading dimension.
void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, const char *opaque,
                                           size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const std::vector<size_t> shape{desc.batch_size * desc.head_dim, desc.q_seqlen, desc.k_seqlen};

  auto in_tensor = TensorWrapper(buffers[0], shape, desc.dtype);
  auto out_tensor = TensorWrapper(buffers[1], shape, desc.dtype);

  nvte_scaled_upper_triang_masked_softmax_forward(in_tensor.data(), out_tensor.data(),
                                                  desc.scale_factor, stream);
}

// Custom-call entry point: applies the scaled upper-triangular-masked softmax
// backward kernel. buffers[0] is the incoming gradient, buffers[1] the softmax
// output saved from the forward pass, and buffers[2] receives the input
// gradient. As in the forward pass, batch and head dims are folded into one
// leading dimension.
void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque,
                                            size_t opaque_len) {
  const auto &desc = *UnpackOpaque<SoftmaxDescriptor>(opaque, opaque_len);
  const std::vector<size_t> shape{desc.batch_size * desc.head_dim, desc.q_seqlen, desc.k_seqlen};

  auto dy_tensor = TensorWrapper(buffers[0], shape, desc.dtype);
  auto y_tensor = TensorWrapper(buffers[1], shape, desc.dtype);
  auto dx_tensor = TensorWrapper(buffers[2], shape, desc.dtype);

  nvte_scaled_upper_triang_masked_softmax_backward(dy_tensor.data(), y_tensor.data(),
                                                   dx_tensor.data(), desc.scale_factor, stream);
}

}  // namespace jax
}  // namespace transformer_engine