/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>

#include <cstdint>
#include <mutex>

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {
namespace fused_attn {

using namespace transformer_engine;

// Identifies which operand of the fused-attention computation a stride
// query refers to (see generateMatrixStrides below).
enum NVTE_QKV_Matrix {
  NVTE_Q_Matrix = 0,            // queries
  NVTE_K_Matrix = 1,            // keys
  NVTE_K_Matrix_Transpose = 2,  // keys transposed
  NVTE_V_Matrix = 3,            // values
  NVTE_V_Matrix_Transpose = 4,  // value matrix transposed
  NVTE_S_Matrix = 5,            // output of GEMM1
  NVTE_O_Matrix = 6,            // final output
};

// Computes the strides of one attention matrix (Q/K/V/S/O, selected by
// `matrix`) for a problem of shape (b, h, s_q, s_kv, d) stored under QKV
// memory layout `layout`; results are written into `strideA`.
// NOTE(review): the required length/ordering of strideA is defined by the
// implementation in the corresponding .cu file — confirm before reuse.
void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
                           int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);

// Engine-config filter predicate for cuDNN frontend plan selection
// (presumably accepts every config, per its name — confirm in the .cu file).
bool allowAllConfig(cudnnBackendDescriptor_t engine_config);

// Convenience builder for a cudnn_frontend::Tensor with the given data type,
// unique id, dimensions, and strides. `is_virtual` and `is_value` map to the
// frontend tensor attributes of the same name.
cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id, int64_t const *dim,
                                     int64_t const *stride, bool is_virtual, bool is_value);

// Same as tensor_create, but additionally attaches a ragged-offset tensor
// (used for variable-sequence-length / packed layouts).
cudnn_frontend::Tensor tensor_create_with_offset(
    cudnnDataType_t type, int64_t id, int64_t const *dim, int64_t const *stride, bool is_virtual,
    bool is_value, std::shared_ptr<cudnn_frontend::Tensor> raggedOffset);

// Builds a cudnn_frontend pointwise-operation descriptor with the given
// compute type and pointwise mode.
cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type, cudnnPointwiseMode_t mode);

// Builds a unary pointwise operation node, y = op(x), where op is described
// by `pwDesc` (see pw_desc_create).
cudnn_frontend::Operation unary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                             cudnn_frontend::Tensor const &yDesc,
                                             cudnn_frontend::PointWiseDesc const &pwDesc);

// Builds a binary pointwise operation node, y = op(x, b), where op is
// described by `pwDesc`.
cudnn_frontend::Operation binary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                              cudnn_frontend::Tensor const &bDesc,
                                              cudnn_frontend::Tensor const &yDesc,
                                              cudnn_frontend::PointWiseDesc const &pwDesc);

// Builds a ternary pointwise operation node, y = op(x, b, t), where op is
// described by `pwDesc`.
cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                               cudnn_frontend::Tensor const &bDesc,
                                               cudnn_frontend::Tensor const &tDesc,
                                               cudnn_frontend::Tensor const &yDesc,
                                               cudnn_frontend::PointWiseDesc const &pwDesc);

// Descriptor of one fused-attention problem configuration (legacy cuDNN
// frontend path). operator< gives a strict weak ordering over all fields so
// instances can serve as keys in an ordered map used for plan caching.
struct FADescriptor {
  std::int64_t b;     // batch size
  std::int64_t h;     // number of attention heads
  std::int64_t s_q;   // query sequence length
  std::int64_t s_kv;  // key/value sequence length
  std::int64_t d;     // head dimension
  float attnScale;    // attention scale factor
  bool isTraining;    // training (true) vs. inference (false)
  float dropoutProbability;
  NVTE_QKV_Layout layout;  // QKV memory layout
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  cudnnDataType_t tensor_type;
  bool use_workspace_opt;

  // Lexicographic comparison over every field. The field order inside the
  // std::tie differs from declaration order (mask_type before bias_type);
  // that is harmless — any consistent order yields a valid map-key ordering.
  bool operator<(const FADescriptor &rhs) const {
    return std::tie(b, h, s_q, s_kv, d, attnScale, isTraining, dropoutProbability, layout,
                    mask_type, bias_type, tensor_type, use_workspace_opt) <
           std::tie(rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.bias_type,
                    rhs.tensor_type, rhs.use_workspace_opt);
  }
};

