/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

10
#include <cudnn.h>
cyanguwa's avatar
cyanguwa committed
11
#include <cudnn_frontend.h>
12
#include <cudnn_frontend_utils.h>
cyanguwa's avatar
cyanguwa committed
13

14
15
16
#include <cstdint>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <tuple>

17
18
19
#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

cyanguwa's avatar
cyanguwa committed
20
21
22
23
24
25
namespace transformer_engine {
namespace fused_attn {

using namespace transformer_engine;

// Identifies one operand/result of the attention computation; passed as the `matrix`
// argument of generateMatrixStrides to select which tensor's strides are produced.
// Kept as an unscoped enum so the enumerators remain usable without qualification.
enum NVTE_QKV_Matrix {
  NVTE_Q_Matrix = 0,            // queries
  NVTE_K_Matrix = 1,            // keys
  NVTE_K_Matrix_Transpose = 2,  // keys transposed
  NVTE_V_Matrix = 3,            // values
  NVTE_V_Matrix_Transpose = 4,  // value matrix transposed
  NVTE_S_Matrix = 5,            // output of GEMM1
  NVTE_O_Matrix = 6,            // final output
};

35
36
void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, int64_t d,
                           int64_t *strideA, NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);
cyanguwa's avatar
cyanguwa committed
37

38
39
bool allowAllConfig(cudnnBackendDescriptor_t engine_config);

40
41
// Builds a cudnn_frontend tensor descriptor of data type `type` with identifier `id`.
// `dim` and `stride` point to per-dimension extents/strides (expected length is the tensor
// rank — TODO confirm in the definition). `is_virtual` / `is_value` presumably map to the
// descriptor's virtual and by-value flags.
cudnn_frontend::Tensor tensor_create(cudnnDataType_t type, int64_t id, int64_t const *dim,
                                     int64_t const *stride, bool is_virtual, bool is_value);

// Like tensor_create, but also attaches `raggedOffset`.
// NOTE(review): presumably the ragged-offset tensor used for packed/variable-length
// layouts — confirm in the definition.
cudnn_frontend::Tensor tensor_create_with_offset(
    cudnnDataType_t type, int64_t id, int64_t const *dim, int64_t const *stride, bool is_virtual,
    bool is_value, std::shared_ptr<cudnn_frontend::Tensor> raggedOffset);

// Builds a pointwise descriptor for operation `mode` computed in data type `type`.
cudnn_frontend::PointWiseDesc pw_desc_create(cudnnDataType_t type, cudnnPointwiseMode_t mode);

// Pointwise operation with one input (xDesc) writing yDesc, configured by pwDesc.
cudnn_frontend::Operation unary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                             cudnn_frontend::Tensor const &yDesc,
                                             cudnn_frontend::PointWiseDesc const &pwDesc);

// Pointwise operation with two inputs (xDesc, bDesc) writing yDesc, configured by pwDesc.
cudnn_frontend::Operation binary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                              cudnn_frontend::Tensor const &bDesc,
                                              cudnn_frontend::Tensor const &yDesc,
                                              cudnn_frontend::PointWiseDesc const &pwDesc);

// Pointwise operation with three inputs (xDesc, bDesc, tDesc) writing yDesc, configured by
// pwDesc.
cudnn_frontend::Operation ternary_pw_op_create(cudnn_frontend::Tensor const &xDesc,
                                               cudnn_frontend::Tensor const &bDesc,
                                               cudnn_frontend::Tensor const &tDesc,
                                               cudnn_frontend::Tensor const &yDesc,
                                               cudnn_frontend::PointWiseDesc const &pwDesc);
63

cyanguwa's avatar
cyanguwa committed
64
65
66
67
68
69
70
71
72
73
struct FADescriptor {
  std::int64_t b;
  std::int64_t h;
  std::int64_t s_q;
  std::int64_t s_kv;
  std::int64_t d;
  float attnScale;
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
74
75
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
cyanguwa's avatar
cyanguwa committed
76
  cudnnDataType_t tensor_type;
77
  bool use_workspace_opt;
cyanguwa's avatar
cyanguwa committed
78
79

