/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "transformer_engine/transpose.h"

9
#include "extensions.h"
10
#include "transformer_engine/cast.h"
11
#include "xla/ffi/api/ffi.h"
12

13
14
15
16
17
namespace transformer_engine {
namespace jax {

void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream,
                   void *output) {
18
19
  auto input_shape = std::vector<size_t>{rows, cols};
  auto output_shape = std::vector<size_t>{cols, rows};
20

21
22
  auto input_tensor = TensorWrapper(input, input_shape, dtype);
  auto transposed_tensor = TensorWrapper(output, output_shape, dtype);
23

24
  nvte_transpose(input_tensor.data(), transposed_tensor.data(), stream);
25
26
27
}

// Custom-call entry point for the transpose primitive (opaque-descriptor API).
//
// Buffer layout: buffers[0] is the input, buffers[1] is the output.
// The shape and dtypes are read from a CustomCallCommonDescriptor unpacked
// from `opaque`; dims[0] and dims[1] are used as rows and cols.
//
// The input and output dtypes must match. This is enforced with NVTE_CHECK
// (as in the other entry points of this file) so the check is not compiled
// out in builds that define NDEBUG.
void Transpose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  void *input = buffers[0];
  void *output = buffers[1];

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  auto rows = desc.shape.dims[0];
  auto cols = desc.shape.dims[1];
  NVTE_CHECK(desc.in_dtype == desc.out_dtype,
             "Input and output dtypes must match in TE/JAX Transpose primitive.");
  auto dtype = desc.out_dtype;

  TransposeImpl(input, rows, cols, dtype, stream, output);
}

40
41
42
43
44
45
46
47
48
49
// XLA FFI entry point for the transpose primitive.
//
// The N-D input is flattened to a 2D {m, n} view by splitting its dimensions
// at `transpose_axis`: m is the product of dimensions [0, transpose_axis) and
// n the product of dimensions [transpose_axis, rank). The output is wrapped as
// {n, m}. A negative `transpose_axis` counts from the end.
//
// Params:
//   stream         - CUDA stream the transpose is issued on.
//   input_buf      - input buffer; its element type gives the input dtype.
//   output_buf     - result buffer; its element type gives the output dtype.
//   transpose_axis - split point into the input dimensions (may be negative).
//
// Returns the result of ffi_with_cuda_error_check().
Error_Type TransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Result_Type output_buf,
                        int64_t transpose_axis) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  void *input = input_buf.untyped_data();
  void *output = output_buf->untyped_data();

  auto input_dims = input_buf.dimensions();
  const auto rank = static_cast<int64_t>(input_dims.size());
  if (transpose_axis < 0) transpose_axis += rank;
  // Reject an axis that is still out of range after normalization; it would
  // otherwise be used as a slice bound into the dimension span below.
  NVTE_CHECK(transpose_axis >= 0 && transpose_axis <= rank,
             "transpose_axis is out of range in TE/JAX Transpose primitive.");
  auto m = product(input_dims, 0, transpose_axis);
  auto n = product(input_dims, transpose_axis, input_dims.size());

  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{n, m};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype);

  nvte_transpose(input_tensor.data(), output_tensor.data(), stream);
  return ffi_with_cuda_error_check();
}

// Defines the FFI handler symbol `TransposeHandler`, bound to TransposeFFI.
// The binding order must match TransposeFFI's parameter list: the stream
// context, one input buffer, one result buffer, then the int64
// "transpose_axis" attribute.
// NOTE(review): FFI_CudaGraph_Traits presumably marks the handler as
// CUDA-graph compatible - confirm against its definition.
XLA_FFI_DEFINE_HANDLER_SYMBOL(TransposeHandler, TransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Ret<Buffer_Type>()      // output
                                  .Attr<int64_t>("transpose_axis"),
                              FFI_CudaGraph_Traits);

