/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_

#include <ATen/ATen.h>
#include <ATen/Dispatch.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAGeneratorImpl.h>
#include <ATen/cudnn/Handle.h>
#include <ATen/native/DispatchStub.h>
#include <c10/macros/Macros.h>
#include <cublasLt.h>
#include <cuda.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <cudnn.h>
#include <torch/extension.h>
#include <torch/torch.h>
#include <transformer_engine/activation.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/cast_transpose_noop.h>
#include <transformer_engine/fused_attn.h>
#include <transformer_engine/fused_rope.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/layer_norm.h>
#include <transformer_engine/permutation.h>
#include <transformer_engine/recipe.h>
#include <transformer_engine/rmsnorm.h>
#include <transformer_engine/softmax.h>
#include <transformer_engine/transformer_engine.h>
#include <transformer_engine/transpose.h>

#include <ATen/cuda/CUDAGraphsUtils.cuh>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

#include "common/util/logging.h"
Przemek Tredak's avatar
Przemek Tredak committed
48
49
50
51
52
53
54

namespace transformer_engine {

// Each tensor here is shape (N, ) holding all scaling
// data for a single FP8 block, e.g. LayerNormLinear
class FP8TensorMeta {
 public:
55
56
57
  at::Tensor scale;
  at::Tensor scale_inv;
  at::Tensor amax_history;
Przemek Tredak's avatar
Przemek Tredak committed
58
59
60
61
62
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8FwdTensors {
63
64
65
66
67
68
69
70
71
  GEMM1_INPUT = 0,
  GEMM1_WEIGHT = 1,
  GEMM1_OUTPUT = 2,
  GEMM2_INPUT = 3,
  GEMM2_WEIGHT = 4,
  GEMM2_OUTPUT = 5,
  GEMM3_INPUT = 6,
  GEMM3_WEIGHT = 7,
  GEMM3_OUTPUT = 8
Przemek Tredak's avatar
Przemek Tredak committed
72
73
74
75
76
};

// Used as named indices on the `scale`, `scale_inv`,
// and `amax` tensors in the `FP8TensorMeta` class.
enum FP8BwdTensors {
77
78
79
80
81
82
  GRAD_OUTPUT1 = 0,
  GRAD_INPUT1 = 1,
  GRAD_OUTPUT2 = 2,
  GRAD_INPUT2 = 3,
  GRAD_OUTPUT3 = 4,
  GRAD_INPUT3 = 5
Przemek Tredak's avatar
Przemek Tredak committed
83
84
85
86
87
};

}  // namespace transformer_engine

// Chooses the FP8 transformer_engine::DType to use for a tensor.
// Defined out of line; the selection logic is not visible in this header.
// NOTE(review): judging only by the parameter names, `e4m3_if_hybrid` picks
// E4M3 vs. E5M2 when `fp8_recipe` names a hybrid recipe — confirm against
// the definition.
transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid,
                                                      const std::string& fp8_recipe);

// Maps a Transformer Engine dtype to the ATen scalar type used to hold it.
//
// Both FP8 formats (E4M3 and E5M2) are stored as raw bytes, so they map to
// at::kByte, the same as DType::kByte. Any other value raises via NVTE_ERROR.
inline at::ScalarType GetATenDType(transformer_engine::DType t) {
  switch (t) {
    case transformer_engine::DType::kInt32:
      return at::kInt;
    case transformer_engine::DType::kInt64:
      return at::kLong;
    case transformer_engine::DType::kFloat32:
      return at::kFloat;
    case transformer_engine::DType::kFloat16:
      return at::kHalf;
    case transformer_engine::DType::kBFloat16:
      return at::kBFloat16;
    case transformer_engine::DType::kByte:
    case transformer_engine::DType::kFloat8E4M3:
    case transformer_engine::DType::kFloat8E5M2:
      // FP8 payloads have no dedicated ATen dtype here; keep them as bytes.
      return at::kByte;
    default:
      NVTE_ERROR("Invalid type");
  }
}

// Maps an ATen scalar type to the corresponding Transformer Engine dtype.
//
// at::kBool and at::kByte both map to DType::kByte (1-byte storage), so a
// bool tensor does not round-trip: GetATenDType(kByte) yields at::kByte.
// Unsupported ATen types raise via NVTE_ERROR.
inline transformer_engine::DType GetTransformerEngineDType(at::ScalarType t) {
  switch (t) {
    case at::kHalf:
      return transformer_engine::DType::kFloat16;
    case at::kFloat:
      return transformer_engine::DType::kFloat32;
    case at::kBFloat16:
      return transformer_engine::DType::kBFloat16;
    case at::kBool:
    case at::kByte:
      return transformer_engine::DType::kByte;
    case at::kInt:
      return transformer_engine::DType::kInt32;
    case at::kLong:
      return transformer_engine::DType::kInt64;
    default:
      NVTE_ERROR("Invalid type");
  }
}

// Reinterprets an integer code as a transformer_engine::DType.
// No range check is performed: any int is cast through unchanged.
// NOTE(review): callers presumably pass a value that originated from a DType
// (e.g. forwarded from Python) — confirm at the call sites.
inline transformer_engine::DType GetTransformerEngineDType(int DType_value) {
  return static_cast<transformer_engine::DType>(DType_value);
}

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
138
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
139

140
141
142
transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const std::vector<size_t>& shape,
                                                              const transformer_engine::DType type,
143
144
                                                              void* amax_ptr, void* scale_ptr,
                                                              void* scale_inv_ptr);
Przemek Tredak's avatar
Przemek Tredak committed
145
146
147

transformer_engine::TensorWrapper makeTransformerEngineTensor(void* data_ptr,
                                                              const NVTEShape& shape,
148
                                                              const transformer_engine::DType type);
Przemek Tredak's avatar
Przemek Tredak committed
149
150
151

transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor);

152
transformer_engine::TensorWrapper makeTransformerEngineTensor(at::Tensor tensor, at::Tensor amax,
153
154
155
                                                              const at::Tensor scale,
                                                              at::Tensor scale_inv);

156
size_t product(const std::vector<size_t>& shape);
Przemek Tredak's avatar
Przemek Tredak committed
157

158
at::Tensor allocateSpace(const std::vector<size_t>& shape, const transformer_engine::DType type,
cyanguwa's avatar
cyanguwa committed
159
                         bool init_to_zeros);
Przemek Tredak's avatar
Przemek Tredak committed
160

161
at::Tensor allocateSpace(const NVTEShape& shape, const transformer_engine::DType type,
Przemek Tredak's avatar
Przemek Tredak committed
162
163
                         bool init_to_zeros = false);

164
at::Tensor allocateTorchTensor(int M, int N, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
165

166
at::Tensor allocateTorchTensor(int M, transformer_engine::DType dtype);
Przemek Tredak's avatar
Przemek Tredak committed
167

168
void* getDataPtr(at::Tensor tensor, int offset = 0);
169

Przemek Tredak's avatar
Przemek Tredak committed
170
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_COMMON_H_