Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
558
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1017 additions
and
121 deletions
+1017
-121
paddle/cinn/frontend/syntax_test.cc
paddle/cinn/frontend/syntax_test.cc
+14
-9
paddle/cinn/gtest_main.cc
paddle/cinn/gtest_main.cc
+2
-2
paddle/cinn/hlir/dialect/.gitignore
paddle/cinn/hlir/dialect/.gitignore
+2
-0
paddle/cinn/hlir/dialect/CMakeLists.txt
paddle/cinn/hlir/dialect/CMakeLists.txt
+2
-6
paddle/cinn/hlir/dialect/jit_kernel_op.cc
paddle/cinn/hlir/dialect/jit_kernel_op.cc
+0
-44
paddle/cinn/hlir/dialect/jit_kernel_op.h
paddle/cinn/hlir/dialect/jit_kernel_op.h
+0
-60
paddle/cinn/hlir/dialect/operator/CMakeLists.txt
paddle/cinn/hlir/dialect/operator/CMakeLists.txt
+2
-0
paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
+70
-0
paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h
paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h
+103
-0
paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
+185
-0
paddle/cinn/hlir/dialect/operator/ir/manual_op.h
paddle/cinn/hlir/dialect/operator/ir/manual_op.h
+89
-0
paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc
paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc
+31
-0
paddle/cinn/hlir/dialect/operator/ir/op_attribute.h
paddle/cinn/hlir/dialect/operator/ir/op_attribute.h
+54
-0
paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
+83
-0
paddle/cinn/hlir/dialect/operator/ir/op_dialect.h
paddle/cinn/hlir/dialect/operator/ir/op_dialect.h
+40
-0
paddle/cinn/hlir/dialect/operator/ir/ops.yaml
paddle/cinn/hlir/dialect/operator/ir/ops.yaml
+60
-0
paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
+22
-0
paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc
.../operator/transforms/add_broadcast_to_elementwise_pass.cc
+210
-0
paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h
...t/operator/transforms/add_broadcast_to_elementwise_pass.h
+35
-0
paddle/cinn/hlir/dialect/operator/transforms/group_merge/CMakeLists.txt
...ir/dialect/operator/transforms/group_merge/CMakeLists.txt
+13
-0
No files found.
Too many changes to show.
To preserve performance only
558 of 558+
files are displayed.
Plain diff
Email patch
paddle/cinn/frontend/syntax_test.cc
View file @
01a10755
...
@@ -23,13 +23,14 @@
...
@@ -23,13 +23,14 @@
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/op/use_ops.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/hlir/pass/use_pass.h"
#include "paddle/cinn/utils/data_util.h"
#include "paddle/cinn/utils/data_util.h"
DEFINE_string
(
model_dir
,
""
,
""
);
PD_
DEFINE_string
(
model_dir
,
""
,
""
);
namespace
cinn
{
namespace
cinn
{
namespace
frontend
{
namespace
frontend
{
...
@@ -69,7 +70,8 @@ TEST(syntax, program_execute_multi_elementwise_add) {
...
@@ -69,7 +70,8 @@ TEST(syntax, program_execute_multi_elementwise_add) {
LOG
(
INFO
)
<<
"graph:
\n
"
<<
graph
->
Visualize
();
LOG
(
INFO
)
<<
"graph:
\n
"
<<
graph
->
Visualize
();
auto
scope
=
BuildScope
(
target
,
graph
);
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
GraphCompiler
gc
(
target
,
scope
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
auto
runtime_program
=
gc
.
Build
();
auto
runtime_program
=
gc
.
Build
();
scope
->
Var
<
hlir
::
framework
::
Tensor
>
(
"A"
);
scope
->
Var
<
hlir
::
framework
::
Tensor
>
(
"A"
);
scope
->
Var
<
hlir
::
framework
::
Tensor
>
(
"B"
);
scope
->
Var
<
hlir
::
framework
::
Tensor
>
(
"B"
);
...
@@ -88,7 +90,8 @@ TEST(syntax, program_execute_multi_elementwise_add2) {
...
@@ -88,7 +90,8 @@ TEST(syntax, program_execute_multi_elementwise_add2) {
LOG
(
INFO
)
<<
"graph:
\n
"
<<
graph
->
Visualize
();
LOG
(
INFO
)
<<
"graph:
\n
"
<<
graph
->
Visualize
();
auto
scope
=
BuildScope
(
target
,
graph
);
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
GraphCompiler
gc
(
target
,
scope
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
auto
runtime_program
=
gc
.
Build
();
auto
runtime_program
=
gc
.
Build
();
scope
->
Var
<
hlir
::
framework
::
Tensor
>
(
"A"
);
scope
->
Var
<
hlir
::
framework
::
Tensor
>
(
"A"
);
...
@@ -121,7 +124,8 @@ std::get<2>(programTuple);
...
@@ -121,7 +124,8 @@ std::get<2>(programTuple);
auto graph = cinn::frontend::Optimize(program.get(), fetch_ids, target);
auto graph = cinn::frontend::Optimize(program.get(), fetch_ids, target);
scope = BuildScope(target, graph, scope);
scope = BuildScope(target, graph, scope);
hlir::framework::GraphCompiler gc(target, scope, graph);
hlir::framework::CompilationContext context(graph, scope,target);
hlir::framework::GraphCompiler gc(context);
auto runtime_program = gc.Build();
auto runtime_program = gc.Build();
auto at = scope->GetTensor("A");
auto at = scope->GetTensor("A");
...
@@ -133,11 +137,12 @@ std::get<2>(programTuple);
...
@@ -133,11 +137,12 @@ std::get<2>(programTuple);
LOG(INFO) << "scope.names: " << Join(scope->var_names(), ",");
LOG(INFO) << "scope.names: " << Join(scope->var_names(), ",");
const std::string output_name = "fc_0.tmp_2";
const std::string output_name = "fc_0.tmp_2";
auto tensor =
auto tensor = scope->GetTensor(var_map_paddle_to_program.at(output_name));
scope->GetTensor(var_map_paddle_to_program.at(output_name)); LOG(INFO) <<
LOG(INFO) << "tensor.shape: " << utils::Join(tensor->shape().data(), ",");
"tensor.shape: " << utils::Join(tensor->shape().data(), ","); auto data =
auto data = GetTensorData<float>(tensor, target);
GetTensorData<float>(tensor, target); for (int i = 0; i < 10; i++) LOG(INFO) <<
for (int i = 0; i < 10; i++) {
"data: " << data[i];
LOG(INFO) << "data: " << data[i];
}
}
}
*/
*/
...
...
paddle/cinn/gtest_main.cc
View file @
01a10755
...
@@ -12,12 +12,12 @@
...
@@ -12,12 +12,12 @@
// See the License for the specific language governing permissions and
// See the License for the specific language governing permissions and
// limitations under the License.
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <gtest/gtest.h>
#include "paddle/utils/flags.h"
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
GFLAGS_NAMESPACE
::
ParseCommandLineFlags
(
&
argc
,
&
argv
,
false
);
paddle
::
flags
::
ParseCommandLineFlags
(
&
argc
,
&
argv
);
return
RUN_ALL_TESTS
();
return
RUN_ALL_TESTS
();
}
}
paddle/cinn/hlir/dialect/.gitignore
0 → 100644
View file @
01a10755
generated/**
generated/*
paddle/cinn/hlir/dialect/CMakeLists.txt
View file @
01a10755
# TODO(Aurelius84): new_ir_compiler depends on pd_dialect and could
add_subdirectory
(
operator
)
# not found under CINN_ONLY mode
add_subdirectory
(
runtime
)
if
(
NOT CINN_ONLY
)
cinn_cc_library
(
cinn_dialect SRCS runtime_dialect.cc jit_kernel_op.cc DEPS
pd_dialect
)
endif
()
paddle/cinn/hlir/dialect/jit_kernel_op.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/jit_kernel_op.h"
#include "paddle/ir/core/builtin_attribute.h"
#include "paddle/ir/core/enforce.h"
namespace
cinn
{
namespace
dialect
{
const
char
*
JitKernelOp
::
attributes_name
[
attributes_num
]
=
{
kAttrName
};
void
JitKernelOp
::
Verify
()
{
VLOG
(
4
)
<<
"Verifying inputs, outputs and attributes for: JitKernelOp."
;
auto
&
attributes
=
this
->
attributes
();
IR_ENFORCE
(
attributes
.
count
(
kAttrName
)
>
0
&&
attributes
.
at
(
kAttrName
).
isa
<::
ir
::
PointerAttribute
>
(),
"Type of attribute: instruction is not right."
);
}
hlir
::
framework
::
Instruction
*
JitKernelOp
::
instruction
()
{
void
*
ptr
=
attributes
().
at
(
kAttrName
).
dyn_cast
<
ir
::
PointerAttribute
>
().
data
();
return
reinterpret_cast
<
hlir
::
framework
::
Instruction
*>
(
ptr
);
}
}
// namespace dialect
}
// namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
JitKernelOp
)
paddle/cinn/hlir/dialect/jit_kernel_op.h
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/ir/core/op_base.h"
namespace
cinn
{
namespace
hlir
{
namespace
framework
{
class
Instruction
;
}
// namespace framework
}
// namespace hlir
namespace
dialect
{
/*
* TODO(Aurelius84): THIS IS NOT FINAL STATE!
* JitKernel is unified runtime operation to represent
* jit compiled function ptr from backend, such as
* nvrct.
* Ideally, JitKernel should only contains ArrayAttribute
* with each element is PointerAttribute, which is jit
* function ptr indeed.
* Currently, we regard hlir::framework::Instruction
* temporarily, and will spilt executor information like
* scope, inputs, outputs into InterpretorCore module.
*/
class
JitKernelOp
:
public
::
ir
::
Op
<
JitKernelOp
>
{
public:
using
Op
::
Op
;
static
const
char
*
name
()
{
return
"cinn.jit_kernel"
;
}
// TODO(Aurelius84): Think deeply what should contains
static
constexpr
uint32_t
attributes_num
=
1
;
static
constexpr
char
*
kAttrName
=
"instruction"
;
static
const
char
*
attributes_name
[
attributes_num
];
hlir
::
framework
::
Instruction
*
instruction
();
void
Verify
();
};
}
// namespace dialect
}
// namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
JitKernelOp
)
paddle/cinn/hlir/dialect/operator/CMakeLists.txt
0 → 100644
View file @
01a10755
add_subdirectory
(
ir
)
add_subdirectory
(
transforms
)
paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
0 → 100644
View file @
01a10755
# TODO(Aurelius84): pir_compiler depends on pd_op_dialect and could
# not found under CINN_ONLY mode
if
(
NOT CINN_ONLY
)
set
(
CINN_DIALECT_SOURCE_DIR
"
${
PADDLE_SOURCE_DIR
}
/paddle/cinn/hlir/dialect/operator/ir"
)
# Generate cinn_op_dialect files defining op using op_gen_file
set
(
cinn_op_gen_parsed_yaml_file
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/operators/generator/parse_op.py
)
set
(
cinn_op_gen_file
${
PADDLE_SOURCE_DIR
}
/paddle/fluid/pir/dialect/op_generator/op_gen.py
)
set
(
cinn_op_compat_yaml_file
${
PADDLE_SOURCE_DIR
}
/paddle/phi/api/yaml/op_compat.yaml
)
set
(
cinn_op_yaml_file
${
PADDLE_SOURCE_DIR
}
/paddle/cinn/hlir/dialect/operator/ir/ops.yaml
)
set
(
parsed_op_dir
${
PADDLE_SOURCE_DIR
}
/paddle/cinn/hlir/dialect/generated
)
set
(
cinn_op_parsed_yaml_file
${
parsed_op_dir
}
/ops.parsed.yaml
)
set
(
cinn_op_parsed_yaml_files
${
cinn_op_parsed_yaml_file
}
)
set
(
cinn_op_namespace cinn,dialect
)
set
(
cinn_op_dialect_name cinn_op
)
set
(
cinn_op_header_file
${
CINN_DIALECT_SOURCE_DIR
}
/cinn_op.h
)
set
(
cinn_op_source_file
${
CINN_DIALECT_SOURCE_DIR
}
/cinn_op.cc
)
set
(
cinn_op_header_file_tmp
${
cinn_op_header_file
}
.tmp
)
set
(
cinn_op_source_file_tmp
${
cinn_op_source_file
}
.tmp
)
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E make_directory
${
parsed_op_dir
}
COMMAND
${
PYTHON_EXECUTABLE
}
${
cinn_op_gen_parsed_yaml_file
}
--op_yaml_path
${
cinn_op_yaml_file
}
--output_path
${
cinn_op_parsed_yaml_file
}
)
execute_process
(
COMMAND
${
PYTHON_EXECUTABLE
}
${
cinn_op_gen_file
}
--op_yaml_files
${
cinn_op_parsed_yaml_files
}
--op_compat_yaml_file
${
cinn_op_compat_yaml_file
}
--namespaces
${
cinn_op_namespace
}
--dialect_name
${
cinn_op_dialect_name
}
--op_def_h_file
${
cinn_op_header_file_tmp
}
--op_def_cc_file
${
cinn_op_source_file_tmp
}
)
set
(
generated_files_cinn_op
"
${
cinn_op_header_file
}
"
"
${
cinn_op_source_file
}
"
)
foreach
(
generated_file
${
generated_files_cinn_op
}
)
if
(
EXISTS
"
${
generated_file
}
.tmp"
AND EXISTS
"
${
generated_file
}
"
)
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E copy_if_different
"
${
generated_file
}
.tmp"
"
${
generated_file
}
"
)
message
(
"copy if different
${
generated_file
}
.tmp
${
generated_file
}
"
)
elseif
(
EXISTS
"
${
generated_file
}
.tmp"
)
execute_process
(
COMMAND
${
CMAKE_COMMAND
}
-E copy
"
${
generated_file
}
.tmp"
"
${
generated_file
}
"
)
message
(
"copy
${
generated_file
}
.tmp
${
generated_file
}
"
)
endif
()
endforeach
()
cinn_cc_library
(
cinn_op_dialect
SRCS
op_dialect.cc
${
cinn_op_source_file
}
manual_op.cc
op_attribute.cc
DEPS
op_dialect_vjp
)
target_include_directories
(
cinn_op_dialect PRIVATE
${
CINN_DIALECT_SOURCE_DIR
}
)
endif
()
paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <functional>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/pir/utils.h"
#include "paddle/pir/core/attribute_base.h"
#include "paddle/pir/core/operation.h"
namespace
cinn
{
namespace
dialect
{
// TODO(Aurelius84): Need to figure out what we need indeed for GroupOp.
// Currently we paste almost members here and will remove them step by
// step.
struct
GroupInfo
{
public:
explicit
GroupInfo
(
const
std
::
vector
<::
pir
::
Operation
*>&
group_ops
)
:
ops
(
group_ops
)
{
Initialize
();
}
explicit
GroupInfo
(
std
::
initializer_list
<::
pir
::
Operation
*>
group_ops
)
:
ops
(
group_ops
)
{
Initialize
();
}
std
::
string
group_id
;
std
::
string
fn_name
;
hlir
::
framework
::
OpPatternKind
op_pattern_kind
;
std
::
vector
<::
pir
::
Operation
*>
ops
;
std
::
vector
<
std
::
string
>
input_names
;
std
::
vector
<
std
::
string
>
output_names
;
private:
void
Initialize
()
{
op_pattern_kind
=
hlir
::
framework
::
OpPatternKind
::
kElementWise
;
fn_name
=
hlir
::
framework
::
pir
::
CompatibleInfo
::
GroupOpsName
(
ops
);
}
};
struct
GroupInfoAttributeStorage
:
public
pir
::
AttributeStorage
{
using
ParamKey
=
GroupInfo
;
explicit
GroupInfoAttributeStorage
(
const
ParamKey
&
key
)
:
data_
(
key
)
{}
static
GroupInfoAttributeStorage
*
Construct
(
const
ParamKey
&
key
)
{
return
new
GroupInfoAttributeStorage
(
key
);
}
static
std
::
size_t
HashValue
(
const
ParamKey
&
key
)
{
return
std
::
hash
<
std
::
string
>
{}(
key
.
group_id
);
}
bool
operator
==
(
const
ParamKey
&
key
)
const
{
return
data_
.
group_id
==
key
.
group_id
;
}
const
ParamKey
&
GetAsKey
()
const
{
return
data_
;
}
private:
ParamKey
data_
;
};
struct
JITInfoAttributeStorage
:
public
pir
::
AttributeStorage
{
using
ParamKey
=
cinn
::
hlir
::
framework
::
pir
::
CUDAJITInfo
;
explicit
JITInfoAttributeStorage
(
const
ParamKey
&
key
)
:
data_
(
key
)
{}
static
JITInfoAttributeStorage
*
Construct
(
const
ParamKey
&
key
)
{
return
new
JITInfoAttributeStorage
(
key
);
}
static
std
::
size_t
HashValue
(
const
ParamKey
&
key
)
{
return
std
::
hash
<
int64_t
>
()(
*
(
reinterpret_cast
<
int64_t
*>
(
key
.
fn_ptr
)));
}
bool
operator
==
(
const
ParamKey
&
key
)
const
{
return
data_
.
fn_ptr
==
key
.
fn_ptr
;
}
const
ParamKey
&
GetAsKey
()
const
{
return
data_
;
}
private:
ParamKey
data_
;
};
}
// namespace dialect
}
// namespace cinn
paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include <vector>
#include "glog/logging.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/pir/core/builtin_type.h"
#include "paddle/pir/core/enforce.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/dialect/control_flow/ir/cf_op.h"
namespace
cinn
{
namespace
dialect
{
const
char
*
GroupOp
::
attributes_name
[
GroupOp
::
attributes_num
]
=
{
"group_info"
};
const
char
*
ConcatOp
::
attributes_name
[
ConcatOp
::
attributes_num
]
=
{
"axis"
};
const
char
*
SplitOp
::
attributes_name
[
SplitOp
::
attributes_num
]
=
{
"num_or_sections"
,
"axis"
};
void
GroupOp
::
Build
(
pir
::
Builder
&
builder
,
pir
::
OperationArgument
&
argument
,
const
std
::
vector
<
pir
::
Type
>
&
output_types
)
{
argument
.
AddRegion
(
nullptr
);
argument
.
output_types
=
output_types
;
}
void
GroupOp
::
Build
(
pir
::
Builder
&
builder
,
// NOLINT
pir
::
OperationArgument
&
argument
,
// NOLINT
std
::
unique_ptr
<
pir
::
Block
>
&&
block
)
{
VLOG
(
4
)
<<
"Start build GroupOp"
;
if
(
block
&&
!
block
->
empty
())
{
IR_ENFORCE
(
block
->
back
().
isa
<
pir
::
YieldOp
>
());
auto
&
op
=
block
->
back
();
for
(
size_t
i
=
0
;
i
<
op
.
num_operands
();
++
i
)
{
argument
.
AddOutput
(
op
.
operand
(
i
).
type
());
}
}
argument
.
AddRegion
()
->
push_back
(
block
.
release
());
}
pir
::
Block
*
GroupOp
::
block
()
{
pir
::
Region
&
region
=
(
*
this
)
->
region
(
0
);
if
(
region
.
empty
())
region
.
emplace_back
();
return
&
region
.
front
();
}
std
::
vector
<
pir
::
Operation
*>
GroupOp
::
ops
()
{
std
::
vector
<
pir
::
Operation
*>
rt_ops
;
for
(
auto
&
op
:
*
block
())
{
rt_ops
.
push_back
(
&
op
);
}
return
rt_ops
;
}
void
GroupOp
::
VerifySig
()
{}
void
GroupOp
::
Print
(
pir
::
IrPrinter
&
printer
)
{
auto
&
os
=
printer
.
os
;
auto
op
=
operation
();
printer
.
PrintOpResult
(
op
);
os
<<
" = "
<<
name
();
printer
.
PrintOpOperands
(
op
);
os
<<
" -> "
;
printer
.
PrintOpReturnType
(
op
);
os
<<
" {"
;
for
(
auto
&
sub_op
:
ops
())
{
os
<<
"
\n
"
;
printer
.
PrintOperation
(
sub_op
);
}
os
<<
"
\n
}"
;
}
void
ConcatOp
::
Build
(
pir
::
Builder
&
builder
,
// NOLINT
pir
::
OperationArgument
&
argument
,
// NOLINT
const
std
::
vector
<
pir
::
Value
>
&
inputs
,
int
axis
)
{
VLOG
(
4
)
<<
"Start build ConcatOp"
;
argument
.
inputs
=
inputs
;
std
::
vector
<
pir
::
Type
>
inputs_type
(
inputs
.
size
());
IR_ENFORCE
(
inputs
.
size
()
>
0
);
auto
first_ele
=
inputs
[
0
].
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
();
phi
::
DDim
out_dims
=
first_ele
.
dims
();
if
(
axis
<
0
)
{
axis
+=
out_dims
.
size
();
}
for
(
size_t
idx
=
0
;
idx
<
inputs
.
size
();
++
idx
)
{
inputs_type
[
idx
]
=
inputs
[
idx
].
type
();
if
(
idx
>
0
)
{
auto
dim_i
=
inputs
[
idx
]
.
type
()
.
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
()
.
dims
();
out_dims
[
axis
]
+=
dim_i
[
axis
];
}
}
auto
out_type
=
paddle
::
dialect
::
DenseTensorType
::
get
(
pir
::
IrContext
::
Instance
(),
first_ele
.
dtype
(),
out_dims
,
first_ele
.
data_layout
(),
first_ele
.
lod
(),
first_ele
.
offset
());
argument
.
output_types
.
emplace_back
(
out_type
);
PassStopGradientsDefaultly
(
argument
);
argument
.
AddAttribute
(
"axis"
,
pir
::
Int32Attribute
::
get
(
pir
::
IrContext
::
Instance
(),
axis
));
}
void
SplitOp
::
Build
(
pir
::
Builder
&
builder
,
// NOLINT
pir
::
OperationArgument
&
argument
,
// NOLINT
pir
::
Value
input
,
const
std
::
vector
<
int
>
&
sections
,
int
axis
)
{
VLOG
(
4
)
<<
"Start build ConcatOp"
;
argument
.
inputs
.
push_back
(
input
);
std
::
vector
<
pir
::
Type
>
output_type
(
sections
.
size
());
auto
input_ele
=
input
.
type
().
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
();
if
(
axis
<
0
)
{
axis
+=
input_ele
.
dims
().
size
();
}
std
::
vector
<
pir
::
Attribute
>
section_attrs
;
for
(
size_t
idx
=
0
;
idx
<
sections
.
size
();
++
idx
)
{
auto
out_dims
=
input_ele
.
dims
();
out_dims
[
axis
]
=
sections
[
idx
];
auto
out_type
=
paddle
::
dialect
::
DenseTensorType
::
get
(
pir
::
IrContext
::
Instance
(),
input_ele
.
dtype
(),
out_dims
,
input_ele
.
data_layout
(),
input_ele
.
lod
(),
input_ele
.
offset
());
argument
.
output_types
.
emplace_back
(
out_type
);
pir
::
Attribute
attr_axis
=
pir
::
Int32Attribute
::
get
(
pir
::
IrContext
::
Instance
(),
sections
[
idx
]);
section_attrs
.
push_back
(
attr_axis
);
}
PassStopGradientsDefaultly
(
argument
);
argument
.
AddAttribute
(
"num_or_sections"
,
pir
::
ArrayAttribute
::
get
(
pir
::
IrContext
::
Instance
(),
section_attrs
));
argument
.
AddAttribute
(
"axis"
,
pir
::
Int32Attribute
::
get
(
pir
::
IrContext
::
Instance
(),
axis
));
}
}
// namespace dialect
}
// namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
GroupOp
)
IR_DEFINE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
ConcatOp
)
IR_DEFINE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
SplitOp
)
paddle/cinn/hlir/dialect/operator/ir/manual_op.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/op_base.h"
#include "paddle/pir/core/operation.h"
#include "paddle/pir/core/operation_utils.h"
namespace
cinn
{
namespace
dialect
{
class
GroupOp
:
public
pir
::
Op
<
GroupOp
>
{
public:
using
Op
::
Op
;
static
const
char
*
name
()
{
return
"cinn_op.group"
;
}
static
constexpr
uint32_t
attributes_num
=
1
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Build
(
pir
::
Builder
&
builder
,
// NOLINT
pir
::
OperationArgument
&
argument
,
// NOLINT
const
std
::
vector
<
pir
::
Type
>
&
output_types
);
static
void
Build
(
pir
::
Builder
&
builder
,
// NOLINT
pir
::
OperationArgument
&
argument
,
// NOLINT
std
::
unique_ptr
<
pir
::
Block
>
&&
block
);
pir
::
Block
*
block
();
std
::
vector
<
pir
::
Operation
*>
ops
();
void
VerifySig
();
void
Print
(
pir
::
IrPrinter
&
printer
);
// NOLINT
};
class
IR_API
ConcatOp
:
public
pir
::
Op
<
ConcatOp
>
{
public:
using
Op
::
Op
;
static
const
char
*
name
()
{
return
"cinn_op.concat"
;
}
static
constexpr
uint32_t
attributes_num
=
1
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Build
(
pir
::
Builder
&
builder
,
// NOLINT
pir
::
OperationArgument
&
argument
,
// NOLINT
const
std
::
vector
<
pir
::
Value
>
&
inputs
,
int
axis
);
void
VerifySig
()
const
{}
};
class
IR_API
SplitOp
:
public
pir
::
Op
<
SplitOp
>
{
public:
using
Op
::
Op
;
static
const
char
*
name
()
{
return
"cinn_op.split"
;
}
static
constexpr
uint32_t
attributes_num
=
2
;
static
const
char
*
attributes_name
[
attributes_num
];
static
void
Build
(
pir
::
Builder
&
builder
,
// NOLINT
pir
::
OperationArgument
&
argument
,
// NOLINT
pir
::
Value
input
,
const
std
::
vector
<
int
>
&
sections
,
int
axis
);
void
VerifySig
()
const
{}
};
}
// namespace dialect
}
// namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
GroupOp
)
IR_DECLARE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
ConcatOp
)
IR_DECLARE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
SplitOp
)
paddle/cinn/hlir/dialect/operator/ir/op_attribute.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
namespace
cinn
{
namespace
dialect
{
const
GroupInfo
&
GroupInfoAttribute
::
data
()
const
{
return
storage
()
->
GetAsKey
();
}
const
cinn
::
hlir
::
framework
::
pir
::
CUDAJITInfo
&
CUDAJITInfoAttribute
::
data
()
const
{
return
storage
()
->
GetAsKey
();
}
}
// namespace dialect
}
// namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
GroupInfoAttribute
)
IR_DEFINE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
CUDAJITInfoAttribute
)
paddle/cinn/hlir/dialect/operator/ir/op_attribute.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/hlir/dialect/operator/ir/attribute_storage.h"
#include "paddle/pir/core/attribute_base.h"
namespace
cinn
{
namespace
dialect
{
class
GroupInfoAttribute
:
public
pir
::
Attribute
{
public:
using
Attribute
::
Attribute
;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR
(
GroupInfoAttribute
,
GroupInfoAttributeStorage
);
bool
operator
<
(
const
GroupInfoAttribute
&
right
)
const
{
return
storage
()
<
right
.
storage
();
}
const
GroupInfo
&
data
()
const
;
};
class
CUDAJITInfoAttribute
:
public
pir
::
Attribute
{
public:
using
Attribute
::
Attribute
;
DECLARE_ATTRIBUTE_UTILITY_FUNCTOR
(
CUDAJITInfoAttribute
,
JITInfoAttributeStorage
);
bool
operator
<
(
const
CUDAJITInfoAttribute
&
right
)
const
{
return
storage
()
<
right
.
storage
();
}
const
cinn
::
hlir
::
framework
::
pir
::
CUDAJITInfo
&
data
()
const
;
};
}
// namespace dialect
}
// namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
GroupInfoAttribute
)
IR_DECLARE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
CUDAJITInfoAttribute
)
paddle/cinn/hlir/dialect/operator/ir/op_dialect.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/operator/ir/op_dialect.h"
// NOTE(chenxi67): File cinn_op.h is generated by op_gen.py, see details in
// paddle/cinn/hlir/dialect/CMakeLists.txt.
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/manual_op.h"
#include "paddle/cinn/hlir/dialect/operator/ir/op_attribute.h"
namespace
cinn
{
namespace
dialect
{
OperatorDialect
::
OperatorDialect
(
::
pir
::
IrContext
*
context
)
:
::
pir
::
Dialect
(
name
(),
context
,
::
pir
::
TypeId
::
get
<
cinn
::
dialect
::
OperatorDialect
>
())
{
this
->
initialize
();
}
void
OperatorDialect
::
initialize
()
{
// NOTE(chenxi67): GET_OP_LIST is defined in cinn_op.h which is
// generated by op_gen.py, see details in
// paddle/cinn/hlir/dialect/CMakeLists.txt.
RegisterOps
<
#define GET_OP_LIST
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.cc" // NOLINT
>
();
RegisterOp
<
GroupOp
>
();
RegisterOp
<
ConcatOp
>
();
RegisterOp
<
SplitOp
>
();
RegisterAttribute
<
GroupInfoAttribute
>
();
RegisterAttribute
<
CUDAJITInfoAttribute
>
();
}
void
OperatorDialect
::
PrintType
(
pir
::
Type
type
,
std
::
ostream
&
os
)
const
{}
void
OperatorDialect
::
PrintAttribute
(
pir
::
Attribute
attr
,
std
::
ostream
&
os
)
const
{
if
(
attr
.
isa
<
GroupInfoAttribute
>
())
{
os
<<
"("
<<
attr
.
dialect
().
name
();
os
<<
'.'
;
if
(
auto
group_info_attr
=
attr
.
dyn_cast
<
GroupInfoAttribute
>
())
{
const
GroupInfo
&
data
=
group_info_attr
.
data
();
os
<<
"GroupInfo)"
<<
"["
<<
data
.
fn_name
<<
"]"
;
}
{
os
<<
"<#AttrNotImplemented>"
;
}
}
else
if
(
attr
.
isa
<
CUDAJITInfoAttribute
>
())
{
auto
cuda_jit_info
=
attr
.
dyn_cast
<
CUDAJITInfoAttribute
>
();
os
<<
"("
<<
cuda_jit_info
.
data
().
fn_ptr
;
os
<<
')'
;
}
else
{
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"cinn dialect only support GrupInfo and CUDAJITInfo"
));
}
}
void
OperatorDialect
::
PrintOperation
(
pir
::
Operation
*
op
,
pir
::
IrPrinter
&
printer
)
const
{
if
(
auto
group_op
=
op
->
dyn_cast
<
GroupOp
>
())
{
group_op
.
Print
(
printer
);
}
else
{
printer
.
PrintGeneralOperation
(
op
);
}
}
}
// namespace dialect
}
// namespace cinn
IR_DEFINE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
OperatorDialect
)
paddle/cinn/hlir/dialect/operator/ir/op_dialect.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pir/core/dialect.h"
namespace
cinn
{
namespace
dialect
{
class
OperatorDialect
:
public
::
pir
::
Dialect
{
public:
explicit
OperatorDialect
(
::
pir
::
IrContext
*
context
);
static
const
char
*
name
()
{
return
"cinn_op"
;
}
void
PrintType
(
pir
::
Type
type
,
std
::
ostream
&
os
)
const
override
;
void
PrintAttribute
(
pir
::
Attribute
type
,
std
::
ostream
&
os
)
const
override
;
void
PrintOperation
(
pir
::
Operation
*
op
,
pir
::
IrPrinter
&
printer
)
const
override
;
// NOLINT
private:
void
initialize
();
};
}
// namespace dialect
}
// namespace cinn
IR_DECLARE_EXPLICIT_TYPE_ID
(
cinn
::
dialect
::
OperatorDialect
)
paddle/cinn/hlir/dialect/operator/ir/ops.yaml
0 → 100644
View file @
01a10755
-
op
:
broadcast
args
:
(Tensor x, int64_t[] broadcast_axes, int64_t[] out_shape)
output
:
Tensor(out)
infer_meta
:
func
:
CINNBroadcastInferMeta
param
:
[
x
,
broadcast_axes
,
out_shape
]
kernel
:
func
:
expand
param
:
[
x
,
broadcast_axes
]
-
op
:
reduce_max
args
:
(Tensor x, int64_t[] dim, bool keep_dim)
output
:
Tensor(out)
infer_meta
:
func
:
ReduceInferMeta
kernel
:
func
:
frobenius_norm
-
op
:
reduce_sum
args
:
(Tensor x, int64_t[] dim, bool keep_dim)
output
:
Tensor(out)
infer_meta
:
func
:
ReduceInferMeta
kernel
:
func
:
frobenius_norm
-
op
:
reshape
args
:
(Tensor x, int[] shape)
output
:
Tensor(out)
infer_meta
:
func
:
ReshapeInferMeta
kernel
:
func
:
reshape
-
op
:
scale
args
:
(Tensor x, float scale=1.0, float bias=0.0, bool bias_after_scale=true)
output
:
Tensor(out)
infer_meta
:
func
:
UnchangedInferMeta
param
:
[
x
]
kernel
:
func
:
scale
-
op
:
slice
args
:
(Tensor x, int64_t[] axes, int64_t[] starts, int64_t[] ends, int64_t[] infer_flags, int64_t[] decrease_axis)
output
:
Tensor
infer_meta
:
func
:
SliceRawInferMeta
kernel
:
func
:
slice
-
op
:
uniform_random
args
:
(int64_t[] shape, float min, float max, int seed, DataType dtype, int diag_num = 0, int diag_step=0, float diag_val=1.0)
output
:
Tensor(out)
infer_meta
:
func
:
CreateVecShapeInferMeta
param
:
[
shape
,
dtype
]
kernel
:
func
:
full_int_array
param
:
[
shape
,
dtype
]
paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
0 → 100644
View file @
01a10755
add_subdirectory
(
group_merge
)
if
(
NOT CINN_ONLY
)
cinn_cc_library
(
pd_to_cinn_pass
SRCS
pd_to_cinn_pass.cc
DEPS
drr
cinn_op_dialect
op_dialect_vjp
)
cinn_cc_library
(
add_broadcast_to_elementwise_pass
SRCS
add_broadcast_to_elementwise_pass.cc
DEPS
pir
cinn_op_dialect
op_dialect_vjp
)
endif
()
paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h"
#include "paddle/cinn/hlir/dialect/operator/ir/cinn_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/drr/api/match_context.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/pir/core/builtin_dialect.h"
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/pattern_applicator.h"
#include "paddle/pir/pattern_rewrite/pattern_match.h"
#include "paddle/pir/pattern_rewrite/pattern_rewrite_driver.h"
namespace
cinn
{
namespace
dialect
{
namespace
ir
{
int64_t
GetDimByIndex
(
const
phi
::
DDim
&
first
,
const
phi
::
DDim
&
second
,
int
short_align_axis
,
int
idx
)
{
// rank of first less than rank of second
if
(
idx
<
short_align_axis
)
{
return
second
[
idx
];
}
else
{
return
first
[
idx
-
short_align_axis
]
>
second
[
idx
]
?
first
[
idx
-
short_align_axis
]
:
second
[
idx
];
}
}
std
::
vector
<
int64_t
>
GetOutputShape
(
const
phi
::
DDim
&
x
,
const
phi
::
DDim
&
y
)
{
std
::
vector
<
int64_t
>
vec_res
;
if
(
x
.
size
()
>=
y
.
size
())
{
int
short_align_axis
=
x
.
size
()
-
y
.
size
();
int
max_rank
=
x
.
size
();
vec_res
.
resize
(
max_rank
);
for
(
size_t
i
=
0
;
i
<
max_rank
;
++
i
)
{
vec_res
[
i
]
=
GetDimByIndex
(
y
,
x
,
short_align_axis
,
i
);
}
}
else
{
int
short_align_axis
=
y
.
size
()
-
x
.
size
();
int
max_rank
=
y
.
size
();
vec_res
.
resize
(
max_rank
);
for
(
size_t
i
=
0
;
i
<
max_rank
;
++
i
)
{
vec_res
[
i
]
=
GetDimByIndex
(
x
,
y
,
short_align_axis
,
max_rank
);
}
}
return
vec_res
;
}
bool
IsSameDim
(
const
phi
::
DDim
&
first
,
const
std
::
vector
<
int64_t
>&
second
)
{
if
(
first
.
size
()
==
second
.
size
())
{
bool
same
=
true
;
for
(
size_t
i
=
0
;
i
<
first
.
size
();
++
i
)
{
if
(
first
[
i
]
!=
second
[
i
])
{
same
=
false
;
break
;
}
}
return
same
;
}
return
false
;
}
std
::
vector
<
int64_t
>
GetBroadcastAxis
(
const
phi
::
DDim
&
in_shape
,
const
std
::
vector
<
int64_t
>&
out_shape
)
{
std
::
vector
<
int64_t
>
broadcast_axes
(
in_shape
.
size
(),
0
);
auto
in_shape_size
=
in_shape
.
size
();
if
(
in_shape_size
>=
1
)
{
for
(
int
i
=
1
;
i
<=
in_shape_size
;
++
i
)
{
broadcast_axes
[
in_shape_size
-
i
]
=
out_shape
.
size
()
-
i
;
}
}
return
broadcast_axes
;
}
bool
ProcessOp
(
pir
::
Operation
*
op
,
pir
::
PatternRewriter
*
rewriter
)
{
auto
x_dims
=
op
->
operand_source
(
0
)
.
type
()
.
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
()
.
dims
();
auto
y_dims
=
op
->
operand_source
(
1
)
.
type
()
.
dyn_cast
<
paddle
::
dialect
::
DenseTensorType
>
()
.
dims
();
if
(
x_dims
!=
y_dims
)
{
auto
output_shape
=
GetOutputShape
(
x_dims
,
y_dims
);
if
(
!
IsSameDim
(
x_dims
,
output_shape
))
{
// add broadcast to input 0
if
(
auto
full_op
=
op
->
operand_source
(
0
)
.
dyn_cast
<
pir
::
OpResult
>
()
.
owner
()
->
dyn_cast
<
paddle
::
dialect
::
FullOp
>
())
{
auto
new_full
=
rewriter
->
Build
<
paddle
::
dialect
::
FullOp
>
(
output_shape
,
full_op
->
attribute
(
"value"
).
dyn_cast
<
pir
::
FloatAttribute
>
().
data
(),
full_op
->
attribute
(
"dtype"
)
.
dyn_cast
<
paddle
::
dialect
::
DataTypeAttribute
>
()
.
data
(),
full_op
->
attribute
(
"place"
)
.
dyn_cast
<
paddle
::
dialect
::
PlaceAttribute
>
()
.
data
());
op
->
operand
(
0
).
set_source
(
new_full
->
result
(
0
));
}
else
{
auto
new_transpose_op
=
rewriter
->
Build
<
cinn
::
dialect
::
BroadcastOp
>
(
op
->
operand_source
(
0
),
GetBroadcastAxis
(
x_dims
,
output_shape
),
output_shape
);
op
->
operand
(
0
).
set_source
(
new_transpose_op
->
result
(
0
));
}
}
if
(
!
IsSameDim
(
y_dims
,
output_shape
))
{
if
(
auto
full_op
=
op
->
operand_source
(
1
)
.
dyn_cast
<
pir
::
OpResult
>
()
.
owner
()
->
dyn_cast
<
paddle
::
dialect
::
FullOp
>
())
{
auto
new_full
=
rewriter
->
Build
<
paddle
::
dialect
::
FullOp
>
(
output_shape
,
full_op
->
attribute
(
"value"
).
dyn_cast
<
pir
::
FloatAttribute
>
().
data
(),
full_op
->
attribute
(
"dtype"
)
.
dyn_cast
<
paddle
::
dialect
::
DataTypeAttribute
>
()
.
data
(),
full_op
->
attribute
(
"place"
)
.
dyn_cast
<
paddle
::
dialect
::
PlaceAttribute
>
()
.
data
());
op
->
operand
(
1
).
set_source
(
new_full
->
result
(
0
));
}
else
{
auto
new_transpose_op
=
rewriter
->
Build
<
cinn
::
dialect
::
BroadcastOp
>
(
op
->
operand_source
(
1
),
GetBroadcastAxis
(
y_dims
,
output_shape
),
output_shape
);
op
->
operand
(
1
).
set_source
(
new_transpose_op
->
result
(
0
));
}
}
return
true
;
}
return
false
;
}
template
<
typename
OPTYPE
>
class
AddBrodcastToElementwisePattern
:
public
pir
::
OpRewritePattern
<
OPTYPE
>
{
public:
using
pir
::
OpRewritePattern
<
OPTYPE
>::
OpRewritePattern
;
bool
MatchAndRewrite
(
OPTYPE
op
,
pir
::
PatternRewriter
&
rewriter
)
const
override
{
return
ProcessOp
(
op
,
&
rewriter
);
}
};
AddBroadcastToElementwisePass
::
AddBroadcastToElementwisePass
()
:
pir
::
PatternRewritePass
(
"add_broadcast_to_elementwise_pass"
,
1
)
{}
pir
::
RewritePatternSet
AddBroadcastToElementwisePass
::
InitializePatterns
(
pir
::
IrContext
*
context
)
{
pir
::
RewritePatternSet
ps
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
AddOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
SubtractOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
MultiplyOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
DivideOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
ElementwisePowOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
LessThanOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
LessEqualOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
EqualOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
NotEqualOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
GreaterThanOp
>>
(
context
);
ps
.
Add
<
AddBrodcastToElementwisePattern
<
paddle
::
dialect
::
GreaterEqualOp
>>
(
context
);
return
ps
;
}
bool
AddBroadcastToElementwisePass
::
CanApplyOn
(
pir
::
Operation
*
op
)
const
{
return
op
->
isa
<
pir
::
ModuleOp
>
()
&&
op
->
num_regions
()
>
0
;
}
}
// namespace ir
}
// namespace dialect
}
// namespace cinn
paddle/cinn/hlir/dialect/operator/transforms/add_broadcast_to_elementwise_pass.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/pir/pass/pass.h"
#include "paddle/pir/pattern_rewrite/frozen_rewrite_pattern_set.h"
namespace
cinn
{
namespace
dialect
{
namespace
ir
{
class
AddBroadcastToElementwisePass
:
public
pir
::
PatternRewritePass
{
public:
AddBroadcastToElementwisePass
();
pir
::
RewritePatternSet
InitializePatterns
(
pir
::
IrContext
*
context
)
override
;
bool
CanApplyOn
(
pir
::
Operation
*
op
)
const
override
;
};
}
// namespace ir
}
// namespace dialect
}
// namespace cinn
paddle/cinn/hlir/dialect/operator/transforms/group_merge/CMakeLists.txt
0 → 100644
View file @
01a10755
if
(
NOT CINN_ONLY
)
cinn_cc_library
(
op_with_group_merge_pass
SRCS
group_with_group_merge_pass.cc
op_with_group_merge_pass.cc
cinn_group_lowering_pass.cc
tensor_node.cc
DEPS
op_dialect_vjp
pir_compiler
cinn_runtime_dialect
)
endif
()
Prev
1
…
10
11
12
13
14
15
16
17
18
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment