/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "util.h"

#include "common.h"
#include "common/common.h"

// Swizzle the scaling factors of a single MXFP8 / NVFP4 tensor.
//
// A new byte tensor with the same dimensions as the selected scale-inverse is allocated on the
// CUDA device, filled by nvte_swizzle_scaling_factors on the current CUDA stream, and the
// selected scale-inverse of `input` is then repointed at that new buffer.
//
// Parameters:
//   input   - tensor whose scaling factors are swizzled; its scale-inverse pointer is updated.
//   rowwise - true to operate on the row-wise data / scale-inverse, false for the column-wise.
//
// Returns the tensor holding the swizzled scaling factors, or std::nullopt when the scaling mode
// is neither NVTE_MXFP8_1D_SCALING nor NVTE_NVFP4_1D_SCALING (nothing is modified in that case).
// Raises an error for NVTE_INVALID_SCALING, for element sizes other than 4 or 8 bits, and for
// data shapes with fewer than two dimensions.
//
// NOTE(review): `input` is only handed the raw data pointer of the returned tensor, so the caller
// presumably has to keep the return value alive while `input` is in use - confirm with callers.
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper& input,
                                                  bool rowwise) {
  using namespace transformer_engine::pytorch;

  // Only MXFP8 and NVFP4 1D scaling are swizzled; every other valid mode is a no-op.
  const NVTEScalingMode mode = input.scaling_mode();
  if (mode == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
  }
  const bool is_mxfp8 = (mode == NVTE_MXFP8_1D_SCALING);
  const bool is_nvfp4 = (mode == NVTE_NVFP4_1D_SCALING);
  if (!is_mxfp8 && !is_nvfp4) {
    return std::nullopt;
  }

  const auto elem_bits = input.element_size_bits();
  NVTE_CHECK(elem_bits == 4 || elem_bits == 8,
             "4-bit or 8-bit input required for swizzling scaling factors.");

  // Pick the data shape and scale-inverse that belong to the requested direction.
  const NVTEShape nvte_data_shape = rowwise ? input.shape() : input.get_columnwise_data().shape;
  const NVTEBasicTensor src_scale =
      rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();

  auto data_shape = nvte_shape_to_vector(nvte_data_shape);
  auto scale_shape = nvte_shape_to_vector(src_scale.shape);

  NVTE_CHECK(data_shape.size() >= 2, "Wrong ndims for swizzle input shape.");

  // Destination buffer: one byte per element, same dimensions as the source scale-inverse.
  std::vector<int64_t> out_dims;
  out_dims.reserve(scale_shape.size());
  for (const auto dim : scale_shape) {
    out_dims.push_back(static_cast<int64_t>(dim));
  }
  auto swizzled =
      at::empty(out_dims, at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA));
  void* const src_ptr = src_scale.data_ptr;
  void* const dst_ptr = getDataPtr(swizzled, 0);

  // Data / scale dtypes reported to the kernel depend on the scaling mode.
  transformer_engine::DType data_dtype = transformer_engine::DType::kFloat8E4M3;
  transformer_engine::DType scale_dtype = transformer_engine::DType::kFloat8E8M0;
  if (is_nvfp4) {
    data_dtype = transformer_engine::DType::kFloat4E2M1;
    scale_dtype = transformer_engine::DType::kFloat8E4M3;
  }

  // Temporary wrappers that expose only the requested direction to the kernel. Both reference
  // the same data buffer; they differ only in the scale-inverse pointer (source vs destination).
  transformer_engine::TensorWrapper kernel_in(mode);
  transformer_engine::TensorWrapper kernel_out(mode);
  if (rowwise) {
    kernel_in.set_rowwise_data(input.dptr(), data_dtype, data_shape);
    kernel_in.set_rowwise_scale_inv(src_ptr, scale_dtype, scale_shape);
    kernel_out.set_rowwise_data(input.dptr(), data_dtype, data_shape);
    kernel_out.set_rowwise_scale_inv(dst_ptr, scale_dtype, scale_shape);
  } else {
    kernel_in.set_columnwise_data(input.columnwise_dptr(), data_dtype, data_shape);
    kernel_in.set_columnwise_scale_inv(src_ptr, scale_dtype, scale_shape);
    kernel_out.set_columnwise_data(input.columnwise_dptr(), data_dtype, data_shape);
    kernel_out.set_columnwise_scale_inv(dst_ptr, scale_dtype, scale_shape);
  }

  // Launch kernel
  nvte_swizzle_scaling_factors(kernel_in.data(), kernel_out.data(),
                               at::cuda::getCurrentCUDAStream());

  // Make the caller's tensor use the swizzled factors from now on.
  if (rowwise) {
    input.set_rowwise_scale_inv(dst_ptr, scale_dtype, scale_shape);
  } else {
    input.set_columnwise_scale_inv(dst_ptr, scale_dtype, scale_shape);
  }

  return swizzled;
}

