/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

9
10
#include <bit>

11
12
#include "./common.h"
#include "./utils.cuh"
13
14
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"
15
16
17
18
19
20

namespace transformer_engine {

namespace {

// Single-thread kernel that refreshes a tensor's inverse scale.
// Reads the scale from `scale_ptr` (a null pointer is treated as a scale of 1) and hands it to
// reciprocal<float>, which writes the result to `scale_inv_ptr` (presumably 1 / scale — see
// utils.cuh). Intended to be launched as <<<1, 1>>>.
__global__ void __launch_bounds__(1)
    update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr,
                                   float *__restrict__ scale_inv_ptr) {
  float scale = 1;
  if (scale_ptr != nullptr) {
    scale = *scale_ptr;
  }
  reciprocal<float>(scale_inv_ptr, scale);
}

}  // namespace

29
30
31
32
33
34
35
36
37
38
39
40
41
// Translate a Transformer Engine DType into CUDA's cudaDataType_t enumeration.
// kFloat4E2M1 is only mapped when compiling against CUDA_VERSION >= 12080 (CUDA 12.8);
// any dtype without a mapping raises NVTE_ERROR("Invalid type").
cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) {
  switch (t) {
    // 16/32-bit floating-point types.
    case DType::kFloat32:
      return CUDA_R_32F;
    case DType::kBFloat16:
      return CUDA_R_16BF;
    case DType::kFloat16:
      return CUDA_R_16F;
    // 8-bit floating-point types.
    case DType::kFloat8E5M2:
      return CUDA_R_8F_E5M2;
    case DType::kFloat8E4M3:
      return CUDA_R_8F_E4M3;
#if CUDA_VERSION >= 12080
    // 4-bit floating point: the enum value only exists in CUDA 12.8+ headers.
    case DType::kFloat4E2M1:
      return CUDA_R_4F_E2M1;
#endif
    default:
      NVTE_ERROR("Invalid type");
  }
}

51
52
53
// Recompute t->scale_inv from t->scale on `stream` for FP8 tensors that use tensor-wise
// scaling; all other tensors are left untouched.
// A single-thread kernel does the work. A null t->scale.dptr is treated as scale == 1.
// Raises if the tensor qualifies but has no scale_inv buffer, or if the launch fails.
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
  const bool uses_fp8_tensor_scaling =
      is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode);
  if (!uses_fp8_tensor_scaling) {
    return;
  }

  NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
  const float *scale_in = reinterpret_cast<const float *>(t->scale.dptr);
  float *scale_inv_out = reinterpret_cast<float *>(t->scale_inv.dptr);
  update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(scale_in, scale_inv_out);
  NVTE_CHECK_CUDA(cudaGetLastError());
}

61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
namespace {

// Number of threads per block used when launching memset_kernel.
constexpr size_t kThreadsPerBlock = 256;

// Byte-wise memset of a device buffer, writing one TVectorized word per thread.
//
// Expected launch: 1D grid of kThreadsPerBlock-thread blocks with at least
// DIVUP(size_in_bytes, kThreadsPerBlock * sizeof(TVectorized)) blocks and no shared memory.
// Precondition: `ptr` is aligned to sizeof(TVectorized); the vectorized store relies on it.
// Thread `idx` owns bytes [idx * sizeof(TVectorized), (idx + 1) * sizeof(TVectorized)).
// The one thread whose word straddles the end of the buffer writes only the remaining bytes
// with a scalar memset. Threads entirely past the end do nothing.
template <typename TVectorized>
__global__ void __launch_bounds__(kThreadsPerBlock)
    memset_kernel(void *__restrict__ ptr, int value, size_t size_in_bytes) {
  // Widen to size_t before multiplying. Otherwise blockIdx.x * blockDim.x is evaluated in
  // 32-bit unsigned arithmetic and wraps once the grid covers 2^32 or more threads.
  const size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
  const size_t offset = idx * sizeof(TVectorized);

  if (offset >= size_in_bytes) {
    return;  // Out of bounds
  }

  // offset < size_in_bytes here, so this subtraction cannot underflow. Comparing the remaining
  // byte count avoids computing (idx + 1) * sizeof(TVectorized), which could overflow.
  const size_t bytes_left = size_in_bytes - offset;
  if (bytes_left < sizeof(TVectorized)) {
    // If the buffer size is not an even multiple of the vectorization, manually set the
    // remaining bytes unvectorized.
    memset(reinterpret_cast<uint8_t *>(ptr) + offset, value, bytes_left);
    return;
  }

  // Replicate the low byte of `value` across one TVectorized word (memset semantics),
  // then issue a single vectorized store.
  union {
    TVectorized word;
    uint8_t bytes[sizeof(TVectorized)];
  } fill;
  const uint8_t fill_byte = static_cast<uint8_t>(value);
#pragma unroll
  for (size_t i = 0; i < sizeof(TVectorized); ++i) {
    fill.bytes[i] = fill_byte;
  }
  reinterpret_cast<TVectorized *>(ptr)[idx] = fill.word;
}

}  // namespace

