/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <bit>
#include <unordered_map>

#include "./common.h"
#include "./utils.cuh"
#include "common/util/cuda_runtime.h"
#include "common/util/logging.h"

namespace transformer_engine {

namespace {

// Single-thread kernel: reads the tensor's FP8 scale (treating a null
// scale pointer as 1) and writes its reciprocal to scale_inv.
__global__ void __launch_bounds__(1)
    update_tensor_scale_inv_kernel(const float *__restrict__ scale_ptr,
                                   float *__restrict__ scale_inv_ptr) {
  const float scale = scale_ptr == nullptr ? 1 : *scale_ptr;
  reciprocal<float>(scale_inv_ptr, scale);
}

}  // namespace

// Refresh an FP8 tensor's scale_inv field to be the reciprocal of its
// current scale. No-op unless the tensor uses FP8 per-tensor scaling.
void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) {
  if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) {
    NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv.");
    update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>(
        reinterpret_cast<const float *>(t->scale.dptr),
        reinterpret_cast<float *>(t->scale_inv.dptr));
  }
}
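
// Illustrative call site (the surrounding names are examples, not this
// file's API): after a kernel has refreshed t->scale on the GPU,
//   update_tensor_scale_inv(&t, stream);
// keeps t->scale_inv equal to the exact reciprocal of t->scale for use by
// later dequantization kernels.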

// Ensure the calling thread has a valid CUDA driver context, retaining and
// activating the device's primary context if none is current. No-op on ROCm.
void checkCuDriverContext(CUstream stream) {
#ifdef __HIP_PLATFORM_AMD__
  return;
#else
  CUcontext ctx;
  const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx);
  switch (driver_status) {
    case CUDA_SUCCESS:
      break;

    case CUDA_ERROR_INVALID_CONTEXT:
      int current_device;
      NVTE_CHECK_CUDA(cudaGetDevice(&current_device));
      NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &ctx, current_device);
      NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, ctx);
      break;

    default:
      const char *desc_NVTE_CHECK_CUDA_DRIVER;
      cuda_driver::call("cuGetErrorString", driver_status, &desc_NVTE_CHECK_CUDA_DRIVER);
      NVTE_ERROR("CUDA Error: ", desc_NVTE_CHECK_CUDA_DRIVER);
  }
#endif
}

#ifndef __HIP_PLATFORM_AMD__
// Translate a TE DType into the CUtensorMap data type used when encoding
// TMA descriptors.
CUtensorMapDataType get_CUtensorMapDataType(DType dtype) {
  static const std::unordered_map<DType, CUtensorMapDataType> dtypeMapping = {
      {DType::kByte, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
      {DType::kFloat32, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT32},
      {DType::kFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_FLOAT16},
      {DType::kBFloat16, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_BFLOAT16},
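      // TMA has no dedicated FP8 data type, so FP8 values are moved as raw bytes.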
      {DType::kFloat8E4M3, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8},
      {DType::kFloat8E5M2, CUtensorMapDataType::CU_TENSOR_MAP_DATA_TYPE_UINT8}};
  return dtypeMapping.at(dtype);
}

// Set up parameters and create a TMA descriptor for a 2D tensor tile.
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
                          const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
                          const uint32_t shmemX, const uint32_t stride_elems,
                          const uint32_t offset_elems, const size_t type_size) {
  // Get a function pointer to the cuTensorMapEncodeTiled driver API
  static PFN_cuTensorMapEncodeTiled cuDriverTensorMapEncodeTiled = []() {
    void *driver_ptr = cuda_driver::get_symbol("cuTensorMapEncodeTiled");
    return reinterpret_cast<PFN_cuTensorMapEncodeTiled>(driver_ptr);
  }();
  // rank is the number of dimensions of the array
  constexpr uint32_t rank = 2;
  uint64_t size[rank] = {globalX, globalY};

  // The stride is the number of bytes to traverse from the first element of one row to the next
  uint64_t stride[rank - 1] = {stride_elems * type_size};

  // The boxSize is the size of the shared memory buffer that is used as the
  // source/destination of a TMA transfer
  uint32_t boxSize[rank] = {shmemX, shmemY};

  // The distance between elements in units of sizeof(element)
  uint32_t elemStride[rank] = {1, 1};

  const CUtensorMapDataType tensorDataType = get_CUtensorMapDataType(tensor.dtype);
  void *dataPtr =
      reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(tensor.dptr) + offset_elems * type_size);

  NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment),
             "Tensor data pointer must be 16B aligned");

  const int TMA_needed_size = TMA_gmem_alignment / type_size;
  NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. For ", type_size,
             "-byte data type, expected multiple of ", TMA_needed_size, ", got ", globalX);

  // Create the tensor descriptor.
  NVTE_CHECK_CUDA_DRIVER(cuDriverTensorMapEncodeTiled(
      &tensorMap,  // CUtensorMap *tensorMap,
      tensorDataType,
      rank,        // cuuint32_t tensorRank,
      dataPtr,     // void *globalAddress,
      size,        // const cuuint64_t *globalDim,
      stride,      // const cuuint64_t *globalStrides,
      boxSize,     // const cuuint32_t *boxDim,
      elemStride,  // const cuuint32_t *elementStrides,
      // Interleave patterns can be used to accelerate loading of values that
      // are less than 4 bytes long.
      CUtensorMapInterleave::CU_TENSOR_MAP_INTERLEAVE_NONE,

      // Swizzling can be used to avoid shared memory bank conflicts.
      CUtensorMapSwizzle::CU_TENSOR_MAP_SWIZZLE_NONE,

      // L2 Promotion can be used to widen the effect of a cache-policy to a wider
      // set of L2 cache lines.
      CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_NONE,
      // CUtensorMapL2promotion::CU_TENSOR_MAP_L2_PROMOTION_L2_256B,

      // Any element that is outside of bounds will be set to zero by the TMA transfer.
      CUtensorMapFloatOOBfill::CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE));
}
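
// Illustrative (hypothetical) call: encode a 64x64 tile view over a
// rows x cols row-major bf16 tensor. The variable names below are
// examples only, not identifiers defined in this file:
//   CUtensorMap tensor_map;
//   create_2D_tensor_map(tensor_map, input.data, rows, cols,
//                        /*shmemY=*/64, /*shmemX=*/64,
//                        /*stride_elems=*/cols, /*offset_elems=*/0,
//                        /*type_size=*/sizeof(__nv_bfloat16));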
#endif

// True when the active CUDA device is SM 10.0 (Blackwell) or newer;
// always false when building for HIP/ROCm.
bool is_supported_by_CC_100() {
#ifdef __HIP_PLATFORM_AMD__
  return false;
#else
  // cuda::sm_arch reports compute capability as 10 * major + minor,
  // so 100 corresponds to SM 10.0 (Blackwell).
  int deviceComputeCapability = cuda::sm_arch(cuda::current_device());

  return deviceComputeCapability >= 100;
#endif
}

}  // namespace transformer_engine