"transformer_engine/jax/__init__.py" did not exist on "996ea169cf42ba887437760d2001b26812ee95bc"
utils.h 2.75 KB
Newer Older
cyanguwa's avatar
cyanguwa committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
/*************************************************************************
 * Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_
#define TRANSFORMER_ENGINE_FUSED_ATTN_UTILS_H_

#include "transformer_engine/transformer_engine.h"
#include <cudnn_frontend.h>

#include <cstdint>
#include <mutex>
#include <tuple>

namespace transformer_engine {
namespace fused_attn {

using namespace transformer_engine;

// Identifies which matrix of the fused-attention computation a stride query
// refers to (used together with NVTE_QKV_Layout in generateMatrixStrides).
enum NVTE_QKV_Matrix {
    NVTE_Q_Matrix            = 0,  // queries
    NVTE_K_Matrix            = 1,  // keys
    NVTE_K_Matrix_Transpose  = 2,  // keys transposed
    NVTE_V_Matrix            = 3,  // values
    NVTE_V_Matrix_Transpose  = 4,  // values transposed
    NVTE_S_Matrix            = 5,  // output of GEMM1 (attention scores)
    NVTE_O_Matrix            = 6,  // final output
};

// Fills strideA with the strides of the requested matrix for a fused-attention
// problem of size (b, h, s_q, s_kv, d) under the given QKV memory layout.
//   b     - batch size
//   s_q   - query sequence length; s_kv - key/value sequence length
//   h, d  - presumably number of heads and head dimension — the exact stride
//           semantics live in the definition, which is not visible here.
//   strideA - output array; required length is defined by the implementation.
void generateMatrixStrides(
            int64_t b, int64_t h,
            int64_t s_q, int64_t s_kv,
            int64_t d, int64_t* strideA,
            NVTE_QKV_Layout layout, NVTE_QKV_Matrix matrix);

// Key describing one fused-attention problem configuration. operator< gives
// the strict weak ordering needed to use FADescriptor as an ordered-map key
// (e.g. for caching cuDNN execution plans per configuration).
struct FADescriptor {
  std::int64_t b;            // batch size
  std::int64_t h;            // heads (per fused-attention convention)
  std::int64_t s_q;          // query sequence length
  std::int64_t s_kv;         // key/value sequence length
  std::int64_t d;            // head dimension (per fused-attention convention)
  float attnScale;           // softmax/attention scaling factor
  bool isTraining;           // training vs inference configuration
  float dropoutProbability;
  NVTE_QKV_Layout layout;
  cudnnDataType_t tensor_type;

  // Lexicographic comparison over every field, in declaration order.
  bool operator<(const FADescriptor &rhs) const {
    // Build a tuple-of-references key once, so both sides are guaranteed to
    // compare the same fields in the same order.
    auto key = [](const FADescriptor &fd) {
      return std::tie(fd.b, fd.h, fd.s_q, fd.s_kv, fd.d,
                      fd.attnScale, fd.isTraining, fd.dropoutProbability,
                      fd.layout, fd.tensor_type);
    };
    return key(*this) < key(rhs);
  }
};

// CUDA kernel (declaration only): derives per-sequence actual lengths
// (actual_seqlens_q) and ragged offsets for the QKV and O tensors from the
// cumulative sequence lengths in cu_seqlens_q.
// NOTE(review): the definition is not visible here — presumably the offsets
// scale the cumulative lengths by factors of h and d; confirm against the
// kernel body. Expected grid/block configuration is set at the launch site.
__global__ void cu_seqlens_to_offsets(size_t b, size_t h, size_t d,
                int32_t *cu_seqlens_q, int32_t *actual_seqlens_q,
                int32_t *qkv_ragged_offset, int32_t *o_ragged_offset);

}  // namespace fused_attn

// Maps a Transformer Engine DType to the corresponding cudnnDataType_t
// (declaration only; mapping details live in the definition).
cudnnDataType_t get_cudnn_dtype(const transformer_engine::DType t);

// Per-thread singleton that owns a lazily created cuDNN handle.
// Each thread gets its own instance (and therefore its own handle) via the
// thread_local static in Instance().
class cudnnExecutionPlanManager {
 public:
    // Returns the calling thread's manager instance.
    static cudnnExecutionPlanManager &Instance() {
        static thread_local cudnnExecutionPlanManager instance;
        return instance;
    }

    // Creates the cuDNN handle exactly once per thread (guarded by the
    // static thread_local once_flag) and returns it on every call.
    // NOTE(review): the cudnnCreate status is ignored; on failure handle_
    // stays nullptr and callers silently receive a null handle — consider
    // checking and surfacing the error.
    cudnnHandle_t GetCudnnHandle() {
        static thread_local std::once_flag flag;
        std::call_once(flag, [&] { cudnnCreate(&handle_); });
        return handle_;
    }

    // Destroys the handle at most once per thread. The once_flag here is
    // static thread_local, i.e. shared by all instances on a thread: if more
    // than one instance were ever destroyed on the same thread, only the
    // first would release its handle. Safe as long as construction only
    // happens through Instance().
    // NOTE(review): copy/move operations are not deleted — a copy would
    // share handle_; consider deleting them to enforce the singleton.
    ~cudnnExecutionPlanManager() {
        static thread_local std::once_flag flag;
        std::call_once(flag, [&] {
                        if (handle_ != nullptr) {
                          cudnnDestroy(handle_);
                        }});
    }

 private:
    cudnnHandle_t handle_ = nullptr;  // nullptr until first GetCudnnHandle()
};
}  // namespace transformer_engine

#endif