/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>
#include "common.h"

namespace transformer_engine {

// Returns the size in bytes of a single element of the given DType.
// The TYPE_SWITCH_ALL macro expands to a switch over every supported
// type, returning TypeInfo<T>::size for the matching case.
size_t typeToSize(const transformer_engine::DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
                                     return TypeInfo<T>::size;);  // NOLINT(*)
}

18
bool is_fp8_dtype(const transformer_engine::DType t) {
19
  return t == transformer_engine::DType::kFloat8E4M3 || t == transformer_engine::DType::kFloat8E5M2;
20
21
22
23
24
25
}

void CheckInputTensor(const Tensor &t, const std::string &name) {
  const DType type = t.data.dtype;
  if (is_fp8_dtype(type)) {
    // FP8 input needs to have scale_inv
26
    NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 input " + name + " must have inverse of scale.");
27
    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
28
    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
29
  } else {
30
31
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 input " + name + ".");
    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 input " + name + ".");
32
33
34
    NVTE_CHECK(t.scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 input " + name + ".");
  }
35
  NVTE_CHECK(t.data.dptr != nullptr, "Input " + name + " is not allocated!");
36
37
38
39
40
41
}

void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
  const DType type = t.data.dtype;
  if (is_fp8_dtype(type)) {
    // FP8 output needs to have scale, amax and scale_inv
42
    NVTE_CHECK(t.amax.dptr != nullptr, "FP8 output " + name + " must have amax tensor.");
43
    NVTE_CHECK(t.amax.dtype == DType::kFloat32);
44
45
    NVTE_CHECK(t.amax.shape == std::vector<size_t>{1});
    NVTE_CHECK(t.scale_inv.dptr != nullptr, "FP8 output " + name + " must have scale.");
46
    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
47
48
    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{1});
    NVTE_CHECK(t.scale.dptr != nullptr, "FP8 output " + name + " must have inverse of scale.");
49
    NVTE_CHECK(t.scale.dtype == DType::kFloat32);
50
    NVTE_CHECK(t.scale.shape == std::vector<size_t>{1});
51
  } else {
52
53
    NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output " + name + ".");
    NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output " + name + ".");
54
55
56
57
58
    NVTE_CHECK(t.scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 output " + name + ".");
  }

  if (!allow_empty) {
59
    NVTE_CHECK(t.data.dptr != nullptr, "Output " + name + " is not allocated!");
60
61
62
  }
}

Przemek Tredak's avatar
Przemek Tredak committed
63
64
}  // namespace transformer_engine

65
66
NVTETensor nvte_create_tensor(void *dptr, const NVTEShape shape, const NVTEDType dtype, float *amax,
                              float *scale, float *scale_inv) {
Przemek Tredak's avatar
Przemek Tredak committed
67
  transformer_engine::Tensor *ret = new transformer_engine::Tensor;
68
69
70
71
72
73
  ret->data.dptr = dptr;
  ret->data.shape = std::vector<size_t>(shape.data, shape.data + shape.ndim);
  ret->data.dtype = static_cast<transformer_engine::DType>(dtype);
  ret->amax.dptr = amax;
  ret->scale.dptr = scale;
  ret->scale_inv.dptr = scale_inv;
Przemek Tredak's avatar
Przemek Tredak committed
74
75
76
77
78
79
80
81
82
83
  return ret;
}

// Releases a tensor handle created by nvte_create_tensor().
// Destroying a null handle is a no-op, mirroring delete/free semantics.
void nvte_destroy_tensor(NVTETensor tensor) {
  if (tensor != nullptr) {
    delete reinterpret_cast<transformer_engine::Tensor *>(tensor);
  }
}

// Returns the data type of the tensor's payload as the public enum.
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
  const auto *t = reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  return static_cast<NVTEDType>(t->data.dtype);
}

// Returns a non-owning view of the tensor's shape. The returned pointer
// aliases the tensor's internal shape vector and stays valid only as long
// as the tensor itself.
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  const auto *t = reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  NVTEShape out;
  out.ndim = t->data.shape.size();
  out.data = t->data.shape.data();
  return out;
}

96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
size_t nvte_tensor_ndim(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  return t.data.shape.size();
}

// Returns the extent of the given dimension; raises via NVTE_CHECK if the
// index is out of range.
size_t nvte_tensor_size(const NVTETensor tensor, const size_t dim) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  // dim is unsigned, so only the upper bound needs checking; the previous
  // `dim >= 0 &&` clause was tautological for size_t and warned under
  // -Wtype-limits.
  NVTE_CHECK(dim < t.data.shape.size(), "Invalid dimension index: ", dim);
  return t.data.shape[dim];
}

// Returns the total element count: the product of all dimension sizes.
// An empty shape (scalar) yields 1.
size_t nvte_tensor_numel(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  size_t total = 1;
  for (size_t i = 0; i < t.data.shape.size(); ++i) {
    total *= t.data.shape[i];
  }
  return total;
}

// Returns the size in bytes of one element of the tensor's data type.
size_t nvte_tensor_element_size(const NVTETensor tensor) {
  const auto *t = reinterpret_cast<const transformer_engine::Tensor *>(tensor);
  return transformer_engine::typeToSize(t->data.dtype);
}

Przemek Tredak's avatar
Przemek Tredak committed
121
void *nvte_tensor_data(const NVTETensor tensor) {
122
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
123
124
125
126
  return t.data.dptr;
}

float *nvte_tensor_amax(const NVTETensor tensor) {
127
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
128
129
  NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
             "Tensor's amax must have Float32 type!");
130
  return reinterpret_cast<float *>(t.amax.dptr);
131
132
133
}

float *nvte_tensor_scale(const NVTETensor tensor) {
134
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
135
136
  NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
             "Tensor's scale must have Float32 type!");
137
  return reinterpret_cast<float *>(t.scale.dptr);
138
139
140
}

float *nvte_tensor_scale_inv(const NVTETensor tensor) {
141
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor *>(tensor);
142
143
  NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32,
             "Tensor's inverse of scale must have Float32 type!");
144
  return reinterpret_cast<float *>(t.scale_inv.dptr);
Przemek Tredak's avatar
Przemek Tredak committed
145
}
cyanguwa's avatar
cyanguwa committed
146

147
// Fills every slot of the pack with a freshly allocated tensor. The pack
// then owns MAX_SIZE live tensors, released by nvte_tensor_pack_destroy().
void nvte_tensor_pack_create(NVTETensorPack *pack) {
  for (int idx = 0; idx < pack->MAX_SIZE; ++idx) {
    auto *slot = new transformer_engine::Tensor;
    pack->tensors[idx] = reinterpret_cast<NVTETensor>(slot);
  }
}

153
// Releases every tensor slot allocated by nvte_tensor_pack_create().
void nvte_tensor_pack_destroy(NVTETensorPack *pack) {
  for (int idx = 0; idx < pack->MAX_SIZE; ++idx) {
    delete reinterpret_cast<transformer_engine::Tensor *>(pack->tensors[idx]);
  }
}