softmax.cpp 8.33 KB
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
7
8
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/softmax.h"

9
#include "extensions.h"
10
#include "xla/ffi/api/c_api.h"
11
12
13
14

namespace transformer_engine {
namespace jax {

15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
// SOFTMAX_COMMON_BLOCK(tensor_buf)
//
// Statement macro shared by every softmax FFI entry point in this file. It must be expanded
// inside a function that has a `double scale_factor_` parameter, and it declares the following
// locals in the caller's scope (callers rely on these exact names):
//   dtype        - TE dtype converted from tensor_buf's FFI element type.
//   tensor_dims  - tensor_buf's dimensions.
//   tensor_ranks - number of dimensions of tensor_buf.
//   batch_size   - product(tensor_dims, 0, rank - 3): all dims before the last three.
//   head_dim     - the third-to-last dim. NOTE(review): callers place this in the second slot
//                  of a [batch, ?, q, k] shape, so it looks like a head COUNT rather than a
//                  per-head feature size -- confirm and consider renaming.
//   q_seqlen     - the second-to-last dim.
//   k_seqlen     - the last dim.
//   scale_factor - scale_factor_ narrowed to float.
//
// The rank is NOT validated here: tensor_ranks comes from .size(), so for a rank < 3 buffer
// `tensor_ranks - 3` presumably wraps around as unsigned arithmetic -- callers must guarantee
// rank >= 3. NOTE(review): for rank == 3, batch_size is product(tensor_dims, 0, 0); confirm
// that product() (defined outside this file) returns 1 for an empty range rather than treating
// an end index of 0 specially.
#define SOFTMAX_COMMON_BLOCK(tensor_buf)                                      \
  auto dtype = convert_ffi_datatype_to_te_dtype((tensor_buf).element_type()); \
  auto tensor_dims = (tensor_buf).dimensions();                               \
  auto tensor_ranks = tensor_dims.size();                                     \
  auto batch_size = product(tensor_dims, 0, tensor_ranks - 3);                \
  auto head_dim = product(tensor_dims, tensor_ranks - 3, tensor_ranks - 2);   \
  auto q_seqlen = product(tensor_dims, tensor_ranks - 2, tensor_ranks - 1);   \
  auto k_seqlen = product(tensor_dims, tensor_ranks - 1, tensor_ranks);       \
  float scale_factor = static_cast<float>(scale_factor_);

// SOFTMAX_FORWARD_COMMON_BLOCK
//
// Statement macro for the forward entry points. It expects `input_buf`, `output_buf`, `shape`
// and `dtype` to already be in scope (dtype normally comes from SOFTMAX_COMMON_BLOCK). It
// declares:
//   input, output               - raw pointers from the input and result buffers.
//   input_tensor, output_tensor - TensorWrappers over those pointers, both using the same
//                                 `shape` and `dtype` (the output's own element type is not
//                                 consulted).
#define SOFTMAX_FORWARD_COMMON_BLOCK                      \
  auto *input = input_buf.untyped_data();                 \
  auto *output = output_buf->untyped_data();              \
  auto input_tensor = TensorWrapper(input, shape, dtype); \
  auto output_tensor = TensorWrapper(output, shape, dtype);

// FFI entry point for the unmasked scaled softmax forward pass.
//
// The input and result buffers are wrapped as 4-D tensors of shape
// [batch_size, head_dim, q_seqlen, k_seqlen] (derived by SOFTMAX_COMMON_BLOCK), both tagged
// with the input's dtype, and passed to nvte_scaled_softmax_forward on `stream` with the
// attribute scale narrowed to float. The returned status is whatever
// ffi_with_cuda_error_check() (defined elsewhere) reports after the call.
Error_Type ScaledSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
                                   Result_Type output_buf, double scale_factor_) {
  SOFTMAX_COMMON_BLOCK(input_buf);
  const std::vector<size_t> shape{batch_size, head_dim, q_seqlen, k_seqlen};
  TensorWrapper input_tensor(input_buf.untyped_data(), shape, dtype);
  TensorWrapper output_tensor(output_buf->untyped_data(), shape, dtype);
  nvte_scaled_softmax_forward(input_tensor.data(), output_tensor.data(), scale_factor, stream);
  return ffi_with_cuda_error_check();
}

// FFI entry point for the scaled softmax forward pass with an explicit mask.
//
// input_buf and output_buf are wrapped as [batch_size, head_dim, q_seqlen, k_seqlen] tensors
// with the input's dtype (see SOFTMAX_COMMON_BLOCK). mask_buf is handed to
// nvte_scaled_masked_softmax_forward as a DType::kByte tensor of shape
// [padding_size, 1, q_seqlen, k_seqlen], where padding_size is the product of the mask's
// LEADING dimensions (every dim before its last three).
//
// Returns an INVALID_ARGUMENT error if mask_buf has rank < 3; otherwise the status reported by
// ffi_with_cuda_error_check() after the kernel call.
Error_Type ScaledMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
                                         Buffer_Type mask_buf, Result_Type output_buf,
                                         double scale_factor_) {
  SOFTMAX_COMMON_BLOCK(input_buf);

  auto mask_dims = mask_buf.dimensions();
  const size_t mask_rank = mask_dims.size();
  // The mask must at least carry its trailing (1, q_seqlen, k_seqlen) dims; checking first also
  // avoids unsigned wrap-around in any `rank - 3` arithmetic.
  if (mask_rank < 3) {
    return Error_Type(XLA_FFI_Error_Code_INVALID_ARGUMENT,
                      "ScaledMaskedSoftmaxForward: mask must have rank >= 3");
  }

  // padding_size is the mask's batch extent: the product of the dims in [0, rank - 3), the same
  // range batch_size covers for the input. Note that product(dims, rank - 3) would treat
  // rank - 3 as the *start* index and multiply the trailing dims (1 * q_seqlen * k_seqlen for a
  // 4-D mask), which is not a batch count.
  size_t padding_size = 1;
  for (size_t i = 0; i + 3 < mask_rank; ++i) {
    padding_size *= static_cast<size_t>(mask_dims[i]);
  }

  // The mask bytes are reinterpreted as uint8 (DType::kByte); no conversion happens here, so the
  // caller presumably supplies 1-byte mask elements -- TODO confirm against the JAX side.
  auto *mask = mask_buf.untyped_data();
  auto mask_shape = std::vector<size_t>{padding_size, 1, q_seqlen, k_seqlen};
  auto mask_tensor = TensorWrapper(mask, mask_shape, DType::kByte);

  auto shape = std::vector<size_t>{batch_size, head_dim, q_seqlen, k_seqlen};
  SOFTMAX_FORWARD_COMMON_BLOCK;
  nvte_scaled_masked_softmax_forward(input_tensor.data(), mask_tensor.data(), output_tensor.data(),
                                     scale_factor, stream);
  return ffi_with_cuda_error_check();
}