// Try to service a memset with memset_kernel<vectorizedType>.
//
// If the buffer holds at least one vectorizedType and `ptr` is aligned to
// sizeof(vectorizedType), the macro does three things:
//   1. launches enough kThreadsPerBlock-thread blocks on `stream` to cover the buffer;
//   2. checks for launch errors;
//   3. `return`s from the *enclosing* function.
// Otherwise it does nothing, so the caller can fall through to a narrower vector type.
//
// The body is wrapped in do { ... } while (0) so that `MACRO(...);` is exactly one statement.
// An unbraced if/else around it therefore cannot capture a dangling `else`. The `return`
// inside the loop still exits the enclosing function. Arguments are parenthesized so that
// expression arguments bind as intended.
#define MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, vectorizedType, stream) \
  do {                                                                                      \
    if ((size_in_bytes) >= sizeof(vectorizedType) &&                                        \
        reinterpret_cast<size_t>(ptr) % sizeof(vectorizedType) == 0) {                      \
      const size_t numBlocks =                                                              \
          DIVUP((size_in_bytes), kThreadsPerBlock * sizeof(vectorizedType));                \
      dim3 grid(numBlocks, 1, 1);                                                           \
      memset_kernel<vectorizedType>                                                         \
          <<<grid, kThreadsPerBlock, 0, (stream)>>>((ptr), (value), (size_in_bytes));       \
      NVTE_CHECK_CUDA(cudaGetLastError());                                                  \
      return;                                                                               \
    }                                                                                       \
  } while (0)

extern "C" {
// Asynchronously set `size_in_bytes` bytes starting at device pointer `ptr` to the byte
// `value`, ordered on `stream`.
// Buffers larger than 4096 bytes go to cudaMemsetAsync. Smaller buffers use memset_kernel
// with the widest vector type (float4, float2, float, uint8_t) that the pointer alignment
// and buffer size allow. Each MEMSET_VECTORIZED_KERNEL_DISPATCH returns from this function
// once it has launched. A zero-byte request is a no-op.
void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream) {
  NVTE_API_CALL(nvte_memset);
  NVTE_CHECK(ptr != nullptr, "Pointer for memset must be allocated.");

  // Largest size still handled by the custom kernel; anything bigger uses cudaMemsetAsync.
  constexpr size_t kMaxKernelMemsetBytes = 4096;

  if (size_in_bytes <= kMaxKernelMemsetBytes) {
    // Widest vector type first; the first dispatch whose preconditions hold launches and returns.
    MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float4, stream);
    MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float2, stream);
    MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, float, stream);
    MEMSET_VECTORIZED_KERNEL_DISPATCH(ptr, size_in_bytes, value, uint8_t, stream);
    return;
  }

  NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream));
}
}  // extern "C"

121
// Make sure the calling host thread has a current CUDA driver context usable with `stream`.
// Steps:
//   1. call cuda_driver::ensure_context_exists();
//   2. query the stream's context with cuStreamGetCtx;
//   3. if the driver reports CUDA_ERROR_INVALID_CONTEXT, retain the current device's primary
//      context and make it current.
// Any other driver error is raised via NVTE_ERROR. No-op on ROCm (__HIP_PLATFORM_AMD__).
void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
  return;