71
// Custom-call entry point for the fused cast + transpose primitive
// (opaque-descriptor API).
//
// Buffer layout:
//   buffers[0] input
//   buffers[1] amax       (float)
//   buffers[2] scale      (float)
//   buffers[3] scale_inv  (float)
//   buffers[4] cast output, shape {m, n}
//   buffers[5] transposed cast output, shape {n, m}
//   buffers[6] amax_out   (float), must alias buffers[1]
//
// m and n come from dims[0] and dims[1] of the CustomCallCommonDescriptor
// unpacked from `opaque`. When the output dtype is not FP8 (per use_fp8), the
// scale, scale_inv and amax_out pointers are replaced by nullptr before the
// output tensor is built. The transposed output is attached as the output
// tensor's column-wise data, sharing the same scale_inv pointer, and a single
// nvte_quantize call is issued on `stream`.
void CastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) {
  auto *input = buffers[0];
  float *amax = reinterpret_cast<float *>(buffers[1]);
  float *scale = reinterpret_cast<float *>(buffers[2]);
  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *input_cast = buffers[4];
  auto *input_cast_trans = buffers[5];
  float *amax_out = reinterpret_cast<float *>(buffers[6]);
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");

  const auto &desc = *UnpackOpaque<CustomCallCommonDescriptor>(opaque, opaque_len);
  if (!use_fp8(desc.out_dtype)) {
    // Scaling metadata is dropped for non-FP8 outputs.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto input_shape = std::vector<size_t>{m, n};
  auto input_trans_shape = std::vector<size_t>{n, m};

  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(input_cast, input_shape, desc.out_dtype, amax_out, scale, scale_inv);
  // Row-wise and column-wise (transposed) outputs live in the same tensor wrapper.
  output_tensor.set_columnwise_data(input_cast_trans, desc.out_dtype, input_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});

  nvte_quantize(input_tensor.data(), output_tensor.data(), stream);
}

101
102
// XLA FFI entry point for the fused cast + transpose primitive.
//
// The N-D input is flattened to {m, n} by splitting its dimensions at
// `transpose_axis` (negative values count from the end). The cast output is
// wrapped as {m, n} and the transposed cast output as {n, m}; the latter is
// attached as the output tensor's column-wise data, sharing scale_inv.
//
// Params:
//   stream           - CUDA stream the quantize call is issued on.
//   input_buf        - input buffer; element type gives the input dtype.
//   amax_buf         - amax buffer (float); must alias amax_out_buf.
//   scale_buf        - scale buffer (float).
//   scale_inv_buf    - inverse-scale buffer (float).
//   output_buf       - cast output; element type gives the output dtype.
//   output_trans_buf - transposed cast output.
//   amax_out_buf     - amax result buffer (float).
//   transpose_axis   - split point into the input dimensions.
//
// When the output dtype is not FP8 (per use_fp8), scale, scale_inv and
// amax_out are replaced by nullptr. Returns ffi_with_cuda_error_check().
Error_Type CastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                            Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                            Result_Type output_buf, Result_Type output_trans_buf,
                            Result_Type amax_out_buf, int64_t transpose_axis) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());

  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  NVTE_CHECK(amax == amax_out, "amax not bound to amax_out in TE/JAX CastTranspose primitive.");

  if (!use_fp8(out_dtype)) {
    // Scaling metadata is dropped for non-FP8 outputs.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto input_dims = input_buf.dimensions();
  const auto rank = static_cast<int64_t>(input_dims.size());
  if (transpose_axis < 0) transpose_axis += rank;
  // Reject an axis that is still out of range after normalization; it would
  // otherwise be used as a slice bound into the dimension span below.
  NVTE_CHECK(transpose_axis >= 0 && transpose_axis <= rank,
             "transpose_axis is out of range in TE/JAX CastTranspose primitive.");
  auto m = product(input_dims, 0, transpose_axis);
  auto n = product(input_dims, transpose_axis, input_dims.size());
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = input_shape;
  auto output_trans_shape = std::vector<size_t>{n, m};

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});

  nvte_quantize(input_tensor.data(), output_tensor.data(), stream);

  return ffi_with_cuda_error_check();
}

// Defines the FFI handler symbol `CastTransposeHandler`, bound to
// CastTransposeFFI. The binding order must match CastTransposeFFI's parameter
// list: the stream context, four argument buffers, three result buffers, then
// the int64 "transpose_axis" attribute.
XLA_FFI_DEFINE_HANDLER_SYMBOL(CastTransposeHandler, CastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Attr<int64_t>("transpose_axis"),
                              FFI_CudaGraph_Traits);
154