// FFI entry point for the upper-triangular-masked scaled softmax forward pass.
//
// Unlike the other forward variants, the leading two logical dims are folded together: the
// input and result buffers are wrapped as 3-D tensors of shape
// [batch_size * head_dim, q_seqlen, k_seqlen] with the input's dtype, then passed to
// nvte_scaled_upper_triang_masked_softmax_forward on `stream`. No mask buffer is bound; the
// masking pattern is implied by the kernel being called.
Error_Type ScaledUpperTriangMaskedSoftmaxForwardFFI(cudaStream_t stream, Buffer_Type input_buf,
                                                    Result_Type output_buf, double scale_factor_) {
  SOFTMAX_COMMON_BLOCK(input_buf);
  const auto attn_batches = batch_size * head_dim;
  const std::vector<size_t> shape{attn_batches, q_seqlen, k_seqlen};
  TensorWrapper input_tensor(input_buf.untyped_data(), shape, dtype);
  TensorWrapper output_tensor(output_buf->untyped_data(), shape, dtype);
  nvte_scaled_upper_triang_masked_softmax_forward(input_tensor.data(), output_tensor.data(),
                                                  scale_factor, stream);
  return ffi_with_cuda_error_check();
}

// XLA FFI handler symbols for the three forward softmax variants.
//
// Each FFI::Bind() chain mirrors the parameter order of the function it binds: the CUDA stream
// taken from the FFI execution context, the operand buffer(s), the single result buffer, and
// the `scale_factor` attribute decoded as a double. FFI_CudaGraph_Traits is defined in the
// project's FFI header; presumably it declares these handlers safe for CUDA graph capture --
// confirm there.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxForwardHandler, ScaledSoftmaxForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

// Same as above plus a second operand: the mask, bound between input and output.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxForwardHandler, ScaledMaskedSoftmaxForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // mask
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

// Upper-triangular variant: no mask operand is bound.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxForwardHandler,
                              ScaledUpperTriangMaskedSoftmaxForwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

// SOFTMAX_BACKWARD_COMMON_BLOCK
//
// Statement macro for the backward entry points. It expects `grad_output_buf`,
// `softmax_output_buf`, `dgrad_buf`, `shape` and `dtype` to already be in scope. It declares
// raw pointers grad_output, softmax_output and dgrad, plus the TensorWrappers
// grad_output_tensor, softmax_output_tensor and dgrad_tensor over them. All three wrappers use
// the same `shape` and `dtype`; in this file's callers, dtype comes from grad_output_buf, and
// softmax_output_buf's own element type is not checked against it.
#define SOFTMAX_BACKWARD_COMMON_BLOCK                                       \
  auto *grad_output = grad_output_buf.untyped_data();                       \
  auto *softmax_output = softmax_output_buf.untyped_data();                 \
  auto *dgrad = dgrad_buf->untyped_data();                                  \
  auto grad_output_tensor = TensorWrapper(grad_output, shape, dtype);       \
  auto softmax_output_tensor = TensorWrapper(softmax_output, shape, dtype); \
  auto dgrad_tensor = TensorWrapper(dgrad, shape, dtype);