#else
  // Ensure the thread's "current" CUDA context is set.
  cuda_driver::ensure_context_exists();

  CUcontext ctx;
  const CUresult status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
  if (status == CUDA_SUCCESS) {
    return;
  }

  if (status == CUDA_ERROR_INVALID_CONTEXT) {
    // No usable context: fall back to the primary context of the runtime's current device.
    int device_id;
    NVTE_CHECK_CUDA(cudaGetDevice(&device_id));
    NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &ctx, device_id);
    NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, ctx);
    return;
  }

  // Unexpected driver failure: report it with the driver's own description.
  const char *error_string;
  cuda_driver::call("cuGetErrorString", status, &error_string);
  NVTE_ERROR("CUDA Error: ", error_string);
#endif
}

yuguo's avatar
yuguo committed
149
#ifndef __HIP_PLATFORM_AMD__
150
// Map a Transformer Engine DType to the CUtensorMapDataType used when encoding TMA
// descriptors.
// - Byte and both FP8 formats are described to TMA as raw UINT8.
// - FP4 (only when FP4_TYPE_SUPPORTED) uses the packed 16U4_ALIGN8B type.
// - Unmapped dtypes make unordered_map::at throw std::out_of_range.
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
  // Built once, on first use (function-local static).
  static const std::unordered_map<DType, CUtensorMapDataType> kTmaTypeOf = [] {
    std::unordered_map<DType, CUtensorMapDataType> table;
    table.emplace(DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8);
    table.emplace(DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32);
    table.emplace(DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16);
    table.emplace(DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16);
    // FP8 payloads are moved as opaque bytes.
    table.emplace(DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8);
    table.emplace(DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8);
#if FP4_TYPE_SUPPORTED
    table.emplace(DType::kFloat4E2M1,
                  CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_16U4_ALIGN8B);
#endif
    return table;
  }();
  return kTmaTypeOf.at(dtype);
}

// Set up parameters to create TMA descriptor.
//
// Encodes a rank-2, tiled TMA descriptor into `tensorMap` through the driver entry point
// cuTensorMapEncodeTiled. The entry point is looked up once and cached in a function-local
// static.
//   globalY x globalX : global extent; X is the innermost dimension (size[0] below).
//   shmemY x shmemX   : box moved by a single TMA transfer.
//   stride_elems      : distance between consecutive rows, in elements.
//   offset_elems      : elements to skip from tensor.dptr before the first element.
//   type_num_bits     : element width in bits (may be sub-byte, e.g. 4).
//   swizzle           : shared-memory swizzle mode, passed straight to the driver.
// Raises via NVTE_CHECK when the base address, the innermost extent, or the row stride
// violate TMA_GMEM_ALIGNMENT. Raises via NVTE_CHECK_CUDA_DRIVER if the driver rejects the
// descriptor.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
                          const uint32_t offset_elems, const size_t type_num_bits,
                          const CUtensorMapSwizzle swizzle) {
  cuda_driver::ensure_context_exists();
  // Get a function pointer to the cuTensorMapEncodeTiled driver API
  // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13
  static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() {
    void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled");
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled_v12000>(driver_ptr);
  }();
  // rank is the number of dimensions of the array
  constexpr uint32_t rank = 2;

  // Dimension for the packed data types must reflect the number of individual U# values.
  uint64_t size[rank] = {globalX, globalY};

  // Row stride in bits, computed in 64-bit arithmetic. Keeping the bit count lets us reject
  // sub-byte strides that are not a whole number of bytes; the "/ 8" below would otherwise
  // silently truncate them into a wrong byte stride.
  const uint64_t stride_bits = static_cast<uint64_t>(stride_elems) * type_num_bits;

  // The stride is the number of bytes to traverse from the first element of one row to the next
  uint64_t stride[rank - 1] = {stride_bits / 8};

  // The boxSize is the size of the shared memory buffer that is used as the
  // source/destination of a TMA transfer
  uint32_t boxSize[rank] = {shmemX, shmemY};

  // The distance between elements in units of sizeof(element)
  uint32_t elemStride[rank] = {1, 1};

  const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
  void *dataPtr = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) +
                                           (offset_elems * type_num_bits) / 8);

  NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT),
             "Tensor data pointer must be 16B aligned");

  // Number of elements that make up TMA_GMEM_ALIGNMENT bytes for this element width.
  const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits;
  NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_num_bits,
             "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX);

  // TMA requires global strides to be multiples of the global-memory alignment. Checking in
  // bits also catches fractional-byte strides of sub-byte types. Failing here gives a
  // descriptive message instead of an opaque driver error (or, for truncated strides, a
  // descriptor that silently addresses the wrong rows).
  NVTE_CHECK(stride_bits % (TMA_GMEM_ALIGNMENT * 8) == 0, "Row stride not supported. For ",
             type_num_bits, "-bit data type, expected multiple of ", TMA_needed_size,
             " elements, got ", stride_elems);

  // Create the tensor descriptor.
  NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
      &tensorMap,  // CUtensorMap *tensorMap,
      tensorDataType,
      rank,        // cuuint32_t tensorRank,
      dataPtr,     // void *globalAddress,
      size,        // const cuuint64_t *globalDim,
      stride,      // const cuuint64_t *globalStrides,
      boxSize,     // const cuuint32_t *boxDim,
      elemStride,  // const cuuint32_t *elementStrides,
      // Interleave patterns can be used to accelerate loading of values that
      // are less than 4 bytes long.
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,

      // Swizzling can be used to avoid shared memory bank conflicts.
      swizzle,

      // L2 Promotion can be used to widen the effect of a cache-policy to a wider
      // set of L2 cache lines.
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      // CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,

      // Any element that is outside of bounds will be set to zero by the TMA transfer.
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
yuguo's avatar
yuguo committed
233
#endif
234
235