155
// Queries the workspace shape and dtype that nvte_quantize_dbias reports for a
// {batch_size, hidden_size} input.
//
// Tensor wrappers are built around a dummy host pointer, and
// nvte_quantize_dbias is called with an empty workspace tensor and a null
// stream. The workspace wrapper's shape and dtype are then read back.
// NOTE(review): this relies on nvte_quantize_dbias filling in the workspace
// metadata without touching the data pointers when the workspace is empty -
// confirm against the TE API contract.
//
// Params:
//   batch_size  - first dimension of the input.
//   hidden_size - second dimension of the input; also the dbias length.
//   in_dtype    - dtype of the input and dbias tensors.
//   out_dtype   - dtype of the cast output (row-wise and column-wise).
//
// Returns a one-element tuple holding a (workspace shape, workspace dtype) pair.
pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size,
                                                    DType in_dtype, DType out_dtype) {
  auto input_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_shape = std::vector<size_t>{batch_size, hidden_size};
  auto output_trans_shape = std::vector<size_t>{hidden_size, batch_size};
  auto dbias_shape = std::vector<size_t>{hidden_size};

  // Evil hack to specify TE impl
  // Note: nvte_quantize_dbias chooses its internal impl based on what
  // pointers are allocated, e.g. whether to output with column-wise
  // data. However, we don't have access to any allocated buffers in
  // this function. We pass a dummy pointer as a workaround.
  int temp = 0;

  auto input_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), input_shape, in_dtype);
  auto output_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), output_shape, out_dtype);
  output_tensor.set_columnwise_data(reinterpret_cast<void *>(&temp), out_dtype, output_trans_shape);
  auto dbias_tensor = TensorWrapper(reinterpret_cast<void *>(&temp), dbias_shape, in_dtype);

  TensorWrapper dummy_workspace;

  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                      dummy_workspace.data(), nullptr);

  auto work_shape = MakeShapeVector(dummy_workspace.shape());
  return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype()));
}

void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque,
184
185
186
187
188
189
190
191
192
193
194
195
                        size_t opaque_len) {
  auto *input = buffers[0];
  float *amax = reinterpret_cast<float *>(buffers[1]);
  float *scale = reinterpret_cast<float *>(buffers[2]);
  float *scale_inv = reinterpret_cast<float *>(buffers[3]);
  auto *output = buffers[4];
  auto *output_trans = buffers[5];
  auto *dbias = buffers[6];
  float *amax_out = reinterpret_cast<float *>(buffers[7]);
  void *workspace_ptr = buffers[8];

  const auto &desc = *UnpackOpaque<CustomCallCommonWkDescriptor>(opaque, opaque_len);
196
197
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
  if (!use_fp8(desc.out_dtype)) {
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }
  auto m = desc.shape.dims[0];
  auto n = desc.shape.dims[1];
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};

  auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype);
  auto output_tensor =
      TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv);
213
214
  output_tensor.set_columnwise_data(output_trans, desc.out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
215
216
217
218
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype);

  auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype);

219
220
  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                      workspace.data(), stream);
221
222
}

