/*************************************************************************
 * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <transformer_engine/transformer_engine.h>
#include "common.h"

namespace transformer_engine {

// Returns the element size reported by TypeInfo<T>::size for the C++ type T
// that corresponds to `type` (presumably sizeof(T) in bytes — TypeInfo is
// defined outside this file; confirm there).
//
// The dispatch from the runtime `type` to the compile-time `T` is done by
// TRANSFORMER_ENGINE_TYPE_SWITCH_ALL (defined in common.h, not visible here).
// NOTE(review): for a DType the macro does not cover, behavior is decided by
// the macro's default branch — confirm it raises an error rather than letting
// control fall off the end of this non-void function.
size_t typeToSize(const transformer_engine::DType type) {
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
        return TypeInfo<T>::size;
    );  // NOLINT(*)
}

// Reports whether `t` is one of the 8-bit floating-point formats
// (E4M3 or E5M2). Every other DType yields false.
bool is_fp8_dtype(const transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kFloat8E4M3:
    case transformer_engine::DType::kFloat8E5M2:
      return true;
    default:
      return false;
  }
}

// Validates a tensor that an operation is about to read from.
//
// An FP8 input (see is_fp8_dtype) must provide a scale_inv buffer whose dtype
// is DType::kFloat32 and whose shape is {1}. A non-FP8 input must leave scale,
// amax and scale_inv all null. In both cases the data pointer must be set.
//
// \param t     Tensor to validate.
// \param name  Human-readable name used in the error messages.
// Violations are reported through NVTE_CHECK.
void CheckInputTensor(const Tensor &t, const std::string &name) {
  const bool holds_fp8 = is_fp8_dtype(t.data.dtype);
  if (!holds_fp8) {
    // Scaling metadata is rejected for non-FP8 inputs.
    NVTE_CHECK(t.scale.dptr == nullptr,
               "Scale is not supported for non-FP8 input " + name + ".");
    NVTE_CHECK(t.amax.dptr == nullptr,
               "Amax is not supported for non-FP8 input " + name + ".");
    NVTE_CHECK(t.scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 input " + name + ".");
  } else {
    // FP8 input needs a single float32 inverse-scale value.
    NVTE_CHECK(t.scale_inv.dptr != nullptr,
               "FP8 input " + name + " must have inverse of scale.");
    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 });
  }
  NVTE_CHECK(t.data.dptr != nullptr,
             "Input " + name + " is not allocated!");
}

// Validates a tensor that an operation is about to write into.
//
// An FP8 output (see is_fp8_dtype) must provide amax, scale_inv and scale
// buffers, each with dtype DType::kFloat32 and shape {1}. A non-FP8 output
// must leave all three null. Unless `allow_empty` is true, the data pointer
// must also be set.
//
// \param t            Tensor to validate.
// \param name         Human-readable name used in the error messages.
// \param allow_empty  When true, a null data pointer is accepted.
// Violations are reported through NVTE_CHECK.
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty) {
  const DType type = t.data.dtype;
  if (is_fp8_dtype(type)) {
    // FP8 output needs to have scale, amax and scale_inv
    NVTE_CHECK(t.amax.dptr != nullptr,
               "FP8 output " + name + " must have amax tensor.");
    NVTE_CHECK(t.amax.dtype == DType::kFloat32);
    NVTE_CHECK(t.amax.shape == std::vector<size_t>{ 1 });
    // Each message below names the buffer its condition actually checks.
    NVTE_CHECK(t.scale_inv.dptr != nullptr,
               "FP8 output " + name + " must have inverse of scale.");
    NVTE_CHECK(t.scale_inv.dtype == DType::kFloat32);
    NVTE_CHECK(t.scale_inv.shape == std::vector<size_t>{ 1 });
    NVTE_CHECK(t.scale.dptr != nullptr,
               "FP8 output " + name + " must have scale.");
    NVTE_CHECK(t.scale.dtype == DType::kFloat32);
    NVTE_CHECK(t.scale.shape == std::vector<size_t>{ 1 });
  } else {
    NVTE_CHECK(t.scale.dptr == nullptr,
               "Scale is not supported for non-FP8 output " + name + ".");
    NVTE_CHECK(t.amax.dptr == nullptr,
               "Amax is not supported for non-FP8 output " + name + ".");
    NVTE_CHECK(t.scale_inv.dptr == nullptr,
               "Scale_inv is not supported for non-FP8 output " + name + ".");
  }

  if (!allow_empty) {
    NVTE_CHECK(t.data.dptr != nullptr,
               "Output " + name + " is not allocated!");
  }
}

}  // namespace transformer_engine

// Allocates a new transformer_engine::Tensor and returns it as an opaque
// NVTETensor handle. The handle is released with nvte_destroy_tensor.
//
// \param dptr       Data pointer. Stored as given; the buffer is not copied.
// \param shape      Shape of the data. The `shape.ndim` entries at
//                   `shape.data` are copied into the tensor.
// \param dtype      Element type of the data.
// \param amax       amax buffer pointer (stored as given, may be null).
// \param scale      scale buffer pointer (stored as given, may be null).
// \param scale_inv  Inverse-scale buffer pointer (stored as given, may be null).
// \return A newly allocated tensor handle.
NVTETensor nvte_create_tensor(void *dptr,
                              const NVTEShape shape,
                              const NVTEDType dtype,
                              float *amax,
                              float *scale,
                              float *scale_inv) {
  // Copy the shape before allocating the Tensor. If this allocation throws,
  // no Tensor has been created yet, so nothing leaks.
  std::vector<size_t> dims(shape.data, shape.data + shape.ndim);

  transformer_engine::Tensor *ret = new transformer_engine::Tensor;
  ret->data.dptr = dptr;
  ret->data.shape.swap(dims);  // non-throwing hand-off of the copied shape
  ret->data.dtype = static_cast<transformer_engine::DType>(dtype);
  ret->amax.dptr = amax;
  ret->scale.dptr = scale;
  ret->scale_inv.dptr = scale_inv;
  return ret;
}

// Deletes a tensor created by nvte_create_tensor.
// A null handle is accepted and ignored.
void nvte_destroy_tensor(NVTETensor tensor) {
  if (tensor != nullptr) {
    auto *owned = reinterpret_cast<transformer_engine::Tensor *>(tensor);
    delete owned;
  }
}

// Returns the dtype of the tensor's data buffer as the C-API enum NVTEDType.
NVTEDType nvte_tensor_type(const NVTETensor tensor) {
  const auto *impl = reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  const auto data_type = impl->data.dtype;
  return static_cast<NVTEDType>(data_type);
}

// Returns a non-owning view of the tensor's data shape.
//
// `data` in the result points directly into the shape storage held by the
// tensor. The view is therefore valid only while `tensor` is alive and its
// shape is not modified.
NVTEShape nvte_tensor_shape(const NVTETensor tensor) {
  const auto *impl = reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  const auto &dims = impl->data.shape;
  NVTEShape view;
  view.ndim = dims.size();
  view.data = dims.data();
  return view;
}

// Returns the raw data pointer held by the tensor (may be null).
void *nvte_tensor_data(const NVTETensor tensor) {
  const auto *impl = reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  void *payload = impl->data.dptr;
  return payload;
}

float *nvte_tensor_amax(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  NVTE_CHECK(t.amax.dtype == transformer_engine::DType::kFloat32,
             "Tensor's amax must have Float32 type!");
  return reinterpret_cast<float*>(t.amax.dptr);
}

float *nvte_tensor_scale(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  NVTE_CHECK(t.scale.dtype == transformer_engine::DType::kFloat32,
             "Tensor's scale must have Float32 type!");
  return reinterpret_cast<float*>(t.scale.dptr);
}

float *nvte_tensor_scale_inv(const NVTETensor tensor) {
  const auto &t = *reinterpret_cast<const transformer_engine::Tensor*>(tensor);
  NVTE_CHECK(t.scale_inv.dtype == transformer_engine::DType::kFloat32,
             "Tensor's inverse of scale must have Float32 type!");
  return reinterpret_cast<float*>(t.scale_inv.dptr);
Przemek Tredak's avatar
Przemek Tredak committed
135
}