// Swizzle the scaling factors of several MXFP8 / NVFP4 tensors with one kernel launch.
//
// All swizzled scale-inverses are written into a single byte buffer allocated on the CUDA
// device; each tensor's region starts at an offset rounded up to 16 bytes. The selected
// scale-inverse of every tensor in `tensors` is repointed into that buffer.
//
// Parameters:
//   tensors - tensors to process; all must share one scaling mode. Their scale-inverse
//             pointers are updated in place.
//   rowwise - true to operate on the row-wise data / scale-inverse, false for the column-wise.
//
// Returns the shared buffer holding every swizzled scale-inverse, or std::nullopt when `tensors`
// is empty or the scaling mode is neither NVTE_MXFP8_1D_SCALING nor NVTE_NVFP4_1D_SCALING.
// Raises an error when the scaling modes differ or the mode is NVTE_INVALID_SCALING.
//
// NOTE(review): the tensors are only handed raw pointers into the returned buffer, so the caller
// presumably has to keep the return value alive while they are in use - confirm with callers.
std::optional<at::Tensor> multi_tensor_swizzle_scaling_factors(
    std::vector<transformer_engine::TensorWrapper>& tensors, bool rowwise) {
  using namespace transformer_engine::pytorch;

  if (tensors.empty()) {
    return std::nullopt;
  }

  // Read the reference mode once instead of calling tensors.front() for every element.
  const auto scaling_mode = tensors.front().scaling_mode();
  const bool all_same_scaling_mode = std::all_of(
      tensors.cbegin(), tensors.cend(),
      [scaling_mode](const transformer_engine::TensorWrapper& val) {
        return val.scaling_mode() == scaling_mode;
      });
  NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same.");

  if (scaling_mode == NVTE_INVALID_SCALING) {
    NVTE_ERROR("Invalid scaling mode for swizzle.");
  } else if (scaling_mode != NVTE_MXFP8_1D_SCALING && scaling_mode != NVTE_NVFP4_1D_SCALING) {
    return std::nullopt;
  }

  const bool nvfp4 = scaling_mode == NVTE_NVFP4_1D_SCALING;
  const size_t num_tensors = tensors.size();

  // Collect scale_inv shapes and calculate buffer size and offsets for scale_invs
  std::vector<std::vector<size_t>> scale_inv_shapes;
  std::vector<void*> scale_inv_dptrs;
  std::vector<size_t> scale_inv_offsets;
  scale_inv_shapes.reserve(num_tensors);
  scale_inv_dptrs.reserve(num_tensors);
  scale_inv_offsets.reserve(num_tensors);
  size_t buffer_size = 0;
  constexpr size_t scale_elem_size = 1;  // scale factors are stored as single bytes
  for (auto& tensor : tensors) {
    const NVTEBasicTensor scale_inv =
        rowwise ? tensor.get_rowwise_scale_inv() : tensor.get_columnwise_scale_inv();
    auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape);
    buffer_size = roundup(buffer_size, 16);  // align to 16B
    scale_inv_offsets.push_back(buffer_size);
    buffer_size += product(scale_inv_shape) * scale_elem_size;
    scale_inv_shapes.emplace_back(std::move(scale_inv_shape));
    scale_inv_dptrs.push_back(scale_inv.data_ptr);
  }

  // Allocate full buffer
  auto buffer = at::empty({static_cast<int64_t>(buffer_size)},
                          at::device(at::kCUDA).dtype(torch::kUInt8));

  // Data / scale dtypes reported to the kernel depend on the scaling mode.
  const auto input_dtype =
      (nvfp4) ? transformer_engine::DType::kFloat4E2M1 : transformer_engine::DType::kFloat8E4M3;
  const auto scale_inv_dtype =
      (nvfp4) ? transformer_engine::DType::kFloat8E4M3 : transformer_engine::DType::kFloat8E8M0;

  // `wrappers` takes ownership of the temporary per-tensor wrappers so that they outlive the
  // kernel launch; the handles in input_tensors / output_tensors are obtained from them.
  std::vector<transformer_engine::TensorWrapper> wrappers;
  std::vector<NVTETensor> input_tensors, output_tensors;
  wrappers.reserve(2 * num_tensors);
  input_tensors.reserve(num_tensors);
  output_tensors.reserve(num_tensors);

  for (size_t i = 0; i < num_tensors; ++i) {
    auto& tensor = tensors[i];
    void* scale_inv_dptr = scale_inv_dptrs[i];
    void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]);

    const NVTEShape nvte_input_shape =
        rowwise ? tensor.shape() : tensor.get_columnwise_data().shape;
    auto input_shape = nvte_shape_to_vector(nvte_input_shape);

    // Reconstruct the tensor with only the requested direction set, to avoid swizzling both
    // directions if not needed.
    transformer_engine::TensorWrapper input_cu(scaling_mode);
    transformer_engine::TensorWrapper output_cu(scaling_mode);
    if (rowwise) {
      input_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
      input_cu.set_rowwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
      output_cu.set_rowwise_data(tensor.dptr(), input_dtype, input_shape);
      output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                      scale_inv_shapes[i]);
      // Set the swizzled scaling factor to the original tensor.
      tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
    } else {
      input_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
      input_cu.set_columnwise_scale_inv(scale_inv_dptr, scale_inv_dtype, scale_inv_shapes[i]);
      output_cu.set_columnwise_data(tensor.columnwise_dptr(), input_dtype, input_shape);
      output_cu.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                         scale_inv_shapes[i]);
      // Set the swizzled scaling factor to the original tensor.
      tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, scale_inv_dtype,
                                      scale_inv_shapes[i]);
    }

    // Record the handles first, then move the wrappers into `wrappers` to extend their lifetime.
    input_tensors.emplace_back(input_cu.data());
    output_tensors.emplace_back(output_cu.data());
    wrappers.emplace_back(std::move(input_cu));
    wrappers.emplace_back(std::move(output_cu));
  }

  // Launch kernel
  nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(),
                                            input_tensors.size(), at::cuda::getCurrentCUDAStream());

  return buffer;
}

// Convert an FP8 block-scaling tensor into an MXFP8 tensor with swizzled scaling factors.
//
// The selected data buffer (row-wise or column-wise) is viewed as a 2D matrix and re-exposed as
// the row-wise data of a new NVTE_MXFP8_1D_SCALING wrapper; the data itself is not copied. A new
// byte tensor is allocated on the CUDA device for the converted scaling factors and filled by
// nvte_swizzle_block_scaling_to_mxfp8_scaling_factors on the current CUDA stream.
//
// Parameters:
//   input   - block-scaling tensor (NVTE_BLOCK_SCALING_1D or NVTE_BLOCK_SCALING_2D). On return
//             it is replaced by the MXFP8 wrapper described above.
//   rowwise - true to convert the row-wise data / scale-inverse, false for the column-wise.
//
// Returns the tensor holding the swizzled MXFP8 scaling factors; it must be kept alive by the
// caller for as long as `input` is used (e.g. during the GEMM).
// Raises an error if `input` is not a block-scaling tensor or its data has no dimensions.
at::Tensor convert_block_scaling_to_mxfp8_tensor(transformer_engine::TensorWrapper& input,
                                                 bool rowwise) {
  using namespace transformer_engine::pytorch;
  using transformer_engine::DIVUP;

  // Check input tensor
  const NVTEScalingMode scaling_mode = input.scaling_mode();
  NVTE_CHECK(scaling_mode == NVTE_BLOCK_SCALING_1D || scaling_mode == NVTE_BLOCK_SCALING_2D,
             "Input tensor must be a block scaling tensor");

  // Get tensor data
  const NVTEBasicTensor data = rowwise ? input.get_rowwise_data() : input.get_columnwise_data();
  const size_t ndim = data.shape.ndim;
  // Without this guard `ndim - 1` below would wrap around for a 0-dim shape and the flattening
  // code would index outside the shape array.
  NVTE_CHECK(ndim >= 1, "Block scaling tensor data must have at least one dimension.");

  // Flatten to 2D: row-wise data keeps the last dim and folds all leading dims together,
  // column-wise data keeps the first dim and folds all trailing dims together.
  size_t data_flat_first_dim = 1;
  size_t data_flat_last_dim = 1;
  if (rowwise) {
    for (size_t i = 0; i + 1 < ndim; ++i) {
      data_flat_first_dim *= data.shape.data[i];
    }
    data_flat_last_dim = data.shape.data[ndim - 1];
  } else {
    data_flat_first_dim = data.shape.data[0];
    for (size_t i = 1; i < ndim; ++i) {
      data_flat_last_dim *= data.shape.data[i];
    }
  }
  NVTEShape data_shape{};
  data_shape.data[0] = data_flat_first_dim;
  data_shape.data[1] = data_flat_last_dim;
  data_shape.ndim = 2;

  // Recreate input tensor with rowwise usage
  transformer_engine::TensorWrapper input_cu(scaling_mode);
  input_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
  const NVTEBasicTensor scale_inv =
      rowwise ? input.get_rowwise_scale_inv() : input.get_columnwise_scale_inv();
  input_cu.set_rowwise_scale_inv(
      scale_inv.data_ptr, static_cast<transformer_engine::DType>(scale_inv.dtype), scale_inv.shape);

  // Create output tensor
  transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING);
  output_cu.set_rowwise_data(data.data_ptr, input.dtype(), data_shape);
  // Output swizzled mxfp8 scaling factor dimensions
  const size_t swizzled_scale_inv_first_dim = DIVUP<size_t>(data_flat_first_dim, 128) * 128;
  const size_t swizzled_scale_inv_last_dim = DIVUP<size_t>(data_flat_last_dim, 128) * 4;
  // Allocate memory for swizzled mxfp8 scaling factors.
  // The explicit casts avoid a narrowing conversion (size_t -> int64_t) inside the braces.
  const auto options = at::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
  at::Tensor swizzled_scale_inv =
      at::empty(std::vector<int64_t>{static_cast<int64_t>(swizzled_scale_inv_first_dim),
                                     static_cast<int64_t>(swizzled_scale_inv_last_dim)},
                options);
  // Set rowwise scaling factors on output
  void* const swizzled_scale_inv_dptr = getDataPtr(swizzled_scale_inv, 0);
  NVTEShape swizzled_scale_inv_shape{};
  swizzled_scale_inv_shape.data[0] = swizzled_scale_inv_first_dim;
  swizzled_scale_inv_shape.data[1] = swizzled_scale_inv_last_dim;
  swizzled_scale_inv_shape.ndim = 2;
  output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0,
                                  swizzled_scale_inv_shape);

  // Convert scaling factors from FP8 block scaling GEMM_READY format to mxfp8 swizzled format
  nvte_swizzle_block_scaling_to_mxfp8_scaling_factors(input_cu.data(), output_cu.data(),
                                                      at::cuda::getCurrentCUDAStream());

  // Set the input tensor to be the converted mxfp8 tensor and return the swizzled scaling factor
  // for it to be kept alive during the GEMM
  input = std::move(output_cu);
  return swizzled_scale_inv;
}