// Descriptor of one fused-attention problem configuration (cuDNN frontend
// v1.x path). Extends FADescriptor with separate query/key-value head counts,
// distinct QK and V head dims, paged-KV-cache geometry, bias shape,
// sliding-window/softmax options, and per-tensor data types. Ordered via
// operator< for use as an ordered-map key in plan caching.
struct FADescriptor_v1 {
  std::int64_t b;   // batch size
  std::int64_t h;   // number of query heads
  std::int64_t hg;  // number of key/value (GQA) heads — TODO confirm
  std::int64_t s_q;   // query sequence length
  std::int64_t s_kv;  // key/value sequence length
  std::int64_t d_qk;  // head dim of Q and K
  std::int64_t d_v;   // head dim of V
  // Paged-attention (paged KV cache) geometry.
  std::int64_t num_pages_k;
  std::int64_t num_pages_v;
  std::int64_t page_size_k;
  std::int64_t page_size_v;
  std::int64_t max_pages_per_seq_k;
  std::int64_t max_pages_per_seq_v;
  // Bias tensor batch/head dimensions.
  std::int64_t bias_b;
  std::int64_t bias_h;
  float attnScale;  // attention scale factor
  bool isTraining;  // training (true) vs. inference (false)
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  NVTE_Softmax_Type softmax_type;
  // Sliding-window attention bounds.
  std::int64_t window_size_left;
  std::int64_t window_size_right;
  bool deterministic;
  // Data types of the individual I/O tensors.
  cudnn_frontend::DataType_t qkv_tensor_type;
  cudnn_frontend::DataType_t o_tensor_type;
  cudnn_frontend::DataType_t do_tensor_type;
  cudnn_frontend::DataType_t dqkv_tensor_type;

  // Lexicographic comparison over every field (the exact order inside the
  // std::tie only needs to be consistent between the two sides).
  bool operator<(const FADescriptor_v1 &rhs) const {
    return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
                    window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
                    o_tensor_type, do_tensor_type, dqkv_tensor_type) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                    rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
                    rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
                    rhs.dqkv_tensor_type);
  }
};

// CUDA kernel: derives per-sequence actual lengths and ragged offsets for the
// QKV and O tensors from cumulative sequence lengths (`cu_seqlens_q`) for a
// problem of shape (b, h, d). NOTE(review): grid/block expectations are set
// by the host-side launcher in the .cu file — confirm before reuse.
__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q,
                                      int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset,
                                      int32_t *o_ragged_offset);

// CUDA kernel: converts cumulative sequence lengths (prefix sums) for Q and
// KV into per-sequence lengths. `actual_b` sequences are populated out of
// buffers sized for `max_b` — presumably the padding region is handled by the
// implementation; confirm in the .cu file.
__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b,
                                             int32_t const *const q_cu_seqlens,
                                             int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
                                             int32_t *kv_seqlens);

// CUDA kernel: converts padded cumulative sequence lengths into ragged
// offsets for the Q/K/V/O/S tensors of a problem with layout `layout_group`
// and dims (h, hg, d_qk, d_v). Offsets are written with the element type
// given by `offset_dtype` (see get_ragged_offset_dtype), hence the untyped
// void* output pointers.
__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b,
                                             int64_t max_b, int64_t h, int64_t hg, int64_t d_qk,
                                             int64_t d_v, const int32_t *cu_seqlens_q_padded,
                                             const int32_t *cu_seqlens_kv_padded,
                                             DType offset_dtype, void *offsets_q, void *offsets_k,
                                             void *offsets_v, void *offsets_o, void *offsets_s);

// Selects the DType used for ragged-offset tensors based on the problem
// dimensions (presumably picking a wider integer type when int32 could
// overflow for large tensors — confirm in the .cu file).
DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads,
                              int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv,
                              int64_t head_dim_qk, int64_t head_dim_v);

// Map an actual batch size / token count to the maximum value used for
// buffer sizing. NOTE(review): the exact rounding policy lives in the .cu
// file — confirm before relying on a particular result.
size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);
// Per-thread singleton that hands out a monotonically increasing offset.
// Each call to GetAndUpdateOffset returns the current value and advances it
// by the requested increment, so successive callers receive disjoint ranges.
class FusedAttnOffsetManager {
 public:
  // Accessor for the thread-local singleton instance.
  static FusedAttnOffsetManager &Instance() {
    static thread_local FusedAttnOffsetManager instance;
    return instance;
  }

  // Returns the current offset, then advances it by `increment`.
  size_t GetAndUpdateOffset(size_t increment) {
    const size_t current = offset_;
    offset_ = current + increment;
    return current;
  }

  // Non-copyable: all users on a thread share the one instance.
  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
  void operator=(FusedAttnOffsetManager const &) = delete;

 private:
  FusedAttnOffsetManager() {}
  size_t offset_ = 0;
};

// CUDA kernel: writes {seed, offset} into the two-element RNG-state buffer
// `rng_state_dst` — presumed from the signature; confirm in the .cu file.
__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                          int64_t offset);

// CUDA kernel: computes the runtime number of segments from a cu_seqlen
// array of length `len`, writing the count to `out`.
__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out);

// Host-side async launcher that populates an RNG-state buffer from `seed` on
// `stream`; the offset increment depends on the sequence lengths and the
// selected `backend` — TODO confirm details against the implementation.
void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream);

// Host-side helper: launches get_runtime_num_segments_kernel (using
// `workspace` for the device-side result) and returns the segment count.
uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);

}  // namespace fused_attn
}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_