// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/pir/core/op_info_impl.h"
#include "paddle/pir/core/dialect.h"
#include "paddle/pir/core/interface_support.h"

namespace pir {

// Attaches `interface_value` to the op described by this OpInfo by handing it
// over (as an rvalue) to the underlying OpInfoImpl.
//
// `impl_` is checked with IR_ENFORCE before it is dereferenced; presumably the
// macro raises an error carrying the message below when `impl_` is null
// (NOTE(review): confirm against the IR_ENFORCE definition).
void OpInfo::AttachInterface(InterfaceValue &&interface_value) {
  IR_ENFORCE(impl_, "Cann't attach interface to a nullptr OpInfo");
  impl_->AttachInterface(std::move(interface_value));
}

// Registers `interface_value` in this op's interface set.
//
// The interface's type id is read *before* the value is moved into the set:
// once the insert succeeds, `interface_value` is a moved-from object and
// querying it afterwards would report an unspecified id in both the error
// message and the log line.
//
// The result of the insertion is passed to IR_ENFORCE, so attaching an
// interface that is already present in the set is reported with the
// "already registered" message below.
void OpInfoImpl::AttachInterface(InterfaceValue &&interface_value) {
  // Copy the id out while `interface_value` is still intact.
  const auto interface_id = interface_value.type_id();
  const bool success =
      interface_set_.insert(std::move(interface_value)).second;
  IR_ENFORCE(success,
             "Interface: id[%u] is already registered. inset failed",
             interface_id);
  VLOG(6) << "Attach a interface: id[" << interface_id << "]. to "
          << op_name_;
}

// Builds an OpInfoImpl from already-prepared pieces; every argument is stored
// in the member of the matching name.
//
//   interface_set  - interfaces of the op; moved into `interface_set_`.
//   dialect        - dialect pointer stored as-is (ir_context() tolerates null).
//   op_id          - TypeId identifying the op.
//   op_name        - name of the op, kept in `op_name_`.
//   num_traits     - number of trait TypeIds; HasTrait() and Destroy() expect
//                    that many TypeIds to sit in memory immediately before
//                    this object, which is how Create() lays them out.
//   num_attributes - number of entries in `p_attributes`.
//   p_attributes   - attribute-name array; only the pointer is stored here.
//   verify_sig     - verifier callback stored in `verify_sig_`.
//   verify_region  - verifier callback stored in `verify_region_`.
OpInfoImpl::OpInfoImpl(std::set<InterfaceValue> &&interface_set,
                       pir::Dialect *dialect,
                       TypeId op_id,
                       const char *op_name,
                       uint32_t num_traits,
                       uint32_t num_attributes,
                       const char **p_attributes,
                       VerifyPtr verify_sig,
                       VerifyPtr verify_region)
    : interface_set_(std::move(interface_set)),
      dialect_(dialect),
      op_id_(op_id),
      op_name_(op_name),
      num_traits_(num_traits),
      num_attributes_(num_attributes),
      p_attributes_(p_attributes),
      verify_sig_(verify_sig),
      verify_region_(verify_region) {}

// Allocates and constructs an OpInfoImpl together with its trait table and
// returns an OpInfo wrapping it.
//
// Memory layout of the single allocation made here:
//
//   [ TypeId x trait_set.size() (sorted) ][ OpInfoImpl ]
//   ^ allocation start                    ^ address wrapped by the OpInfo
//
// The trait ids are copied from `trait_set` and sorted so that HasTrait() can
// binary-search them. `interface_set` is moved into the new object; the
// remaining arguments are forwarded to the OpInfoImpl constructor unchanged.
// The storage is released by OpInfoImpl::Destroy().
OpInfo OpInfoImpl::Create(Dialect *dialect,
                          TypeId op_id,
                          const char *op_name,
                          std::set<InterfaceValue> &&interface_set,
                          const std::vector<TypeId> &trait_set,
                          size_t attributes_num,
                          const char *attributes_name[],  // NOLINT
                          VerifyPtr verify_sig,
                          VerifyPtr verify_region) {
  const size_t traits_num = trait_set.size();
  VLOG(6) << "Create OpInfoImpl with: " << interface_set.size()
          << " interfaces, " << traits_num << " traits, " << attributes_num
          << " attributes.";

  // One raw allocation holds the trait table followed by the impl object.
  const size_t traits_bytes = sizeof(TypeId) * traits_num;
  const size_t base_size = traits_bytes + sizeof(OpInfoImpl);
  char *storage = static_cast<char *>(::operator new(base_size));
  VLOG(6) << "Malloc " << base_size << " Bytes at "
          << static_cast<void *>(storage);

  // With no traits the impl object starts at the beginning of the storage;
  // otherwise it starts right after the sorted trait table.
  char *impl_addr = storage;
  if (traits_num > 0) {
    memcpy(storage, trait_set.data(), traits_bytes);
    TypeId *traits_begin = reinterpret_cast<TypeId *>(storage);
    std::sort(traits_begin, traits_begin + traits_num);
    impl_addr = storage + traits_bytes;
  }

  VLOG(6) << "Construct OpInfoImpl at " << reinterpret_cast<void *>(impl_addr)
          << " ......";
  OpInfoImpl *impl = new (impl_addr) OpInfoImpl(std::move(interface_set),
                                                dialect,
                                                op_id,
                                                op_name,
                                                traits_num,
                                                attributes_num,
                                                attributes_name,
                                                verify_sig,
                                                verify_region);
  return OpInfo(impl);
}
// Destroys the OpInfoImpl wrapped by `info`. An OpInfo without an impl is
// only reported with a warning; nothing is freed in that case.
void OpInfoImpl::Destroy(OpInfo info) {
  if (!info.impl_) {
    LOG(WARNING) << "A nullptr OpInfo is destoryed.";
    return;
  }
  info.impl_->Destroy();
}

// Returns the IrContext of the owning dialect, or nullptr when this op info
// has no dialect attached.
pir::IrContext *OpInfoImpl::ir_context() const {
  if (dialect_ == nullptr) {
    return nullptr;
  }
  return dialect_->ir_context();
}

// Returns true when `trait_id` is one of this op's traits.
//
// The trait ids are not members of the object: they are read from the
// `num_traits_` TypeIds located directly in front of `this`, and the lookup
// is a binary search, so that range must be sorted (Create() sorts it).
bool OpInfoImpl::HasTrait(TypeId trait_id) const {
  if (num_traits_ == 0) {
    return false;
  }
  // Step back from the object address to the first stored trait id.
  const char *self = reinterpret_cast<const char *>(this);
  const TypeId *traits_begin = reinterpret_cast<const TypeId *>(
      self - sizeof(pir::TypeId) * num_traits_);
  const TypeId *traits_end = traits_begin + num_traits_;
  return std::binary_search(traits_begin, traits_end, trait_id);
}

// Returns true when an interface with id `interface_id` is present in this
// op's interface set.
bool OpInfoImpl::HasInterface(TypeId interface_id) const {
  const auto found = interface_set_.find(interface_id);
  return found != interface_set_.end();
}

// Looks up the interface with id `interface_id` and returns its model
// pointer, or nullptr when the op does not have that interface.
void *OpInfoImpl::GetInterfaceImpl(TypeId interface_id) const {
  const auto found = interface_set_.find(interface_id);
  if (found == interface_set_.end()) {
    return nullptr;
  }
  return found->model();
}

// Destroys this object and releases the allocation it lives in.
//
// The allocation starts `num_traits_` TypeIds before `this`, so that start
// address is worked out first, while `num_traits_` can still be read. Then
// the destructor is run, and finally the whole allocation is handed back
// with ::operator delete. No member may be touched after the destructor call.
void OpInfoImpl::Destroy() {
  VLOG(10) << "Destroy op_info impl at " << this;
  // (1) Locate the start of the allocation (the trait ids precede `this`).
  const size_t traits_bytes = sizeof(pir::TypeId) * num_traits_;
  void *allocation = reinterpret_cast<char *>(this) - traits_bytes;
  // (2) Run the destructor, which releases the members (e.g. the interfaces).
  this->~OpInfoImpl();
  // (3) Release the raw memory.
  VLOG(10) << "Free base_ptr " << allocation;
  ::operator delete(allocation);
}

}  // namespace pir
