/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
#pragma once

#include <cublasLt.h>
9
#include <transformer_engine/activation.h>
10
#include <transformer_engine/cast.h>
Shijie's avatar
Shijie committed
11
#include <transformer_engine/fused_attn.h>
12
#include <transformer_engine/gemm.h>
13
#include <transformer_engine/normalization.h>
14
#include <transformer_engine/recipe.h>
Shijie's avatar
Shijie committed
15
#include <transformer_engine/softmax.h>
16
#include <transformer_engine/transformer_engine.h>
17
#include <transformer_engine/transpose.h>
18
19
20
21

#include <cstdlib>
#include <vector>

22
#include "common/util/logging.h"
23
24
#include "paddle/extension.h"
#include "paddle/phi/backends/all_context.h"
25
26
27
28
29
30

namespace transformer_engine {
namespace paddle_ext {
// Paddle Tensor Utils
template <typename T>
inline const void *GetDataPtr(const paddle::Tensor &x, int64_t index) {
  // Returns a read-only pointer to element `index` of `x`, viewing its storage as T.
  // Reports "Index out of bound" via NVTE_ERROR when `index` is negative or
  // not less than x.numel().
  const bool in_range = index >= 0 && index < x.numel();
  if (!in_range) {
    NVTE_ERROR("Index out of bound");
  }
  const T *base = x.data<T>();
  return static_cast<const void *>(base + static_cast<size_t>(index));
}

template <typename T>
inline void *GetDataPtr(paddle::Tensor &x, int64_t index) {  // NOLINT
  // Mutable counterpart of the const overload: returns a writable pointer to
  // element `index` of `x` (storage viewed as T), after the same bounds check.
  const bool in_range = index >= 0 && index < x.numel();
  if (!in_range) {
    NVTE_ERROR("Index out of bound");
  }
  T *base = x.data<T>();
  return static_cast<void *>(base + static_cast<size_t>(index));
}

template <typename T>
inline const void *GetOptionalDataPtr(const paddle::optional<paddle::Tensor> &x, int64_t index) {
  // An empty optional yields nullptr; otherwise defer to GetDataPtr<T>, which
  // performs the bounds check on `index`.
  if (!x) {
    return nullptr;
  }
  return GetDataPtr<T>(*x, index);
}

template <typename T>
inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x, int64_t index) {  // NOLINT
  // Mutable variant: nullptr for an empty optional, otherwise a bounds-checked
  // writable pointer from GetDataPtr<T>.
  if (!x) {
    return nullptr;
  }
  return GetDataPtr<T>(*x, index);
}

inline const void *GetOptionalDataPtr(const paddle::optional<paddle::Tensor> &x) {
  // Untyped access: the tensor's base data pointer, or nullptr when the
  // optional holds no tensor. No index or dtype checks are performed.
  if (!x) {
    return nullptr;
  }
  return x->data();
}

inline void *GetOptionalDataPtr(paddle::optional<paddle::Tensor> &x) {  // NOLINT
  // Mutable untyped access: the tensor's base data pointer, or nullptr when
  // the optional is empty.
  if (!x) {
    return nullptr;
  }
  return x->data();
}

inline std::vector<size_t> GetShapeArray(const paddle::Tensor &x) {
  // Converts the tensor's dimensions, in order, to an unsigned size_t vector
  // suitable for NVTE shape arguments.
  const auto dims = x.shape();
  std::vector<size_t> out(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) {
    out[i] = static_cast<size_t>(dims[i]);
  }
  return out;
}

Shijie's avatar
Shijie committed
71
inline std::vector<size_t> GetShapeArray(const paddle::optional<paddle::Tensor> &x) {
  // A present tensor reports its real shape; an absent one reports the
  // one-element shape {0}.
  return x ? GetShapeArray(x.get()) : std::vector<size_t>{0};
}

76
77
78
// Allocates a paddle::Tensor with the given NVTE shape and dtype on `place`.
// Defined out of line; `init_to_zeros` presumably requests zero-filled storage
// (NOTE(review): confirm against the definition).
paddle::Tensor AllocateSpace(const NVTEShape &shape, const DType type, const paddle::Place &place,
                             bool init_to_zeros = false);

79
80
// DType Utils
// Maps a Transformer Engine DType to the paddle::DataType used for it on the
// Paddle side. Several NVTE types share one Paddle type:
//   - kByte and both FP8 formats map to UINT8.
//   - kInt32 currently maps to FLOAT32.
// NOTE(review): kInt32 -> FLOAT32 is asymmetric with Paddle2NvteDType
// (INT32 -> kInt32). It may be deliberate (same element width), so it is kept
// as-is; confirm with callers before changing it to INT32.
// Any type without a mapping is reported via NVTE_ERROR("Invalid type").
inline paddle::DataType Nvte2PaddleDType(DType t) {
  switch (t) {
    case DType::kInt32:
    case DType::kFloat32:
      return paddle::DataType::FLOAT32;
    case DType::kInt64:
      // Mirrors Paddle2NvteDType's INT64 -> kInt64 so the conversion
      // round-trips instead of erroring.
      return paddle::DataType::INT64;
    case DType::kFloat16:
      return paddle::DataType::FLOAT16;
    case DType::kBFloat16:
      return paddle::DataType::BFLOAT16;
    case DType::kByte:
    case DType::kFloat8E4M3:
    case DType::kFloat8E5M2:
      return paddle::DataType::UINT8;
    default:
      NVTE_ERROR("Invalid type");
  }
}

inline DType Paddle2NvteDType(paddle::DataType t) {
  // Maps a paddle::DataType to the Transformer Engine DType describing the same
  // storage. Unsupported Paddle types are reported via NVTE_ERROR("Invalid type").
  switch (t) {
    // Floating-point types map one-to-one.
    case paddle::DataType::FLOAT32:
      return DType::kFloat32;
    case paddle::DataType::FLOAT16:
      return DType::kFloat16;
    case paddle::DataType::BFLOAT16:
      return DType::kBFloat16;
    // Integer types map one-to-one.
    case paddle::DataType::INT32:
      return DType::kInt32;
    case paddle::DataType::INT64:
      return DType::kInt64;
    // Booleans and raw bytes are both handed to NVTE as kByte.
    case paddle::DataType::BOOL:
    case paddle::DataType::UINT8:
      return DType::kByte;
    default:
      NVTE_ERROR("Invalid type");
  }
}

inline DType Int2NvteDType(int64_t dtype) {
  // Converts an integer dtype code (e.g. passed from Python) to DType. Codes
  // outside [0, kNumTypes) are reported via NVTE_ERROR("Type not supported.").
  const bool out_of_range = dtype < 0 || dtype >= static_cast<int64_t>(DType::kNumTypes);
  if (out_of_range) {
    NVTE_ERROR("Type not supported.");
  }
  return static_cast<DType>(dtype);
}

127
128
129
130
// get the fused attention backend
inline NVTE_Fused_Attn_Backend get_fused_attn_backend(
    const transformer_engine::DType q_dtype, const transformer_engine::DType kv_dtype,
    NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type attn_mask_type,
    float p_dropout, size_t num_attn_heads, size_t num_gqa_groups, size_t max_seqlen_q,
    size_t max_seqlen_kv, size_t head_dim) {
  // Thin wrapper over nvte_get_fused_attn_backend that accepts C++ DType values.
  const auto q_type = static_cast<NVTEDType>(q_dtype);
  const auto kv_type = static_cast<NVTEDType>(kv_dtype);
  // The single `head_dim` is forwarded for both of the C API's head-dimension
  // arguments, and -1 for the two trailing arguments.
  // NOTE(review): presumably these are the QK/V head dims and the sliding-window
  // bounds (-1 = unbounded); confirm against the nvte_get_fused_attn_backend signature.
  return nvte_get_fused_attn_backend(q_type, kv_type, qkv_layout, bias_type, attn_mask_type,
                                     p_dropout, num_attn_heads, num_gqa_groups, max_seqlen_q,
                                     max_seqlen_kv, head_dim, head_dim, -1, -1);
}

140
141
142
// CUDA Utils
// Lazily caches the cudaDeviceProp of the CUDA device that is current on the
// calling thread when a getter is first used.
// Each thread has its own instance (thread_local). The properties are queried
// once and never refreshed, so a later device switch on the same thread is not
// reflected in the cached values.
class cudaDevicePropertiesManager {
 public:
  static cudaDevicePropertiesManager &Instance() {
    static thread_local cudaDevicePropertiesManager instance;
    return instance;
  }

