/*************************************************************************
 * Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_

#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/transpose.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/logging.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/cast.h>
#include <ATen/ATen.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <cuda.h>
#include <cuda_runtime.h>
#include <cuda_bf16.h>
#include <stdexcept>
#include <memory>
#include <iomanip>
#include <random>
#include <cstring>
#include <vector>
#include <iostream>


namespace transformer_engine {

// Bundle of the scaling tensors that belong to a single FP8 block,
// e.g. LayerNormLinear. Each tensor here is shape (N, ); individual
// entries are addressed with the FP8FwdTensors / FP8BwdTensors
// enumerators declared below.
class FP8TensorMeta {
 public:
    // Scaling factors.
    at::Tensor scale;
    // NOTE(review): presumably the reciprocals of `scale` -- confirm
    // against the code that fills this tensor.
    at::Tensor scale_inv;
    // Recorded amax values. NOTE(review): the name suggests several past
    // steps are kept; the exact layout is not visible from this header.
    at::Tensor amax_history;
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax_history` tensors in the `FP8TensorMeta` class.
// Plain (unscoped) enum so the enumerators convert implicitly to
// integer indices. Judging by the name these cover the forward pass;
// the enumerator names suggest the input/weight operands of two
// GEMMs -- confirm against the callers.
enum FP8FwdTensors {
    GEMM1_INPUT  = 0,
    GEMM1_WEIGHT = 1,
    GEMM2_INPUT  = 2,
    GEMM2_WEIGHT = 3
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax_history` tensors in the `FP8TensorMeta` class.
// Plain (unscoped) enum so the enumerators convert implicitly to
// integer indices. Judging by the name these cover the backward pass
// (output gradients) -- confirm against the callers.
enum FP8BwdTensors {
    GRAD_OUTPUT1 = 0,
    GRAD_OUTPUT2 = 1
};


}  // namespace transformer_engine


// Returns a transformer_engine::DType chosen from the recipe name
// `fp8_recipe` and the `e4m3_if_hybrid` flag.
// Declaration only; the definition is not part of this header.
// NOTE(review): `e4m3_if_hybrid` presumably selects E4M3 over E5M2 when
// the recipe is a hybrid one -- confirm against the definition.
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                      const std::string &fp8_recipe);


// Maps a Transformer Engine dtype onto the ATen scalar type used for the
// torch tensor that backs it:
//
//   kInt32                    -> at::kInt
//   kFloat32                  -> at::kFloat
//   kFloat16                  -> at::kHalf
//   kBFloat16                 -> at::kBFloat16
//   kByte, kFloat8E4M3/E5M2   -> at::kByte
//
// Any other value is reported through NVTE_ERROR("Invalid type").
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
    switch (t) {
        case transformer_engine::DType::kInt32:
            // This case used to fall through to kFloat32 and return
            // at::kFloat, which labelled int32 buffers as floating point
            // on the torch side.
            return at::kInt;
        case transformer_engine::DType::kFloat32:
            return at::kFloat;
        case transformer_engine::DType::kFloat16:
            return at::kHalf;
        case transformer_engine::DType::kBFloat16:
            return at::kBFloat16;
        case transformer_engine::DType::kByte:
        case transformer_engine::DType::kFloat8E4M3:
        case transformer_engine::DType::kFloat8E5M2:
            // The FP8 enumerators share the byte mapping with kByte.
            return at::kByte;
        default:
            NVTE_ERROR("Invalid type");
    }
}


// Maps an ATen scalar type onto the matching Transformer Engine dtype:
//
//   at::kHalf     -> kFloat16
//   at::kFloat    -> kFloat32
//   at::kBFloat16 -> kBFloat16
//   at::kByte     -> kByte
//   at::kInt      -> kInt32
//
// at::kByte and at::kInt are accepted so that tensors of the types
// GetATenDType can hand out are not rejected here; both used to hit the
// error path. A byte tensor is always reported as kByte: this function
// cannot tell whether the bytes hold FP8 data, so FP8 callers must keep
// stating the dtype explicitly.
// Any other value is reported through NVTE_ERROR("Invalid type").
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
    switch (t) {
        case at::kHalf:
            return transformer_engine::DType::kFloat16;
        case at::kFloat:
            return transformer_engine::DType::kFloat32;
        case at::kBFloat16:
            return transformer_engine::DType::kBFloat16;
        case at::kByte:
            return transformer_engine::DType::kByte;
        case at::kInt:
            return transformer_engine::DType::kInt32;
        default:
            NVTE_ERROR("Invalid type");
    }
}


// Converts a raw integer into the transformer_engine::DType enumerator
// with that underlying value. The value is cast as-is; no range check
// is performed here.
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
    using TEDType = transformer_engine::DType;
    const TEDType converted = static_cast<TEDType>(DType_value);
    return converted;
}

// Builds a transformer_engine::TensorWrapper from a raw data pointer, a
// shape given as a vector of extents, and an explicit dtype.
// Declaration only; the definition is not part of this header.
// NOTE(review): presumably the wrapper does not take ownership of
// `data_ptr` -- confirm against the definition.
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
                                                              const transformer_engine::DType type
);


// Overload taking the shape as an NVTEShape instead of a std::vector.
// Builds a transformer_engine::TensorWrapper from a raw data pointer,
// that shape, and an explicit dtype.
// Declaration only; the definition is not part of this header.
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const NVTEShape& shape,
                                                              const transformer_engine::DType type
);


// Overload building a transformer_engine::TensorWrapper directly from a
// torch tensor (taken by value). Declaration only.
// NOTE(review): presumably pointer, shape and dtype are all derived from
// `tensor` -- confirm against the definition.
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);


// Declaration only. NOTE(review): judging by the name, returns the
// product of all extents in `shape` (i.e. the element count) -- confirm
// against the definition, including the result for an empty vector.
size_t product(const std::vector<size_t> &shape);


// Returns a torch tensor for the given NVTEShape and dtype.
// `init_to_zeros` defaults to false; judging by the name it requests a
// zero-filled buffer -- confirm against the definition.
// Declaration only; the definition is not part of this header.
at::Tensor allocateSpace(const NVTEShape &shape,
                         const transformer_engine::DType type,
                         bool init_to_zeros = false);


// Two-extent overload: returns a torch tensor described by `M`, `N` and
// the Transformer Engine dtype `dtype`.
// Declaration only. NOTE(review): presumably the result has shape (M, N);
// device and initialization are not visible from this header.
at::Tensor allocateTorchTensor(int M,
                               int N,
                               transformer_engine::DType dtype
);


// Single-extent overload: returns a torch tensor described by `M` and
// the Transformer Engine dtype `dtype`.
// Declaration only. NOTE(review): presumably the result has shape (M, );
// device and initialization are not visible from this header.
at::Tensor allocateTorchTensor(int M,
                               transformer_engine::DType dtype
);


// Layer-normalization dispatch entry point.
// Declaration only; the definition is not part of this header.
//
// Every tensor is described by a (data pointer, shape, dtype) triple.
// The trailing `// i` / `// o` markers presumably tag inputs and
// outputs -- confirm against the definition.
//
// `epsilon` and `multiProcessorCount` are plain values; judging by the
// names they are the normalization epsilon and the number of
// multiprocessors on the device -- TODO confirm.
void dispatch_layernorm(void* input,                                    // i
                        const std::vector<size_t>& input_shape,
                        const transformer_engine::DType input_type,
                        void* gamma,                                    // i
                        const std::vector<size_t>& gamma_shape,
                        const transformer_engine::DType gamma_type,
                        void* beta,                                     // i
                        const std::vector<size_t>& beta_shape,
                        const transformer_engine::DType beta_type,
                        void* scale,                                    // i
                        const std::vector<size_t>& scale_shape,
                        const transformer_engine::DType scale_type,
                        const float epsilon,                            // i
                        void* z,                                        // o
                        const std::vector<size_t>& z_shape,
                        const transformer_engine::DType z_type,
                        void* mu,                                       // o
                        const std::vector<size_t>& mu_shape,
                        const transformer_engine::DType mu_type,
                        void* rsigma,                                   // o
                        const std::vector<size_t>& rsigma_shape,
                        const transformer_engine::DType rsigma_type,
                        void* amax,                                     // o
                        const std::vector<size_t>& amax_shape,
                        const transformer_engine::DType amax_type,
                        void* scale_inv,                                // o
                        const std::vector<size_t>& scale_inv_shape,
                        const transformer_engine::DType scale_inv_type,
                        const int multiProcessorCount
);


// Dispatch entry point for the fused cast + transpose operation.
// Declaration only; the definition is not part of this header.
//
// Every tensor is described by a (data pointer, shape, dtype) triple.
// The trailing `// i` / `// o` markers presumably tag inputs and
// outputs -- confirm against the definition.
// Two result tensors are taken: `output_cast` and `output_transpose`.
void dispatch_cast_transpose_fusion(void* input,                                            // i
                                    const std::vector<size_t>& input_shape,
                                    const transformer_engine::DType input_type,
                                    void* scale,                                            // i
                                    const std::vector<size_t>& scale_shape,
                                    const transformer_engine::DType scale_type,
                                    void* output_cast,                                      // o
                                    const std::vector<size_t>& output_cast_shape,
                                    const transformer_engine::DType output_cast_type,
                                    void* output_transpose,                                 // o
                                    const std::vector<size_t>& output_transpose_shape,
                                    const transformer_engine::DType output_transpose_type,
                                    void* amax,                                             // o
                                    const std::vector<size_t>& amax_shape,
                                    const transformer_engine::DType amax_type,
                                    void* scale_inv,                                        // o
                                    const std::vector<size_t>& scale_inv_shape,
                                    const transformer_engine::DType scale_inv_type
);


// Dispatch entry point for the GELU activation.
// Declaration only; the definition is not part of this header.
//
// Every tensor is described by a (data pointer, shape, dtype) triple.
// The trailing `// i` / `// o` markers presumably tag inputs and
// outputs -- confirm against the definition.
void dispatch_gelu(void* input,                                            // i
                   const std::vector<size_t>& input_shape,
                   const transformer_engine::DType input_type,
                   void* scale,                                            // i
                   const std::vector<size_t>& scale_shape,
                   const transformer_engine::DType scale_type,
                   void* output,                                           // o
                   const std::vector<size_t>& output_shape,
                   const transformer_engine::DType output_type,
                   void* amax,                                             // o
                   const std::vector<size_t>& amax_shape,
                   const transformer_engine::DType amax_type,
                   void* scale_inv,                                        // o
                   const std::vector<size_t>& scale_inv_shape,
                   const transformer_engine::DType scale_inv_type
);


// Dispatch entry point for a plain transpose.
// Declaration only; the definition is not part of this header.
//
// `input` and `output` are each described by a (data pointer, shape,
// dtype) triple. The trailing `// i` / `// o` markers presumably tag the
// input and the output -- confirm against the definition.
void dispatch_transpose(void* input,                                            // i
                        const std::vector<size_t>& input_shape,
                        const transformer_engine::DType input_type,
                        void* output,                                           // o
                        const std::vector<size_t>& output_shape,
                        const transformer_engine::DType output_type
);


// Dispatch entry point for the fused bias-gradient + cast + transpose
// operation (reading of the name; confirm against the definition).
// Declaration only; the definition is not part of this header.
//
// Every tensor is described by a (data pointer, shape, dtype) triple.
// The trailing `// i` / `// o` markers presumably tag inputs and
// outputs. Result tensors taken here: `cast_output`,
// `transposed_output`, `amax`, `dbias` and `scale_inv`.
void dispatch_bgrad_cast_transpose_fusion(void* input,                                          // i
                                          const std::vector<size_t>& input_shape,
                                          const transformer_engine::DType input_type,
                                          void* scale,                                          // i
                                          const std::vector<size_t>& scale_shape,
                                          const transformer_engine::DType scale_type,
                                          void* cast_output,                                    // o
                                          const std::vector<size_t>& cast_output_shape,
                                          const transformer_engine::DType cast_output_type,
                                          void* transposed_output,                              // o
                                          const std::vector<size_t>& transposed_output_shape,
                                          const transformer_engine::DType transposed_output_type,
                                          void* amax,                                           // o
                                          const std::vector<size_t>& amax_shape,
                                          const transformer_engine::DType amax_type,
                                          void* dbias,                                          // o
                                          const std::vector<size_t>& dbias_shape,
                                          const transformer_engine::DType dbias_type,
                                          void* scale_inv,                                      // o
                                          const std::vector<size_t>& scale_inv_shape,
                                          const transformer_engine::DType scale_inv_type
);


// Dispatch entry point for the fused bias-gradient + GELU-backward +
// cast + transpose operation (reading of the name; confirm against the
// definition).
// Declaration only; the definition is not part of this header.
//
// Every tensor is described by a (data pointer, shape, dtype) triple.
// The trailing `// i` / `// o` markers presumably tag inputs and
// outputs. Compared with dispatch_bgrad_cast_transpose_fusion this takes
// one extra input triple, `gelu_input`.
void dispatch_bgrad_dgelu_cast_transpose_fusion(
        void* input,                                            // i
        const std::vector<size_t>& input_shape,
        const transformer_engine::DType input_type,
        void* gelu_input,                                       // i
        const std::vector<size_t>& gelu_input_shape,
        const transformer_engine::DType gelu_input_type,
        void* scale,                                            // i
        const std::vector<size_t>& scale_shape,
        const transformer_engine::DType scale_type,
        void* cast_output,                                      // o
        const std::vector<size_t>& cast_output_shape,
        const transformer_engine::DType cast_output_type,
        void* transposed_output,                                // o
        const std::vector<size_t>& transposed_output_shape,
        const transformer_engine::DType transposed_output_type,
        void* amax,                                             // o
        const std::vector<size_t>& amax_shape,
        const transformer_engine::DType amax_type,
        void* dbias,                                            // o
        const std::vector<size_t>& dbias_shape,
        const transformer_engine::DType dbias_type,
        void* scale_inv,                                        // o
        const std::vector<size_t>& scale_inv_shape,
        const transformer_engine::DType scale_inv_type
);


// Dispatch entry point for cast + transpose over a list of tensors.
// Declaration only; the definition is not part of this header.
//
// Each tensor role is described by three parallel lists: data pointers
// (passed by value), shapes and dtypes (both passed by const reference).
// The trailing `// i` / `// o` markers presumably tag inputs and
// outputs -- confirm against the definition.
// NOTE(review): presumably all lists must have the same length, with
// index k describing the k-th tensor -- confirm against the definition.
void dispatch_multi_cast_transpose(
        std::vector<void*> input_dptr_list,                     // i
        const std::vector<std::vector<size_t>>& input_shape_list,
        const std::vector<transformer_engine::DType>& input_type_list,
        std::vector<void*> scale_dptr_list,                     // i
        const std::vector<std::vector<size_t>>& scale_shape_list,
        const std::vector<transformer_engine::DType>& scale_type_list,
        std::vector<void*> cast_output_dptr_list,               // o
        const std::vector<std::vector<size_t>>& cast_output_shape_list,
        const std::vector<transformer_engine::DType>& cast_output_type_list,
        std::vector<void*> transposed_output_dptr_list,         // o
        const std::vector<std::vector<size_t>>& transposed_output_shape_list,
        const std::vector<transformer_engine::DType>& transposed_output_type_list,
        std::vector<void*> amax_dptr_list,                      // o
        const std::vector<std::vector<size_t>>& amax_shape_list,
        const std::vector<transformer_engine::DType>& amax_type_list,
        std::vector<void*> scale_inv_dptr_list,                 // o
        const std::vector<std::vector<size_t>>& scale_inv_shape_list,
        const std::vector<transformer_engine::DType>& scale_inv_type_list
);


#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_