/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>

#include <cstdint>
#include <mutex>

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {
namespace fused_attn {

using namespace transformer_engine;

// Identifies which matrix of the fused attention computation a tensor refers
// to; used (together with the QKV layout) when generating tensor strides.
enum NVTE_QKV_Matrix {
  NVTE_Q_Matrix = 0,            // queries
  NVTE_K_Matrix = 1,            // keys
  NVTE_K_Matrix_Transpose = 2,  // keys transposed
  NVTE_V_Matrix = 3,            // values
  NVTE_V_Matrix_Transpose = 4,  // value matrix transposed
  NVTE_S_Matrix = 5,            // output of GEMM1
  NVTE_O_Matrix = 6,            // final output
};

// Fills strideA with the strides of the requested matrix (Q/K/V/S/O, possibly
// transposed) for a problem of batch b, heads h, seqlens s_q/s_kv, head dim d,
// under the given QKV memory layout.
void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
                           int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);

// Engine-config filter passed to the cuDNN frontend heuristics.
// NOTE(review): name suggests it accepts every config — confirm in the .cu file.
bool allowAllConfig(cudnnBackendDescriptor_t engine_config);

// Builds a cuDNN frontend tensor descriptor with the given unique id, dims,
// and strides; is_virtual/is_value select the frontend tensor category.
cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id, int64_t const *dim,
                                     int64_t const *stride, bool is_virtual, bool is_value);

// Same as tensor_create, but attaches a ragged (per-sequence) offset tensor.
cudnn_frontend::Tensor tensor_create_with_offset(
    cudnnDataType_t type, int64_t id, int64_t const *dim, int64_t const *stride, bool is_virtual,
    bool is_value, std::shared_ptr<cudnn_frontend::Tensor> raggedOffset);

// Builds a pointwise-operation descriptor for the given compute type and mode.
cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type, cudnnPointwiseMode_t mode);

// Builds a pointwise operation with one input (x) and one output (y).
cudnn_frontend::Operation unary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                             cudnn_frontend::Tensor const &yDesc,
                                             cudnn_frontend::PointWiseDesc const &pwDesc);

// Builds a pointwise operation with two inputs (x, b) and one output (y).
cudnn_frontend::Operation binary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                              cudnn_frontend::Tensor const &bDesc,
                                              cudnn_frontend::Tensor const &yDesc,
                                              cudnn_frontend::PointWiseDesc const &pwDesc);

// Builds a pointwise operation with three inputs (x, b, t) and one output (y).
cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                               cudnn_frontend::Tensor const &bDesc,
                                               cudnn_frontend::Tensor const &tDesc,
                                               cudnn_frontend::Tensor const &yDesc,
                                               cudnn_frontend::PointWiseDesc const &pwDesc);
// Cache key describing one fused-attention problem configuration (legacy
// graph path). Compared with operator< so it can serve as a std::map key.
struct FADescriptor {
  std::int64_t b;       // batch size
  std::int64_t h;       // number of attention heads
  std::int64_t s_q;     // query sequence length
  std::int64_t s_kv;    // key/value sequence length
  std::int64_t d;       // head dimension
  float attnScale;      // softmax scaling factor
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  cudnnDataType_t tensor_type;
  bool use_workspace_opt;

  // Lexicographic comparison over every field; provides the strict weak
  // ordering required for use as an associative-container key.
  bool operator<(const FADescriptor &rhs) const {
    return std::tie(b, h, s_q, s_kv, d, attnScale, isTraining, dropoutProbability, layout,
                    mask_type, bias_type, tensor_type, use_workspace_opt) <
           std::tie(rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.bias_type,
                    rhs.tensor_type, rhs.use_workspace_opt);
  }
};

// Cache key describing one fused-attention problem configuration for the
// cuDNN frontend v1 graph API. Compared with operator< for std::map keys.
struct FADescriptor_v1 {
  std::int64_t b;     // batch size
  std::int64_t h;     // number of query heads
  std::int64_t hg;    // number of KV heads (GQA groups) — inferred from name
  std::int64_t s_q;   // query sequence length
  std::int64_t s_kv;  // key/value sequence length
  std::int64_t d_qk;  // head dimension of Q/K
  std::int64_t d_v;   // head dimension of V
  // Paged-KV cache geometry.
  // NOTE(review): values for non-paged layouts presumably unused — confirm
  // with callers.
  std::int64_t num_pages_k;
  std::int64_t num_pages_v;
  std::int64_t page_size_k;
  std::int64_t page_size_v;
  std::int64_t max_pages_per_seq_k;
  std::int64_t max_pages_per_seq_v;
  std::int64_t bias_b;  // bias batch dimension
  std::int64_t bias_h;  // bias head dimension
  float attnScale;      // softmax scaling factor
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  NVTE_Softmax_Type softmax_type;
  std::int64_t window_size_left;   // sliding-window extent, left of diagonal
  std::int64_t window_size_right;  // sliding-window extent, right of diagonal
  bool bottom_right_diagonal;
  bool deterministic;
  cudnn_frontend::DataType_t qkv_tensor_type;
  cudnn_frontend::DataType_t o_tensor_type;
  cudnn_frontend::DataType_t do_tensor_type;
  cudnn_frontend::DataType_t dqkv_tensor_type;
  bool generate_max_sum_exp;

