// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/operation_utils.h"

namespace cinn {
namespace dialect {

// Operation class registered under the name "cinn_op.group".
//
// Only declarations live here (apart from the inline name()); the static
// attribute-name table, both Build overloads, the accessors, VerifySig and
// Print are all defined outside this header.
//
// NOTE(review): unlike ConcatOp and SplitOp later in this file, this class
// is not marked IR_API — confirm whether that is intentional.
class GroupOp : public pir::Op<GroupOp> {
 public:
  // Inherit the constructors of the pir::Op base.
  using Op::Op;
  // Name string identifying this operation.
  static const char *name() { return "cinn_op.group"; }
  // Number of entries in attributes_name.
  static constexpr uint32_t attributes_num = 1;
  // Attribute-name table; declared here, defined elsewhere.
  static const char *attributes_name[attributes_num];
  // Fills `argument` for a group op described by its output types.
  // (Exact population of `argument` is in the out-of-line definition.)
  static void Build(pir::Builder &builder,             // NOLINT
                    pir::OperationArgument &argument,  // NOLINT
                    const std::vector<pir::Type> &output_types);

  // Fills `argument` for a group op from an existing block. The block is
  // taken by rvalue reference to a unique_ptr, so the caller hands over
  // ownership of it.
  static void Build(pir::Builder &builder,             // NOLINT
                    pir::OperationArgument &argument,  // NOLINT
                    std::unique_ptr<pir::Block> &&block);

  // Returns a pointer to a pir::Block — presumably the block held by this
  // op; confirm against the definition.
  pir::Block *block();
  // Returns a list of operation pointers — presumably the operations inside
  // block(); confirm against the definition.
  std::vector<pir::Operation *> ops();

  // Signature verification hook (non-const here, unlike the sibling ops);
  // defined elsewhere.
  void VerifySig();
  // Custom printing hook that receives the IR printer; defined elsewhere.
  void Print(pir::IrPrinter &printer);  // NOLINT
};

// Operation class registered under the name "cinn_op.concat".
//
// Declares a single Build entry point taking a list of input values and an
// integer axis. The attribute-name table and Build are defined outside this
// header.
class IR_API ConcatOp : public pir::Op<ConcatOp> {
 public:
  // Inherit the constructors of the pir::Op base.
  using Op::Op;

  // Name string identifying this operation.
  static const char *name() { return "cinn_op.concat"; }

  // Number of entries in attributes_name.
  static constexpr uint32_t attributes_num = 1;

  // Attribute-name table; declared here, defined elsewhere.
  static const char *attributes_name[attributes_num];

  // Fills `argument` for a concat op over `inputs` with the given `axis`.
  // NOTE(review): presumably `axis` is the dimension along which the inputs
  // are joined and is stored as the op's single attribute — confirm against
  // the definition.
  static void Build(pir::Builder &builder,             // NOLINT
                    pir::OperationArgument &argument,  // NOLINT
                    const std::vector<pir::Value> &inputs,
                    int axis);

  // Intentionally empty: no signature checks are performed for this op.
  void VerifySig() const {}
};

// Operation class registered under the name "cinn_op.split".
//
// Declares a single Build entry point taking one input value, a list of
// section sizes and an integer axis. The attribute-name table and Build are
// defined outside this header.
class IR_API SplitOp : public pir::Op<SplitOp> {
 public:
  // Inherit the constructors of the pir::Op base.
  using Op::Op;

  // Name string identifying this operation.
  static const char *name() { return "cinn_op.split"; }

  // Number of entries in attributes_name.
  static constexpr uint32_t attributes_num = 2;

  // Attribute-name table; declared here, defined elsewhere.
  static const char *attributes_name[attributes_num];

  // Fills `argument` for a split op on `input`.
  // NOTE(review): presumably `sections` holds the size of each output piece
  // along `axis`, and the two values map onto the op's two attributes —
  // confirm against the definition.
  static void Build(pir::Builder &builder,             // NOLINT
                    pir::OperationArgument &argument,  // NOLINT
                    pir::Value input,
                    const std::vector<int> &sections,
                    int axis);

  // Intentionally empty: no signature checks are performed for this op.
  void VerifySig() const {}
};

}  // namespace dialect
}  // namespace cinn

// Explicit type-id declarations for each op class declared in this header.
// They are placed at global scope, after the cinn::dialect namespace has been
// closed, and therefore use fully qualified class names.
// NOTE(review): presumably each declaration is paired with a matching
// definition macro in a source file — confirm.
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp)
