/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#ifndef TRANSFORMER_ENGINE_GEMM_CONFIG_H_
#define TRANSFORMER_ENGINE_GEMM_CONFIG_H_

#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <optional>

namespace transformer_engine {

// Plain aggregate of options for a matrix multiplication.
//
// Every member has a default, so a default-constructed MatmulConfig has all
// tensor handles null, all flags false, and sm_count == 0.
struct MatmulConfig {
  // Tensor handles; nullptr is the default.
  // NOTE(review): nullptr presumably means "no bias / no dbias output" - confirm.
  NVTETensor bias_tensor = nullptr;
  NVTETensor dbias_tensor = nullptr;

  // Epilogue switches, both off by default.
  // NOTE(review): presumably request a fused GELU / dGELU epilogue - confirm.
  bool with_gelu_epilogue = false;
  bool with_dgelu_epilogue = false;

  // Auxiliary tensor handle associated with the epilogue; nullptr by default.
  NVTETensor epilogue_aux_tensor = nullptr;

  // Split-accumulator switch, off by default.
  bool use_split_accumulator = false;

  // SM count setting; defaults to 0.
  // NOTE(review): 0 presumably means "no explicit limit" - confirm.
  int sm_count = 0;

  // Byte size recorded for each attribute, listed in member declaration order.
  // The bool flags are recorded as sizeof(uint8_t) and sm_count as
  // sizeof(int32_t), i.e. fixed-width sizes rather than sizeof(bool)/sizeof(int).
  // NOTE(review): presumably these are the sizes exchanged through the
  // attribute get/set API - confirm against the caller.
  static constexpr size_t attr_sizes[] = {
      sizeof(NVTETensor),  // bias_tensor
      sizeof(NVTETensor),  // dbias_tensor
      sizeof(uint8_t),     // with_gelu_epilogue
      sizeof(uint8_t),     // with_dgelu_epilogue
      sizeof(NVTETensor),  // epilogue_aux_tensor
      sizeof(uint8_t),     // use_split_accumulator
      sizeof(int32_t)      // sm_count
  };
};

struct GroupedMatmulConfig {
  // Optional average-dimension hints fed to the cuBLASLt algorithm selection
  // heuristics. An empty optional means the hint was not set, in which case
  // the value is computed automatically from the tensor shapes.
  std::optional<int64_t> avg_m = std::nullopt;
  std::optional<int64_t> avg_n = std::nullopt;
  std::optional<int64_t> avg_k = std::nullopt;

  // Number of streaming multiprocessors to use in GEMM kernel
  int sm_count{0};

  // Byte size of each attribute, in member declaration order.
  // Note: the API transfers the optional's value type, not std::optional itself.
  static constexpr size_t attr_sizes[] = {
      sizeof(decltype(avg_m)::value_type),  // avg_m
      sizeof(decltype(avg_n)::value_type),  // avg_n
      sizeof(decltype(avg_k)::value_type),  // avg_k
      sizeof(decltype(sm_count))            // sm_count
  };
};

}  // namespace transformer_engine

#endif  // TRANSFORMER_ENGINE_GEMM_CONFIG_H_