/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

#include "transformer_engine/fused_attn.h"
#include "transformer_engine/transformer_engine.h"

#include <cudnn.h>
#include <cudnn_frontend.h>
#include <cudnn_frontend_utils.h>

#include <cstdint>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <tuple>

namespace transformer_engine {
namespace fused_attn {

using namespace transformer_engine;

/*! \brief Selects which attention tensor generateMatrixStrides() describes.
 *
 *  Enumerator values are spelled out explicitly and must stay stable.
 */
enum NVTE_QKV_Matrix {
  NVTE_Q_Matrix = 0,            // query tensor
  NVTE_K_Matrix = 1,            // key tensor
  NVTE_K_Matrix_Transpose = 2,  // key tensor, transposed
  NVTE_V_Matrix = 3,            // value tensor
  NVTE_V_Matrix_Transpose = 4,  // value tensor, transposed
  NVTE_S_Matrix = 5,            // result of the first GEMM
  NVTE_O_Matrix = 6,            // final attention output
};

/*! \brief Produce the stride array for one attention tensor.
 *
 *  `matrix` picks the tensor and `layout` the QKV memory layout; the result is
 *  presumably written through the caller-owned `strideA` buffer.
 *  NOTE(review): b/h/s_q/s_kv/d presumably mean batch, heads, query and
 *  key-value sequence lengths, and head dim; the required length of strideA is
 *  fixed by the out-of-line definition — confirm there.
 */
void generateMatrixStrides(int64_t b, int64_t h, int64_t s_q, int64_t s_kv,
                           int64_t d, int64_t *strideA,
                           NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);

41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
// ---- cudnn_frontend construction helpers (defined out of line) ----

// Engine-config filter; going by its name it presumably accepts every engine
// config — confirm against the definition.
bool allowAllConfig(cudnnBackendDescriptor_t engine_config);

// Build a tensor descriptor with data type `type`, uid `id`, and the given
// dims/strides. `is_virtual` / `is_value` presumably set the matching
// cudnn_frontend tensor attributes — confirm in the definition.
cudnn_frontend::Tensor tensor_create(
    cudnnDataType_t type, int64_t id,
    int64_t const *dim, int64_t const *stride,
    bool is_virtual, bool is_value);

// Variant of tensor_create that additionally takes a ragged-offset tensor
// (passed as shared ownership).
cudnn_frontend::Tensor tensor_create_with_offset(
    cudnnDataType_t type, int64_t id,
    int64_t const *dim, int64_t const *stride,
    bool is_virtual, bool is_value,
    std::shared_ptr<cudnn_frontend::Tensor> raggedOffset);

// Pointwise descriptor for data type `type` and pointwise `mode`.
cudnn_frontend::PointWiseDesc pw_desc_create(
    cudnnDataType_t type, cudnnPointwiseMode_t mode);

// Pointwise operation with one input (x) writing y.
cudnn_frontend::Operation unary_pw_op_create(
    cudnn_frontend::Tensor const &xDesc,
    cudnn_frontend::Tensor const &yDesc,
    cudnn_frontend::PointWiseDesc const &pwDesc);

// Pointwise operation with two inputs (x, b) writing y.
cudnn_frontend::Operation binary_pw_op_create(
    cudnn_frontend::Tensor const &xDesc,
    cudnn_frontend::Tensor const &bDesc,
    cudnn_frontend::Tensor const &yDesc,
    cudnn_frontend::PointWiseDesc const &pwDesc);

// Pointwise operation with three inputs (x, b, t) writing y.
cudnn_frontend::Operation ternary_pw_op_create(
    cudnn_frontend::Tensor const &xDesc,
    cudnn_frontend::Tensor const &bDesc,
    cudnn_frontend::Tensor const &tDesc,
    cudnn_frontend::Tensor const &yDesc,
    cudnn_frontend::PointWiseDesc const &pwDesc);

cyanguwa's avatar
cyanguwa committed
71
72
73
74
75
76
77
78
79
80
struct FADescriptor {
  std::int64_t b;
  std::int64_t h;
  std::int64_t s_q;
  std::int64_t s_kv;
  std::int64_t d;
  float attnScale;
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
81
82
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
cyanguwa's avatar
cyanguwa committed
83
  cudnnDataType_t tensor_type;
84
  bool use_workspace_opt;
cyanguwa's avatar
cyanguwa committed
85
86
87
88