// FFI entry point for the scaled softmax backward pass (also bound as the masked variant's
// backward handler).
//
// Geometry and dtype are derived from grad_output_buf. grad_output, the saved softmax output
// and the dgrad result are all wrapped as [batch_size, head_dim, q_seqlen, k_seqlen] tensors
// with that dtype and passed to nvte_scaled_softmax_backward on `stream`. The returned status
// is whatever ffi_with_cuda_error_check() (defined elsewhere) reports after the call.
Error_Type ScaledSoftmaxBackwardFFI(cudaStream_t stream, Buffer_Type grad_output_buf,
                                    Buffer_Type softmax_output_buf, Result_Type dgrad_buf,
                                    double scale_factor_) {
  SOFTMAX_COMMON_BLOCK(grad_output_buf);
  const std::vector<size_t> shape{batch_size, head_dim, q_seqlen, k_seqlen};
  TensorWrapper grad_output_tensor(grad_output_buf.untyped_data(), shape, dtype);
  TensorWrapper softmax_output_tensor(softmax_output_buf.untyped_data(), shape, dtype);
  TensorWrapper dgrad_tensor(dgrad_buf->untyped_data(), shape, dtype);
  nvte_scaled_softmax_backward(grad_output_tensor.data(), softmax_output_tensor.data(),
                               dgrad_tensor.data(), scale_factor, stream);
  return ffi_with_cuda_error_check();
}

// FFI entry point for the upper-triangular-masked scaled softmax backward pass.
//
// Mirrors the forward upper-triangular variant: the leading two logical dims of grad_output are
// folded into one, so grad_output, the saved softmax output and the dgrad result are wrapped as
// [batch_size * head_dim, q_seqlen, k_seqlen] tensors (dtype taken from grad_output_buf) and
// passed to nvte_scaled_upper_triang_masked_softmax_backward on `stream`.
Error_Type ScaledUpperTriangMaskedSoftmaxBackwardFFI(cudaStream_t stream,
                                                     Buffer_Type grad_output_buf,
                                                     Buffer_Type softmax_output_buf,
                                                     Result_Type dgrad_buf, double scale_factor_) {
  SOFTMAX_COMMON_BLOCK(grad_output_buf);
  const auto attn_batches = batch_size * head_dim;
  const std::vector<size_t> shape{attn_batches, q_seqlen, k_seqlen};
  TensorWrapper grad_output_tensor(grad_output_buf.untyped_data(), shape, dtype);
  TensorWrapper softmax_output_tensor(softmax_output_buf.untyped_data(), shape, dtype);
  TensorWrapper dgrad_tensor(dgrad_buf->untyped_data(), shape, dtype);
  nvte_scaled_upper_triang_masked_softmax_backward(grad_output_tensor.data(),
                                                   softmax_output_tensor.data(),
                                                   dgrad_tensor.data(), scale_factor, stream);
  return ffi_with_cuda_error_check();
}

// XLA FFI handler symbols for the backward softmax variants.
//
// Each FFI::Bind() chain mirrors the parameter order of the function it binds: the CUDA stream
// from the FFI execution context, grad_output, the saved softmax output, the dgrad result
// buffer, and the `scale_factor` attribute decoded as a double. FFI_CudaGraph_Traits is defined
// in the project's FFI header; presumably it declares these handlers safe for CUDA graph
// capture -- confirm there.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // grad_output
                                  .Arg<Buffer_Type>()      // softmax_output
                                  .Ret<Buffer_Type>()      // dgrad
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

// The backward of ScaledMaskedSoftmax is equivalent to ScaledSoftmax, so the same
// ScaledSoftmaxBackwardFFI is bound here and no mask operand is passed. NOTE(review): this
// presumably relies on masked positions already holding zero probability in softmax_output,
// which zeroes their gradient -- confirm against the kernel.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledMaskedSoftmaxBackwardHandler, ScaledSoftmaxBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // grad_output
                                  .Arg<Buffer_Type>()      // softmax_output
                                  .Ret<Buffer_Type>()      // dgrad
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

// Upper-triangular variant: same operand list as above, bound to its dedicated backward.
XLA_FFI_DEFINE_HANDLER_SYMBOL(ScaledUpperTriangMaskedSoftmaxBackwardHandler,
                              ScaledUpperTriangMaskedSoftmaxBackwardFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // grad_output
                                  .Arg<Buffer_Type>()      // softmax_output
                                  .Ret<Buffer_Type>()      // dgrad
                                  .Attr<double>("scale_factor"),
                              FFI_CudaGraph_Traits);

156
157
}  // namespace jax
}  // namespace transformer_engine