223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
// XLA FFI entry point for the fused dbias + cast + transpose primitive.
//
// The N-D input is flattened to {m, n} by splitting its dimensions at
// `transpose_axis` (negative values count from the end). Outputs are wrapped
// as: cast output {m, n}, transposed cast output {n, m} (attached as
// column-wise data, sharing scale_inv), and dbias {n} with the input dtype.
// The workspace tensor takes its shape and dtype from workspace_buf.
//
// Params:
//   stream           - CUDA stream the quantize call is issued on.
//   input_buf        - input buffer; element type gives the input dtype.
//   amax_buf         - amax buffer (float); must alias amax_out_buf.
//   scale_buf        - scale buffer (float).
//   scale_inv_buf    - inverse-scale buffer (float).
//   output_buf       - cast output; element type gives the output dtype.
//   output_trans_buf - transposed cast output.
//   dbias_buf        - dbias result buffer.
//   amax_out_buf     - amax result buffer (float).
//   workspace_buf    - workspace buffer.
//   transpose_axis   - split point into the input dimensions.
//
// When the output dtype is not FP8 (per use_fp8), scale, scale_inv and
// amax_out are replaced by nullptr. Returns ffi_with_cuda_error_check().
Error_Type DBiasCastTransposeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type amax_buf,
                                 Buffer_Type scale_buf, Buffer_Type scale_inv_buf,
                                 Result_Type output_buf, Result_Type output_trans_buf,
                                 Result_Type dbias_buf, Result_Type amax_out_buf,
                                 Result_Type workspace_buf, int64_t transpose_axis) {
  auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type());
  auto out_dtype = convert_ffi_datatype_to_te_dtype(output_buf->element_type());
  auto workspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type());

  auto *input = input_buf.untyped_data();
  float *amax = reinterpret_cast<float *>(amax_buf.untyped_data());
  float *scale = reinterpret_cast<float *>(scale_buf.untyped_data());
  float *scale_inv = reinterpret_cast<float *>(scale_inv_buf.untyped_data());

  auto *output = output_buf->untyped_data();
  auto *output_trans = output_trans_buf->untyped_data();
  auto *dbias = dbias_buf->untyped_data();
  float *amax_out = reinterpret_cast<float *>(amax_out_buf->untyped_data());
  void *workspace = workspace_buf->untyped_data();
  NVTE_CHECK(amax == amax_out,
             "amax not bound to amax_out in TE/JAX DBiasCastTranspose primitive.");
  if (!use_fp8(out_dtype)) {
    // Scaling metadata is dropped for non-FP8 outputs.
    scale = nullptr;
    scale_inv = nullptr;
    amax_out = nullptr;
  }

  auto input_dims = input_buf.dimensions();
  auto workspace_dims = workspace_buf->dimensions();
  const auto rank = static_cast<int64_t>(input_dims.size());
  if (transpose_axis < 0) transpose_axis += rank;
  // Reject an axis that is still out of range after normalization; it would
  // otherwise be used as a slice bound into the dimension span below.
  NVTE_CHECK(transpose_axis >= 0 && transpose_axis <= rank,
             "transpose_axis is out of range in TE/JAX DBiasCastTranspose primitive.");
  auto m = product(input_dims, 0, transpose_axis);
  auto n = product(input_dims, transpose_axis, input_dims.size());
  auto input_shape = std::vector<size_t>{m, n};
  auto output_shape = std::vector<size_t>{m, n};
  auto output_trans_shape = std::vector<size_t>{n, m};
  auto dbias_shape = std::vector<size_t>{n};
  std::vector<size_t> workspace_shape(workspace_dims.begin(), workspace_dims.end());

  auto input_tensor = TensorWrapper(input, input_shape, in_dtype);
  auto output_tensor = TensorWrapper(output, output_shape, out_dtype, amax_out, scale, scale_inv);
  output_tensor.set_columnwise_data(output_trans, out_dtype, output_trans_shape);
  output_tensor.set_columnwise_scale_inv(scale_inv, DType::kFloat32, std::vector<size_t>{1});
  auto dbias_tensor = TensorWrapper(dbias, dbias_shape, in_dtype);
  auto workspace_tensor = TensorWrapper(workspace, workspace_shape, workspace_dtype);

  nvte_quantize_dbias(input_tensor.data(), output_tensor.data(), dbias_tensor.data(),
                      workspace_tensor.data(), stream);
  return ffi_with_cuda_error_check();
}

// Defines the FFI handler symbol `DBiasCastTransposeHandler`, bound to
// DBiasCastTransposeFFI. The binding order must match DBiasCastTransposeFFI's
// parameter list: the stream context, four argument buffers, five result
// buffers, then the int64 "transpose_axis" attribute.
// NOTE(review): FFI_CudaGraph_Traits presumably marks the handler as
// CUDA-graph compatible - confirm against its definition.
XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasCastTransposeHandler, DBiasCastTransposeFFI,
                              FFI::Bind()
                                  .Ctx<FFI_Stream_Type>()  // stream
                                  .Arg<Buffer_Type>()      // input
                                  .Arg<Buffer_Type>()      // amax
                                  .Arg<Buffer_Type>()      // scale
                                  .Arg<Buffer_Type>()      // scale_inv
                                  .Ret<Buffer_Type>()      // output
                                  .Ret<Buffer_Type>()      // output_trans
                                  .Ret<Buffer_Type>()      // dbias
                                  .Ret<Buffer_Type>()      // amax_out
                                  .Ret<Buffer_Type>()      // workspace
                                  .Attr<int64_t>("transpose_axis"),
                              FFI_CudaGraph_Traits);

288
289
}  // namespace jax
}  // namespace transformer_engine