/*************************************************************************
 * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "./config.h"

#include <cstdint>
#include <cstring>

#include "../util/logging.h"

/// \brief Allocate a matmul config with default values.
/// \return Opaque handle owned by the caller; release it with
///         nvte_destroy_matmul_config.
NVTEMatmulConfig nvte_create_matmul_config() { return new transformer_engine::MatmulConfig; }

/// \brief Query one attribute of a matmul config.
///
/// Always writes the attribute's size to *size_written. If buf is
/// non-null and large enough, the attribute value is also copied into
/// buf. Passing buf == nullptr is the supported way to query the
/// required buffer size.
void nvte_get_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
                                      void *buf, size_t size_in_bytes, size_t *size_written) {
  // Write attribute size
  NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
             static_cast<int>(attr), ")");
  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
  const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
  *size_written = attr_size;

  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }

  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for matmul config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");

  // bool size is implementation-dependent, so we explicitly specify
  // uint8_t in the user-facing API.
  auto bool_to_uint8 = [](bool in, void *out) {
    *reinterpret_cast<uint8_t *>(out) = static_cast<uint8_t>(in);
  };

  // Write to buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
  const auto &config_ = *reinterpret_cast<const transformer_engine::MatmulConfig *>(config);
  switch (attr) {
    case kNVTEMatmulConfigBiasTensor:
      std::memcpy(buf, &config_.bias_tensor, attr_size);
      break;
    case kNVTEMatmulConfigDBiasTensor:
      std::memcpy(buf, &config_.dbias_tensor, attr_size);
      break;
    case kNVTEMatmulConfigWithGELUEpilogue:
      bool_to_uint8(config_.with_gelu_epilogue, buf);
      break;
    case kNVTEMatmulConfigWithDGELUEpilogue:
      bool_to_uint8(config_.with_dgelu_epilogue, buf);
      break;
    case kNVTEMatmulConfigEpilogueAuxTensor:
      std::memcpy(buf, &config_.epilogue_aux_tensor, attr_size);
      break;
    case kNVTEMatmulConfigUseSplitAccumulator:
      bool_to_uint8(config_.use_split_accumulator, buf);
      break;
    case kNVTEMatmulConfigSMCount:
      // NOTE(review): cast targets reconstructed as int — assumes the
      // user-facing SM count attribute is a plain int; confirm against
      // MatmulConfig::attr_sizes in config.h.
      *reinterpret_cast<int *>(buf) = static_cast<int>(config_.sm_count);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

/// \brief Set one attribute of a matmul config.
///
/// buf must hold at least the attribute's size in bytes; boolean
/// attributes are passed as uint8_t since the size of bool is
/// implementation-dependent.
void nvte_set_matmul_config_attribute(NVTEMatmulConfig config, NVTEMatmulConfigAttribute attr,
                                      const void *buf, size_t size_in_bytes) {
  // Check attribute and buffer
  NVTE_CHECK(attr < kNVTEMatmulConfigNumAttributes, "Invalid NVTEMatmulConfigAttribute (got ",
             static_cast<int>(attr), ")");
  const auto &attr_size = transformer_engine::MatmulConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for matmul config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // bool size is implementation-dependent, so we explicitly specify
  // uint8_t in the user-facing API.
  auto uint8_to_bool = [](const void *in, bool &out) {
    out = static_cast<bool>(*reinterpret_cast<const uint8_t *>(in));
  };

  // Read from buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEMatmulConfig (got NULL)");
  auto &config_ = *reinterpret_cast<transformer_engine::MatmulConfig *>(config);
  switch (attr) {
    case kNVTEMatmulConfigBiasTensor:
      std::memcpy(&config_.bias_tensor, buf, attr_size);
      break;
    case kNVTEMatmulConfigDBiasTensor:
      std::memcpy(&config_.dbias_tensor, buf, attr_size);
      break;
    case kNVTEMatmulConfigWithGELUEpilogue:
      uint8_to_bool(buf, config_.with_gelu_epilogue);
      break;
    case kNVTEMatmulConfigWithDGELUEpilogue:
      uint8_to_bool(buf, config_.with_dgelu_epilogue);
      break;
    case kNVTEMatmulConfigEpilogueAuxTensor:
      std::memcpy(&config_.epilogue_aux_tensor, buf, attr_size);
      break;
    case kNVTEMatmulConfigUseSplitAccumulator:
      uint8_to_bool(buf, config_.use_split_accumulator);
      break;
    case kNVTEMatmulConfigSMCount:
      // NOTE(review): cast targets reconstructed as int — mirror of the
      // getter case above; confirm against config.h.
      config_.sm_count = static_cast<int>(*reinterpret_cast<const int *>(buf));
      break;
    default:
      NVTE_ERROR("Unsupported NVTEMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  }
}

/// \brief Release a matmul config created by nvte_create_matmul_config.
/// Accepts NULL as a no-op.
void nvte_destroy_matmul_config(NVTEMatmulConfig config) {
  if (config != nullptr) {
    delete reinterpret_cast<transformer_engine::MatmulConfig *>(config);
  }
}

/// \brief Allocate a grouped matmul config with default values.
/// \return Opaque handle owned by the caller; release it with
///         nvte_destroy_grouped_matmul_config.
NVTEGroupedMatmulConfig nvte_create_grouped_matmul_config() {
  return new transformer_engine::GroupedMatmulConfig;
}

/// \brief Query one attribute of a grouped matmul config.
///
/// Always writes the attribute's size to *size_written. If buf is
/// non-null and large enough, the attribute value is also copied into
/// buf. The avg_m/avg_n/avg_k attributes are optional internally and
/// are reported as 0 when unset.
void nvte_get_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
                                              NVTEGroupedMatmulConfigAttribute attr, void *buf,
                                              size_t size_in_bytes, size_t *size_written) {
  // Write attribute size
  NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
             "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  NVTE_CHECK(size_written != nullptr, "Invalid size_written (got NULL)");
  const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
  *size_written = attr_size;

  // Return immediately if buffer is not provided
  if (buf == nullptr) {
    return;
  }

  // Check buffer size
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for grouped matmul config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");

  // Write to buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
  const auto &config_ = *reinterpret_cast<const transformer_engine::GroupedMatmulConfig *>(config);
  switch (attr) {
    case kNVTEGroupedMatmulConfigAvgM: {
      int64_t val = config_.avg_m.value_or(0);
      std::memcpy(buf, &val, attr_size);
      break;
    }
    case kNVTEGroupedMatmulConfigAvgN: {
      int64_t val = config_.avg_n.value_or(0);
      std::memcpy(buf, &val, attr_size);
      break;
    }
    case kNVTEGroupedMatmulConfigAvgK: {
      int64_t val = config_.avg_k.value_or(0);
      std::memcpy(buf, &val, attr_size);
      break;
    }
    case kNVTEGroupedMatmulConfigSMCount:
      std::memcpy(buf, &config_.sm_count, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr),
                 ")");
  }
}

/// \brief Set one attribute of a grouped matmul config.
///
/// buf must hold at least the attribute's size in bytes. Setting
/// avg_m/avg_n/avg_k populates the corresponding optional field.
void nvte_set_grouped_matmul_config_attribute(NVTEGroupedMatmulConfig config,
                                              NVTEGroupedMatmulConfigAttribute attr,
                                              const void *buf, size_t size_in_bytes) {
  // Check attribute and buffer
  NVTE_CHECK(attr < kNVTEGroupedMatmulConfigNumAttributes,
             "Invalid NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr), ")");
  const auto &attr_size = transformer_engine::GroupedMatmulConfig::attr_sizes[attr];
  NVTE_CHECK(size_in_bytes >= attr_size,
             "Buffer is too small for grouped matmul config attribute "
             "(attribute ",
             static_cast<int>(attr), " needs ", attr_size, " bytes, but buffer has ",
             size_in_bytes, " bytes)");
  NVTE_CHECK(buf != nullptr, "Invalid buffer (got NULL)");

  // Read from buffer
  NVTE_CHECK(config != nullptr, "Invalid NVTEGroupedMatmulConfig (got NULL)");
  auto &config_ = *reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
  switch (attr) {
    case kNVTEGroupedMatmulConfigAvgM: {
      int64_t val;
      std::memcpy(&val, buf, attr_size);
      config_.avg_m = val;
      break;
    }
    case kNVTEGroupedMatmulConfigAvgN: {
      int64_t val;
      std::memcpy(&val, buf, attr_size);
      config_.avg_n = val;
      break;
    }
    case kNVTEGroupedMatmulConfigAvgK: {
      int64_t val;
      std::memcpy(&val, buf, attr_size);
      config_.avg_k = val;
      break;
    }
    case kNVTEGroupedMatmulConfigSMCount:
      std::memcpy(&config_.sm_count, buf, attr_size);
      break;
    default:
      NVTE_ERROR("Unsupported NVTEGroupedMatmulConfigAttribute (got ", static_cast<int>(attr),
                 ")");
  }
}

/// \brief Release a grouped matmul config created by
/// nvte_create_grouped_matmul_config. Accepts NULL as a no-op.
void nvte_destroy_grouped_matmul_config(NVTEGroupedMatmulConfig config) {
  if (config != nullptr) {
    delete reinterpret_cast<transformer_engine::GroupedMatmulConfig *>(config);
  }
}