/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#pragma once

#include <cublasLt.h>

#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>

#include <cstdlib>
#include <vector>

#include "common/util/logging.h"

#include "paddle/extension.h"
#include "paddle/phi/backends/all_context.h"

namespace transformer_engine {
namespace paddle_ext {
// Paddle Tensor Utils
// Returns a read-only pointer to element `index` of `x`, typed as T but
// erased to const void*. Bounds-checked: out-of-range indices abort via
// NVTE_ERROR instead of yielding a wild pointer.
template <typename T>
inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) {
  if (index < 0 || index >= x.numel()) {
    NVTE_ERROR("Index out of bound");
  }
  return reinterpret_cast<const void *>(x.data<T>() + static_cast<size_t>(index));
}

// Mutable counterpart of GetDataPtr: pointer to element `index` of `x`.
// Bounds-checked; aborts via NVTE_ERROR on out-of-range access.
template <typename T>
inline void *GetDataPtr(paddle::Tensor &x, int64_t index) {  // NOLINT
  if (index < 0 || index >= x.numel()) {
    NVTE_ERROR("Index out of bound");
  }
  return reinterpret_cast<void *>(x.data<T>() + static_cast<size_t>(index));
}

// Read-only pointer to element `index` of an optional tensor, or nullptr
// when the optional is empty.
template <typename T>
inline const void *GetOptionalDataPtr(const paddle::optional<paddle::Tensor> &x, int64_t index) {
  return x ? GetDataPtr<T>(*x, index) : nullptr;
}

// Mutable pointer to element `index` of an optional tensor, or nullptr
// when the optional is empty.
template <typename T>
inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x, int64_t index) {  // NOLINT
  return x ? GetDataPtr<T>(*x, index) : nullptr;
}

// Untyped read-only data pointer of an optional tensor (start of buffer),
// or nullptr when the optional is empty.
inline const void *GetOptionalDataPtr(const paddle::optional<paddle::Tensor> &x) {
  return x ? x->data() : nullptr;
}

// Untyped mutable data pointer of an optional tensor (start of buffer),
// or nullptr when the optional is empty.
inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x) {  // NOLINT
  return x ? x->data() : nullptr;
}

// Converts a paddle shape (int64_t dims) into the size_t vector form used
// by the NVTE tensor APIs.
inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
  // Hoist x.shape() so it is evaluated once, and reserve to avoid reallocations.
  const auto &dims = x.shape();
  std::vector<size_t> shapes;
  shapes.reserve(dims.size());
  for (auto dim : dims) {
    shapes.push_back(static_cast<size_t>(dim));
  }
  return shapes;
}

Shijie's avatar
Shijie committed
72
inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
73
74
  if (x) return GetShapeArray(x.get());
  return {0};
Shijie's avatar
Shijie committed
75
76
}

77
78
79
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
                             bool init_to_zeros = 0);

80
81
// DType Utils
inline paddle::DataType Nvte2PaddleDType(DType t) {
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
  switch (t) {
    case DType::kInt32:
    case DType::kFloat32:
      return paddle::DataType::FLOAT32;
    case DType::kFloat16:
      return paddle::DataType::FLOAT16;
    case DType::kBFloat16:
      return paddle::DataType::BFLOAT16;
    case DType::kByte:
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return paddle::DataType::UINT8;
    default:
      NVTE_ERROR("Invalid type");
  }
97
98
99
}

// Maps a paddle dtype to its NVTE equivalent. BOOL and UINT8 both map to
// kByte (merged as a single fallthrough); unsupported dtypes abort.
inline DType Paddle2NvteDType(paddle::DataType t) {
  switch (t) {
    case paddle::DataType::FLOAT16:
      return DType::kFloat16;
    case paddle::DataType::FLOAT32:
      return DType::kFloat32;
    case paddle::DataType::BFLOAT16:
      return DType::kBFloat16;
    case paddle::DataType::BOOL:
    case paddle::DataType::UINT8:
      return DType::kByte;
    case paddle::DataType::INT32:
      return DType::kInt32;
    case paddle::DataType::INT64:
      return DType::kInt64;
    default:
      NVTE_ERROR("Invalid type");
  }
}

// Converts an integer dtype code (e.g. passed through a Python binding) to
// the NVTE DType enum, validating it is within [0, kNumTypes).
inline DType Int2NvteDType(int64_t dtype) {
  if (dtype >= 0 && dtype < static_cast<int64_t>(DType::kNumTypes)) {
    return static_cast<DType>(dtype);
  } else {
    NVTE_ERROR("Type not supported.");
  }
}

128
129
130
131
// get the fused attention backend
inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
    const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
132
133
134
135
136
    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim) {
  NVTE_Fused_Attn_Backend fused_attention_backend =
      nvte_get_fused_attn_backend(static_cast<NVTEDType>(q_dtype), static_cast<NVTEDType>(kv_dtype),
                                  qkv_layout, bias_type, attn_mask_type, p_dropout, num_attn_heads,
137
                                  num_gqa_groups, max_seqlen_q, max_seqlen_kv, head_dim, -1, -1);
138
  return fused_attention_backend;
139
140
}

141
142
143
// CUDA Utils
class cudaDevicePropertiesManager {
 public:
144
145
146
147
148
149
150
151
152
153
154
  static cudaDevicePropertiesManager &Instance() {
    static thread_local cudaDevicePropertiesManager instance;
    return instance;
  }

  int GetMultiProcessorCount() {
    if (!prop_queried_) {
      int device_id;
      NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      cudaGetDeviceProperties(&prop_, device_id);
      prop_queried_ = true;
155
    }
156
157
158
159
160
161
162
163
164
    return prop_.multiProcessorCount;
  }

  int GetMajor() {
    if (!prop_queried_) {
      int device_id;
      NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
      cudaGetDeviceProperties(&prop_, device_id);
      prop_queried_ = true;
165
    }
166
167
    return prop_.major;
  }
168
169

 private:
170
171
  bool prop_queried_ = false;
  cudaDeviceProp prop_;
172
173
};

174
// NVTE Tensor Utils
Shijie's avatar
Shijie committed
175
176
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
                             const DType type);
177
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type);
178
179
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
                             void *amax_ptr, void *scale_ptr, void *scale_inv_ptr);
180
TensorWrapper MakeNvteTensor(paddle::Tensor &tensor);  // NOLINT
181
182
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor);

183
184
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout);

185
186
}  // namespace paddle_ext
}  // namespace transformer_engine