  bool operator<(const FADescriptor &rhs) const {
80
81
82
83
84
    return std::tie(b, h, s_q, s_kv, d, attnScale, isTraining, dropoutProbability, layout,
                    mask_type, bias_type, tensor_type, use_workspace_opt) <
           std::tie(rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d, rhs.attnScale, rhs.isTraining,
                    rhs.dropoutProbability, rhs.layout, rhs.mask_type, rhs.bias_type,
                    rhs.tensor_type, rhs.use_workspace_opt);
cyanguwa's avatar
cyanguwa committed
85
86
87
  }
};

88
89
90
91
92
93
94
struct FADescriptor_v1 {
  std::int64_t b;
  std::int64_t h;
  std::int64_t hg;
  std::int64_t s_q;
  std::int64_t s_kv;
  std::int64_t d;
95
96
  std::int64_t bias_b;
  std::int64_t bias_h;
97
98
99
100
101
102
  float attnScale;
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
103
104
  cudnn_frontend::DataType_t fwd_tensor_type;
  cudnn_frontend::DataType_t bwd_tensor_type;
105
106

  bool operator<(const FADescriptor_v1 &rhs) const {
107
108
109
110
111
112
    return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h, attnScale, isTraining,
                    dropoutProbability, layout, mask_type, bias_type, fwd_tensor_type,
                    bwd_tensor_type) <
           std::tie(rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d, rhs.bias_b, rhs.bias_h,
                    rhs.attnScale, rhs.isTraining, rhs.dropoutProbability, rhs.layout,
                    rhs.mask_type, rhs.bias_type, rhs.fwd_tensor_type, rhs.bwd_tensor_type);
113
114
115
  }
};

116
117
118
// CUDA kernel: derives ragged offsets (qkv_ragged_offset, o_ragged_offset) and
// actual_seqlens_q from the cumulative query sequence lengths cu_seqlens_q, for b sequences
// with h heads of dimension d.
// NOTE(review): offset units and array lengths are defined in the .cu file — confirm there.
// cu_seqlens_q is non-const here; the declaration must keep matching the definition.
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d, int32_t *cu_seqlens_q,
                                      int32_t *actual_seqlens_q, int32_t *qkv_ragged_offset,
                                      int32_t *o_ragged_offset);

// CUDA kernel: converts cumulative sequence-length arrays (q_cu_seqlens, kv_cu_seqlens;
// presumably b + 1 entries each) into per-sequence lengths q_seqlens / kv_seqlens
// (presumably b entries each).
__global__ void cu_seqlens_to_actual_seqlens(size_t b, int32_t const *const q_cu_seqlens,
                                             int32_t const *const kv_cu_seqlens, int32_t *q_seqlens,
                                             int32_t *kv_seqlens);
123

cyanguwa's avatar
cyanguwa committed
124
125
126
}  // namespace fused_attn

cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
127
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
cyanguwa's avatar
cyanguwa committed
128
129
130

class cudnnExecutionPlanManager {
 public:
131
132
133
134
  static cudnnExecutionPlanManager &Instance() {
    static thread_local cudnnExecutionPlanManager instance;
    return instance;
  }
cyanguwa's avatar
cyanguwa committed
135

136
137
138
139
140
  cudnnHandle_t GetCudnnHandle() {
    static thread_local std::once_flag flag;
    std::call_once(flag, [&] { cudnnCreate(&handle_); });
    return handle_;
  }
cyanguwa's avatar
cyanguwa committed
141

142
  ~cudnnExecutionPlanManager() {}
cyanguwa's avatar
cyanguwa committed
143
144

 private:
145
  cudnnHandle_t handle_ = nullptr;
cyanguwa's avatar
cyanguwa committed
146
147
148
149
};
}  // namespace transformer_engine

#endif