config.cpp 4.49 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
/*************************************************************************
 * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "./config.h"

#include <transformer_engine/gemm.h>
#include <transformer_engine/transformer_engine.h>

#include <cstring>

#include "../util/logging.h"

// Create a new matmul config.
//
// Heap-allocates a default-initialized transformer_engine::MatmulConfig and
// returns it as an NVTEMatmulConfig handle. nvte_destroy_matmul_config in
// this file deletes a handle produced here.
NVTEMatmulConfig nvte_create_matmul_config() {
  auto *new_config = new transformer_engine::MatmulConfig;
  return new_config;
}

// Query one attribute of a matmul config.
//
// After attr and size_written have been validated, the attribute's size in
// bytes (taken from MatmulConfig::attr_sizes) is stored in *size_written.
// If buf is NULL the function returns at that point, without touching
// config. Otherwise buf must provide at least that many bytes, config must
// be non-NULL, and the attribute's value is copied into buf.
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
                                      void *buf, size_t size_in_bytes, size_t *size_written) {
  // Validate the request and report the attribute's size
  NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
             static_cast<int>(attr), ")");
  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
  const auto attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
  *size_written = attr_size;

  // No destination buffer: only the size was requested
  if (buf == nullptr) return;

  // The destination must be able to hold the whole attribute
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for matmul config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");

  // Copy the selected field's bytes out of the config
  NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
  const auto &cfg = *reinterpret_cast<const transformer_engine::MatmulConfig *>(config);
  const auto copy_to_buf = [buf, attr_size](const void *field) {
    std::memcpy(buf, field, attr_size);
  };
  switch (attr) {
    case kNVTEMatmulConfigBiasTensor:
      copy_to_buf(&cfg.bias_tensor);
      break;
    case kNVTEMatmulConfigDBiasTensor:
      copy_to_buf(&cfg.dbias_tensor);
      break;
    case kNVTEMatmulConfigWithGELUEpilogue:
      copy_to_buf(&cfg.with_gelu_epilogue);
      break;
    case kNVTEMatmulConfigWithDGELUEpilogue:
      copy_to_buf(&cfg.with_dgelu_epilogue);
      break;
    case kNVTEMatmulConfigEpilogueAuxTensor:
      copy_to_buf(&cfg.epilogue_aux_tensor);
      break;
    case kNVTEMatmulConfigUseSplitAccumulator:
      copy_to_buf(&cfg.use_split_accumulator);
      break;
    case kNVTEMatmulConfigSMCount:
      copy_to_buf(&cfg.sm_count);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

// Set one attribute of a matmul config.
//
// attr must be below kNVTEMatmulConfigNumAttributes, buf must be non-NULL
// and provide at least MatmulConfig::attr_sizes[attr] bytes, and config
// must be non-NULL. Exactly that many bytes are copied from buf into the
// matching field of the config.
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
                                      const void *buf, size_t size_in_bytes) {
  // Validate the attribute id, then the source buffer's size and pointer
  NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
             static_cast<int>(attr), ")");
  const auto attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for matmul config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ", size_in_bytes,
             " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // Copy the caller's bytes into the selected field of the config
  NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
  auto &cfg = *reinterpret_cast<transformer_engine::MatmulConfig *>(config);
  const auto copy_from_buf = [buf, attr_size](void *field) {
    std::memcpy(field, buf, attr_size);
  };
  switch (attr) {
    case kNVTEMatmulConfigBiasTensor:
      copy_from_buf(&cfg.bias_tensor);
      break;
    case kNVTEMatmulConfigDBiasTensor:
      copy_from_buf(&cfg.dbias_tensor);
      break;
    case kNVTEMatmulConfigWithGELUEpilogue:
      copy_from_buf(&cfg.with_gelu_epilogue);
      break;
    case kNVTEMatmulConfigWithDGELUEpilogue:
      copy_from_buf(&cfg.with_dgelu_epilogue);
      break;
    case kNVTEMatmulConfigEpilogueAuxTensor:
      copy_from_buf(&cfg.epilogue_aux_tensor);
      break;
    case kNVTEMatmulConfigUseSplitAccumulator:
      copy_from_buf(&cfg.use_split_accumulator);
      break;
    case kNVTEMatmulConfigSMCount:
      copy_from_buf(&cfg.sm_count);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

// Destroy a matmul config.
//
// A NULL handle is accepted and ignored. Any other handle is cast back to
// transformer_engine::MatmulConfig and deleted.
void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
  if (config == nullptr) {
    return;
  }
  auto *owned_config = reinterpret_cast<transformer_engine::MatmulConfig *>(config);
  delete owned_config;
}