util.h 826 Bytes
Newer Older
1
/*************************************************************************
2
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
4
5
6
 *
 * See LICENSE for license information.
 ************************************************************************/

7
8
#ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
#define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_
9

10
11
12
13
14
15
16
17
18
19
20
21
22
#include <torch/extension.h>

#include <optional>

#include "transformer_engine/transformer_engine.h"

/* Swizzle the scaling factor of the input tensor.
 *
 * The returned swizzled scaling factor tensor should be kept alive during the GEMM.
 */
std::optional<at::Tensor> swizzle_scaling_factors(transformer_engine::TensorWrapper &input,
                                                  bool trans);

23
#endif  // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_