  // Lexicographic comparison over every field (all 31 are included in both
  // tie lists); provides the strict weak ordering required for map keys.
  bool operator<(const FADescriptor_v1 &rhs) const {
    return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
                    window_size_left, window_size_right, bottom_right_diagonal, deterministic,
                    bias_type, qkv_tensor_type, o_tensor_type, do_tensor_type, dqkv_tensor_type,
                    generate_max_sum_exp) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                    rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
                    rhs.window_size_left, rhs.window_size_right, rhs.bottom_right_diagonal,
                    rhs.deterministic, rhs.bias_type, rhs.qkv_tensor_type, rhs.o_tensor_type,
                    rhs.do_tensor_type, rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
  }
};

// Kernel: converts cumulative sequence lengths (cu_seqlens_q) into per-sequence
// actual lengths and the ragged byte/element offsets for the QKV and O tensors.
__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q,
                                      int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset,
                                      int32_t *o_ragged_offset);

// Kernel: derives per-sequence lengths for Q and KV from their cumulative
// sequence-length arrays (actual_b real sequences, max_b allocated slots).
__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b,
                                             int32_t const *const q_cu_seqlens,
                                             int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
                                             int32_t *kv_seqlens);

// Kernel: converts padded cumulative sequence lengths into offsets for the
// Q/K/V/O/S tensors; the element type written is selected by offset_dtype,
// hence the void* outputs.
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b,
                                             int64_t max_b, int64_t h, int64_t hg, int64_t d_qk,
                                             int64_t d_v, const int32_t *cu_seqlens_q_padded,
                                             const int32_t *cu_seqlens_kv_padded,
                                             DType offset_dtype, void *offsets_q, void *offsets_k,
                                             void *offsets_v, void *offsets_o, void *offsets_s);

// Selects the integer DType used for ragged offsets for the given problem
// sizes.
// NOTE(review): selection criterion (narrowest type that avoids overflow?)
// inferred from the name — confirm in the implementation.
DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads,
                              int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv,
                              int64_t head_dim_qk, int64_t head_dim_v);

// Round a batch size / token count up to the sizes used for allocations.
// NOTE(review): rounding policy not visible here — confirm in the .cu file.
size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);
// Per-thread singleton that hands out monotonically increasing offsets.
// Each thread owns an independent counter (thread_local instance), so no
// locking is needed.
class FusedAttnOffsetManager {
 public:
  // Returns this thread's instance, constructed on first use.
  static FusedAttnOffsetManager &Instance() {
    static thread_local FusedAttnOffsetManager manager;
    return manager;
  }

  // Returns the current offset, then advances it by `increment`.
  size_t GetAndUpdateOffset(size_t increment) {
    const size_t current = offset_;
    offset_ += increment;
    return current;
  }

  // Non-copyable: exactly one instance per thread.
  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
  void operator=(FusedAttnOffsetManager const &) = delete;

 private:
  FusedAttnOffsetManager() = default;
  size_t offset_ = 0;  // next offset to hand out
};

// Kernel: writes the RNG seed and offset into rng_state_dst.
// NOTE(review): exact packing (seed then offset?) inferred from parameter
// names — confirm in the .cu implementation.
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                          int64_t offset);

// Kernel: scans a cu_seqlen array of length `len` and accumulates the number
// of segments found into *out.
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out);

// Host wrapper: asynchronously populates the RNG state buffer on `stream`.
// NOTE(review): presumably the offset increment depends on backend and the
// max sequence lengths — confirm in the .cu implementation.
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream);

// Host wrapper: returns the number of segments in the device-side cu_seqlen
// array, using `workspace` as device scratch for the counter; a host return
// value implies this blocks on the result.
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);

}  // namespace fused_attn
}  // namespace transformer_engine

#endif