  // Number of streaming multiprocessors of the cached device.
  int GetMultiProcessorCount() {
    QueryIfNeeded();
    return prop_.multiProcessorCount;
  }

  // Compute-capability major version of the cached device.
  int GetMajor() {
    QueryIfNeeded();
    return prop_.major;
  }

 private:
  // Fills prop_ on first use. Both CUDA calls are checked: previously a failing
  // cudaGetDeviceProperties went unnoticed, and its garbage prop_ was cached
  // as valid.
  void QueryIfNeeded() {
    if (prop_queried_) {
      return;
    }
    int device_id = 0;
    NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
    NVTE_CHECK_CUDA(cudaGetDeviceProperties(&prop_, device_id));
    prop_queried_ = true;
  }

  bool prop_queried_ = false;
  cudaDeviceProp prop_{};
};

173
// NVTE Tensor Utils
Shijie's avatar
Shijie committed
174
175
// Wraps a read-only `data_ptr` with the given shape and dtype in a TensorWrapper
// (defined out of line; presumably no copy of the data is made — confirm).
TensorWrapper MakeNvteTensor(const void *data_ptr, const std::vector<size_t> &shape,
                             const DType type);
176
// Wraps `data_ptr` with an NVTEShape and dtype in a TensorWrapper (defined out of line).
TensorWrapper MakeNvteTensor(void *data_ptr, const NVTEShape &shape, const DType type);
177
178
// Wraps `data_ptr` in a TensorWrapper that also carries amax / scale /
// scale-inverse pointers. NOTE(review): presumably used for FP8 tensors — confirm.
TensorWrapper MakeNvteTensor(void *data_ptr, const std::vector<size_t> &shape, const DType type,
                             void *amax_ptr, void *scale_ptr, void *scale_inv_ptr);
179
// Wraps a mutable paddle::Tensor in a TensorWrapper (shape/dtype taken from the tensor; defined out of line).
TensorWrapper MakeNvteTensor(paddle::Tensor &tensor);  // NOLINT
180
181
// Wraps a read-only paddle::Tensor in a TensorWrapper (defined out of line).
TensorWrapper MakeNvteTensor(const paddle::Tensor &tensor);

182
183
// Parses a QKV layout name into the corresponding NVTE_QKV_Layout (defined out of line;
// behavior on unknown names is not visible here).
NVTE_QKV_Layout get_nvte_qkv_layout(const std::string &qkv_layout);

184
185
}  // namespace paddle_ext
}  // namespace transformer_engine