/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#pragma once

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

namespace test {
using namespace transformer_engine;

// Maps a width in bytes to the unsigned integer type of exactly that width,
// exposed as the nested `Type`. Only widths 1, 2, 4 and 8 are mapped; for any
// other width the primary template is selected and it has no `Type` member.
template <size_t kNumBytes>
struct BytesToType {};

// 1 byte -> uint8_t
template <>
struct BytesToType<1> { typedef uint8_t Type; };

// 2 bytes -> uint16_t
template <>
struct BytesToType<2> { typedef uint16_t Type; };

// 4 bytes -> uint32_t
template <>
struct BytesToType<4> { typedef uint32_t Type; };

// 8 bytes -> uint64_t
template <>
struct BytesToType<8> { typedef uint64_t Type; };

// Short aliases for the element types used throughout the tests. The same
// names are listed in TypeInfo::types and used by
// TRANSFORMER_ENGINE_TYPE_SWITCH_ALL.
using byte = uint8_t;
using int32 = int32_t;
using int64 = int64_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

template <typename T>
struct TypeInfo{
    using types = std::tuple<byte,
                             int32,
cyanguwa's avatar
cyanguwa committed
60
                             int64,
Przemek Tredak's avatar
Przemek Tredak committed
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
                             fp32,
                             fp16,
                             bf16,
                             fp8e4m3,
                             fp8e5m2>;

    template <typename U, DType current>
    struct Helper {
        constexpr static DType getType() {
            constexpr int i = static_cast<int>(current);
            if (std::is_same<U, typename std::tuple_element<i, types>::type>::value) {
                return current;
            } else {
                return Helper<U, static_cast<DType>(i + 1)>::getType();
            }
        }
    };

    template <typename U>
    struct Helper<U, DType::kNumTypes> {
        constexpr static DType getType() {
            return DType::kNumTypes;
        }
    };

    template <typename U>
    constexpr static DType getType() {
        return Helper<U, DType::kByte>::getType();
    }

    constexpr static DType dtype = getType<T>();
    constexpr static size_t size = sizeof(T);
};

class Tensor {
 public:
  Tensor(const NVTEShape &shape, const DType type);

  Tensor(const std::vector<size_t> &shape, const DType type) :
    Tensor(NVTEShape{shape.data(), shape.size()}, type) {}

  Tensor() {}

  Tensor& operator=(const Tensor &other) = delete;
  Tensor(const Tensor &other) = delete;

  Tensor(Tensor &&other) = default;
  Tensor& operator=(Tensor &&other) = default;

  ~Tensor() {
    if (tensor_.dptr() != nullptr) {
      cudaFree(tensor_.dptr());
    }
  }
  NVTETensor data() const noexcept {
    return tensor_.data();
  }

  const NVTEShape shape() const noexcept {
    return tensor_.shape();
  }

  DType dtype() const noexcept {
    return tensor_.dtype();
  }

  void *dptr() const noexcept {
    return tensor_.dptr();
  }

  template <typename T>
  T *cpu_dptr() const {
    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
    return reinterpret_cast<T *>(cpu_data_.get());
  }

137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
  float amax() const {
    if(amax_cpu_data_) {
      to_cpu();
      return *amax_cpu_data_;
    } else {
      return 0;
    }
  }

  float scale() const {
    if(scale_cpu_data_) {
      to_cpu();
      return *scale_cpu_data_;
    } else {
      return 1;
    }
  }

  float scale_inv() const {
    if(scale_inv_cpu_data_) {
      to_cpu();
      return *scale_inv_cpu_data_;
    } else {
      return 1;
    }
  }

Przemek Tredak's avatar
Przemek Tredak committed
164
165
  void to_cpu() const;
  void from_cpu() const;
166
167
168
  void set_scale(float scale);
  void set_scale_inv(float scale_inv);
  void shareFP8Meta(const Tensor &other);
Przemek Tredak's avatar
Przemek Tredak committed
169
170
171
172

 private:
  TensorWrapper tensor_;
  std::unique_ptr<unsigned char[]> cpu_data_;
173
174
175
  std::shared_ptr<float> amax_cpu_data_;
  std::shared_ptr<float> scale_cpu_data_;
  std::shared_ptr<float> scale_inv_cpu_data_;
Przemek Tredak's avatar
Przemek Tredak committed
176
177
178
179
180
181
182
183
184
};

size_t typeToSize(DType type);
size_t product(const NVTEShape &shape);

bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);

void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    double atol = 1e-5, double rtol = 1e-8);
185
186
void compareResults(const std::string &name, const float test, const float ref,
                    double atol = 1e-5, double rtol = 1e-8);
Przemek Tredak's avatar
Przemek Tredak committed
187
188
189

std::pair<double, double> getTolerances(const DType type);

190
191
void fillUniform(Tensor *t);
void setRandomScale(Tensor *t);
Przemek Tredak's avatar
Przemek Tredak committed
192
193
194
195
196
197
198

constexpr int THREADS_PER_WARP = 32;

const std::string &typeName(DType type);

extern std::vector<DType> all_fp_types;

199
200
bool isFp8Type(DType type);

Przemek Tredak's avatar
Przemek Tredak committed
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
}  // namespace test

// Dispatches on a runtime DType value.
//   dtype - expression of type DType to switch on
//   type  - identifier declared, inside the matching case, as an alias for
//           the C++ element type that corresponds to `dtype`
//   ...   - statements executed inside that case with `type` in scope
// Any value without a case ends in NVTE_ERROR("Invalid type.").
// The element aliases (byte, int32, ...) are named without qualification, so
// the macro has to be expanded where the test:: aliases are visible.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kByte: \
            { \
                using type = byte; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kInt32: \
            { \
                using type = int32; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kInt64: \
            { \
                using type = int64; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat32: \
            { \
                using type = fp32; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat16: \
            { \
                using type = fp16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kBFloat16: \
            { \
                using type = bf16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E4M3: \
            { \
                using type = fp8e4m3; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E5M2: \
            { \
                using type = fp8e5m2; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }