/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_

#include <torch/extension.h>

#include <optional>
#include <tuple>
#include <vector>

#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {
namespace pytorch {

/*! \brief Convert tensor block scales into GEMM swizzled format.
22
 *
23
 *  The returned swizzled scales should be kept alive during the GEMM.
24
 */
25
26
std::tuple<std::optional<at::Tensor>, std::optional<at::Tensor>> swizzle_scales_for_gemm(
    TensorWrapper& tensor, bool rowwise_usage, bool columnwise_usage);
27

28
/*! \brief Convert multiple tensor block scales into GEMM swizzled format.
29
 *
30
 *  The returned swizzled scales should be kept alive during the GEMMs.
31
 */
32
33
34
std::optional<at::Tensor> multi_tensor_swizzle_scales_for_gemm(std::vector<TensorWrapper>& tensors,
                                                               bool rowwise_usage,
                                                               bool columnwise_usage);
35

36
37
/*! \brief Convert a block scaling tensor to an mxfp8 tensor in-place.
 *
38
39
40
41
42
 *  If rowwise==false, the columnwise data will be reinterpreted as
 *  rowwise data to avoid transposing it in memory. Due to differences
 *  in how block scaling and mxfp8 store data, this requires the
 *  calling code to treat the output tensor as having been transposed
 *  in this case.
43
 *
44
45
46
 *  Returns the swizzled scaling factor of the converted mxfp8 tensor.
 *  The returned swizzled scaling factor tensor should be kept alive
 *  during the GEMM.
47
 */
48
49
50
51
at::Tensor convert_block_scaling_to_mxfp8_tensor(TensorWrapper& input, bool rowwise);

}  // namespace pytorch
}  // namespace transformer_engine
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_