test_common.h 4.91 KB
Newer Older
Przemek Tredak's avatar
Przemek Tredak committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
/*************************************************************************
 * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#pragma once

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>

#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime_api.h>

#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/logging.h>

namespace test {
using namespace transformer_engine;

// Maps a width in bytes to the unsigned integer type of exactly that width.
// The primary template is intentionally empty: only the widths specialized
// below (1, 2, 4 and 8) expose a nested `Type`.
template <size_t kNumBytes>
struct BytesToType {};

// 1 byte -> uint8_t
template <> struct BytesToType<1> { using Type = uint8_t; };

// 2 bytes -> uint16_t
template <> struct BytesToType<2> { using Type = uint16_t; };

// 4 bytes -> uint32_t
template <> struct BytesToType<4> { using Type = uint32_t; };

// 8 bytes -> uint64_t
template <> struct BytesToType<8> { using Type = uint64_t; };

// Short names for the element types handled by the test utilities.
// fp16/bf16 and the two fp8 aliases refer to the types provided by the
// CUDA headers <cuda_fp16.h>, <cuda_bf16.h> and <cuda_fp8.h> included above.
using byte = uint8_t;
using int32 = int32_t;
using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;
using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;

// Compile-time traits for an element type T: the DType enumerator that
// corresponds to it (`dtype`) and its size in bytes (`size`).
template <typename T>
struct TypeInfo{
    // Candidate element types. An entry's position in the tuple is matched
    // against the integer value of the DType enumerator being examined.
    using types = std::tuple<byte, int32, fp32, fp16, bf16, fp8e4m3, fp8e5m2>;

    // Recursive search: starting from `current`, yields the first enumerator
    // whose tuple entry is exactly U.
    template <typename U, DType current>
    struct Helper {
        constexpr static DType getType() {
            constexpr int index = static_cast<int>(current);
            using Candidate = typename std::tuple_element<index, types>::type;
            using Next = Helper<U, static_cast<DType>(index + 1)>;
            return std::is_same<U, Candidate>::value ? current : Next::getType();
        }
    };

    // End of the search: U matched none of the entries, so the sentinel
    // DType::kNumTypes is reported.
    template <typename U>
    struct Helper<U, DType::kNumTypes> {
        constexpr static DType getType() { return DType::kNumTypes; }
    };

    // Looks up the enumerator for U, beginning the search at DType::kByte.
    template <typename U>
    constexpr static DType getType() { return Helper<U, DType::kByte>::getType(); }

    // Enumerator matching T (DType::kNumTypes when T is not listed in `types`).
    constexpr static DType dtype = getType<T>();
    // sizeof(T).
    constexpr static size_t size = sizeof(T);
};

// Move-only test tensor: pairs a TensorWrapper handle with a host-side byte
// buffer. On destruction, the device pointer reported by the wrapper (if any)
// is released with cudaFree.
class Tensor {
 public:
  // Empty tensor: both members are default-constructed.
  Tensor() {}

  // Builds a tensor of the given shape and element type (defined out of line).
  Tensor(const NVTEShape &shape, const DType type);

  // Convenience overload that views the vector's storage as an NVTEShape and
  // delegates to the constructor above.
  Tensor(const std::vector<size_t> &shape, const DType type) :
    Tensor(NVTEShape{shape.data(), shape.size()}, type) {}

  // Copying is disabled; moving uses the member-wise defaults.
  Tensor(const Tensor &other) = delete;
  Tensor& operator=(const Tensor &other) = delete;
  Tensor(Tensor &&other) = default;
  Tensor& operator=(Tensor &&other) = default;

  ~Tensor() {
    void *const device_ptr = tensor_.dptr();
    if (device_ptr == nullptr) {
      return;  // nothing to release
    }
    cudaFree(device_ptr);
  }

  // Thin forwarders to the wrapped TensorWrapper.
  NVTETensor data() const noexcept { return tensor_.data(); }

  const NVTEShape shape() const noexcept { return tensor_.shape(); }

  DType dtype() const noexcept { return tensor_.dtype(); }

  void *dptr() const noexcept { return tensor_.dptr(); }

  // Typed view of the host buffer. NVTE_CHECK verifies that T corresponds to
  // the tensor's DType before the pointer is handed out.
  template <typename T>
  T *cpu_dptr() const {
    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
    unsigned char *const host_bytes = cpu_data_.get();
    return reinterpret_cast<T *>(host_bytes);
  }

  // Defined out of line.
  // NOTE(review): presumably copy device -> host and host -> device
  // respectively - confirm against the definitions.
  void to_cpu() const;
  void from_cpu() const;

 private:
  TensorWrapper tensor_;                       // handle exposing data()/shape()/dtype()/dptr()
  std::unique_ptr<unsigned char[]> cpu_data_;  // host-side bytes returned by cpu_dptr()
};

// The helpers below are only declared in this header; their definitions live
// elsewhere, so the notes on semantics are assumptions to be confirmed.

// NOTE(review): presumably the size in bytes of one element of `type` - confirm.
size_t typeToSize(DType type);
// NOTE(review): presumably the product of all dimensions in `shape` - confirm.
size_t product(const NVTEShape &shape);

// NOTE(review): presumably true when both shapes have identical dimensions - confirm.
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);

// Compares the contents of `test` against the reference buffer `ref`, labelled
// with `name`. Tolerances default to atol = 1e-5 and rtol = 1e-8.
// NOTE(review): how a mismatch is reported is not visible here - confirm.
void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    double atol = 1e-5, double rtol = 1e-8);

// NOTE(review): presumably returns an (atol, rtol) pair suited to `type`;
// the order of the pair is not visible here - confirm.
std::pair<double, double> getTolerances(const DType type);

// NOTE(review): presumably fills `t` with uniformly distributed random values;
// range and seeding are not visible here - confirm.
void fillUniform(const Tensor &t);

// Number of threads in a warp.
constexpr int THREADS_PER_WARP = 32;

// NOTE(review): presumably a printable name for `type` - confirm.
const std::string &typeName(DType type);

// Defined elsewhere.
// NOTE(review): presumably the list of floating-point DTypes iterated by the tests - confirm.
extern std::vector<DType> all_fp_types;

159
160
// Declared here, defined elsewhere.
// NOTE(review): presumably true for the FP8 enumerators (kFloat8E4M3 / kFloat8E5M2) - confirm.
bool isFp8Type(DType type);

Przemek Tredak's avatar
Przemek Tredak committed
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
}  // namespace test

// Runtime-to-compile-time dispatch over every supported DType.
//
//   dtype : expression of type transformer_engine::DType to switch on.
//   type  : identifier that is declared, inside the selected case, as an alias
//           of the C++ element type matching `dtype`.
//   ...   : statements executed with `type` in scope.
//
// Any value without a case (including DType::kNumTypes) reaches
// NVTE_ERROR("Invalid type.").
//
// The element-type aliases are spelled as ::test::<alias> so that the macro
// expands correctly from any scope: this macro is defined outside
// `namespace test`, and an unqualified `byte` would additionally clash with
// std::byte in translation units that use `using namespace std`.
#define TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(dtype, type, ...) \
    switch (dtype) { \
        using namespace transformer_engine; \
        case DType::kByte: \
            { \
                using type = ::test::byte; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kInt32: \
            { \
                using type = ::test::int32; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat32: \
            { \
                using type = ::test::fp32; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat16: \
            { \
                using type = ::test::fp16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kBFloat16: \
            { \
                using type = ::test::bf16; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E4M3: \
            { \
                using type = ::test::fp8e4m3; \
                {__VA_ARGS__} \
            } \
        break; \
        case DType::kFloat8E5M2: \
            { \
                using type = ::test::fp8e5m2; \
                {__VA_ARGS__} \
            } \
        break; \
        default: \
            NVTE_ERROR("Invalid type."); \
    }