/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <bit>
#include <unordered_map>

#include "./common.h"
#include "./utils.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"

namespace transformer_engine {

namespace {

__global__ void __launch_bounds__(1)
    update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr,
                                   float *__restrict__ scale_inv_ptr) {
  const float scale = scale_ptr == nullptr ? 1 : *scale_ptr;
  reciprocal<float>(scale_inv_ptr, scale);
}

}  // namespace

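// Map a TE DType to the corresponding cudaDataType_t enum value.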
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  using namespace transformer_engine;
  switch (t) {
    case DType::kFloat16:
      return CUDA_R_16F;
    case DType::kFloat32:
      return CUDA_R_32F;
    case DType::kBFloat16:
      return CUDA_R_16BF;
    case DType::kFloat8E4M3:
      return CUDA_R_8F_E4M3;
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
    default:
      NVTE_ERROR("Invalid type");
  }
}

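// Refresh an FP8 tensor's scale_inv so that scale_inv = 1 / scale. Only
// applies to FP8 tensors with per-tensor scaling; a single-thread kernel
// is enqueued on the given stream.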
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
  if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) {
    NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
    update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
        reinterpret_cast<const float *>(t->scale.dptr),
        reinterpret_cast<float *>(t->scale_inv.dptr));
  }
}

namespace {

constexpr size_t kThreadsPerBlock = 256;
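// Fill a buffer with a byte value using vectorized stores of TVectorized.
// A trailing partial vector, if any, is written bytewise.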
template <typename TVectorized>
__global__ void __launch_bounds__(kThreadsPerBlock)
    memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) {
  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;

  if (idx * sizeof(TVectorized) >= size_in_bytes) {
    return;  // Out of bounds
  }

  if ((idx + 1) * sizeof(TVectorized) > size_in_bytes) {
    // The buffer size is not a whole multiple of the vector width:
    // write the remaining tail bytes one at a time.
    size_t remaining_bytes = size_in_bytes - idx * sizeof(TVectorized);
    memset(reinterpret_cast<uint8_t *>(ptr) + idx * sizeof(TVectorized), value, remaining_bytes);
    return;
  }

  union {
    TVectorized value;
    uint8_t data[sizeof(TVectorized)];
  } data;
  for (size_t i = 0; i < sizeof(TVectorized); ++i) {
    data.data[i] = static_cast<uint8_t>(value);
  }
  reinterpret_cast<TVectorized *>(ptr)[idx] = data.value;
}

}  // namespace

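// Launch memset_kernel with the given vector type if the buffer is at least
// one vector wide and the pointer is aligned to that vector size; otherwise
// fall through to the next (narrower) dispatch.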
#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
  if (size_in_bytes >= sizeof(vectorizedType) &&                                             \
      reinterpret_cast<size_t>(ptr) % sizeof(vectorizedType) == 0) {                         \
    size_t numBlocks = DIVUP(size_in_bytes, kThreadsPerBlock * sizeof(vectorizedType));      \
    dim3 grid(numBlocks, 1, 1);                                                              \
    memset_kernel<vectorizedType>                                                            \
        <<<grid, kThreadsPerBlock, 0, stream>>>(ptr, value, size_in_bytes);                  \
    return;                                                                                  \
  }

extern "C" {
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
  NVTE_API_CALL(nvte_memset);
  NVTE_CHECK(ptr != nullptr, "Pointer for memset must be allocated.");

  if (size_in_bytes > 4096) {
    // Use cudaMemsetAsync for larger sizes.
    cudaMemsetAsync(ptr, value, size_in_bytes, stream);
    return;
  }

  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float4, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float2, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
  MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
}
}  // extern "C"
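
// Example (sketch): zero-initialize a small device buffer. Fills of at most
// 4096 bytes take the vectorized-kernel path above; larger fills are
// forwarded to cudaMemsetAsync.
//
//   void *buf = nullptr;
//   cudaMalloc(&buf, 1024);
//   nvte_memset(buf, 0, 1024, stream);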

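// Ensure the calling thread has a valid CUDA driver context for stream
// operations, retaining and activating the device's primary context if no
// context is current.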
void checkCuDriverContext(CUstream stream) {
  // Ensure the thread's "current" CUDA context is set.
  cuda_driver::ensure_context_exists();

  CUcontext ctx;
  const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
  switch (driver_status) {
    case CUDA_SUCCESS:
      break;

    case CUDA_ERROR_INVALID_CONTEXT:
      int current_device;
      NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &ctx, current_device);
      NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, ctx);
      break;

    default:
      const char *desc_NVTE_CHECK_CUDA_DRIVER;
      cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER);
      NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER);
  }
}

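// Map a TE DType to the CUtensorMapDataType used when encoding TMA
// descriptors. FP8 values are transferred as opaque bytes (UINT8).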
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
  static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = []() {
    std::unordered_map<DType, CUtensorMapDataType> typeMapping = {
        {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
        {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
        {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
        {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
        {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
        {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
#if FP4_TYPE_SUPPORTED
    typeMapping.insert(
        {DType::kFloat4E2M1, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B});
#endif
    return typeMapping;
  }();
  return dtypeMapping.at(dtype);
}

// Set up parameters to create TMA descriptor.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
                          const uint32_t offset_elems, const size_t type_num_bits) {
  // Get a function pointer to the cuTensorMapEncodeTiled driver API
  // Note: the unversioned PFN_cuTensorMapEncodeTiled alias is not defined in
  // CUDA 13, so the versioned _v12000 typedef is used instead.
  static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
    void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled");
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(driver_ptr);
  }();
  // rank is the number of dimensions of the array
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {globalX, globalY};

  // The stride is the number of bytes to traverse from the first element of one row to the next
  uint64_t stride[rank - 1] = {(stride_elems * type_num_bits) / 8};

  // The boxSize is the size of the shared memory buffer that is used as the
  // source/destination of a TMA transfer
  uint32_t boxSize[rank] = {shmemX, shmemY};

  // The distance between elements in units of sizeof(element)
  uint32_t elemStride[rank] = {1, 1};

  const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
  void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
                                           (offset_elems * type_num_bits) / 8);

  NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
             "Tensor data pointer must be 16B aligned");

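  // Number of elements per TMA_GMEM_ALIGNMENT-sized chunk of global memory,
  // e.g. with 16-byte alignment and 16-bit elements: (16 * 8) / 16 = 8.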
  const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
  NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
             "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);

  // Create the tensor descriptor.
  NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
      &tensorMap,  // CUtensorMap *tensorMap,
      tensorDataType,
      rank,        // cuuint32_t tensorRank,
      dataPtr,     // void *globalAddress,
      size,        // const cuuint64_t *globalDim,
      stride,      // const cuuint64_t *globalStrides,
      boxSize,     // const cuuint32_t *boxDim,
      elemStride,  // const cuuint32_t *elementStrides,
      // Interleave patterns can be used to accelerate loading of values that
      // are less than 4 bytes long.
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,

      // Swizzling can be used to avoid shared memory bank conflicts.
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,

      // L2 Promotion can be used to widen the effect of a cache-policy to a wider
      // set of L2 cache lines.
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      // CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,

      // Any element that is outside of bounds will be set to zero by the TMA transfer.
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
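
// Example (sketch, values illustrative): encode a descriptor for
// 128x64-element tiles of a row-major [rows, cols] FP16 tensor:
//
//   CUtensorMap tmap;
//   create_2D_tensor_map(tmap, tensor, /*globalY=*/rows, /*globalX=*/cols,
//                        /*shmemY=*/128, /*shmemX=*/64,
//                        /*stride_elems=*/cols, /*offset_elems=*/0,
//                        /*type_num_bits=*/16);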

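// True when the current device has compute capability 10.0 (SM 100) or newer.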
bool is_supported_by_CC_100() {
  int deviceComputeCapability = cuda::sm_arch(cuda::current_device());

  return deviceComputeCapability >= 100;
}

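// Convert a 2D C array of NVTETensor handles into nested vectors of
// Tensor pointers.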
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
                                                        size_t outer_size, size_t inner_size) {
  std::vector<std::vector<Tensor *>> ret;
  for (size_t i = 0; i < outer_size; ++i) {
    ret.emplace_back();
    for (size_t j = 0; j < inner_size; ++j) {
      ret.back().push_back(convertNVTETensor(nvte_tensors[i][j]));
    }
  }
  return ret;
}

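// Size in bytes of a buffer of elements_num elements of buffer_dtype;
// exact for sub-byte types because sizes are computed in bits.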
size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype) {
  return (elements_num * typeToNumBits(buffer_dtype)) / 8;
}

size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
                             const DType buffer_dtype) {
  if (buffer_dtype == DType::kFloat4E2M1) {
    NVTE_CHECK(dim_last % 2 == 0,
               "Last dimension of a tensor with FP4 data must be an even number!");
  }
  const size_t elements_num = dim_first * dim_last;
  return get_buffer_size_bytes(elements_num, buffer_dtype);
}

}  // namespace transformer_engine