Commit 01a10755 authored by yuguo-Jack

2.5.2-dtk24.04

parent 63eb0da5
@@ -23,13 +23,14 @@
 #include "paddle/cinn/frontend/optimize.h"
 #include "paddle/cinn/hlir/framework/graph.h"
 #include "paddle/cinn/hlir/framework/graph_compiler.h"
+#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
 #include "paddle/cinn/hlir/framework/pass.h"
 #include "paddle/cinn/hlir/framework/scope.h"
 #include "paddle/cinn/hlir/op/use_ops.h"
 #include "paddle/cinn/hlir/pass/use_pass.h"
 #include "paddle/cinn/utils/data_util.h"

-DEFINE_string(model_dir, "", "");
+PD_DEFINE_string(model_dir, "", "");

 namespace cinn {
 namespace frontend {
@@ -69,7 +70,8 @@ TEST(syntax, program_execute_multi_elementwise_add) {
   LOG(INFO) << "graph:\n" << graph->Visualize();
   auto scope = BuildScope(target, graph);

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto runtime_program = gc.Build();

   scope->Var<hlir::framework::Tensor>("A");
   scope->Var<hlir::framework::Tensor>("B");
@@ -88,7 +90,8 @@ TEST(syntax, program_execute_multi_elementwise_add2) {
   LOG(INFO) << "graph:\n" << graph->Visualize();
   auto scope = BuildScope(target, graph);

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto runtime_program = gc.Build();

   scope->Var<hlir::framework::Tensor>("A");
@@ -121,7 +124,8 @@ std::get<2>(programTuple);
   auto graph = cinn::frontend::Optimize(program.get(), fetch_ids, target);
   scope = BuildScope(target, graph, scope);

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto runtime_program = gc.Build();

   auto at = scope->GetTensor("A");
@@ -133,11 +137,12 @@ std::get<2>(programTuple);
   LOG(INFO) << "scope.names: " << Join(scope->var_names(), ",");

   const std::string output_name = "fc_0.tmp_2";
-  auto tensor =
-      scope->GetTensor(var_map_paddle_to_program.at(output_name));
+  auto tensor = scope->GetTensor(var_map_paddle_to_program.at(output_name));
   LOG(INFO) << "tensor.shape: " << utils::Join(tensor->shape().data(), ",");
   auto data = GetTensorData<float>(tensor, target);
-  for (int i = 0; i < 10; i++) LOG(INFO) << "data: " << data[i];
+  for (int i = 0; i < 10; i++) {
+    LOG(INFO) << "data: " << data[i];
+  }
 }
 */
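Every hunk above applies the same mechanical change: GraphCompiler is no longer constructed from (target, scope, graph) directly; the three are first bundled into a CompilationContext, declared in the newly included graph_compiler_util.h, and note that the argument order also changes to (graph, scope, target). A minimal before/after sketch of the pattern, using only identifiers that appear in the hunks:

// Old API (removed in this commit):
//   hlir::framework::GraphCompiler gc(target, scope, graph);
// New API: bundle the pieces into a CompilationContext first.
hlir::framework::CompilationContext context(graph, scope, target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();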
@@ -12,12 +12,12 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include <gflags/gflags.h>
 #include <gtest/gtest.h>

+#include "paddle/utils/flags.h"

 int main(int argc, char **argv) {
   testing::InitGoogleTest(&argc, argv);
-  GFLAGS_NAMESPACE::ParseCommandLineFlags(&argc, &argv, false);
+  paddle::flags::ParseCommandLineFlags(&argc, &argv);
   return RUN_ALL_TESTS();
 }
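This hunk completes the gflags-to-paddle::flags migration started by the PD_DEFINE_string change above. Collected in one place, the new pattern looks like this (a minimal sketch assuming only the calls visible in the diff):

#include "paddle/utils/flags.h"  // replaces <gflags/gflags.h>

PD_DEFINE_string(model_dir, "", "");  // replaces DEFINE_string

int main(int argc, char** argv) {
  // paddle::flags::ParseCommandLineFlags drops gflags' third
  // (remove_flags) argument.
  paddle::flags::ParseCommandLineFlags(&argc, &argv);
  return 0;
}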
-# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could
-# not found under CINN_ONLY mode
-if(NOT CINN_ONLY)
-  cinn_cc_library(cinn_dialect SRCS runtime_dialect.cc jit_kernel_op.cc DEPS
-                  pd_dialect)
-endif()
+add_subdirectory(operator)
+add_subdirectory(runtime)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/enforce.h"
namespace cinn {
namespace dialect {
const char* JitKernelOp::attributes_name[attributes_num] = {kAttrName};
void JitKernelOp::Verify() {
VLOG(4) << "Verifying inputs, outputs and attributes for: JitKernelOp.";
auto& attributes = this->attributes();
IR_ENFORCE(attributes.count(kAttrName) > 0 &&
attributes.at(kAttrName).isa<::ir::PointerAttribute>(),
"Type of attribute: instruction is not right.");
}
hlir::framework::Instruction* JitKernelOp::instruction() {
void* ptr =
attributes().at(kAttrName).dyn_cast<ir::PointerAttribute>().data();
return reinterpret_cast<hlir::framework::Instruction*>(ptr);
}
} // namespace dialect
} // namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::JitKernelOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/op_base.h"
namespace cinn {
namespace hlir {
namespace framework {
class Instruction;
} // namespace framework
} // namespace hlir
namespace dialect {
/*
 * TODO(Aurelius84): THIS IS NOT THE FINAL STATE!
 * JitKernel is a unified runtime operation representing a
 * jit-compiled function pointer produced by a backend such as
 * nvrtc.
 * Ideally, JitKernelOp should only contain an ArrayAttribute
 * in which each element is a PointerAttribute, i.e. the jit
 * function pointer itself.
 * Currently, we hold a hlir::framework::Instruction
 * temporarily, and will split executor information such as
 * scope, inputs and outputs into the InterpreterCore module.
*/
class JitKernelOp : public ::ir::Op<JitKernelOp> {
public:
using Op::Op;
static const char* name() { return "cinn.jit_kernel"; }
// TODO(Aurelius84): Think deeply what should contains
static constexpr uint32_t attributes_num = 1;
static constexpr const char* kAttrName = "instruction";
static const char* attributes_name[attributes_num];
hlir::framework::Instruction* instruction();
void Verify();
};
} // namespace dialect
} // namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::JitKernelOp)
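For orientation, a hedged sketch of how a lowering step might attach and later recover the instruction pointer. AttachInstruction is a name invented here, and the PointerAttribute::get / set_attribute calls are assumed to follow the usual ::ir conventions; none of them appear in this commit.

// Hypothetical helper; assumes ir::PointerAttribute::get(ctx, void*) and
// ir::Operation::set_attribute(name, attr) with their conventional forms.
void AttachInstruction(::ir::Operation* op,
                       cinn::hlir::framework::Instruction* instr) {
  op->set_attribute(
      cinn::dialect::JitKernelOp::kAttrName,
      ::ir::PointerAttribute::get(::ir::IrContext::Instance(),
                                  static_cast<void*>(instr)));
}

// The executor side recovers it through the typed accessor, which is what
// JitKernelOp::instruction() implements above:
//   auto jit_op = op->dyn_cast<cinn::dialect::JitKernelOp>();
//   cinn::hlir::framework::Instruction* instr = jit_op.instruction();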
add_subdirectory(ir)
add_subdirectory(transforms)
# TODO(Aurelius84): pir_compiler depends on pd_op_dialect and could
# not found under CINN_ONLY mode
if(NOT CINN_ONLY)
set(CINN_DIALECT_SOURCE_DIR
"${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/operator/ir")
# Generate cinn_op_dialect files defining op using op_gen_file
set(cinn_op_gen_parsed_yaml_file
${PADDLE_SOURCE_DIR}/paddle/fluid/operators/generator/parse_op.py)
set(cinn_op_gen_file
${PADDLE_SOURCE_DIR}/paddle/fluid/pir/dialect/op_generator/op_gen.py)
set(cinn_op_compat_yaml_file
${PADDLE_SOURCE_DIR}/paddle/phi/api/yaml/op_compat.yaml)
set(cinn_op_yaml_file
${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/operator/ir/ops.yaml)
set(parsed_op_dir ${PADDLE_SOURCE_DIR}/paddle/cinn/hlir/dialect/generated)
set(cinn_op_parsed_yaml_file ${parsed_op_dir}/ops.parsed.yaml)
set(cinn_op_parsed_yaml_files ${cinn_op_parsed_yaml_file})
set(cinn_op_namespace cinn,dialect)
set(cinn_op_dialect_name cinn_op)
set(cinn_op_header_file ${CINN_DIALECT_SOURCE_DIR}/cinn_op.h)
set(cinn_op_source_file ${CINN_DIALECT_SOURCE_DIR}/cinn_op.cc)
set(cinn_op_header_file_tmp ${cinn_op_header_file}.tmp)
set(cinn_op_source_file_tmp ${cinn_op_source_file}.tmp)
execute_process(
COMMAND ${CMAKE_COMMAND} -E make_directory ${parsed_op_dir}
COMMAND ${PYTHON_EXECUTABLE} ${cinn_op_gen_parsed_yaml_file} --op_yaml_path
${cinn_op_yaml_file} --output_path ${cinn_op_parsed_yaml_file})
execute_process(
COMMAND
${PYTHON_EXECUTABLE} ${cinn_op_gen_file} --op_yaml_files
${cinn_op_parsed_yaml_files} --op_compat_yaml_file
${cinn_op_compat_yaml_file} --namespaces ${cinn_op_namespace}
--dialect_name ${cinn_op_dialect_name} --op_def_h_file
${cinn_op_header_file_tmp} --op_def_cc_file ${cinn_op_source_file_tmp})
set(generated_files_cinn_op "${cinn_op_header_file}" "${cinn_op_source_file}")
foreach(generated_file ${generated_files_cinn_op})
if(EXISTS "${generated_file}.tmp" AND EXISTS "${generated_file}")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy_if_different
"${generated_file}.tmp" "${generated_file}")
message("copy if different ${generated_file}.tmp ${generated_file}")
elseif(EXISTS "${generated_file}.tmp")
execute_process(COMMAND ${CMAKE_COMMAND} -E copy "${generated_file}.tmp"
"${generated_file}")
message("copy ${generated_file}.tmp ${generated_file}")
endif()
endforeach()
cinn_cc_library(
cinn_op_dialect
SRCS
op_dialect.cc
${cinn_op_source_file}
manual_op.cc
op_attribute.cc
DEPS
op_dialect_vjp)
target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_SOURCE_DIR})
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/pir/core/attribute_base.h"
#include "paddle/pir/core/operation.h"
namespace cinn {
namespace dialect {
// TODO(Aurelius84): Need to figure out what GroupOp actually requires.
// Currently we copy almost all of its members here and will remove them
// step by step.
struct GroupInfo {
public:
explicit GroupInfo(const std::vector<::pir::Operation*>& group_ops)
: ops(group_ops) {
Initialize();
}
explicit GroupInfo(std::initializer_list<::pir::Operation*> group_ops)
: ops(group_ops) {
Initialize();
}
std::string group_id;
std::string fn_name;
hlir::framework::OpPatternKind op_pattern_kind;
std::vector<::pir::Operation*> ops;
std::vector<std::string> input_names;
std::vector<std::string> output_names;
private:
void Initialize() {
op_pattern_kind = hlir::framework::OpPatternKind::kElementWise;
fn_name = hlir::framework::pir::CompatibleInfo::GroupOpsName(ops);
}
};
struct GroupInfoAttributeStorage : public pir::AttributeStorage {
using ParamKey = GroupInfo;
explicit GroupInfoAttributeStorage(const ParamKey& key) : data_(key) {}
static GroupInfoAttributeStorage* Construct(const ParamKey& key) {
return new GroupInfoAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey& key) {
return std::hash<std::string>{}(key.group_id);
}
bool operator==(const ParamKey& key) const {
return data_.group_id == key.group_id;
}
const ParamKey& GetAsKey() const { return data_; }
private:
ParamKey data_;
};
struct JITInfoAttributeStorage : public pir::AttributeStorage {
using ParamKey = cinn::hlir::framework::pir::CUDAJITInfo;
explicit JITInfoAttributeStorage(const ParamKey& key) : data_(key) {}
static JITInfoAttributeStorage* Construct(const ParamKey& key) {
return new JITInfoAttributeStorage(key);
}
static std::size_t HashValue(const ParamKey& key) {
return std::hash<int64_t>()(*(reinterpret_cast<int64_t*>(key.fn_ptr)));
}
bool operator==(const ParamKey& key) const {
return data_.fn_ptr == key.fn_ptr;
}
const ParamKey& GetAsKey() const { return data_; }
private:
ParamKey data_;
};
} // namespace dialect
} // namespace cinn
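Note the deliberately shallow uniquing in both storages above: GroupInfoAttributeStorage hashes and compares only group_id, and JITInfoAttributeStorage only fn_ptr, so two keys that agree on those fields map to the same cached attribute instance even if their other members differ.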
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
namespace cinn {
namespace dialect {
const char *GroupOp::attributes_name[GroupOp::attributes_num] = {"group_info"};
const char *ConcatOp::attributes_name[ConcatOp::attributes_num] = {"axis"};
const char *SplitOp::attributes_name[SplitOp::attributes_num] = {
"num_or_sections", "axis"};
void GroupOp::Build(pir::Builder &builder,
pir::OperationArgument &argument,
const std::vector<pir::Type> &output_types) {
argument.AddRegion(nullptr);
argument.output_types = output_types;
}
void GroupOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
std::unique_ptr<pir::Block> &&block) {
VLOG(4) << "Start build GroupOp";
if (block && !block->empty()) {
IR_ENFORCE(block->back().isa<pir::YieldOp>());
auto &op = block->back();
for (size_t i = 0; i < op.num_operands(); ++i) {
argument.AddOutput(op.operand(i).type());
}
}
argument.AddRegion()->push_back(block.release());
}
pir::Block *GroupOp::block() {
pir::Region &region = (*this)->region(0);
if (region.empty()) region.emplace_back();
return &region.front();
}
std::vector<pir::Operation *> GroupOp::ops() {
std::vector<pir::Operation *> rt_ops;
for (auto &op : *block()) {
rt_ops.push_back(&op);
}
return rt_ops;
}
void GroupOp::VerifySig() {}
void GroupOp::Print(pir::IrPrinter &printer) {
auto &os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " = " << name();
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
os << " {";
for (auto &sub_op : ops()) {
os << "\n";
printer.PrintOperation(sub_op);
}
os << " \n }";
}
void ConcatOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
int axis) {
VLOG(4) << "Start build ConcatOp";
argument.inputs = inputs;
std::vector<pir::Type> inputs_type(inputs.size());
IR_ENFORCE(inputs.size() > 0);
auto first_ele =
inputs[0].type().dyn_cast<paddle::dialect::DenseTensorType>();
phi::DDim out_dims = first_ele.dims();
if (axis < 0) {
axis += out_dims.size();
}
for (size_t idx = 0; idx < inputs.size(); ++idx) {
inputs_type[idx] = inputs[idx].type();
if (idx > 0) {
auto dim_i = inputs[idx]
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();
out_dims[axis] += dim_i[axis];
}
}
auto out_type =
paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(),
first_ele.dtype(),
out_dims,
first_ele.data_layout(),
first_ele.lod(),
first_ele.offset());
argument.output_types.emplace_back(out_type);
PassStopGradientsDefaultly(argument);
argument.AddAttribute(
"axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis));
}
void SplitOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value input,
const std::vector<int> &sections,
int axis) {
VLOG(4) << "Start build ConcatOp";
argument.inputs.push_back(input);
std::vector<pir::Type> output_type(sections.size());
auto input_ele = input.type().dyn_cast<paddle::dialect::DenseTensorType>();
if (axis < 0) {
axis += input_ele.dims().size();
}
std::vector<pir::Attribute> section_attrs;
for (size_t idx = 0; idx < sections.size(); ++idx) {
auto out_dims = input_ele.dims();
out_dims[axis] = sections[idx];
auto out_type =
paddle::dialect::DenseTensorType::get(pir::IrContext::Instance(),
input_ele.dtype(),
out_dims,
input_ele.data_layout(),
input_ele.lod(),
input_ele.offset());
argument.output_types.emplace_back(out_type);
pir::Attribute attr_axis =
pir::Int32Attribute::get(pir::IrContext::Instance(), sections[idx]);
section_attrs.push_back(attr_axis);
}
PassStopGradientsDefaultly(argument);
argument.AddAttribute(
"num_or_sections",
pir::ArrayAttribute::get(pir::IrContext::Instance(), section_attrs));
argument.AddAttribute(
"axis", pir::Int32Attribute::get(pir::IrContext::Instance(), axis));
}
} // namespace dialect
} // namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp)
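A worked example of the shape arithmetic in the two Build methods above (the dims are chosen purely for illustration):

// ConcatOp::Build: inputs of type [2, 3] and [2, 5], axis = 1
//   (axis = -1 would first be normalized to 1 by adding the rank)
//   out_dims starts as the first input's dims, [2, 3];
//   out_dims[1] += 5  =>  one output of type [2, 8], attribute "axis" = 1.
//
// SplitOp::Build: input of type [4, 6], sections = {2, 4}, axis = 1
//   each section copies the input dims and overwrites dim 1:
//   outputs [4, 2] and [4, 4];
//   attributes: "num_or_sections" = [2, 4], "axis" = 1.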
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/operation_utils.h"
namespace cinn {
namespace dialect {
class GroupOp : public pir::Op<GroupOp> {
public:
using Op::Op;
static const char *name() { return "cinn_op.group"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Type> &output_types);
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
std::unique_ptr<pir::Block> &&block);
pir::Block *block();
std::vector<pir::Operation *> ops();
void VerifySig();
void Print(pir::IrPrinter &printer); // NOLINT
};
class IR_API ConcatOp : public pir::Op<ConcatOp> {
public:
using Op::Op;
static const char *name() { return "cinn_op.concat"; }
static constexpr uint32_t attributes_num = 1;
static const char *attributes_name[attributes_num];
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
int axis);
void VerifySig() const {}
};
class IR_API SplitOp : public pir::Op<SplitOp> {
public:
using Op::Op;
static const char *name() { return "cinn_op.split"; }
static constexpr uint32_t attributes_num = 2;
static const char *attributes_name[attributes_num];
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value input,
const std::vector<int> &sections,
int axis);
void VerifySig() const {}
};
} // namespace dialect
} // namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::ConcatOp)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::SplitOp)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
namespace cinn {
namespace dialect {
const GroupInfo &GroupInfoAttribute::data() const {
return storage()->GetAsKey();
}
const cinn::hlir::framework::pir::CUDAJITInfo &CUDAJITInfoAttribute::data()
const {
return storage()->GetAsKey();
}
} // namespace dialect
} // namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute)
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::CUDAJITInfoAttribute)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/pir/core/attribute_base.h"
namespace cinn {
namespace dialect {
class GroupInfoAttribute : public pir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(GroupInfoAttribute,
GroupInfoAttributeStorage);
bool operator<(const GroupInfoAttribute& right) const {
return storage() < right.storage();
}
const GroupInfo& data() const;
};
class CUDAJITInfoAttribute : public pir::Attribute {
public:
using Attribute::Attribute;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR(CUDAJITInfoAttribute,
JITInfoAttributeStorage);
bool operator<(const CUDAJITInfoAttribute& right) const {
return storage() < right.storage();
}
const cinn::hlir::framework::pir::CUDAJITInfo& data() const;
};
} // namespace dialect
} // namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::GroupInfoAttribute)
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::CUDAJITInfoAttribute)
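A hypothetical round-trip through the attribute, tying it to the storages defined earlier; it assumes the DECLARE_ATTRIBUTE_UTILITY_FUNCTOR macro exposes pir's usual get(ctx, key) factory, which this commit does not itself show:

// Sketch only; group_ops is a placeholder for the ops being fused.
std::vector<::pir::Operation*> group_ops = {/* ... */};
cinn::dialect::GroupInfo info(group_ops);
info.group_id = "fusion_group_0";  // the key the storage hashes on
auto attr = cinn::dialect::GroupInfoAttribute::get(
    ::pir::IrContext::Instance(), info);
const cinn::dialect::GroupInfo& stored = attr.data();  // round-trips the key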
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
// NOTE(chenxi67): File cinn_op.h is generated by op_gen.py, see details in
// paddle/cinn/hlir/dialect/CMakeLists.txt.
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
namespace cinn {
namespace dialect {
OperatorDialect::OperatorDialect(::pir::IrContext *context)
: ::pir::Dialect(name(),
context,
::pir::TypeId::get<cinn::dialect::OperatorDialect>()) {
this->initialize();
}
void OperatorDialect::initialize() {
// NOTE(chenxi67): GET_OP_LIST is defined in cinn_op.h which is
// generated by op_gen.py, see details in
// paddle/cinn/hlir/dialect/CMakeLists.txt.
RegisterOps<
#define GET_OP_LIST
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.cc" // NOLINT
>();
RegisterOp<GroupOp>();
RegisterOp<ConcatOp>();
RegisterOp<SplitOp>();
RegisterAttribute<GroupInfoAttribute>();
RegisterAttribute<CUDAJITInfoAttribute>();
}
void OperatorDialect::PrintType(pir::Type type, std::ostream &os) const {}
void OperatorDialect::PrintAttribute(pir::Attribute attr,
std::ostream &os) const {
if (attr.isa<GroupInfoAttribute>()) {
os << "(" << attr.dialect().name();
os << '.';
if (auto group_info_attr = attr.dyn_cast<GroupInfoAttribute>()) {
const GroupInfo &data = group_info_attr.data();
os << "GroupInfo)"
<< "[" << data.fn_name << "]";
} else {
os << "<#AttrNotImplemented>";
}
} else if (attr.isa<CUDAJITInfoAttribute>()) {
auto cuda_jit_info = attr.dyn_cast<CUDAJITInfoAttribute>();
os << "(" << cuda_jit_info.data().fn_ptr;
os << ')';
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"cinn dialect only supports GroupInfo and CUDAJITInfo"));
}
}
void OperatorDialect::PrintOperation(pir::Operation *op,
pir::IrPrinter &printer) const {
if (auto group_op = op->dyn_cast<GroupOp>()) {
group_op.Print(printer);
} else {
printer.PrintGeneralOperation(op);
}
}
} // namespace dialect
} // namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID(cinn::dialect::OperatorDialect)
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pir/core/dialect.h"
namespace cinn {
namespace dialect {
class OperatorDialect : public ::pir::Dialect {
public:
explicit OperatorDialect(::pir::IrContext* context);
static const char* name() { return "cinn_op"; }
void PrintType(pir::Type type, std::ostream& os) const override;
void PrintAttribute(pir::Attribute type, std::ostream& os) const override;
void PrintOperation(pir::Operation* op,
pir::IrPrinter& printer) const override; // NOLINT
private:
void initialize();
};
} // namespace dialect
} // namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID(cinn::dialect::OperatorDialect)
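For context, the dialect becomes usable only after it is registered with the IrContext. A short sketch using the standard pir registration entry point (the call site itself is not part of this commit):

#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
#include "paddle/pir/core/ir_context.h"

void RegisterCinnOpDialect() {
  ::pir::IrContext* ctx = ::pir::IrContext::Instance();
  ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();
  // The context can now create, verify and print cinn_op.group,
  // cinn_op.concat, cinn_op.split and the generated ops from cinn_op.h.
}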
- op : broadcast
args : (Tensor x, int64_t[] broadcast_axes, int64_t[] out_shape)
output : Tensor(out)
infer_meta :
func : CINNBroadcastInferMeta
param : [x, broadcast_axes, out_shape]
kernel :
func : expand
param : [x, broadcast_axes]
- op : reduce_max
args : (Tensor x, int64_t[] dim, bool keep_dim)
output : Tensor(out)
infer_meta :
func : ReduceInferMeta
kernel :
func : frobenius_norm
- op : reduce_sum
args : (Tensor x, int64_t[] dim, bool keep_dim)
output : Tensor(out)
infer_meta :
func : ReduceInferMeta
kernel :
func : frobenius_norm
- op : reshape
args : (Tensor x, int[] shape)
output : Tensor(out)
infer_meta :
func : ReshapeInferMeta
kernel :
func : reshape
- op : scale
args : (Tensor x, float scale=1.0, float bias=0.0, bool bias_after_scale=true)
output : Tensor(out)
infer_meta :
func : UnchangedInferMeta
param : [x]
kernel :
func : scale
- op : slice
args : (Tensor x, int64_t[] axes, int64_t[] starts, int64_t[] ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output : Tensor
infer_meta :
func : SliceRawInferMeta
kernel :
func : slice
- op : uniform_random
args : (int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)
output : Tensor(out)
infer_meta :
func : CreateVecShapeInferMeta
param : [shape, dtype]
kernel :
func : full_int_array
param : [shape, dtype]
add_subdirectory(group_merge)
if(NOT CINN_ONLY)
cinn_cc_library(
pd_to_cinn_pass
SRCS
pd_to_cinn_pass.cc
DEPS
drr
cinn_op_dialect
op_dialect_vjp)
cinn_cc_library(
add_broadcast_to_elementwise_pass
SRCS
add_broadcast_to_elementwise_pass.cc
DEPS
pir
cinn_op_dialect
op_dialect_vjp)
endif()
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
namespace cinn {
namespace dialect {
namespace ir {
int64_t GetDimByIndex(const phi::DDim& first,
const phi::DDim& second,
int short_align_axis,
int idx) {
// `first` has the smaller rank; indices below short_align_axis exist
// only in `second`.
if (idx < short_align_axis) {
return second[idx];
} else {
return first[idx - short_align_axis] > second[idx]
? first[idx - short_align_axis]
: second[idx];
}
}
std::vector<int64_t> GetOutputShape(const phi::DDim& x, const phi::DDim& y) {
std::vector<int64_t> vec_res;
if (x.size() >= y.size()) {
int short_align_axis = x.size() - y.size();
int max_rank = x.size();
vec_res.resize(max_rank);
for (size_t i = 0; i < max_rank; ++i) {
vec_res[i] = GetDimByIndex(y, x, short_align_axis, i);
}
} else {
int short_align_axis = y.size() - x.size();
int max_rank = y.size();
vec_res.resize(max_rank);
for (size_t i = 0; i < max_rank; ++i) {
vec_res[i] = GetDimByIndex(x, y, short_align_axis, i);
}
}
return vec_res;
}
bool IsSameDim(const phi::DDim& first, const std::vector<int64_t>& second) {
if (first.size() == second.size()) {
bool same = true;
for (size_t i = 0; i < first.size(); ++i) {
if (first[i] != second[i]) {
same = false;
break;
}
}
return same;
}
return false;
}
std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
const std::vector<int64_t>& out_shape) {
std::vector<int64_t> broadcast_axes(in_shape.size(), 0);
auto in_shape_size = in_shape.size();
if (in_shape_size >= 1) {
for (int i = 1; i <= in_shape_size; ++i) {
broadcast_axes[in_shape_size - i] = out_shape.size() - i;
}
}
return broadcast_axes;
}
bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
auto x_dims = op->operand_source(0)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();
auto y_dims = op->operand_source(1)
.type()
.dyn_cast<paddle::dialect::DenseTensorType>()
.dims();
if (x_dims != y_dims) {
auto output_shape = GetOutputShape(x_dims, y_dims);
if (!IsSameDim(x_dims, output_shape)) {
// add broadcast to input 0
if (auto full_op = op->operand_source(0)
.dyn_cast<pir::OpResult>()
.owner()
->dyn_cast<paddle::dialect::FullOp>()) {
auto new_full = rewriter->Build<paddle::dialect::FullOp>(
output_shape,
full_op->attribute("value").dyn_cast<pir::FloatAttribute>().data(),
full_op->attribute("dtype")
.dyn_cast<paddle::dialect::DataTypeAttribute>()
.data(),
full_op->attribute("place")
.dyn_cast<paddle::dialect::PlaceAttribute>()
.data());
op->operand(0).set_source(new_full->result(0));
} else {
auto new_broadcast_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(0),
GetBroadcastAxis(x_dims, output_shape),
output_shape);
op->operand(0).set_source(new_broadcast_op->result(0));
}
}
if (!IsSameDim(y_dims, output_shape)) {
if (auto full_op = op->operand_source(1)
.dyn_cast<pir::OpResult>()
.owner()
->dyn_cast<paddle::dialect::FullOp>()) {
auto new_full = rewriter->Build<paddle::dialect::FullOp>(
output_shape,
full_op->attribute("value").dyn_cast<pir::FloatAttribute>().data(),
full_op->attribute("dtype")
.dyn_cast<paddle::dialect::DataTypeAttribute>()
.data(),
full_op->attribute("place")
.dyn_cast<paddle::dialect::PlaceAttribute>()
.data());
op->operand(1).set_source(new_full->result(0));
} else {
auto new_broadcast_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(1),
GetBroadcastAxis(y_dims, output_shape),
output_shape);
op->operand(1).set_source(new_broadcast_op->result(0));
}
}
return true;
}
return false;
}
template <typename OPTYPE>
class AddBroadcastToElementwisePattern : public pir::OpRewritePattern<OPTYPE> {
public:
using pir::OpRewritePattern<OPTYPE>::OpRewritePattern;
bool MatchAndRewrite(OPTYPE op,
pir::PatternRewriter& rewriter) const override {
return ProcessOp(op, &rewriter);
}
};
AddBroadcastToElementwisePass::AddBroadcastToElementwisePass()
: pir::PatternRewritePass("add_broadcast_to_elementwise_pass", 1) {}
pir::RewritePatternSet AddBroadcastToElementwisePass::InitializePatterns(
pir::IrContext* context) {
pir::RewritePatternSet ps(context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::AddOp>>(context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::SubtractOp>>(
context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::MultiplyOp>>(
context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::DivideOp>>(context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::ElementwisePowOp>>(
context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::LessThanOp>>(context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::LessEqualOp>>(
context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::EqualOp>>(context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::NotEqualOp>>(context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::GreaterThanOp>>(
context);
ps.Add<AddBroadcastToElementwisePattern<paddle::dialect::GreaterEqualOp>>(
context);
return ps;
}
bool AddBroadcastToElementwisePass::CanApplyOn(pir::Operation* op) const {
return op->isa<pir::ModuleOp>() && op->num_regions() > 0;
}
} // namespace ir
} // namespace dialect
} // namespace cinn
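The broadcast rule implemented by GetDimByIndex and GetOutputShape above is ordinary right-aligned, NumPy-style alignment: the two shapes are aligned at their trailing dimensions and each output dimension is the larger of the pair. A standalone, runnable restatement in plain C++ (no pir dependencies; BroadcastShape is a name invented for this sketch):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Restates GetOutputShape with plain vectors.
std::vector<int64_t> BroadcastShape(std::vector<int64_t> x,
                                    std::vector<int64_t> y) {
  if (x.size() < y.size()) std::swap(x, y);   // make x the longer shape
  const size_t offset = x.size() - y.size();  // right-align the shorter one
  std::vector<int64_t> out(x);
  for (size_t i = 0; i < y.size(); ++i) {
    out[offset + i] = std::max(x[offset + i], y[i]);
  }
  return out;
}

int main() {
  // [8, 1, 3] vs [4, 3]: trailing dims are aligned and the size-1 dim
  // broadcasts against 4.
  assert((BroadcastShape({8, 1, 3}, {4, 3}) == std::vector<int64_t>{8, 4, 3}));
  assert((BroadcastShape({5}, {2, 5}) == std::vector<int64_t>{2, 5}));
  return 0;
}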
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
namespace cinn {
namespace dialect {
namespace ir {
class AddBroadcastToElementwisePass : public pir::PatternRewritePass {
public:
AddBroadcastToElementwisePass();
pir::RewritePatternSet InitializePatterns(pir::IrContext *context) override;
bool CanApplyOn(pir::Operation *op) const override;
};
} // namespace ir
} // namespace dialect
} // namespace cinn
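A hedged sketch of wiring the pass into a pipeline; it assumes the standard pir::PassManager API (AddPass/Run) and its usual header location, neither of which is part of this commit:

#include <memory>

#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"
#include "paddle/pir/pass/pass_manager.h"

void RunBroadcastPass(pir::IrContext* ctx, pir::Program* program) {
  pir::PassManager pm(ctx);
  pm.AddPass(
      std::make_unique<cinn::dialect::ir::AddBroadcastToElementwisePass>());
  pm.Run(program);  // CanApplyOn() gates this to the top-level ModuleOp
}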
if(NOT CINN_ONLY)
cinn_cc_library(
op_with_group_merge_pass
SRCS
group_with_group_merge_pass.cc
op_with_group_merge_pass.cc
cinn_group_lowering_pass.cc
tensor_node.cc
DEPS
op_dialect_vjp
pir_compiler
cinn_runtime_dialect)
endif()