  bool operator<(const FADescriptor &rhs) const {
    return std::tie(b, h, s_q, s_kv, d,
                    attnScale, isTraining, dropoutProbability,
89
                    layout, mask_type, bias_type, tensor_type, use_workspace_opt)
90
91
92
93
                    < std::tie(
                      rhs.b, rhs.h, rhs.s_q, rhs.s_kv, rhs.d,
                      rhs.attnScale, rhs.isTraining,
                      rhs.dropoutProbability, rhs.layout,
94
95
                      rhs.mask_type, rhs.bias_type,
                      rhs.tensor_type, rhs.use_workspace_opt);
cyanguwa's avatar
cyanguwa committed
96
97
98
  }
};

99
100
101
102
103
104
105
struct FADescriptor_v1 {
  std::int64_t b;
  std::int64_t h;
  std::int64_t hg;
  std::int64_t s_q;
  std::int64_t s_kv;
  std::int64_t d;
106
107
  std::int64_t bias_b;
  std::int64_t bias_h;
108
109
110
111
112
113
114
115
116
  float attnScale;
  bool isTraining;
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  NVTE_Bias_Type bias_type;
  NVTE_Mask_Type mask_type;
  cudnn_frontend::DataType_t tensor_type;

  bool operator<(const FADescriptor_v1 &rhs) const {
117
    return std::tie(b, h, hg, s_q, s_kv, d, bias_b, bias_h,
118
119
120
121
                    attnScale, isTraining, dropoutProbability,
                    layout, mask_type, bias_type, tensor_type)
                    < std::tie(
                      rhs.b, rhs.h, rhs.hg, rhs.s_q, rhs.s_kv, rhs.d,
122
                      rhs.bias_b, rhs.bias_h,
123
124
125
126
127
128
129
                      rhs.attnScale, rhs.isTraining,
                      rhs.dropoutProbability, rhs.layout,
                      rhs.mask_type, rhs.bias_type,
                      rhs.tensor_type);
  }
};

cyanguwa's avatar
cyanguwa committed
130
131
132
133
// CUDA kernel, defined out of line. Going by its parameter names it presumably
// turns the cumulative sequence lengths in cu_seqlens_q into per-sequence
// lengths (actual_seqlens_q) and ragged offsets for QKV and O — confirm
// against the definition.
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
                int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
                int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);

// CUDA kernel, defined out of line. Reads the cumulative query and key/value
// sequence lengths (const inputs) and presumably writes per-sequence lengths
// for `b` sequences into q_seqlens / kv_seqlens — confirm against the
// definition. Top-level const on the pointer parameters is omitted: it is not
// part of the function type, so this declaration still matches the definition.
__global__ void cu_seqlens_to_actual_seqlens(size_t b,
                const int32_t *q_cu_seqlens,
                const int32_t *kv_cu_seqlens,
                int32_t *q_seqlens, int32_t *kv_seqlens);

cyanguwa's avatar
cyanguwa committed
139
140
141
}  // namespace fused_attn

cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);
142
cudnn_frontend::DataType_t get_cudnn_fe_dtype(const transformer_engine::DType t);
cyanguwa's avatar
cyanguwa committed
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165

/*! \brief Holds a lazily created cuDNN handle, one manager per host thread.
 *
 *  Instance() returns a thread_local object, so each thread that calls it gets
 *  its own manager and its own handle.
 */
class cudnnExecutionPlanManager {
 public:
    /*! \brief Return the calling thread's manager. */
    static cudnnExecutionPlanManager &Instance() {
        static thread_local cudnnExecutionPlanManager instance;
        return instance;
    }

    /*! \brief Return this object's cuDNN handle, creating it on first use.
     *
     *  Creation is tracked per object through handle_. A function-static
     *  once_flag would be shared by every object on the thread, so a second
     *  instance would never get a handle. If cudnnCreate fails, handle_ stays
     *  null and the next call retries.
     *
     *  \throws std::runtime_error if cudnnCreate does not return
     *          CUDNN_STATUS_SUCCESS.
     */
    cudnnHandle_t GetCudnnHandle() {
        if (handle_ == nullptr) {
            const cudnnStatus_t status = cudnnCreate(&handle_);
            if (status != CUDNN_STATUS_SUCCESS) {
                handle_ = nullptr;
                throw std::runtime_error(std::string("cudnnCreate failed: ") +
                                         cudnnGetErrorString(status));
            }
        }
        return handle_;
    }

    // NOTE(review): handle_ is never passed to cudnnDestroy. This may be
    // intentional (a thread_local destructor can run after the CUDA context is
    // torn down at process exit) — confirm before adding cleanup here.
    ~cudnnExecutionPlanManager() {
    }

 private:
    cudnnHandle_t handle_ = nullptr;
};
}  // namespace transformer_engine

#endif