// True when the current device's SM architecture, as reported by cuda::sm_arch, is 100 or
// newer. Always false on ROCm builds (__HIP_PLATFORM_AMD__).
bool is_supported_by_CC_100() {
#ifdef __HIP_PLATFORM_AMD__
  return false;
#else
  const int device_id = cuda::current_device();
  const int sm = cuda::sm_arch(device_id);
  return sm >= 100;
#endif
}

245
246
247
248
249
250
// Convert an outer_size x inner_size array of NVTETensor handles into nested vectors of
// Tensor pointers, element by element via convertNVTETensor, preserving row/column order.
std::vector<std::vector<Tensor *>> convert_tensor_array(NVTETensor **nvte_tensors,
                                                        size_t outer_size, size_t inner_size) {
  std::vector<std::vector<Tensor *>> converted(outer_size);
  for (size_t row = 0; row < outer_size; ++row) {
    std::vector<Tensor *> &dst = converted[row];
    dst.reserve(inner_size);
    for (size_t col = 0; col < inner_size; ++col) {
      dst.push_back(convertNVTETensor(nvte_tensors[row][col]));
    }
  }
  return converted;
}

257
258
259
260
261
262
// Number of bytes occupied by `elements_num` elements of `buffer_dtype`.
// The element width comes from typeToNumBits; the total bit count is floor-divided by 8.
size_t get_buffer_size_bytes(const size_t elements_num, const DType buffer_dtype) {
  const size_t total_bits = elements_num * typeToNumBits(buffer_dtype);
  return total_bits / 8;
}

// Number of bytes occupied by a dim_first x dim_last buffer of `buffer_dtype`.
// When FP4_TYPE_SUPPORTED, an FP4 buffer must have an even last dimension; this is presumably
// so each row fills whole bytes with 4-bit elements (verify against typeToNumBits).
size_t get_buffer_size_bytes(const size_t dim_first, const size_t dim_last,
                             const DType buffer_dtype) {
#if FP4_TYPE_SUPPORTED
  const bool is_fp4 = (buffer_dtype == DType::kFloat4E2M1);
  NVTE_CHECK(!is_fp4 || dim_last % 2 == 0,
             "Last dimension of a tensor with FP4 type of data must be an even number!");
#endif
  return get_buffer_size_bytes(dim_first * dim_last, buffer_dtype);
}

273
}  // namespace transformer_engine