/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

#ifndef __HIP_PLATFORM_AMD__
#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>
#endif

#include <cstdint>
#include <mutex>

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {
namespace fused_attn {
#ifndef __HIP_PLATFORM_AMD__
using namespace transformer_engine;

enum NVTE_QKV_Matrix {
  NVTE_Q_Matrix = 0,            // queries
  NVTE_K_Matrix = 1,            // keys
  NVTE_K_Matrix_Transpose = 2,  // keys transposed
  NVTE_V_Matrix = 3,            // values
  NVTE_V_Matrix_Transpose = 4,  // values transposed
  NVTE_S_Matrix = 5,            // output of GEMM1
  NVTE_O_Matrix = 6,            // final output
};

void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
                           int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);
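// Usage sketch (illustrative; the concrete layout below is only an example):
// fill the 4-D stride array for the Q matrix of a separate-QKV BSHD layout.
//
//   int64_t q_strides[4];
//   generateMatrixStrides(/*b=*/2, /*h=*/16, /*s_q=*/512, /*s_kv=*/512, /*d=*/64,
//                         q_strides, NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD,
//                         NVTE_QKV_Matrix::NVTE_Q_Matrix);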

bool allowAllConfig(cudnnBackendDescriptor_t engine_config);

cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id, int64_t const *dim,
                                     int64_t const *stride, bool is_virtual, bool is_value);
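// Usage sketch (assumes an fp16 Q tensor and an arbitrary unique id; `layout`
// is a hypothetical NVTE_QKV_Layout value):
//
//   int64_t q_dim[4] = {b, h, s_q, d};
//   int64_t q_stride[4];
//   generateMatrixStrides(b, h, s_q, s_kv, d, q_stride, layout, NVTE_Q_Matrix);
//   auto q = tensor_create(CUDNN_DATA_HALF, /*id=*/1, q_dim, q_stride,
//                          /*is_virtual=*/false, /*is_value=*/false);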

cudnn_frontend::Tensor tensor_create_with_offset(
    cudnnDataType_t type, int64_t id, int64_t const *dim, int64_t const *stride, bool is_virtual,
    bool is_value, std::shared_ptr<cudnn_frontend::Tensor> raggedOffset);

cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type, cudnnPointwiseMode_t mode);

cudnn_frontend::Operation unary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                             cudnn_frontend::Tensor const &yDesc,
                                             cudnn_frontend::PointWiseDesc const &pwDesc);

cudnn_frontend::Operation binary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                              cudnn_frontend::Tensor const &bDesc,
                                              cudnn_frontend::Tensor const &yDesc,
                                              cudnn_frontend::PointWiseDesc const &pwDesc);

cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                               cudnn_frontend::Tensor const &bDesc,
                                               cudnn_frontend::Tensor const &tDesc,
                                               cudnn_frontend::Tensor const &yDesc,
                                               cudnn_frontend::PointWiseDesc const &pwDesc);
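// Usage sketch (illustrative): scaling S = Q * K^T by attnScale is expressed
// as a binary pointwise multiply, following the usual cuDNN frontend pattern.
//
//   auto mul_desc = pw_desc_create(CUDNN_DATA_FLOAT, CUDNN_POINTWISE_MUL);
//   auto scale_op = binary_pw_op_create(s_tensor, scale_tensor, s_scaled_tensor, mul_desc);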

struct FADescriptor {
  std::int64_t b;
  std::int64_t h;
  std::int64_t s_q;
  std::int64_t s_kv;
  std::int64_t d;
  float attnScale;
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  cudnnDataType_t tensor_type;
  bool use_workspace_opt;

  bool operator<(const FADescriptor &rhs) const {
    return std::tie(b, h, s_q, s_kv, d, attnScale, isTraining, dropoutProbability, layout,
                    mask_type, bias_type, tensor_type, use_workspace_opt) <
           std::tie(rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.bias_type,
                    rhs.tensor_type, rhs.use_workspace_opt);
  }
};
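// The strict weak ordering above lets FADescriptor act as the key of an
// ordered container, e.g. a per-thread plan cache (a sketch; the mapped type
// is an assumption):
//
//   static thread_local std::map<FADescriptor, cudnn_frontend::ExecutionPlan> plan_cache;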

struct FADescriptor_v1 {
  std::int64_t b;
  std::int64_t h;
  std::int64_t hg;
  std::int64_t s_q;
  std::int64_t s_kv;
  std::int64_t d_qk;
  std::int64_t d_v;
  std::int64_t num_pages_k;
  std::int64_t num_pages_v;
  std::int64_t page_size_k;
  std::int64_t page_size_v;
  std::int64_t max_pages_per_seq_k;
  std::int64_t max_pages_per_seq_v;
  std::int64_t bias_b;
  std::int64_t bias_h;
  float attnScale;
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  NVTE_Softmax_Type softmax_type;
  std::int64_t window_size_left;
  std::int64_t window_size_right;
  bool deterministic;
  cudnn_frontend::DataType_t qkv_tensor_type;
  cudnn_frontend::DataType_t o_tensor_type;
  cudnn_frontend::DataType_t do_tensor_type;
  cudnn_frontend::DataType_t dqkv_tensor_type;
  bool generate_max_sum_exp;

  bool operator<(const FADescriptor_v1 &rhs) const {
    return std::tie(b, h, hg, s_q, s_kv, d_qk, d_v, num_pages_k, num_pages_v, page_size_k,
                    page_size_v, max_pages_per_seq_k, max_pages_per_seq_v, bias_b, bias_h,
                    attnScale, isTraining, dropoutProbability, layout, mask_type, softmax_type,
                    window_size_left, window_size_right, deterministic, bias_type, qkv_tensor_type,
                    o_tensor_type, do_tensor_type, dqkv_tensor_type, generate_max_sum_exp) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d_qk, rhs.d_v, rhs.num_pages_k,
                    rhs.num_pages_v, rhs.page_size_k, rhs.page_size_v, rhs.max_pages_per_seq_k,
                    rhs.max_pages_per_seq_v, rhs.bias_b, rhs.bias_h, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.softmax_type,
                    rhs.window_size_left, rhs.window_size_right, rhs.deterministic, rhs.bias_type,
                    rhs.qkv_tensor_type, rhs.o_tensor_type, rhs.do_tensor_type,
                    rhs.dqkv_tensor_type, rhs.generate_max_sum_exp);
  }
};
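// FADescriptor_v1 plays the same cache-key role for the cuDNN frontend v1.x
// graph API (a sketch; the mapped type is an assumption):
//
//   static thread_local std::map<FADescriptor_v1,
//                                std::shared_ptr<cudnn_frontend::graph::Graph>> graph_cache;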

__global__ void cu_seqlens_to_offsets(int64_t b, int64_t h, int64_t d, int32_t *cu_seqlens_q,
                                      int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset,
                                      int32_t *o_ragged_offset);
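// Semantics, in elements (for the packed-QKV path this kernel serves):
// qkv_ragged_offset[i] scales cu_seqlens_q[i] by 3 * h * d, o_ragged_offset[i]
// by h * d, and actual_seqlens_q[i] = cu_seqlens_q[i + 1] - cu_seqlens_q[i].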

__global__ void cu_seqlens_to_actual_seqlens(int64_t actual_b, int64_t max_b,
                                             int32_t const *const q_cu_seqlens,
                                             int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
                                             int32_t *kv_seqlens);
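// Worked example: q_cu_seqlens = {0, 3, 8, 8} yields q_seqlens = {3, 5, 0};
// each actual length is the difference of consecutive cumulative offsets.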

__global__ void cu_seqlens_padded_to_offsets(NVTE_QKV_Layout_Group layout_group, int64_t actual_b,
                                             int64_t max_b, int64_t h, int64_t hg, int64_t d_qk,
                                             int64_t d_v, const int32_t *cu_seqlens_q_padded,
                                             const int32_t *cu_seqlens_kv_padded,
                                             DType offset_dtype, void *offsets_q, void *offsets_k,
                                             void *offsets_v, void *offsets_o, void *offsets_s);
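// Illustration (the exact per-layout scaling lives in the kernel): for an
// HD-style layout group, offsets_q[i] is cu_seqlens_q_padded[i] scaled by the
// elements per Q token (h * d_qk), with analogous scalings for K, V, and O.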

DType get_ragged_offset_dtype(NVTE_QKV_Layout_Group layout_group, int64_t num_attn_heads,
                              int64_t num_gqa_groups, int64_t max_seqlen_q, int64_t max_seqlen_kv,
                              int64_t head_dim_qk, int64_t head_dim_v);
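// The returned dtype is chosen so that the largest possible element offset
// (on the order of batch * seqlen * heads * head_dim) still fits: a 32-bit
// type when it does, a 64-bit type otherwise.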

size_t get_max_batch_size(size_t batch_size);
size_t get_max_tokens(size_t num_tokens);

class FusedAttnOffsetManager {
 public:
  static FusedAttnOffsetManager &Instance() {
    static thread_local FusedAttnOffsetManager instance;
    return instance;
  }

  size_t GetAndUpdateOffset(size_t increment) {
    size_t ret = offset_;
    offset_ += increment;
    return ret;
  }

  FusedAttnOffsetManager(FusedAttnOffsetManager const &) = delete;
  void operator=(FusedAttnOffsetManager const &) = delete;

 private:
  FusedAttnOffsetManager() {}
  size_t offset_ = 0;
};
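// Usage sketch: reserve `increment` RNG offsets for one kernel launch.
//
//   size_t rng_offset = FusedAttnOffsetManager::Instance().GetAndUpdateOffset(increment);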

__global__ void populate_rng_state_kernel(int64_t *rng_state_dst, const int64_t *const seed,
                                          int64_t offset);

__global__ void get_runtime_num_segments_kernel(int32_t *cu_seqlen, size_t len, uint32_t *out);

void PopulateRngStateAsync(void *rng_state_dst, const void *const seed, size_t q_max_seqlen,
                           size_t kv_max_seqlen, NVTE_Fused_Attn_Backend backend,
                           cudaStream_t stream);

uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cudaStream_t stream);
#endif
}  // namespace fused_attn
}  // namespace transformer_engine

#endif