/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>

#include <memory>

#include "common.h"

namespace transformer_engine {

size_t typeToSize(const transformer_engine::DType type) {
14
15
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
                                     return TypeInfo<T>::size;);  // NOLINT(*)
Przemek Tredak's avatar
Przemek Tredak committed
16
17
}

18
bool is_fp8_dtype(const transformer_engine::DType t) {
19
  return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2;
20
21
22
23
24
25
}

void CheckInputTensor(const Tensor &t, const std::string &name) {
  const DType type = t.data.dtype;
  if (is_fp8_dtype(type)) {
    // FP8 input needs to have scale_inv
26
    NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale.");
27
    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
28
    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
29
  } else {
30
31
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + ".");
    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + ".");
32
33
34
    NVTE_CHECK(t.scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 input " + name + ".");
  }
35
  NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!");
36
37
38
39
40
41
}

void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
  const DType type = t.data.dtype;
  if (is_fp8_dtype(type)) {
    // FP8 output needs to have scale, amax and scale_inv
42
    NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor.");
43
    NVTE_CHECK(t.amax.dtype == DType::kFloat32);
44
45
    NVTE_CHECK(t.amax.shape == std::vector<size_t>{1});
    NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale.");
46
    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
47
48
    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
    NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale.");
49
    NVTE_CHECK(t.scale.dtype == DType::kFloat32);
50
    NVTE_CHECK(t.scale.shape == std::vector<size_t>{1});
51
  } else {
52
53
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + ".");
    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + ".");
54
55
56
57
58
    NVTE_CHECK(t.scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 output " + name + ".");
  }

  if (!allow_empty) {
59
    NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!");
60
61
62
  }
}

Przemek Tredak's avatar
Przemek Tredak committed
63
64
}  // namespace transformer_engine

65
66
// Allocates a new Tensor that wraps caller-owned buffers and returns it as an
// opaque NVTETensor handle.
//
// `dptr`, `amax`, `scale` and `scale_inv` are stored as-is (not copied), while
// the `shape.ndim` dimensions pointed to by `shape.data` are copied into the
// tensor. The returned handle is owned by the caller and is released with
// nvte_destroy_tensor (which deletes it).
NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax,
                              float *scale, float *scale_inv) {
  // Copying [data, data + ndim) from a null pointer with ndim > 0 would be undefined
  // behaviour, so reject it before allocating anything.
  NVTE_CHECK(shape.ndim == 0 || shape.data != nullptr,
             "Tensor shape with non-zero ndim must provide dimension data!");
  // Keep the tensor in a unique_ptr until it is fully initialized so that an
  // exception (e.g. std::bad_alloc while copying the shape) does not leak it.
  auto ret = std::make_unique<transformer_engine::Tensor>();
  ret->data.dptr = dptr;
  ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
  ret->data.dtype = static_cast<transformer_engine::DType>(dtype);
  ret->amax.dptr = amax;
  ret->scale.dptr = scale;
  ret->scale_inv.dptr = scale_inv;
  // Hand ownership over to the caller.
  return ret.release();
}

// Releases a tensor handle obtained from nvte_create_tensor. A null handle is
// accepted and ignored.
void nvte_destroy_tensor(NVTETensor tensor) {
  if (tensor != nullptr) {
    delete reinterpret_cast<transformer_engine::Tensor *>(tensor);
  }
}

// Returns the element type of the tensor's data, converted to the C-API enum.
// `tensor` must be a valid (non-null) handle.
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
  const auto *impl = reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  const transformer_engine::DType element_type = impl->data.dtype;
  return static_cast<NVTEDType>(element_type);
}

// Returns a view of the tensor's data shape. `tensor` must be a valid
// (non-null) handle.
//
// The returned `data` pointer aliases the tensor's internal shape storage: it
// stays valid only while the tensor is alive and its shape is not reassigned.
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  const auto *impl = reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  const auto &dims = impl->data.shape;
  NVTEShape view;
  view.ndim = dims.size();
  view.data = dims.data();
  return view;
}

void *nvte_tensor_data(const NVTETensor tensor) {
97
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
98
99
100
101
  return t.data.dptr;
}

float *nvte_tensor_amax(const NVTETensor tensor) {
102
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
103
104
  NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
             "Tensor's amax must have Float32 type!");
105
  return reinterpret_cast<float *>(t.amax.dptr);
106
107
108
}

float *nvte_tensor_scale(const NVTETensor tensor) {
109
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
110
111
  NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
             "Tensor's scale must have Float32 type!");
112
  return reinterpret_cast<float *>(t.scale.dptr);
113
114
115
}

float *nvte_tensor_scale_inv(const NVTETensor tensor) {
116
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
117
118
  NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32,
             "Tensor's inverse of scale must have Float32 type!");
119
  return reinterpret_cast<float *>(t.scale_inv.dptr);
Przemek Tredak's avatar
Przemek Tredak committed
120
}
cyanguwa's avatar
cyanguwa committed
121

122
// Fills all MAX_SIZE slots of `pack->tensors` with freshly allocated,
// default-constructed Tensors. Release them with nvte_tensor_pack_destroy.
void nvte_tensor_pack_create(NVTETensorPack *pack) {
  const auto slot_count = pack->MAX_SIZE;
  for (int slot = 0; slot < slot_count; ++slot) {
    auto *fresh = new transformer_engine::Tensor;
    pack->tensors[slot] = reinterpret_cast<NVTETensor>(fresh);
  }
}

// Deletes the Tensor in each of the MAX_SIZE slots of `pack->tensors` (normally
// the ones allocated by nvte_tensor_pack_create) and resets every slot to nullptr.
//
// Resetting the slots means that a repeated destroy on the same pack deletes
// nullptr (a no-op) instead of double-freeing, and handles read from the pack
// afterwards are null rather than dangling.
void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
  for (int i = 0; i < pack->MAX_SIZE; i++) {
    delete reinterpret_cast<transformer_engine::Tensor *>(pack->tensors[i]);
    pack->tensors[i] = nullptr;
  }
}