/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/transpose.h"

#include "jax/csrc/extensions.h"

namespace transformer_engine {
namespace jax {

void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                   void *output) {
16
17
  auto input_shape = std::vector<size_t>{rows, cols};
  auto output_shape = std::vector<size_t>{cols, rows};
18

19
20
  auto input_tensor = TensorWrapper(input, input_shape, dtype);
  auto transposed_tensor = TensorWrapper(output, output_shape, dtype);
21

22
  nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
23
24
25
}

// Custom-call entry point for the transpose primitive.
//
// buffers[0] is read as the input pointer and buffers[1] as the output pointer. The two
// matrix dimensions and the dtypes are taken from the CustomCallCommonDescriptor packed
// in `opaque` (`opaque_len` bytes).
//
// Fails through NVTE_CHECK when the descriptor's input and output dtypes differ.
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *input = buffers[0];
  void *output = buffers[1];

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);

  // Enforced in every build type: a plain assert() compiles away under NDEBUG, which would
  // let a mismatched descriptor through silently. TransposeImpl uses one dtype for both the
  // input and the output tensor, so the descriptor must not ask for a conversion.
  NVTE_CHECK(desc.in_dtype == desc.out_dtype,
             "Input and output dtypes must match in TE/JAX Transpose primitive.");

  const auto rows = desc.shape.dims[0];
  const auto cols = desc.shape.dims[1];

  TransposeImpl(input, rows, cols, desc.out_dtype, stream, output);
}

// Custom-call entry point for the cast + transpose primitive.
//
// Buffers read by this function:
//   buffers[0] input            buffers[1] amax        buffers[2] scale
//   buffers[3] scale_inv        buffers[4] cast output buffers[5] cast + transposed output
//   buffers[6] amax_out
// The 2D shape and the input/output dtypes come from the CustomCallCommonDescriptor packed
// in `opaque`.
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  // amax and amax_out must be the very same buffer; checked before anything else is touched.
  float *amax = static_cast<float *>(buffers[1]);
  float *amax_out = static_cast<float *>(buffers[6]);
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);

  // The amax/scale/scale_inv pointers are attached to the outputs only when
  // use_fp8(out_dtype) holds; otherwise the outputs are built with null pointers for them.
  float *amax_ptr = nullptr;
  float *scale_ptr = nullptr;
  float *scale_inv_ptr = nullptr;
  if (use_fp8(desc.out_dtype)) {
    amax_ptr = amax_out;
    scale_ptr = static_cast<float *>(buffers[2]);
    scale_inv_ptr = static_cast<float *>(buffers[3]);
  }

  const auto m = desc.shape.dims[0];
  const auto n = desc.shape.dims[1];
  const std::vector<size_t> plain_dims{m, n};
  const std::vector<size_t> transposed_dims{n, m};

  void *const src = buffers[0];
  void *const cast_dst = buffers[4];
  void *const cast_trans_dst = buffers[5];

  TensorWrapper src_tensor(src, plain_dims, desc.in_dtype);
  TensorWrapper cast_tensor(cast_dst, plain_dims, desc.out_dtype, amax_ptr, scale_ptr,
                            scale_inv_ptr);
  TensorWrapper cast_trans_tensor(cast_trans_dst, transposed_dims, desc.out_dtype, amax_ptr,
                                  scale_ptr, scale_inv_ptr);

  nvte_cast_transpose(src_tensor.data(), cast_tensor.data(), cast_trans_tensor.data(), stream);
}

// Returns a one-element tuple holding (workspace shape, workspace dtype) for
// nvte_cast_transpose_dbias on a {batch_size, hidden_size} input.
//
// All tensors are built with null data pointers and the call is made with a
// default-constructed workspace and a null stream; afterwards the workspace wrapper's shape
// and dtype are read back. NOTE(review): this presumably relies on
// nvte_cast_transpose_dbias only reporting its workspace requirements when the workspace is
// empty - confirm against the NVTE API documentation.
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype) {
  // The input and the non-transposed output share the same dimensions.
  const std::vector<size_t> plain_dims{batch_size, hidden_size};
  const std::vector<size_t> transposed_dims{hidden_size, batch_size};
  const std::vector<size_t> dbias_dims{hidden_size};

  TensorWrapper src(nullptr, plain_dims, in_dtype);
  TensorWrapper cast_dst(nullptr, plain_dims, out_dtype);
  TensorWrapper cast_trans_dst(nullptr, transposed_dims, out_dtype);
  TensorWrapper dbias(nullptr, dbias_dims, in_dtype);

  TensorWrapper workspace_probe;
  nvte_cast_transpose_dbias(src.data(), cast_dst.data(), cast_trans_dst.data(), dbias.data(),
                            workspace_probe.data(), nullptr);

  auto workspace_dims = MakeShapeVector(workspace_probe.shape());
  auto workspace_info = std::make_pair(workspace_dims, workspace_probe.dtype());
  return pybind11::make_tuple(workspace_info);
}

void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
91
92
93
94
95
96
97
98
99
100
101
102
                        size_t opaque_len) {
  auto *input = buffers[0];
  float *amax = reinterpret_cast<float *>(buffers[1]);
  float *scale = reinterpret_cast<float *>(buffers[2]);
  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *output = buffers[4];
  auto *output_trans = buffers[5];
  auto *dbias = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
  void *workspace_ptr = buffers[8];

  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
103
104
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};

  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto output_trans_tensor =
      TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv);
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);

  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

  nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), output_trans_tensor.data(),
                            dbias_tensor.data(), workspace.data(), stream);
128
129
130
131
}

}  // namespace jax
}  // namespace transformer_engine