Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
558
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1356 additions
and
553 deletions
+1356
-553
paddle/cinn/ir/lowered_func.cc
paddle/cinn/ir/lowered_func.cc
+21
-14
paddle/cinn/ir/lowered_func.h
paddle/cinn/ir/lowered_func.h
+11
-2
paddle/cinn/ir/module.cc
paddle/cinn/ir/module.cc
+10
-1
paddle/cinn/ir/module.h
paddle/cinn/ir/module.h
+2
-0
paddle/cinn/ir/operation.cc
paddle/cinn/ir/operation.cc
+6
-4
paddle/cinn/ir/operation.h
paddle/cinn/ir/operation.h
+2
-0
paddle/cinn/ir/schedule/CMakeLists.txt
paddle/cinn/ir/schedule/CMakeLists.txt
+8
-2
paddle/cinn/ir/schedule/factorize_reduction.h
paddle/cinn/ir/schedule/factorize_reduction.h
+424
-0
paddle/cinn/ir/schedule/ir_schedule.cc
paddle/cinn/ir/schedule/ir_schedule.cc
+289
-350
paddle/cinn/ir/schedule/ir_schedule.h
paddle/cinn/ir/schedule/ir_schedule.h
+59
-35
paddle/cinn/ir/schedule/ir_schedule_error.cc
paddle/cinn/ir/schedule/ir_schedule_error.cc
+3
-3
paddle/cinn/ir/schedule/ir_schedule_util.cc
paddle/cinn/ir/schedule/ir_schedule_util.cc
+144
-98
paddle/cinn/ir/schedule/ir_schedule_util.h
paddle/cinn/ir/schedule/ir_schedule_util.h
+15
-4
paddle/cinn/ir/schedule/schedule_base.cc
paddle/cinn/ir/schedule/schedule_base.cc
+74
-0
paddle/cinn/ir/schedule/schedule_base.h
paddle/cinn/ir/schedule/schedule_base.h
+169
-0
paddle/cinn/ir/schedule/schedule_desc.cc
paddle/cinn/ir/schedule/schedule_desc.cc
+12
-0
paddle/cinn/ir/schedule_block_graph.cc
paddle/cinn/ir/schedule_block_graph.cc
+1
-1
paddle/cinn/ir/schedule_block_graph.h
paddle/cinn/ir/schedule_block_graph.h
+2
-4
paddle/cinn/ir/tensor.cc
paddle/cinn/ir/tensor.cc
+67
-5
paddle/cinn/ir/tensor.h
paddle/cinn/ir/tensor.h
+37
-30
No files found.
Too many changes to show.
To preserve performance only
558 of 558+
files are displayed.
Plain diff
Email patch
paddle/cinn/ir/lowered_func.cc
View file @
01a10755
...
...
@@ -25,9 +25,8 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
...
...
@@ -65,6 +64,16 @@ LoweredFunc _LoweredFunc_::Make(const std::string& name,
return
LoweredFunc
(
n
);
}
// Minimal factory: allocates a _LoweredFunc_ node and fills in only its name,
// argument list and body before wrapping it in a LoweredFunc handle.
LoweredFunc _LoweredFunc_::Make(const std::string& name,
                                const std::vector<Argument>& args,
                                const Expr& body) {
  auto* func_node = make_shared<_LoweredFunc_>();
  // The three assignments are independent of each other.
  func_node->body = body;
  func_node->args = args;
  func_node->name = name;
  return LoweredFunc(func_node);
}
void
_LoweredFunc_
::
CheckValid
()
const
{
// check there is at least one output
int
out_count
=
0
;
...
...
@@ -83,7 +92,7 @@ std::vector<const Expr*> _LoweredFunc_::expr_fields() const { return {&body}; }
void
_LoweredFunc_
::
PrepareCudaAxisInfoFromBody
()
{
std
::
set
<
Expr
>
bound_for_exprs
=
ir
::
CollectIRNodes
(
body
,
[](
const
Expr
*
expr
)
{
ir
::
ir_utils
::
CollectIRNodes
(
body
,
[](
const
Expr
*
expr
)
{
const
ir
::
For
*
for_expr
=
expr
->
As
<
ir
::
For
>
();
return
for_expr
!=
nullptr
&&
for_expr
->
is_binded
();
});
...
...
@@ -209,8 +218,7 @@ void _LoweredFunc_::AllocTempBuffer() {}
void
_LoweredFunc_
::
PrepareBufferCastExprs
(
bool
with_expr_gen_tensor
)
{
buffer_data_cast_exprs
.
clear
();
// collect write.
optim
::
TensorWriteTeller
write_teller
;
write_teller
.
Collect
(
&
body
);
auto
write_teller
=
ir
::
ir_utils
::
CollectTensorNeedsWrite
(
&
body
);
auto
tensors
=
CollectAllTensorReference
(
with_expr_gen_tensor
);
std
::
sort
(
tensors
.
begin
(),
...
...
@@ -224,7 +232,7 @@ void _LoweredFunc_::PrepareBufferCastExprs(bool with_expr_gen_tensor) {
if
(
!
tensor
->
buffer
.
defined
())
continue
;
Type
value_type
=
tensor
->
type
().
ElementOf
();
bool
is_const
=
!
write_teller
.
IsWrite
(
tensor
->
name
);
bool
is_const
=
!
write_teller
.
count
(
tensor
->
name
);
value_type
.
set_cpp_handle
();
value_type
.
set_cpp_const
(
is_const
);
Var
variable
=
_Var_
::
Make
(
tensor
->
name
,
value_type
);
...
...
@@ -250,8 +258,7 @@ std::vector<Expr> _LoweredFunc_::CudaAliasVarExprs() const {
}
// collect write.
std
::
vector
<
Expr
>
res
;
optim
::
TensorWriteTeller
write_teller
;
write_teller
.
Collect
(
&
body
);
auto
write_teller
=
ir
::
ir_utils
::
CollectTensorNeedsWrite
(
&
body
);
auto
tensors
=
CollectAllTensorReference
();
std
::
sort
(
tensors
.
begin
(),
...
...
@@ -269,7 +276,7 @@ std::vector<Expr> _LoweredFunc_::CudaAliasVarExprs() const {
continue
;
}
Type
value_type
=
tensor
->
type
().
ElementOf
();
bool
is_const
=
!
write_teller
.
IsWrite
(
tensor
->
name
);
bool
is_const
=
!
write_teller
.
count
(
tensor
->
name
);
value_type
.
set_cpp_handle
();
value_type
.
set_cpp_const
(
is_const
);
Var
variable
=
_Var_
::
Make
(
tensor
->
name
,
value_type
);
...
...
@@ -406,11 +413,11 @@ std::vector<Tensor> _LoweredFunc_::CollectAllTensorReference(
bool
with_expr_gen_tensor
)
const
{
std
::
set
<
Expr
>
tensor_exprs
=
with_expr_gen_tensor
?
ir
::
CollectIRNodes
(
?
ir
::
ir_utils
::
CollectIRNodes
(
body
,
[](
const
Expr
*
expr
)
{
return
expr
->
As
<
ir
::
_Tensor_
>
();
})
:
ir
::
CollectIRNodesWithoutTensor
(
body
,
[](
const
Expr
*
expr
)
{
return
expr
->
As
<
ir
::
_Tensor_
>
();
});
:
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
body
,
[](
const
Expr
*
expr
)
{
return
expr
->
As
<
ir
::
_Tensor_
>
();
});
std
::
vector
<
Tensor
>
tensors
;
// remove the duplicate tensor by their name.
...
...
paddle/cinn/ir/lowered_func.h
View file @
01a10755
...
...
@@ -30,8 +30,10 @@ class _LoweredFunc_;
* the function signature of generated code.
*/
struct
Argument
{
//! Input or output.
enum
class
IO
{
kInput
=
0
,
kOutput
=
1
};
//! kInput: arg is input
//! kOutput: arg is output
//! kUnknown: arg maybe input or output
enum
class
IO
{
kInput
=
0
,
kOutput
=
1
,
kUnknown
=
2
};
IO
io
{
IO
::
kInput
};
...
...
@@ -164,6 +166,13 @@ struct _LoweredFunc_ : ExprNode<_LoweredFunc_> {
const
Expr
&
body
,
const
std
::
vector
<
ir
::
Buffer
>&
temp_bufs
);
// A simple version of the make function method,
// regardless of the argument buffer information and IO information of
// Argument, after building the function to optimize the buffer through pass
static
LoweredFunc
Make
(
const
std
::
string
&
name
,
const
std
::
vector
<
Argument
>&
args
,
const
Expr
&
body
);
bool
is_gpu_host
()
const
{
return
cuda_axis_info
.
valid
();
}
void
Verify
()
const
override
{}
...
...
paddle/cinn/ir/module.cc
View file @
01a10755
...
...
@@ -16,6 +16,7 @@
#include <memory>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/optimize.h"
...
...
@@ -48,12 +49,19 @@ void Module::Builder::AddBuffer(ir::Buffer buffer) {
}
}
// Appends `predicate` to the predicate list of the module being built.
void Module::Builder::AddPredicate(ir::Expr predicate) {
  auto& predicate_list = module_->predicates;
  predicate_list.push_back(predicate);
}
// Empties the buffers, functions, submodules and predicates collected so far
// in the module being built. Other fields of the module are left untouched.
void Module::Builder::Clear() {
  auto* mod = module_.get();
  mod->buffers.clear();
  mod->functions.clear();
  mod->submodules.clear();
  mod->predicates.clear();
}
// Returns the `arch` field of the target stored in the module being built.
Target::Arch Module::Builder::GetTargetArch() {
  const auto& module_target = module_->target;
  return module_target.arch;
}
Module
Module
::
Builder
::
Build
()
{
if
(
module_
->
functions
.
empty
())
{
VLOG
(
1
)
<<
"Module has no functions"
;
...
...
@@ -61,7 +69,8 @@ Module Module::Builder::Build() {
auto
res
=
ir
::
Module
(
module_
.
get
());
return
optim
::
Optimize
(
res
,
module_
->
target
);
res
=
optim
::
Optimize
(
res
,
module_
->
target
);
return
res
;
}
// Returns the underlying node viewed as an ir::_Module_ (via `p_->as<>()`).
ir::_Module_* Module::self() {
  auto* module_node = p_->as<ir::_Module_>();
  return module_node;
}
...
...
paddle/cinn/ir/module.h
View file @
01a10755
...
...
@@ -44,7 +44,9 @@ class Module : public ir::IrNodeRef {
void
AddFunction
(
ir
::
LoweredFunc
func
);
void
AddFunctionWithoutOptim
(
const
ir
::
LoweredFunc
&
func
);
void
AddBuffer
(
ir
::
Buffer
buffer
);
void
AddPredicate
(
ir
::
Expr
predicate
);
void
Clear
();
Target
::
Arch
GetTargetArch
();
Module
Build
();
...
...
paddle/cinn/ir/operation.cc
View file @
01a10755
...
...
@@ -49,10 +49,12 @@ Operation ComputeOp::Make(const std::string &name,
n
->
reduce_axis
=
reduce_axis
;
n
->
tag
=
tag
;
n
->
attrs
=
attrs
;
auto
axis
=
common
::
GenDefaultAxis
(
domain
.
size
());
std
::
vector
<
Expr
>
_axis
;
for
(
auto
&
x
:
axis
)
_axis
.
push_back
(
x
);
n
->
body
=
{
handle
(
_axis
)};
n
->
axis
=
common
::
GenDefaultAxis
(
domain
.
size
());
std
::
vector
<
Expr
>
tmp_axis
;
for
(
auto
&
x
:
n
->
axis
)
{
tmp_axis
.
push_back
(
x
);
}
n
->
body
=
{
handle
(
tmp_axis
)};
n
->
reduce_axis
=
reduce_axis
;
return
Operation
(
n
);
}
...
...
paddle/cinn/ir/operation.h
View file @
01a10755
...
...
@@ -105,6 +105,8 @@ struct BufferShareOp : public _Operation_ {
*/
struct
ComputeOp
:
public
_Operation_
{
using
handle_t
=
std
::
function
<
Expr
(
const
std
::
vector
<
Expr
>
&
)
>
;
//! Var on each dimension
std
::
vector
<
Var
>
axis
;
//! Var on each reduction axis, if the body is a Reduction.
std
::
vector
<
Var
>
reduce_axis
;
//! Shape of the output.
...
...
paddle/cinn/ir/schedule/CMakeLists.txt
View file @
01a10755
cinn_proto_library
(
schedule_desc_proto SRCS schedule_desc.proto
)
core_gather_headers
()
gather_srcs
(
cinnapi_src SRCS ir_schedule.cc ir_schedule_util.cc
ir_schedule_error.cc schedule_desc.cc
)
gather_srcs
(
cinnapi_src
SRCS
schedule_base.cc
ir_schedule.cc
ir_schedule_util.cc
ir_schedule_error.cc
schedule_desc.cc
)
foreach
(
header
${
schedule_desc_proto_HDRS
}
)
set
(
core_proto_includes
...
...
paddle/cinn/ir/schedule/factorize_reduction.h
0 → 100644
View file @
01a10755
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Used in FactorizeReduction
#pragma once
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/error.h"
namespace
cinn
{
namespace
ir
{
// Create the new Reduction-Factorized (rf) tensor,
// only used for FactorizeReduction schedule primitive.
//
// The rf tensor copies the type, operation and reduce axes of
// `original_tensor`; its shape (and domain) is the original shape with the
// extent of `rf_loop` inserted at position `rf_axis`. It is given a buffer
// named after the tensor with memory type string "global".
//
// @param original_tensor Tensor written by the reduction being factorized.
// @param rf_loop         Must hold an ir::For node; its extent becomes the
//                        size of the new dimension.
// @param rf_axis         Insert position of the new dimension, must lie in
//                        [0, original_tensor->shape.size()].
// @return The newly created rf tensor.
//
// Declared `inline` because this function is defined in a header: without it,
// including the header from more than one translation unit produces multiple
// definitions of the same symbol.
inline Tensor CreateRFTensor(const Tensor& original_tensor,
                             const Expr& rf_loop,
                             int rf_axis) {
  const auto* rf_for = rf_loop.As<For>();
  CHECK(rf_for) << "rf_loop of CreateRFTensor must be a For node.";

  std::string name = common::UniqName(original_tensor->name + "_rf");
  std::vector<Expr> new_shape = original_tensor->shape;
  // Guard the iterator arithmetic below: an out-of-range offset is undefined
  // behavior for std::vector::insert.
  CHECK(rf_axis >= 0 && rf_axis <= static_cast<int>(new_shape.size()))
      << "rf_axis " << rf_axis << " is out of range [0, " << new_shape.size()
      << "].";
  new_shape.insert(new_shape.begin() + rf_axis, rf_for->extent);

  // The same vector is passed twice (as shape and as domain), as before.
  Tensor rf_tensor = _Tensor_::Make(name,
                                    original_tensor->type(),
                                    new_shape,
                                    new_shape,
                                    original_tensor->operation,
                                    original_tensor->reduce_axis);
  rf_tensor->WithBuffer("global", name, original_tensor->type());
  return rf_tensor;
}
// Base class to create a new reduce block,
// only used for FactorizeReduction schedule primitive.
class
ReduceBlockCreater
{
public:
ReduceBlockCreater
(
const
Expr
&
original_block
,
const
std
::
vector
<
Expr
>&
original_loops
,
const
Expr
&
rf_loop
,
const
Expr
&
original_update_stmt
,
const
ir
::
Tensor
&
rf_tensor
,
bool
is_rf_block
)
:
original_block_
(
original_block
),
original_loops_
(
original_loops
),
rf_loop_
(
rf_loop
),
original_update_stmt_
(
original_update_stmt
),
rf_tensor_
(
rf_tensor
),
is_rf_block_
(
is_rf_block
)
{
const
ScheduleBlockRealize
*
block_real
=
original_block_
.
As
<
ir
::
ScheduleBlockRealize
>
();
CHECK_NOTNULL
(
block_real
);
num_block_iters_
=
block_real
->
iter_values
.
size
();
}
void
CreateBlock
()
{
CreateRFIter
();
for
(
int
i
=
0
;
i
<
num_block_iters_
;
++
i
)
{
CreateNormalIter
(
i
);
}
CreateUpdateStmt
();
std
::
string
new_update_block_name
=
original_block_
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
;
if
(
is_rf_block_
)
{
new_update_block_name
=
rf_tensor_
->
name
;
}
std
::
string
new_init_block_name
=
ir
::
GenReduceInitTensorNameOf
(
new_update_block_name
);
VLOG
(
5
)
<<
"new_init_block_name = "
<<
new_init_block_name
;
const
ir
::
Tensor
&
real_tensor
=
is_rf_block_
?
rf_tensor_
:
original_update_stmt_
.
As
<
ir
::
Store
>
()
->
tensor
.
as_tensor_ref
();
Expr
init_value
=
real_tensor
->
GetReduceInitVal
();
const
std
::
vector
<
Expr
>&
domain
=
real_tensor
->
domain_without_reduce_axis
();
ir
::
Tensor
init_tensor
=
lang
::
Compute
(
domain
,
[
=
](
const
std
::
vector
<
Expr
>&
axis
)
{
return
init_value
;
},
new_init_block_name
);
init_tensor
->
Bind
(
real_tensor
->
buffer
);
Expr
init_stmt
=
ir
::
Store
::
Make
(
init_tensor
,
init_value
,
new_update_stmt_
.
As
<
ir
::
Store
>
()
->
indices
);
new_init_sch_block_
=
ScheduleBlock
::
Make
(
new_init_iter_vars_
,
{},
{},
new_init_block_name
,
init_stmt
);
new_init_block_realize_
=
ScheduleBlockRealize
::
Make
(
new_init_iter_values_
,
new_init_sch_block_
);
new_update_sch_block_
=
ScheduleBlock
::
Make
(
new_iter_vars_
,
{},
{},
new_update_block_name
,
new_update_stmt_
);
new_update_block_realize_
=
ScheduleBlockRealize
::
Make
(
new_iter_values_
,
new_update_sch_block_
);
VLOG
(
4
)
<<
"new_update_block_realize:
\n
"
<<
new_update_block_realize_
;
}
Expr
CreateLoops
()
{
int
num_loops
=
original_loops_
.
size
();
std
::
vector
<
Expr
>
new_loops
(
num_loops
);
Expr
body
=
new_update_block_realize_
;
bool
has_add_init_block
=
false
;
for
(
int
i
=
num_loops
-
1
;
i
>=
0
;
--
i
)
{
bool
is_spatial_loop
=
new_spatial_loop_var_names_
.
count
(
original_loops_
[
i
].
As
<
For
>
()
->
loop_var
->
name
)
>
0
;
bool
is_rf_loop
=
rf_loop_
.
As
<
For
>
()
->
loop_var
->
name
==
original_loops_
[
i
].
As
<
For
>
()
->
loop_var
->
name
;
// Skip non rf reduction loops of write back block.
if
(
!
is_rf_block_
&&
!
is_spatial_loop
&&
!
is_rf_loop
)
{
continue
;
}
// Add reduce init block.
if
(
!
has_add_init_block
&&
is_spatial_loop
)
{
body
=
Block
::
Make
({
new_init_block_realize_
,
body
});
has_add_init_block
=
true
;
}
// Add loops
Var
loop_var
=
ir_utils
::
IRCopy
(
original_loops_
[
i
].
As
<
For
>
()
->
loop_var
);
Expr
min
=
ir_utils
::
IRCopy
(
original_loops_
[
i
].
As
<
For
>
()
->
min
);
Expr
extent
=
ir_utils
::
IRCopy
(
original_loops_
[
i
].
As
<
For
>
()
->
extent
);
body
=
For
::
Make
(
loop_var
,
min
,
extent
,
original_loops_
[
i
].
As
<
For
>
()
->
for_type
(),
original_loops_
[
i
].
As
<
For
>
()
->
device_api
,
body
,
original_loops_
[
i
].
As
<
For
>
()
->
vectorize_info
(),
original_loops_
[
i
].
As
<
For
>
()
->
bind_info
());
VLOG
(
5
)
<<
"new body:
\n
"
<<
body
;
}
VLOG
(
4
)
<<
"new loop nest:
\n
"
<<
body
;
return
body
;
}
private:
virtual
void
CreateRFIter
()
=
0
;
virtual
void
CreateNormalIter
(
int
idx
)
=
0
;
virtual
void
CreateUpdateStmt
()
=
0
;
public:
Var
rf_var_
;
std
::
vector
<
Expr
>
rf_tensor_access_indices_
;
protected:
const
Expr
&
original_block_
;
const
std
::
vector
<
Expr
>&
original_loops_
;
const
Expr
&
rf_loop_
;
const
Expr
&
original_update_stmt_
;
const
ir
::
Tensor
&
rf_tensor_
;
std
::
map
<
Var
,
Expr
,
CompVar
>
original_indice2new_expr_
;
int
num_block_iters_
;
bool
is_rf_block_
;
std
::
vector
<
Var
>
new_iter_vars_
;
std
::
vector
<
Expr
>
new_iter_values_
;
std
::
vector
<
Var
>
new_init_iter_vars_
;
std
::
vector
<
Expr
>
new_init_iter_values_
;
std
::
unordered_set
<
std
::
string
>
new_spatial_loop_var_names_
;
Expr
new_update_stmt_
;
Expr
new_update_sch_block_
;
Expr
new_update_block_realize_
;
Expr
new_init_sch_block_
;
Expr
new_init_block_realize_
;
};
// Implement class for building Reduction-Factorized block,
// only used for FactorizeReduction schedule primitive.
class
RFBlockCreater
:
public
ReduceBlockCreater
{
public:
RFBlockCreater
(
const
Expr
&
original_block
,
const
std
::
vector
<
Expr
>&
original_loops
,
const
Expr
&
rf_loop
,
const
Expr
&
original_update_stmt
,
const
ir
::
Tensor
&
rf_tensor
,
const
std
::
map
<
Var
,
Expr
,
CompVar
>&
var2loops
,
int
rf_axis
)
:
ReduceBlockCreater
(
original_block
,
original_loops
,
rf_loop
,
original_update_stmt
,
rf_tensor
,
true
),
var2loops_
(
var2loops
),
rf_axis_
(
rf_axis
)
{}
private:
void
CreateRFIter
()
override
{
std
::
string
loop_var_name
=
rf_loop_
.
As
<
ir
::
For
>
()
->
loop_var
->
name
;
std
::
string
rf_var_name
=
"v"
+
loop_var_name
;
rf_var_
=
Var
(
rf_loop_
.
As
<
ir
::
For
>
()
->
min
,
rf_loop_
.
As
<
ir
::
For
>
()
->
extent
,
rf_var_name
,
/* is_reduce = */
false
);
loop_var2block_iters_
[
rf_loop_
.
As
<
ir
::
For
>
()
->
loop_var
]
=
rf_var_
;
new_iter_vars_
.
push_back
(
rf_var_
);
new_iter_values_
.
push_back
(
rf_loop_
.
As
<
ir
::
For
>
()
->
loop_var
);
new_init_iter_vars_
.
push_back
(
rf_var_
);
new_init_iter_values_
.
push_back
(
rf_loop_
.
As
<
ir
::
For
>
()
->
loop_var
);
new_spatial_loop_var_names_
.
insert
(
rf_loop_
.
As
<
ir
::
For
>
()
->
loop_var
->
name
);
VLOG
(
4
)
<<
"create new_rf_var = "
<<
rf_var_
<<
", with iter value = "
<<
new_iter_values_
.
back
();
}
void
CreateNormalIter
(
int
idx
)
override
{
Var
original_iter_var
=
original_block_
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
[
idx
];
Expr
original_iter_value
=
original_block_
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
iter_values
[
idx
];
// The original iter is either a spatial iter, or a reduction iter that
// doesn't touch the rf loop. In this case reuse the old iter var and its
// corresponding iter value.
if
(
!
original_iter_var
->
is_reduce_axis
)
{
new_iter_vars_
.
push_back
(
original_iter_var
);
new_iter_values_
.
push_back
(
original_iter_value
);
new_init_iter_vars_
.
push_back
(
original_iter_var
);
new_init_iter_values_
.
push_back
(
original_iter_value
);
ir_utils
::
CollectIRNodesWithoutTensor
(
original_iter_value
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
as_var
())
{
new_spatial_loop_var_names_
.
insert
(
x
->
as_var
()
->
name
);
}
return
false
;
});
return
;
}
else
if
(
!
ContainVar
({
original_iter_value
},
rf_loop_
.
As
<
ir
::
For
>
()
->
loop_var
->
name
))
{
new_iter_vars_
.
push_back
(
original_iter_var
);
new_iter_values_
.
push_back
(
original_iter_value
);
return
;
}
CHECK
(
original_iter_var
->
is_reduce_axis
);
// This iter is a reduction iter and touches the rfactor loop. So we try to
// create a new iter for each loop var that appear in the original iter
// value.
std
::
vector
<
Var
>
vars_in_original_iter_values
;
ir_utils
::
CollectIRNodesWithoutTensor
(
original_iter_value
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
as_var
())
{
vars_in_original_iter_values
.
push_back
(
x
->
as_var_ref
());
}
return
false
;
});
for
(
const
Var
&
loop_var
:
vars_in_original_iter_values
)
{
if
(
var2loops_
.
count
(
loop_var
)
==
0
)
{
continue
;
}
Expr
loop
=
var2loops_
.
at
(
loop_var
);
if
(
loop_var2block_iters_
.
count
(
loop_var
)
==
0
)
{
Var
new_iter_var
(
loop
.
As
<
ir
::
For
>
()
->
min
,
loop
.
As
<
ir
::
For
>
()
->
extent
,
"v"
+
loop_var
->
name
,
/* is_reduce = */
true
);
new_iter_vars_
.
push_back
(
new_iter_var
);
new_iter_values_
.
emplace_back
(
loop_var
);
loop_var2block_iters_
[
loop_var
]
=
new_iter_var
;
}
}
// Substitute the original iter values with new iter vars,
// and store the new iter values in original_indice2new_expr_,
// it will be used in Load/Store indices.
Expr
new_iters
=
ir_utils
::
IRCopy
(
original_iter_value
);
ReplaceExpr
(
&
new_iters
,
loop_var2block_iters_
);
original_indice2new_expr_
[
original_iter_var
]
=
new_iters
;
VLOG
(
4
)
<<
"original_indice2new_expr_["
<<
original_iter_var
<<
"] = "
<<
new_iters
;
}
void
CreateUpdateStmt
()
override
{
rf_tensor_access_indices_
=
original_update_stmt_
.
As
<
ir
::
Store
>
()
->
indices
;
rf_tensor_access_indices_
.
insert
(
rf_tensor_access_indices_
.
begin
()
+
rf_axis_
,
rf_var_
);
Expr
original_store_body
=
original_update_stmt_
.
As
<
ir
::
Store
>
()
->
value
;
Expr
new_store_body
=
ir_utils
::
IRCopy
(
original_store_body
);
#define REPLACE_RF_TENSOR(Op) \
if (new_store_body.As<Op>()) { \
auto* node = new_store_body.As<Op>(); \
CHECK(node); \
auto& operand = node->a(); \
operand = Load::Make(rf_tensor_, rf_tensor_access_indices_); \
}
REPLACE_RF_TENSOR
(
Add
)
REPLACE_RF_TENSOR
(
Mul
)
REPLACE_RF_TENSOR
(
Max
)
REPLACE_RF_TENSOR
(
Min
)
REPLACE_RF_TENSOR
(
And
)
REPLACE_RF_TENSOR
(
Or
)
REPLACE_RF_TENSOR
(
LT
)
REPLACE_RF_TENSOR
(
LE
)
REPLACE_RF_TENSOR
(
GT
)
REPLACE_RF_TENSOR
(
GE
)
#undef REPLACE_RF_TENSOR
new_update_stmt_
=
ir
::
Store
::
Make
(
rf_tensor_
,
new_store_body
,
rf_tensor_access_indices_
);
ReplaceExpr
(
&
new_update_stmt_
,
original_indice2new_expr_
);
VLOG
(
4
)
<<
"new_update_stmt of rf block:
\n
"
<<
new_update_stmt_
;
}
private:
const
std
::
map
<
Var
,
Expr
,
CompVar
>&
var2loops_
;
int
rf_axis_
;
std
::
map
<
Var
,
Expr
,
CompVar
>
loop_var2block_iters_
;
};
// Implement class for building Writing-Back block,
// only used for FactorizeReduction schedule primitive.
class
RBBlockCreater
:
public
ReduceBlockCreater
{
public:
RBBlockCreater
(
const
Expr
&
original_block
,
const
std
::
vector
<
Expr
>&
original_loops
,
const
Expr
&
rf_loop
,
const
Expr
&
original_update_stmt
,
const
ir
::
Tensor
&
rf_tensor
,
const
std
::
vector
<
Expr
>&
rf_tensor_access_indices
,
const
Var
&
rf_block_rf_iter_var
)
:
ReduceBlockCreater
(
original_block
,
original_loops
,
rf_loop
,
original_update_stmt
,
rf_tensor
,
false
),
rf_tensor_access_indices_
(
rf_tensor_access_indices
),
rf_block_rf_iter_var_
(
rf_block_rf_iter_var
)
{}
private:
void
CreateRFIter
()
override
{
std
::
string
loop_var_name
=
rf_loop_
.
As
<
ir
::
For
>
()
->
loop_var
->
name
;
std
::
string
rf_var_name
=
"v"
+
loop_var_name
;
rf_var_
=
Var
(
rf_loop_
.
As
<
ir
::
For
>
()
->
min
,
rf_loop_
.
As
<
ir
::
For
>
()
->
extent
,
rf_var_name
,
/* is_reduce = */
true
);
new_iter_vars_
.
push_back
(
rf_var_
);
new_iter_values_
.
push_back
(
rf_loop_
.
As
<
ir
::
For
>
()
->
loop_var
);
original_indice2new_expr_
[
rf_block_rf_iter_var_
]
=
Expr
(
rf_var_
);
VLOG
(
4
)
<<
"create new_rf_var = "
<<
rf_var_
<<
", with iter value = "
<<
new_iter_values_
.
back
();
}
void
CreateNormalIter
(
int
idx
)
override
{
Var
original_iter_var
=
original_block_
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
[
idx
];
Expr
original_iter_value
=
original_block_
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
iter_values
[
idx
];
if
(
!
original_iter_var
->
is_reduce_axis
)
{
new_iter_vars_
.
push_back
(
original_iter_var
);
new_iter_values_
.
push_back
(
original_iter_value
);
new_init_iter_vars_
.
push_back
(
original_iter_var
);
new_init_iter_values_
.
push_back
(
original_iter_value
);
ir_utils
::
CollectIRNodesWithoutTensor
(
original_iter_value
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
as_var
())
{
new_spatial_loop_var_names_
.
insert
(
x
->
as_var
()
->
name
);
}
return
false
;
});
// original_indice2new_expr_[original_iter_var] = new_iter_vars_.back();
VLOG
(
4
)
<<
"create new iter var = "
<<
new_iter_vars_
.
back
()
<<
", with iter value = "
<<
new_iter_values_
.
back
();
}
}
void
CreateUpdateStmt
()
override
{
Expr
original_store_body
=
original_update_stmt_
.
As
<
ir
::
Store
>
()
->
value
;
Expr
new_store_body
=
ir_utils
::
IRCopy
(
original_store_body
);
#define REPLACE_RF_TENSOR(Op) \
if (new_store_body.As<Op>()) { \
auto* node = new_store_body.As<Op>(); \
CHECK(node); \
auto& operand = node->b(); \
operand = Load::Make(rf_tensor_, rf_tensor_access_indices_); \
}
REPLACE_RF_TENSOR
(
Add
)
REPLACE_RF_TENSOR
(
Mul
)
REPLACE_RF_TENSOR
(
Max
)
REPLACE_RF_TENSOR
(
Min
)
REPLACE_RF_TENSOR
(
And
)
REPLACE_RF_TENSOR
(
Or
)
REPLACE_RF_TENSOR
(
LT
)
REPLACE_RF_TENSOR
(
LE
)
REPLACE_RF_TENSOR
(
GT
)
REPLACE_RF_TENSOR
(
GE
)
#undef REPLACE_RF_TENSOR
Expr
original_store_tensor
=
original_update_stmt_
.
As
<
ir
::
Store
>
()
->
tensor
;
std
::
vector
<
Expr
>
original_store_indices
=
original_update_stmt_
.
As
<
ir
::
Store
>
()
->
indices
;
new_update_stmt_
=
ir
::
Store
::
Make
(
original_store_tensor
,
new_store_body
,
original_store_indices
);
ReplaceExpr
(
&
new_update_stmt_
,
original_indice2new_expr_
);
VLOG
(
4
)
<<
"new_update_stmt of write back block:
\n
"
<<
new_update_stmt_
;
}
private:
const
std
::
vector
<
Expr
>&
rf_tensor_access_indices_
;
const
Var
&
rf_block_rf_iter_var_
;
};
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/schedule/ir_schedule.cc
View file @
01a10755
...
...
@@ -27,56 +27,46 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/dev_info_manager.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/dy_schedule/ir_schedule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_analyzer/ir_analyzer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/schedule/factorize_reduction.h"
#include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/utils/string.h"
DECLARE_int32
(
cinn_error_message_level
);
PD_
DECLARE_int32
(
cinn_error_message_level
);
namespace
cinn
{
namespace
ir
{
/**
* A struct helps to implement Schedule primitives.
* A struct helps to implement
static shape
Schedule primitives.
*/
class
ScheduleImpl
{
class
St
ScheduleImpl
:
public
ScheduleBase
{
public:
ScheduleImpl
()
=
default
;
explicit
ScheduleImpl
(
const
ModuleExpr
&
module_expr
,
bool
debug_flag
=
false
,
utils
::
ErrorMessageLevel
err_msg_level
=
utils
::
ErrorMessageLevel
::
kGeneral
)
:
module_expr_
(
module_expr
),
debug_flag_
(
debug_flag
)
{
err_msg_level_
=
static_cast
<
utils
::
ErrorMessageLevel
>
(
FLAGS_cinn_error_message_level
||
static_cast
<
int
>
(
err_msg_level
));
}
explicit
ScheduleImpl
(
ModuleExpr
&&
module_expr
)
:
module_expr_
(
std
::
move
(
module_expr
))
{}
//! Set the debug flag.
void
SetDebugFlag
(
bool
debug_flag
)
{
debug_flag_
=
debug_flag
;
}
//! Get the ModuleExpr stored in ScheduleImpl.
const
ModuleExpr
&
GetModule
()
const
{
return
module_expr_
;
}
StScheduleImpl
()
=
delete
;
explicit
StScheduleImpl
(
const
ModuleExpr
&
module_expr
,
bool
debug_flag
=
false
,
utils
::
ErrorMessageLevel
err_msg_level
=
utils
::
ErrorMessageLevel
::
kGeneral
)
:
ScheduleBase
(
module_expr
,
false
,
err_msg_level
)
{}
explicit
StScheduleImpl
(
ModuleExpr
&&
module_expr
)
:
ScheduleBase
(
std
::
move
(
module_expr
))
{}
void
MergeExprs
();
void
SetExprs
(
const
std
::
vector
<
Expr
>&
exprs
)
{
module_expr_
.
SetExprs
(
exprs
);
}
bool
HasBlock
(
const
std
::
string
&
block_name
)
const
;
std
::
vector
<
Expr
>
GetLoops
(
const
Expr
&
block
)
const
;
std
::
vector
<
Expr
>
GetLoops
(
const
std
::
string
&
block_name
)
const
;
std
::
vector
<
Expr
>
GetAllBlocks
()
const
;
...
...
@@ -120,6 +110,7 @@ class ScheduleImpl {
void
ReverseComputeInline
(
const
Expr
&
schedule_block
);
void
Bind
(
const
Expr
&
loop
,
const
std
::
string
&
thread_axis
);
Expr
Rfactor
(
const
Expr
&
rf_loop
,
int
rf_axis
);
Expr
FactorizeReduction
(
const
Expr
&
rf_loop
,
int
rf_axis
);
Expr
AddUnitLoop
(
const
Expr
&
block
)
const
;
void
Annotate
(
const
Expr
&
block
,
const
std
::
string
&
key
,
const
attr_t
&
value
);
void
Unannotate
(
Expr
&
block
,
const
std
::
string
&
key
);
// NOLINT
...
...
@@ -131,14 +122,32 @@ class ScheduleImpl {
Expr
SampleCategorical
(
utils
::
LinearRandomEngine
::
StateType
*
rand_seed
,
const
std
::
vector
<
int
>&
candidates
,
const
std
::
vector
<
float
>&
probs
);
};
private:
void
Replace
(
const
Expr
&
src_sref
,
const
Expr
&
tgt_stmt
);
// Factory for schedule implementations built from a ModuleExpr copy.
// Returns a DyScheduleImpl when `is_dynamic` is true and a StScheduleImpl
// otherwise; the remaining arguments are forwarded to the chosen constructor.
// The former trailing `return nullptr;` was unreachable and has been dropped.
std::unique_ptr<ScheduleBase> ScheduleBase::Make(
    const ModuleExpr& module_expr,
    bool debug_flag,
    utils::ErrorMessageLevel err_msg_level,
    bool is_dynamic) {
  if (is_dynamic) {
    return std::make_unique<DyScheduleImpl>(
        module_expr, debug_flag, err_msg_level);
  }
  return std::make_unique<StScheduleImpl>(
      module_expr, debug_flag, err_msg_level);
}
ModuleExpr
module_expr_
;
bool
debug_flag_
{
false
};
utils
::
ErrorMessageLevel
err_msg_level_
=
utils
::
ErrorMessageLevel
::
kGeneral
;
};
// Factory for schedule implementations that takes ownership of `module_expr`.
// Returns a DyScheduleImpl when `is_dynamic` is true and a StScheduleImpl
// otherwise. The former trailing `return nullptr;` was unreachable and has
// been dropped.
std::unique_ptr<ScheduleBase> ScheduleBase::Make(ModuleExpr&& module_expr,
                                                 bool is_dynamic) {
  if (is_dynamic) {
    return std::make_unique<DyScheduleImpl>(std::move(module_expr));
  }
  return std::make_unique<StScheduleImpl>(std::move(module_expr));
}
/** \brief A macro that guards the beginning of each implementation of schedule
*/
...
...
@@ -156,8 +165,8 @@ class ScheduleImpl {
CINN_THROW(err_hanlder.FormatErrorMessage(err_msg_level)); \
}
std
::
vector
<
Expr
>
ScheduleImpl
::
Split
(
const
Expr
&
loop
,
const
std
::
vector
<
int
>&
factors
)
{
std
::
vector
<
Expr
>
St
ScheduleImpl
::
Split
(
const
Expr
&
loop
,
const
std
::
vector
<
int
>&
factors
)
{
CHECK
(
loop
.
As
<
ir
::
For
>
())
<<
"Expr param of Split must be For node! Please check."
;
auto
*
for_node
=
loop
.
As
<
ir
::
For
>
();
...
...
@@ -189,7 +198,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
new_loop_vars
.
push_back
(
temp_var
);
}
substitute_value
=
common
::
AutoSimplify
(
substitute_value
);
Expr
new_node
=
optim
::
IRCopy
(
for_node
->
body
);
Expr
new_node
=
ir
::
ir_utils
::
IRCopy
(
for_node
->
body
);
ReplaceExpr
(
&
new_node
,
{
for_node
->
loop_var
},
{
substitute_value
});
std
::
vector
<
Expr
>
splited_loops
;
splited_loops
.
resize
(
processed_factors
.
size
());
...
...
@@ -213,7 +222,7 @@ std::vector<Expr> ScheduleImpl::Split(const Expr& loop,
return
splited_loops
;
}
Expr
ScheduleImpl
::
Fuse
(
const
std
::
vector
<
Expr
>&
loops
)
{
Expr
St
ScheduleImpl
::
Fuse
(
const
std
::
vector
<
Expr
>&
loops
)
{
VLOG
(
3
)
<<
"Tring to fuse:
\n
"
<<
cinn
::
utils
::
Join
(
loops
,
"
\n
"
);
std
::
vector
<
const
ir
::
For
*>
for_nodes
;
std
::
vector
<
Var
>
loop_vars
;
...
...
@@ -252,7 +261,7 @@ Expr ScheduleImpl::Fuse(const std::vector<Expr>& loops) {
}
substitute_value
[
0
]
=
fused_expr
;
Expr
fused_body
=
optim
::
IRCopy
(
for_nodes
.
back
()
->
body
);
Expr
fused_body
=
ir
::
ir_utils
::
IRCopy
(
for_nodes
.
back
()
->
body
);
ReplaceExpr
(
&
fused_body
,
loop_vars
,
substitute_value
);
optim
::
Simplify
(
&
fused_body
);
Expr
fused_extent
(
1
);
...
...
@@ -274,8 +283,8 @@ Expr ScheduleImpl::Fuse(const std::vector<Expr>& loops) {
return
new_stmt
;
}
Expr
ScheduleImpl
::
Fuse
(
const
std
::
string
&
block_name
,
const
std
::
vector
<
int
>&
loops_index
)
{
Expr
St
ScheduleImpl
::
Fuse
(
const
std
::
string
&
block_name
,
const
std
::
vector
<
int
>&
loops_index
)
{
std
::
vector
<
Expr
>
all_loops
=
this
->
GetLoops
(
block_name
);
std
::
vector
<
Expr
>
loops_expr
;
loops_expr
.
reserve
(
loops_index
.
size
());
...
...
@@ -293,8 +302,8 @@ Expr ScheduleImpl::Fuse(const std::string& block_name,
return
this
->
Fuse
(
loops_expr
);
}
Expr
ScheduleImpl
::
Fuse
(
const
Expr
&
block
,
const
std
::
vector
<
int
>&
loops_index
)
{
Expr
St
ScheduleImpl
::
Fuse
(
const
Expr
&
block
,
const
std
::
vector
<
int
>&
loops_index
)
{
std
::
vector
<
Expr
>
all_loops
=
this
->
GetLoops
(
block
);
std
::
vector
<
Expr
>
loops_expr
;
loops_expr
.
reserve
(
loops_index
.
size
());
...
...
@@ -312,16 +321,16 @@ Expr ScheduleImpl::Fuse(const Expr& block,
return
this
->
Fuse
(
loops_expr
);
}
void
ScheduleImpl
::
MutateForType
(
const
Expr
&
loop
,
ForType
for_type
,
int
factor
)
{
void
St
ScheduleImpl
::
MutateForType
(
const
Expr
&
loop
,
ForType
for_type
,
int
factor
)
{
auto
*
for_node
=
loop
.
As
<
ir
::
For
>
();
CHECK
(
for_node
)
<<
"loop param must be For node! Please check."
;
CHECK
(
for_node
->
is_serial
())
<<
"loop is not serial, current forloop type is "
<<
static_cast
<
int
>
(
for_node
->
for_type
())
<<
", and it cannot become "
<<
static_cast
<
int
>
(
for_type
);
auto
loop_copy
=
optim
::
IRCopy
(
loop
);
auto
loop_copy
=
ir
::
ir_utils
::
IRCopy
(
loop
);
auto
*
new_for_node
=
loop_copy
.
As
<
ir
::
For
>
();
CHECK
(
new_for_node
);
new_for_node
->
set_for_type
(
for_type
);
...
...
@@ -335,20 +344,21 @@ void ScheduleImpl::MutateForType(const Expr& loop,
this
->
Replace
(
loop
,
loop_copy
);
}
void
ScheduleImpl
::
Parallel
(
const
Expr
&
loop
)
{
void
St
ScheduleImpl
::
Parallel
(
const
Expr
&
loop
)
{
MutateForType
(
loop
,
ForType
::
Parallel
);
}
void
ScheduleImpl
::
Vectorize
(
const
Expr
&
loop
,
int
factor
)
{
void
St
ScheduleImpl
::
Vectorize
(
const
Expr
&
loop
,
int
factor
)
{
CHECK_GT
(
factor
,
0
)
<<
"vectorize factor should be more than 0"
;
MutateForType
(
loop
,
ForType
::
Vectorized
,
factor
);
}
void
ScheduleImpl
::
Unroll
(
const
Expr
&
loop
)
{
void
St
ScheduleImpl
::
Unroll
(
const
Expr
&
loop
)
{
MutateForType
(
loop
,
ForType
::
Unrolled
);
}
void
ScheduleImpl
::
Bind
(
const
Expr
&
loop
,
const
std
::
string
&
thread_axis
)
{
void
StScheduleImpl
::
Bind
(
const
Expr
&
loop
,
const
std
::
string
&
thread_axis
)
{
#ifdef CINN_WITH_CUDA
static
std
::
set
<
std
::
string
>
thread_axes
=
{
"blockIdx.x"
,
"blockIdx.y"
,
"blockIdx.z"
,
...
...
@@ -358,11 +368,24 @@ void ScheduleImpl::Bind(const Expr& loop, const std::string& thread_axis) {
CHECK
(
thread_axes
.
count
(
thread_axis
))
<<
"thread_axis "
<<
thread_axis
<<
" is not supported"
;
int
offset
=
thread_axis
.
back
()
-
'x'
;
auto
cur_dev_info
=
common
::
DevInfoMgr
<
common
::
Target
::
Arch
::
NVGPU
>::
GetDevInfo
(
0
);
const
std
::
array
<
int
,
3
>
kMaxBlockDims
=
cur_dev_info
->
GetMaxBlockDims
();
const
std
::
array
<
int
,
3
>
kMaxGridDims
=
cur_dev_info
->
GetMaxGridDims
();
auto
check_offset
=
[
&
](
const
char
&
c
)
->
bool
{
auto
extent
=
loop
.
As
<
ir
::
For
>
()
->
extent
.
as_int32
();
return
extent
<=
(
c
==
'b'
?
kMaxGridDims
[
offset
]
:
kMaxBlockDims
[
offset
]);
};
if
(
thread_axis
[
0
]
==
'b'
)
{
CHECK
(
check_offset
(
thread_axis
[
0
]))
<<
"Invalid Bind! The extent of loop is out of range on grid size!
\n
"
;
MutateForType
(
loop
,
ForType
::
GPUBlock
,
offset
);
}
else
{
CHECK
(
check_offset
(
thread_axis
[
0
]))
<<
"Invalid Bind! The extent of loop is out of range on block size!
\n
"
;
MutateForType
(
loop
,
ForType
::
GPUThread
,
offset
);
}
#endif
}
// The struct used to mutate new rfactor forloop and its' schedule block.
...
...
@@ -674,7 +697,7 @@ struct RfCreater : public ir::IRMutator<> {
CHECK
(
root_realize
);
auto
root_block
=
root_realize
->
schedule_block
.
As
<
ScheduleBlock
>
();
CHECK
(
root_block
);
Expr
root_loop
=
optim
::
IRCopy
(
root_block
->
body
);
Expr
root_loop
=
ir
::
ir_utils
::
IRCopy
(
root_block
->
body
);
if
(
auto
block
=
root_loop
.
As
<
Block
>
())
{
CHECK_EQ
(
block
->
stmts
.
size
(),
1U
)
<<
"rfactor root should only have one block stmt"
;
...
...
@@ -685,13 +708,13 @@ struct RfCreater : public ir::IRMutator<> {
auto
rf_for
=
rf_loop_
.
As
<
For
>
();
CHECK
(
rf_for
);
// create new rfactor forloops
Expr
new_rf_forloop
=
optim
::
IRCopy
(
root_loop
);
Expr
new_rf_forloop
=
ir
::
ir_utils
::
IRCopy
(
root_loop
);
RfMutator
rf_mutator
(
rf_loop_
,
rf_axis_
);
rf_mutator
(
&
new_rf_forloop
);
VLOG
(
3
)
<<
"After RfMutator, new rf_forloop is
\n
"
<<
new_rf_forloop
;
auto
new_rf_tensor
=
rf_mutator
.
GetNewRfTensor
();
// create final write-back forloops
Expr
final_forloop
=
optim
::
IRCopy
(
root_loop
);
Expr
final_forloop
=
ir
::
ir_utils
::
IRCopy
(
root_loop
);
FinalMutator
final_mutator
(
rf_loop_
,
rf_axis_
,
new_rf_tensor
);
final_mutator
(
&
final_forloop
);
VLOG
(
3
)
<<
"After FinalMuator, final write-back forloop is
\n
"
...
...
@@ -707,7 +730,7 @@ struct RfCreater : public ir::IRMutator<> {
int
rf_axis_
;
};
Expr
ScheduleImpl
::
Rfactor
(
const
Expr
&
rf_loop
,
int
rf_axis
)
{
Expr
St
ScheduleImpl
::
Rfactor
(
const
Expr
&
rf_loop
,
int
rf_axis
)
{
CHECKRfactorValidation
(
rf_loop
,
rf_axis
);
// get root ScheduleBlockRealize
Expr
root
=
GetRootBlock
(
rf_loop
);
...
...
@@ -717,11 +740,84 @@ Expr ScheduleImpl::Rfactor(const Expr& rf_loop, int rf_axis) {
return
rf_create
.
CreateRfAllStmts
();
}
Expr
StScheduleImpl
::
FactorizeReduction
(
const
Expr
&
rf_loop
,
int
rf_axis
)
{
std
::
string
primitive
=
"FactorizeReduction"
;
// Get child block of the rf_loop and check.
std
::
vector
<
Expr
>
blocks
=
GetChildBlocks
(
rf_loop
);
if
(
blocks
.
size
()
!=
1
)
{
std
::
ostringstream
os
;
os
<<
"The rf_loop is required to have only one child block, but got "
<<
blocks
.
size
()
<<
std
::
endl
;
throw
IRScheduleErrorHandler
(
primitive
,
os
.
str
(),
this
->
module_expr_
);
}
Expr
original_block
=
blocks
.
at
(
0
);
Expr
root_block
=
GetRootBlock
(
original_block
);
// TODO(BiynXu): Add CheckReductionBlock()
// Collect the loops of the block.
// Construct a map from loop var names to corresponding loops.
std
::
vector
<
Expr
>
original_loops
=
this
->
GetLoops
(
original_block
);
CHECK_GT
(
original_loops
.
size
(),
0
);
VLOG
(
3
)
<<
"before FactorizeReduction, original computational body of the "
"reduction is:
\n
"
<<
original_loops
[
0
];
std
::
map
<
Var
,
Expr
,
CompVar
>
var2loops
;
for
(
const
Expr
&
loop
:
original_loops
)
{
var2loops
[
loop
.
As
<
For
>
()
->
loop_var
]
=
loop
;
}
// Get original stmt of reduction update and original store tensor.
Expr
original_update_body
=
original_block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
body
;
Expr
original_update_stmt
;
CHECK
(
original_update_body
.
As
<
Block
>
()
||
original_update_body
.
As
<
Store
>
());
if
(
original_update_body
.
As
<
Block
>
())
{
CHECK_EQ
(
original_update_body
.
As
<
Block
>
()
->
stmts
.
size
(),
1
);
original_update_stmt
=
original_update_body
.
As
<
Block
>
()
->
stmts
[
0
];
}
else
if
(
original_update_body
.
As
<
Store
>
())
{
original_update_stmt
=
original_update_body
;
}
Tensor
original_tensor
=
original_update_stmt
.
As
<
Store
>
()
->
tensor
.
as_tensor_ref
();
// Create new blocks and loops.
Tensor
rf_tensor
=
CreateRFTensor
(
original_tensor
,
rf_loop
,
rf_axis
);
RFBlockCreater
rf_block_creater
(
original_block
,
original_loops
,
rf_loop
,
original_update_stmt
,
rf_tensor
,
var2loops
,
rf_axis
);
rf_block_creater
.
CreateBlock
();
RBBlockCreater
wb_block_creater
(
original_block
,
original_loops
,
rf_loop
,
original_update_stmt
,
rf_tensor
,
rf_block_creater
.
rf_tensor_access_indices_
,
rf_block_creater
.
rf_var_
);
wb_block_creater
.
CreateBlock
();
Expr
rf_body
=
rf_block_creater
.
CreateLoops
();
Expr
wb_body
=
wb_block_creater
.
CreateLoops
();
Expr
new_computational_body
=
Block
::
Make
({
rf_body
,
wb_body
});
// Replace and update the AST.
this
->
Replace
(
original_loops
[
0
],
new_computational_body
);
VLOG
(
3
)
<<
"After FactorizeReduction, new computational body of the "
"reduction is:
\n
"
<<
new_computational_body
;
return
rf_tensor
;
}
struct
CacheReadRewriter
:
public
ir
::
IRMutator
<>
{
public:
static
Expr
Rewrite
(
const
Expr
&
root
,
CacheBlockInfo
*
info
)
{
CacheReadRewriter
rewriter
(
root
,
info
);
Expr
new_root
=
optim
::
IRCopy
(
root
);
Expr
new_root
=
ir
::
ir_utils
::
IRCopy
(
root
);
rewriter
(
&
new_root
);
return
new_root
;
}
...
...
@@ -762,12 +858,12 @@ struct CacheWriteRewriter : public ir::IRMutator<> {
public:
static
Expr
Rewrite
(
const
Expr
&
root
,
CacheBlockInfo
*
info
)
{
CacheWriteRewriter
rewriter
(
root
,
info
);
Expr
new_root
=
optim
::
IRCopy
(
root
);
Expr
new_root
=
ir
::
ir_utils
::
IRCopy
(
root
);
rewriter
.
mutate_cache_block
=
true
;
rewriter
(
&
info
->
cache_block
);
rewriter
.
mutate_cache_block
=
false
;
rewriter
(
&
new_root
);
auto
find_tensor
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_tensor
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
new_root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
Store
>
()
&&
...
...
@@ -775,7 +871,7 @@ struct CacheWriteRewriter : public ir::IRMutator<> {
},
true
);
if
(
!
find_tensor
.
empty
())
{
auto
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
(
*
find_tensor
.
begin
()),
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
Load
>
()
&&
(
x
->
As
<
Load
>
()
->
tensor
==
Expr
(
info
->
write_tensor
));
...
...
@@ -862,17 +958,14 @@ struct ChangeBodyToBlock : public ir::IRMutator<> {
}
};
DeviceAPI
ScheduleImpl
::
GetDeviceAPI
()
const
{
DeviceAPI
St
ScheduleImpl
::
GetDeviceAPI
()
const
{
auto
exprs
=
this
->
GetModule
().
GetExprs
();
auto
find_for_nodes
=
ir
::
CollectIRNodesWithoutTensor
(
exprs
.
front
(),
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
For
>
();
},
true
);
CHECK
(
!
find_for_nodes
.
empty
());
return
(
*
find_for_nodes
.
begin
()).
As
<
ir
::
For
>
()
->
device_api
;
return
analyzer
::
GetDeviceAPI
(
exprs
);
}
Expr
ScheduleImpl
::
CacheRead
(
const
Expr
&
block
,
int
read_tensor_index
,
const
std
::
string
&
memory_type
)
{
Expr
St
ScheduleImpl
::
CacheRead
(
const
Expr
&
block
,
int
read_tensor_index
,
const
std
::
string
&
memory_type
)
{
CHECK
(
block
.
As
<
ScheduleBlockRealize
>
());
auto
root
=
GetRootBlock
(
block
);
ChangeBodyToBlock
::
Change
(
&
root
);
...
...
@@ -898,9 +991,9 @@ Expr ScheduleImpl::CacheRead(const Expr& block,
return
new_block
;
}
Expr
ScheduleImpl
::
CacheWrite
(
const
Expr
&
block
,
int
write_buffer_index
,
const
std
::
string
&
memory_type
)
{
Expr
St
ScheduleImpl
::
CacheWrite
(
const
Expr
&
block
,
int
write_buffer_index
,
const
std
::
string
&
memory_type
)
{
CHECK
(
block
.
As
<
ScheduleBlockRealize
>
());
auto
root
=
GetRootBlock
(
block
);
ChangeBodyToBlock
::
Change
(
&
root
);
...
...
@@ -925,7 +1018,7 @@ Expr ScheduleImpl::CacheWrite(const Expr& block,
->
schedule_block
.
As
<
ScheduleBlock
>
()
->
body
);
auto
find_cache_block
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_cache_block
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
...
...
@@ -937,9 +1030,10 @@ Expr ScheduleImpl::CacheWrite(const Expr& block,
CHECK
(
info
.
write_tensor
->
buffer
.
defined
());
// Replace buffer
auto
all_tensors
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
buffer
.
defined
();
});
auto
all_tensors
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
buffer
.
defined
();
});
for
(
auto
i
:
all_tensors
)
{
if
(
i
.
as_tensor
()
->
name
!=
info
.
write_tensor
->
name
&&
...
...
@@ -1007,7 +1101,7 @@ struct InsertExpr : public ir::IRMutator<> {
bool
after_node_
;
};
void
ScheduleImpl
::
SyncThreads
(
const
Expr
&
ir_node
,
bool
after_node
)
{
void
St
ScheduleImpl
::
SyncThreads
(
const
Expr
&
ir_node
,
bool
after_node
)
{
CHECK
(
ir_node
.
As
<
ScheduleBlockRealize
>
()
||
ir_node
.
As
<
ir
::
For
>
());
auto
root
=
GetRootBlock
(
ir_node
);
ChangeBodyToBlock
::
Change
(
&
root
);
...
...
@@ -1016,60 +1110,7 @@ void ScheduleImpl::SyncThreads(const Expr& ir_node, bool after_node) {
return
;
}
/**
* Replace a For node to another For node.
* @param src_sref The For node to be changed.
* @param tgt_stmt The For node we want.
*/
void
ScheduleImpl
::
Replace
(
const
Expr
&
src_sref
,
const
Expr
&
tgt_stmt
)
{
CHECK
(
src_sref
.
As
<
ir
::
For
>
()
||
src_sref
.
As
<
ir
::
Block
>
()
||
src_sref
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
tgt_stmt
.
As
<
ir
::
For
>
()
||
tgt_stmt
.
As
<
ir
::
Block
>
()
||
tgt_stmt
.
As
<
ir
::
ScheduleBlockRealize
>
());
if
(
src_sref
==
tgt_stmt
)
{
return
;
}
struct
ForLoopMutator
:
public
ir
::
IRMutator
<>
{
ForLoopMutator
(
const
Expr
&
source
,
const
Expr
&
target
)
:
source_
(
source
),
target_
(
target
)
{}
void
operator
()(
Expr
*
expr
)
{
ir
::
IRMutator
<>::
Visit
(
expr
,
expr
);
}
void
Visit
(
const
ir
::
For
*
op
,
Expr
*
expr
)
override
{
if
(
*
expr
==
source_
)
{
*
expr
=
target_
;
return
;
}
ir
::
IRMutator
<>::
Visit
(
op
,
expr
);
}
void
Visit
(
const
ir
::
ScheduleBlockRealize
*
op
,
Expr
*
expr
)
override
{
if
(
*
expr
==
source_
)
{
*
expr
=
target_
;
return
;
}
ir
::
IRMutator
<>::
Visit
(
op
,
expr
);
}
void
Visit
(
const
ir
::
Block
*
op
,
Expr
*
expr
)
override
{
if
(
*
expr
==
source_
)
{
*
expr
=
target_
;
return
;
}
ir
::
IRMutator
<>::
Visit
(
op
,
expr
);
}
const
Expr
&
source_
;
const
Expr
&
target_
;
};
auto
exprs
=
module_expr_
.
GetExprs
();
ForLoopMutator
mutator
(
src_sref
,
tgt_stmt
);
for
(
auto
&
i
:
exprs
)
{
mutator
(
&
i
);
}
}
Expr
ScheduleImpl
::
Reorder
(
const
std
::
vector
<
Expr
>&
loops
)
{
Expr
StScheduleImpl
::
Reorder
(
const
std
::
vector
<
Expr
>&
loops
)
{
if
(
loops
.
size
()
<=
1
)
{
return
Expr
{
nullptr
};
}
...
...
@@ -1088,8 +1129,8 @@ Expr ScheduleImpl::Reorder(const std::vector<Expr>& loops) {
return
new_loop
;
}
Expr
ScheduleImpl
::
Reorder
(
const
std
::
string
&
block_name
,
const
std
::
vector
<
int
>&
loops_index
)
{
Expr
St
ScheduleImpl
::
Reorder
(
const
std
::
string
&
block_name
,
const
std
::
vector
<
int
>&
loops_index
)
{
std
::
vector
<
Expr
>
all_loops
=
this
->
GetLoops
(
block_name
);
std
::
vector
<
Expr
>
loops_expr
;
loops_expr
.
reserve
(
loops_index
.
size
());
...
...
@@ -1102,8 +1143,8 @@ Expr ScheduleImpl::Reorder(const std::string& block_name,
return
this
->
Reorder
(
loops_expr
);
}
Expr
ScheduleImpl
::
Reorder
(
const
Expr
&
block
,
const
std
::
vector
<
int
>&
loops_index
)
{
Expr
St
ScheduleImpl
::
Reorder
(
const
Expr
&
block
,
const
std
::
vector
<
int
>&
loops_index
)
{
std
::
vector
<
Expr
>
all_loops
=
this
->
GetLoops
(
block
);
std
::
vector
<
Expr
>
loops_expr
;
loops_expr
.
reserve
(
loops_index
.
size
());
...
...
@@ -1116,25 +1157,9 @@ Expr ScheduleImpl::Reorder(const Expr& block,
return
this
->
Reorder
(
loops_expr
);
}
Expr
ScheduleImpl
::
GetRootBlock
(
const
Expr
&
expr
)
const
{
Expr
St
ScheduleImpl
::
GetRootBlock
(
const
Expr
&
expr
)
const
{
auto
exprs
=
this
->
GetModule
().
GetExprs
();
for
(
auto
&
it_expr
:
exprs
)
{
auto
find_expr
=
ir
::
CollectIRNodesWithoutTensor
(
it_expr
,
[
&
](
const
Expr
*
x
)
{
return
x
->
node_type
()
==
expr
.
node_type
()
&&
*
x
==
expr
;
},
true
);
if
(
!
find_expr
.
empty
())
{
CHECK
(
it_expr
.
As
<
ir
::
Block
>
());
CHECK_EQ
(
it_expr
.
As
<
ir
::
Block
>
()
->
stmts
.
size
(),
1U
);
CHECK
(
it_expr
.
As
<
ir
::
Block
>
()
->
stmts
[
0
].
As
<
ir
::
ScheduleBlockRealize
>
());
return
it_expr
.
As
<
ir
::
Block
>
()
->
stmts
[
0
];
}
}
LOG
(
FATAL
)
<<
"Didn't find expr
\n
"
<<
expr
<<
"in ScheduleImpl:
\n
"
<<
exprs
[
0
];
return
analyzer
::
GetRootBlock
(
exprs
,
expr
);
}
// The struct used to reconstruct the new For node to replace the old For node.
...
...
@@ -1193,25 +1218,26 @@ struct LoopReconstructor : public ir::IRMutator<> {
loop_
.
As
<
ir
::
For
>
()
->
device_api
,
std
::
move
(
loop_body
));
}
new_loop_
=
optim
::
IRCopy
(
loop_
);
new_loop_
=
ir
::
ir_utils
::
IRCopy
(
loop_
);
// Replace the copied Tensor object with the original Tensor object,
// to ensure that the same Tensor in a AST is the same object.
std
::
unordered_map
<
std
::
string
,
ir
::
Expr
>
tensors_map
;
ir
::
CollectIRNodesWithoutTensor
(
loop_
,
[
&
tensors_map
](
const
Expr
*
x
)
{
if
(
x
->
as_tensor
())
{
tensors_map
.
insert
({
x
->
as_tensor
()
->
name
,
*
x
});
return
true
;
}
return
false
;
});
auto
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
loop_
,
[
&
tensors_map
](
const
Expr
*
x
)
{
if
(
x
->
as_tensor
())
{
tensors_map
.
insert
({
x
->
as_tensor
()
->
name
,
*
x
});
return
true
;
}
return
false
;
});
auto
find_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
new_loop_
,
[](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
});
for
(
auto
store
:
find_store
)
{
store
.
As
<
ir
::
Store
>
()
->
tensor
=
tensors_map
.
at
(
store
.
As
<
ir
::
Store
>
()
->
tensor
.
as_tensor
()
->
name
);
}
auto
find_load
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_load
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
new_loop_
,
[](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
();
});
for
(
auto
load
:
find_load
)
{
load
.
As
<
ir
::
Load
>
()
->
tensor
=
...
...
@@ -1271,11 +1297,11 @@ struct FixLocalBufferSize : public ir::IRMutator<> {
std
::
string
tensor_name_
;
};
void
ScheduleImpl
::
SetBuffer
(
Expr
&
block
,
const
std
::
string
&
memory_type
,
bool
fixed
)
{
void
St
ScheduleImpl
::
SetBuffer
(
Expr
&
block
,
const
std
::
string
&
memory_type
,
bool
fixed
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
auto
find_tensor
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_tensor
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
},
true
);
CHECK_EQ
(
find_tensor
.
size
(),
1U
)
<<
"One block should only have one Store node!(except for root block)"
;
...
...
@@ -1286,7 +1312,7 @@ void ScheduleImpl::SetBuffer(Expr& block,
auto
exprs
=
this
->
GetModule
().
GetExprs
();
for
(
auto
&
it_expr
:
exprs
)
{
auto
find_tensor
=
ir
::
CollectIRNodesWithoutTensor
(
it_expr
,
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
it_expr
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
(
x
->
as_tensor
()
->
name
==
tensor
.
as_tensor_ref
()
->
name
||
x
->
as_tensor
()
->
name
==
...
...
@@ -1308,7 +1334,7 @@ void ScheduleImpl::SetBuffer(Expr& block,
}
}
void
ScheduleImpl
::
MergeExprs
()
{
void
St
ScheduleImpl
::
MergeExprs
()
{
auto
exprs
=
this
->
GetModule
().
GetExprs
();
if
(
exprs
.
size
()
==
1U
)
return
;
CHECK
(
exprs
[
0
].
As
<
ir
::
Block
>
());
...
...
@@ -1328,7 +1354,7 @@ void ScheduleImpl::MergeExprs() {
->
body
);
VLOG
(
3
)
<<
"Before merging, exprs[0] is : "
<<
exprs
[
0
];
for
(
int
i
=
1
;
i
<
exprs
.
size
();
++
i
)
{
auto
root_block
=
ir
::
CollectIRNodesWithoutTensor
(
auto
root_block
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
exprs
[
i
],
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
...
...
@@ -1358,9 +1384,9 @@ void ScheduleImpl::MergeExprs() {
this
->
SetExprs
(
exprs
);
}
void
ScheduleImpl
::
ComputeAt
(
const
Expr
&
block
,
const
Expr
&
loop
,
bool
keep_unit_loops
)
{
void
St
ScheduleImpl
::
ComputeAt
(
const
Expr
&
block
,
const
Expr
&
loop
,
bool
keep_unit_loops
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
loop
.
As
<
ir
::
For
>
());
Expr
root
=
this
->
GetRootBlock
(
block
);
...
...
@@ -1386,7 +1412,7 @@ void ScheduleImpl::ComputeAt(const Expr& block,
VLOG
(
3
)
<<
"After SimpleComputeAt, ir is:
\n
"
<<
reconstructor
.
new_loop_
;
}
void
ScheduleImpl
::
SimpleComputeAt
(
const
Expr
&
block
,
const
Expr
&
loop
)
{
void
St
ScheduleImpl
::
SimpleComputeAt
(
const
Expr
&
block
,
const
Expr
&
loop
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
loop
.
As
<
ir
::
For
>
());
std
::
vector
<
Expr
>
block_loops
=
this
->
GetLoops
(
block
);
...
...
@@ -1429,15 +1455,15 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
}
Expr
result
=
loops
.
size
()
<
block_loops
.
size
()
?
optim
::
IRCopy
(
block_loops
[
loops
.
size
()])
:
optim
::
IRCopy
(
this_block
);
Expr
new_loop
=
optim
::
IRCopy
(
this_loop
);
?
ir
::
ir_utils
::
IRCopy
(
block_loops
[
loops
.
size
()])
:
ir
::
ir_utils
::
IRCopy
(
this_block
);
Expr
new_loop
=
ir
::
ir_utils
::
IRCopy
(
this_loop
);
// Get the body of block_loop under the same loops
auto
body
=
block_loops
.
at
(
loops
.
size
()
-
1
).
As
<
ir
::
For
>
()
->
body
;
// collect if
auto
if_checker
=
[](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
IfThenElse
>
();
};
auto
if_set
=
ir
::
CollectIRNodesWithoutTensor
(
body
,
if_checker
);
auto
if_set
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
body
,
if_checker
);
for
(
auto
if_expr
:
if_set
)
{
auto
checker
=
[
block_name
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
...
...
@@ -1445,7 +1471,8 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
->
schedule_block
.
As
<
ScheduleBlock
>
()
->
name
==
block_name
;
};
if
(
ir
::
CollectIRNodesWithoutTensor
(
if_expr
,
checker
,
true
).
size
()
>
0
)
{
if
(
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
if_expr
,
checker
,
true
)
.
size
()
>
0
)
{
result
=
IfThenElse
::
Make
(
if_expr
.
As
<
ir
::
IfThenElse
>
()
->
condition
,
result
);
break
;
...
...
@@ -1498,9 +1525,9 @@ void ScheduleImpl::SimpleComputeAt(const Expr& block, const Expr& loop) {
VLOG
(
3
)
<<
"After SimpleComputeAt, ir is:
\n
"
<<
new_loop
;
}
void
ScheduleImpl
::
ReverseComputeAt
(
const
Expr
&
block
,
const
Expr
&
loop
,
bool
keep_unit_loops
)
{
void
St
ScheduleImpl
::
ReverseComputeAt
(
const
Expr
&
block
,
const
Expr
&
loop
,
bool
keep_unit_loops
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
loop
.
As
<
ir
::
For
>
());
Expr
root
=
this
->
GetRootBlock
(
block
);
...
...
@@ -1582,7 +1609,7 @@ bool ComputeInliner::BodyPatternAllowInline() {
return
false
;
}
CHECK
(
inlined_store_
.
As
<
Store
>
());
auto
find_vars
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_vars
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
inlined_store_
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_var
();
});
std
::
set
<
Var
,
CompVar
>
vars_set
;
for
(
auto
&
i
:
find_vars
)
vars_set
.
insert
(
i
.
as_var_ref
());
...
...
@@ -1605,12 +1632,12 @@ void ComputeInliner::Visit(const ir::Load* expr, Expr* op) {
Expr
ComputeInliner
::
ReplaceInlinedTensor
(
Expr
*
load
)
{
CHECK
(
load
->
As
<
ir
::
Load
>
());
SetIndexSubstitution
(
load
->
As
<
ir
::
Load
>
()
->
indices
);
Expr
value_copy
=
optim
::
IRCopy
(
inlined_store_
.
As
<
Store
>
()
->
value
);
Expr
value_copy
=
ir
::
ir_utils
::
IRCopy
(
inlined_store_
.
As
<
Store
>
()
->
value
);
ReplaceExpr
(
&
value_copy
,
idx_sub_var_
,
idx_sub_expr_
);
return
value_copy
;
}
void
ScheduleImpl
::
ComputeInline
(
const
Expr
&
schedule_block
)
{
void
St
ScheduleImpl
::
ComputeInline
(
const
Expr
&
schedule_block
)
{
CHECK
(
schedule_block
.
As
<
ir
::
ScheduleBlockRealize
>
());
Expr
root
=
this
->
GetRootBlock
(
schedule_block
);
Expr
store
=
CheckComputeInlineValidationAndGetStore
(
schedule_block
,
root
);
...
...
@@ -1650,7 +1677,7 @@ bool ReverseComputeInliner::BodyPatternAllowInline() {
CHECK
(
inlined_store_
.
As
<
Store
>
());
CHECK
(
inlined_load_
.
As
<
Load
>
());
CHECK
(
target_store_
.
As
<
Store
>
());
auto
find_vars
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_vars
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
inlined_store_
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_var
();
});
std
::
set
<
Var
,
CompVar
>
vars_set
;
for
(
auto
&
i
:
find_vars
)
vars_set
.
insert
(
i
.
as_var_ref
());
...
...
@@ -1681,7 +1708,7 @@ void ReverseComputeInliner::Visit(const ir::Store* expr, Expr* op) {
Expr
ReverseComputeInliner
::
ReplaceInlinedTensor
(
Expr
*
load
)
{
CHECK
(
load
->
As
<
ir
::
Load
>
());
SetIndexSubstitution
(
load
->
As
<
ir
::
Load
>
()
->
indices
);
Expr
value_copy
=
optim
::
IRCopy
(
inlined_store_
.
As
<
Store
>
()
->
value
);
Expr
value_copy
=
ir
::
ir_utils
::
IRCopy
(
inlined_store_
.
As
<
Store
>
()
->
value
);
return
value_copy
;
}
...
...
@@ -1696,12 +1723,12 @@ Expr ReverseComputeInliner::ReplaceTargetTensor(Expr* store) {
idx_sub_expr_
.
emplace_back
(
idx_vars_
[
i
]);
}
Expr
value_copy
=
optim
::
IRCopy
(
target_store_
);
Expr
value_copy
=
ir
::
ir_utils
::
IRCopy
(
target_store_
);
ReplaceExpr
(
&
value_copy
,
idx_sub_var_
,
idx_sub_expr_
);
return
value_copy
;
}
void
ScheduleImpl
::
ReverseComputeInline
(
const
Expr
&
schedule_block
)
{
void
St
ScheduleImpl
::
ReverseComputeInline
(
const
Expr
&
schedule_block
)
{
Expr
root
=
this
->
GetRootBlock
(
schedule_block
);
auto
exprs
=
CheckReverseComputeInlineValidationAndGetExprs
(
schedule_block
,
root
);
...
...
@@ -1777,170 +1804,55 @@ struct FindBlockParent : public ir::IRMutator<> {
ir
::
Expr
*
target_
{
nullptr
};
};
Expr
ScheduleImpl
::
AddUnitLoop
(
const
Expr
&
block
)
const
{
Expr
St
ScheduleImpl
::
AddUnitLoop
(
const
Expr
&
block
)
const
{
auto
exprs
=
module_expr_
.
GetExprs
();
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
std
::
string
block_name
=
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
;
FindBlockParent
visitor
(
block_name
);
for
(
auto
expr
:
exprs
)
{
visitor
(
&
expr
);
if
(
visitor
.
target_
)
{
break
;
}
}
CHECK
(
visitor
.
target_
)
<<
", block name : "
<<
block_name
<<
"
\n
"
<<
exprs
;
if
(
visitor
.
target_
->
As
<
ir
::
Block
>
())
{
for
(
auto
&
stmt
:
visitor
.
target_
->
As
<
ir
::
Block
>
()
->
stmts
)
{
if
(
stmt
.
As
<
ir
::
ScheduleBlockRealize
>
())
{
if
(
stmt
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
==
block_name
)
{
auto
block
=
ir
::
Block
::
Make
({
GetBlock
(
block_name
)});
auto
loop
=
ir
::
For
::
Make
(
ir
::
Var
(
common
::
UniqName
(
"ix"
)),
ir
::
Expr
(
0
),
ir
::
Expr
(
1
),
ir
::
ForType
::
Serial
,
ir
::
DeviceAPI
::
UNK
,
block
);
stmt
=
loop
;
return
loop
;
}
}
}
}
else
if
(
visitor
.
target_
->
As
<
ir
::
For
>
())
{
auto
block
=
ir
::
Block
::
Make
({
visitor
.
target_
->
As
<
ir
::
For
>
()
->
body
});
auto
loop
=
ir
::
For
::
Make
(
ir
::
Var
(
common
::
UniqName
(
"ix"
)),
ir
::
Expr
(
0
),
ir
::
Expr
(
1
),
ir
::
ForType
::
Serial
,
ir
::
DeviceAPI
::
UNK
,
block
);
visitor
.
target_
->
As
<
ir
::
For
>
()
->
body
=
loop
;
return
loop
;
}
else
if
(
visitor
.
target_
->
As
<
ir
::
ScheduleBlock
>
())
{
auto
block
=
ir
::
Block
::
Make
({
visitor
.
target_
->
As
<
ir
::
ScheduleBlock
>
()
->
body
});
auto
loop
=
ir
::
For
::
Make
(
ir
::
Var
(
common
::
UniqName
(
"ix"
)),
ir
::
Expr
(
0
),
ir
::
Expr
(
1
),
ir
::
ForType
::
Serial
,
ir
::
DeviceAPI
::
UNK
,
block
);
visitor
.
target_
->
As
<
ir
::
ScheduleBlock
>
()
->
body
=
loop
;
return
loop
;
}
else
{
LOG
(
FATAL
)
<<
"Can't find block's parent!"
;
}
LOG
(
FATAL
)
<<
"Shouldn't reach code here in AddUnitLoop"
;
return
Expr
{
nullptr
};
return
analyzer
::
AddUnitLoop
(
exprs
,
block
);
}
std
::
vector
<
Expr
>
ScheduleImpl
::
GetLoops
(
const
Expr
&
block
)
const
{
std
::
vector
<
Expr
>
result
;
std
::
vector
<
Expr
>
StScheduleImpl
::
GetLoops
(
const
Expr
&
block
)
const
{
auto
exprs
=
module_expr_
.
GetExprs
();
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
std
::
string
block_name
=
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
;
for
(
auto
&
it_expr
:
exprs
)
{
ir
::
FindLoopsVisitor
visitor
(
block
);
auto
find_loops
=
visitor
(
&
it_expr
);
if
(
!
find_loops
.
empty
())
{
if
(
!
result
.
empty
())
LOG
(
FATAL
)
<<
"Find block with name:
\n
"
<<
block_name
<<
" appeared in more than one AST!"
;
result
=
find_loops
;
}
}
if
(
result
.
empty
())
{
result
.
push_back
(
AddUnitLoop
(
block
));
}
return
result
;
return
analyzer
::
GetLoops
(
exprs
,
block
);
}
std
::
vector
<
Expr
>
ScheduleImpl
::
GetLoops
(
const
std
::
string
&
block_name
)
const
{
Expr
block
=
this
->
GetBlock
(
block_name
)
;
std
::
vector
<
Expr
>
result
=
this
->
GetLoops
(
block
);
return
result
;
std
::
vector
<
Expr
>
St
ScheduleImpl
::
GetLoops
(
const
std
::
string
&
block_name
)
const
{
auto
exprs
=
module_expr_
.
GetExprs
(
);
return
analyzer
::
GetLoops
(
exprs
,
block_name
)
;
}
std
::
vector
<
Expr
>
ScheduleImpl
::
GetAllBlocks
()
const
{
std
::
vector
<
Expr
>
result
;
std
::
vector
<
Expr
>
StScheduleImpl
::
GetAllBlocks
()
const
{
auto
exprs
=
module_expr_
.
GetExprs
();
for
(
auto
&
it_expr
:
exprs
)
{
ir
::
FindBlocksVisitor
visitor
;
auto
find_blocks
=
visitor
(
&
it_expr
);
result
.
insert
(
result
.
end
(),
find_blocks
.
begin
(),
find_blocks
.
end
());
}
for
(
auto
&
it_expr
:
exprs
)
{
VLOG
(
3
)
<<
"it_expr is : "
<<
it_expr
;
}
CHECK
(
!
result
.
empty
())
<<
"Didn't find blocks in expr."
;
return
result
;
return
analyzer
::
GetAllBlocks
(
exprs
);
}
std
::
vector
<
Expr
>
ScheduleImpl
::
GetChildBlocks
(
const
Expr
&
expr
)
const
{
CHECK
(
expr
.
As
<
ir
::
ScheduleBlockRealize
>
()
||
expr
.
As
<
ir
::
For
>
());
ir
::
FindBlocksVisitor
visitor
;
std
::
vector
<
Expr
>
result
=
visitor
(
&
expr
);
return
result
;
std
::
vector
<
Expr
>
StScheduleImpl
::
GetChildBlocks
(
const
Expr
&
expr
)
const
{
return
analyzer
::
GetChildBlocks
(
expr
);
}
bool
ScheduleImpl
::
HasBlock
(
const
std
::
string
&
block_name
)
const
{
bool
St
ScheduleImpl
::
HasBlock
(
const
std
::
string
&
block_name
)
const
{
auto
exprs
=
module_expr_
.
GetExprs
();
for
(
auto
&
it_expr
:
exprs
)
{
ir
::
FindBlocksVisitor
visitor
(
block_name
);
auto
find_blocks
=
visitor
(
&
it_expr
);
if
(
!
find_blocks
.
empty
())
{
CHECK_EQ
(
find_blocks
.
size
(),
1U
)
<<
"There should not be more than 1 block with identical name!"
;
return
true
;
}
}
return
false
;
return
analyzer
::
HasBlock
(
exprs
,
block_name
);
}
Expr
ScheduleImpl
::
GetBlock
(
const
std
::
string
&
block_name
)
const
{
Expr
result
;
Expr
StScheduleImpl
::
GetBlock
(
const
std
::
string
&
block_name
)
const
{
auto
exprs
=
module_expr_
.
GetExprs
();
for
(
auto
&
it_expr
:
exprs
)
{
ir
::
FindBlocksVisitor
visitor
(
block_name
);
auto
find_blocks
=
visitor
(
&
it_expr
);
if
(
!
find_blocks
.
empty
())
{
CHECK_EQ
(
find_blocks
.
size
(),
1U
)
<<
"There should not be more than 1 block with identical name!"
;
result
=
find_blocks
[
0
];
return
result
;
}
}
LOG
(
FATAL
)
<<
"Didn't find a block with name "
<<
block_name
<<
" in this ModuleExpr!"
;
return
analyzer
::
GetBlock
(
exprs
,
block_name
);
}
void
ScheduleImpl
::
Annotate
(
const
Expr
&
block
,
const
std
::
string
&
key
,
const
attr_t
&
value
)
{
void
St
ScheduleImpl
::
Annotate
(
const
Expr
&
block
,
const
std
::
string
&
key
,
const
attr_t
&
value
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
auto
copied_block
=
optim
::
IRCopy
(
block
);
auto
copied_block
=
ir
::
ir_utils
::
IRCopy
(
block
);
auto
*
schedule_block
=
copied_block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
schedule_block
->
attrs
.
emplace
(
key
,
value
);
this
->
Replace
(
block
,
copied_block
);
}
void
ScheduleImpl
::
Unannotate
(
Expr
&
block
,
const
std
::
string
&
ann_key
)
{
void
St
ScheduleImpl
::
Unannotate
(
Expr
&
block
,
const
std
::
string
&
ann_key
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
...
...
@@ -1954,8 +1866,8 @@ void ScheduleImpl::Unannotate(Expr& block, const std::string& ann_key) {
}
}
void
ScheduleImpl
::
FlattenLoops
(
const
std
::
vector
<
Expr
>&
loops
,
const
bool
flat_tensor
)
{
void
St
ScheduleImpl
::
FlattenLoops
(
const
std
::
vector
<
Expr
>&
loops
,
const
bool
flat_tensor
)
{
CHECK_GT
(
loops
.
size
(),
0
)
<<
"Loops can't be empty!"
;
VLOG
(
4
)
<<
"Before FlattenLoops, ir is:
\n
"
<<
loops
[
0
];
// compute loop
...
...
@@ -2031,12 +1943,12 @@ void ScheduleImpl::FlattenLoops(const std::vector<Expr>& loops,
CHECK_EQ
(
iter
.
as_var_ref
()
->
name
,
loop_vars
[
idx
]
->
name
)
<<
"loops is not the same order with tensor!"
;
}
else
{
CHECK
(
iter
.
As
<
IntImm
>
());
CHECK
(
iter
.
As
<
IntImm
>
())
<<
iter
.
node_type
()
<<
" is not IntImm"
;
CHECK_EQ
(
iter
.
as_int32
(),
0
);
}
}
auto
exprs
=
ir
::
CollectIRNodesInOrder
(
auto
exprs
=
ir
::
ir_utils
::
CollectIRNodesInOrder
(
schedule_block
->
body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
||
x
->
As
<
ir
::
Load
>
();
});
// reverse exprs from last to first.
...
...
@@ -2136,15 +2048,15 @@ void ScheduleImpl::FlattenLoops(const std::vector<Expr>& loops,
VLOG
(
4
)
<<
"After FlattenLoops, ir is:
\n
"
<<
loop
;
}
void
ScheduleImpl
::
CopyTransformAndLoopInfo
(
void
St
ScheduleImpl
::
CopyTransformAndLoopInfo
(
const
std
::
string
&
block_name
,
const
std
::
string
&
block_target_name
)
{
auto
block
=
this
->
GetBlock
(
block_name
);
auto
block_target
=
this
->
GetBlock
(
block_target_name
);
this
->
CopyTransformAndLoopInfo
(
block
,
block_target
);
}
void
ScheduleImpl
::
CopyTransformAndLoopInfo
(
const
Expr
&
block
,
const
Expr
&
block_target
)
{
void
St
ScheduleImpl
::
CopyTransformAndLoopInfo
(
const
Expr
&
block
,
const
Expr
&
block_target
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
block_target
.
As
<
ir
::
ScheduleBlockRealize
>
());
auto
exprs
=
this
->
GetModule
().
GetExprs
();
...
...
@@ -2185,16 +2097,16 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
std
::
set
<
std
::
string
>
used_target_loop_vars
;
for
(
auto
&
iter_val
:
new_iter_values
)
{
auto
find_partial_loop
=
ir
::
CollectIRNodesWithoutTensor
(
iter_val
,
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
iter_val
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
as_var
())
used_target_loop_vars
.
insert
(
x
->
as_var_ref
()
->
name
);
return
x
->
as_var
();
});
}
CHECK
(
!
used_target_loop_vars
.
empty
());
std
::
vector
<
Expr
>
used_target_loops
;
auto
expr_copy
=
optim
::
IRCopy
(
expr
);
auto
expr_copy
=
ir
::
ir_utils
::
IRCopy
(
expr
);
for
(
auto
&
var
:
used_target_loop_vars
)
{
auto
find_loop_var
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_loop_var
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
expr_copy
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
For
>
()
&&
x
->
As
<
ir
::
For
>
()
->
loop_var
->
name
==
var
&&
...
...
@@ -2217,12 +2129,12 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
VLOG
(
3
)
<<
"changed_loop_num is : "
<<
changed_loop_num
;
VLOG
(
3
)
<<
"old_iter_values.size() is : "
<<
old_iter_values
.
size
();
if
(
changed_loop_num
>=
static_cast
<
int
>
(
old_iter_values
.
size
()))
{
new_loop
=
optim
::
IRCopy
(
block
);
new_loop
=
ir
::
ir_utils
::
IRCopy
(
block
);
new_loop
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
iter_values
=
new_iter_values
;
}
else
{
CHECK
(
old_iter_values
[
changed_loop_num
].
as_var
());
auto
old_var
=
old_iter_values
[
changed_loop_num
].
as_var_ref
();
auto
find_partial_loop
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_partial_loop
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
expr
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
For
>
()
&&
...
...
@@ -2231,8 +2143,8 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
},
true
);
CHECK_EQ
(
find_partial_loop
.
size
(),
1U
);
new_loop
=
optim
::
IRCopy
(
*
find_partial_loop
.
begin
());
auto
find_schedule_block
=
ir
::
CollectIRNodesWithoutTensor
(
new_loop
=
ir
::
ir_utils
::
IRCopy
(
*
find_partial_loop
.
begin
());
auto
find_schedule_block
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
new_loop
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
();
},
true
);
...
...
@@ -2265,7 +2177,7 @@ void ScheduleImpl::CopyTransformAndLoopInfo(const Expr& block,
this
->
Replace
(
all_loops
[
0
],
res
);
}
std
::
vector
<
Expr
>
ScheduleImpl
::
SamplePerfectTile
(
std
::
vector
<
Expr
>
St
ScheduleImpl
::
SamplePerfectTile
(
utils
::
LinearRandomEngine
::
StateType
*
rand_seed
,
const
Expr
&
loop
,
int
n
,
...
...
@@ -2296,7 +2208,7 @@ std::vector<Expr> ScheduleImpl::SamplePerfectTile(
return
result_expr
;
}
Expr
ScheduleImpl
::
SampleCategorical
(
Expr
St
ScheduleImpl
::
SampleCategorical
(
utils
::
LinearRandomEngine
::
StateType
*
rand_seed
,
const
std
::
vector
<
int
>&
candidates
,
const
std
::
vector
<
float
>&
probs
)
{
...
...
@@ -2314,41 +2226,52 @@ IRSchedule::IRSchedule() {}
IRSchedule
::
IRSchedule
(
const
ModuleExpr
&
module_expr
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
,
bool
debug_flag
,
utils
::
ErrorMessageLevel
err_msg_level
)
{
impl_
=
std
::
make_unique
<
ScheduleImpl
>
(
module_expr
,
debug_flag
,
err_msg_level
);
utils
::
ErrorMessageLevel
err_msg_level
,
bool
is_dynamic_shape
)
:
impl_
(
ScheduleBase
::
Make
(
module_expr
,
debug_flag
,
err_msg_level
,
is_dynamic_shape
)),
is_dynamic_shape_
(
is_dynamic_shape
)
{
this
->
InitSeed
(
rand_seed
);
}
IRSchedule
::
IRSchedule
(
ir
::
ModuleExpr
&&
mod_expr
,
ScheduleDesc
&&
trace
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
)
:
impl_
(
std
::
make_unique
<
ScheduleImpl
>
(
std
::
move
(
mod_expr
))),
trace_
(
std
::
move
(
trace
))
{
utils
::
LinearRandomEngine
::
StateType
rand_seed
,
bool
is_dynamic_shape
)
:
impl_
(
ScheduleBase
::
Make
(
std
::
move
(
mod_expr
),
is_dynamic_shape
)),
trace_
(
std
::
move
(
trace
)),
is_dynamic_shape_
(
is_dynamic_shape
)
{
this
->
InitSeed
(
rand_seed
);
}
IRSchedule
::
IRSchedule
(
const
IRSchedule
&
other
)
:
impl_
(
std
::
make_unique
<
ScheduleImpl
>
(
optim
::
IRCopy
(
other
.
GetModule
()))),
trace_
(
other
.
trace_
)
{
:
impl_
(
ScheduleBase
::
Make
(
ir
::
ir_utils
::
IRCopy
(
other
.
GetModule
()),
other
.
IsDynamicShape
())),
trace_
(
other
.
trace_
),
is_dynamic_shape_
(
other
.
IsDynamicShape
())
{
this
->
InitSeed
(
other
.
ForkSeed
());
}
IRSchedule
&
IRSchedule
::
operator
=
(
const
IRSchedule
&
src
)
{
impl_
=
std
::
make_unique
<
ScheduleImpl
>
(
optim
::
IRCopy
(
src
.
GetModule
()));
impl_
=
ScheduleBase
::
Make
(
ir
::
ir_utils
::
IRCopy
(
src
.
GetModule
()),
src
.
IsDynamicShape
());
trace_
=
src
.
trace_
;
is_dynamic_shape_
=
src
.
IsDynamicShape
();
this
->
InitSeed
(
src
.
ForkSeed
());
return
*
this
;
}
IRSchedule
::
IRSchedule
(
IRSchedule
&&
other
)
:
impl_
(
std
::
move
(
other
.
impl_
)),
trace_
(
std
::
move
(
other
.
trace_
))
{
:
impl_
(
std
::
move
(
other
.
impl_
)),
trace_
(
std
::
move
(
other
.
trace_
)),
is_dynamic_shape_
(
other
.
IsDynamicShape
())
{
this
->
InitSeed
(
other
.
ForkSeed
());
}
// Move assignment: takes over `src`'s implementation object and trace, copies
// its dynamic-shape flag, and re-initializes this schedule's random seed from a
// seed forked off `src`.
IRSchedule& IRSchedule::operator=(IRSchedule&& src) {
  impl_ = std::move(src.impl_);
  trace_ = std::move(src.trace_);
  // Read through the accessor; the flag is a plain bool, so it is copied.
  is_dynamic_shape_ = src.IsDynamicShape();
  // ForkSeed() is called on `src` after its impl_/trace_ have been moved out.
  this->InitSeed(src.ForkSeed());
  return *this;
}
...
...
@@ -2561,6 +2484,13 @@ void IRSchedule::SetBuffer(Expr& block,
{}));
}
// Forwards to the schedule implementation to add a unit loop for `block`, then
// appends an "AddUnitLoop" step (input: the block, output: the returned Expr)
// to the trace so the primitive can be serialized and replayed.
Expr IRSchedule::AddUnitLoop(const Expr& block) {
  Expr unit_loop = impl_->AddUnitLoop(block);
  const std::vector<Expr> block_input({block});
  trace_.Append(ScheduleDesc::Step(
      "AddUnitLoop", {{"block", block_input}}, {}, {unit_loop}));
  return unit_loop;
}
Expr
IRSchedule
::
Reorder
(
const
std
::
vector
<
Expr
>&
loops
)
{
Expr
ret
=
impl_
->
Reorder
(
loops
);
trace_
.
Append
(
ScheduleDesc
::
Step
(
"Reorder"
,
{{
"loops"
,
loops
}},
{},
{
ret
}));
...
...
@@ -2643,6 +2573,15 @@ Expr IRSchedule::Rfactor(const Expr& rf_loop, int rf_axis) {
return
result
;
}
// Forwards to the schedule implementation's FactorizeReduction for `rf_loop`
// and `rf_axis`, then appends a "FactorizeReduction" step (input: the loop,
// attribute: rf_axis, output: the returned Expr) to the trace.
Expr IRSchedule::FactorizeReduction(const Expr& rf_loop, int rf_axis) {
  Expr rf_result = impl_->FactorizeReduction(rf_loop, rf_axis);
  const std::vector<Expr> loop_input({rf_loop});
  trace_.Append(ScheduleDesc::Step("FactorizeReduction",
                                   {{"rf_loop", loop_input}},
                                   {{"rf_axis", rf_axis}},
                                   {rf_result}));
  return rf_result;
}
void
IRSchedule
::
Annotate
(
const
Expr
&
block
,
const
std
::
string
&
key
,
const
attr_t
&
value
)
{
...
...
paddle/cinn/ir/schedule/ir_schedule.h
View file @
01a10755
...
...
@@ -21,51 +21,23 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/schedule/schedule_base.h"
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/random_engine.h"
namespace
cinn
{
namespace
ir
{
/**
* A struct representing a module that contains Expr. This struct is only used
* in Schedule process.
*/
class ModuleExpr {
 public:
  ModuleExpr() = default;
  ModuleExpr(const ModuleExpr& mod_expr) = default;
  ModuleExpr(ModuleExpr&& mod_expr) = default;

  ModuleExpr& operator=(const ModuleExpr& mod_expr) = default;
  // The user-declared copy operations above suppress the implicit move
  // assignment, so without this declaration assigning from an rvalue would
  // silently fall back to copying the whole Expr vector.
  ModuleExpr& operator=(ModuleExpr&& mod_expr) = default;

  //! Build a ModuleExpr holding a copy of `exprs`.
  explicit ModuleExpr(const std::vector<Expr>& exprs) : exprs_(exprs) {}
  //! Build a ModuleExpr that takes over the contents of `exprs`.
  explicit ModuleExpr(std::vector<Expr>&& exprs) : exprs_(std::move(exprs)) {}

  //! Get all the Expr in this ModuleExpr (returned by value, i.e. a copy of the
  //! stored vector).
  std::vector<Expr> GetExprs() { return exprs_; }

  //! Const overload of GetExprs(); also returns a copy of the stored vector.
  std::vector<Expr> GetExprs() const { return exprs_; }

  //! Replace the stored Exprs with a copy of `exprs`.
  void SetExprs(const std::vector<Expr>& exprs) { exprs_ = exprs; }

 private:
  //! Exprs stored in ModuleExpr. Each one is an AST, representing a computation
  //! kernel.
  std::vector<Expr> exprs_;
};
/**
* A struct containing all the schedule primitives. Each schedule primitive is a
* member function of IRSchedule. Schedule primitives are implemented by
* ScheduleImpl manipulating the AST - IR(Expr). To support serializing and
*
St
ScheduleImpl manipulating the AST - IR(Expr). To support serializing and
* replaying, each schedule primitive should append a ScheduleDesc::Step to the
* trace_ in its corresponding function implementation.
*/
class
ScheduleImpl
;
class
IRSchedule
{
public:
IRSchedule
();
...
...
@@ -73,10 +45,12 @@ class IRSchedule {
utils
::
LinearRandomEngine
::
StateType
rand_seed
=
-
1
,
bool
debug_flag
=
false
,
utils
::
ErrorMessageLevel
err_msg_level
=
utils
::
ErrorMessageLevel
::
kGeneral
);
utils
::
ErrorMessageLevel
::
kGeneral
,
bool
is_dynamic
=
false
);
IRSchedule
(
ir
::
ModuleExpr
&&
mod_expr
,
ScheduleDesc
&&
trace
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
=
-
1
);
utils
::
LinearRandomEngine
::
StateType
rand_seed
=
-
1
,
bool
is_dynamic
=
false
);
IRSchedule
(
const
IRSchedule
&
other
);
IRSchedule
&
operator
=
(
const
IRSchedule
&
src
);
IRSchedule
(
IRSchedule
&&
other
);
...
...
@@ -97,6 +71,8 @@ class IRSchedule {
//! Get the ScheduleDesc that traces the scheduling process
const
ScheduleDesc
&
GetTraceDesc
()
const
{
return
trace_
;
}
bool
IsDynamicShape
()
const
{
return
is_dynamic_shape_
;
}
/**
* \brief Get all the loops of specific Block stored in ModuleExpr.
* @param block The block we find loop in.
...
...
@@ -244,7 +220,7 @@ class IRSchedule {
*/
void
SyncThreads
(
const
Expr
&
ir_node
,
bool
after_node
=
true
);
/*
!
/*
*
* \brief Set a tensor's buffer type(memory_type)
* \param block The ScheduleBlockRealize corresponding to an unique tensor.
* \param memory_type The memory type we want to set. Should be "local",
...
...
@@ -254,6 +230,13 @@ class IRSchedule {
const
std
::
string
&
memory_type
,
bool
fixed
=
false
);
// NOLINT
/**
* \brief Create a new unit loop on top of the block.
* @param block The block on top of which the new loop is added.
* @return The new unit loop.
*/
Expr
AddUnitLoop
(
const
Expr
&
block
);
/**
* \brief Reorder the loops in the order of vector.
* @param loops The loops to be reordered.
...
...
@@ -381,6 +364,46 @@ class IRSchedule {
*/
Expr
Rfactor
(
const
Expr
&
rf_loop
,
int
rf_axis
);
/**
* \brief Factorize the reduction block by the given loop. The block will be
* split into two blocks: reduction-factorized block and write-back block.
* @param rf_loop the reduce loop to be factorized.
* @param rf_axis The position where the new dimension is placed in the new rf
* tensor.
* @return The newly created rf tensor.
*
* For example, input the block:
* \code
* for (i, 0, 10) // serial loop
* B_init[i] = 0
* for (j, 0, 20) // reduce loop
* for (k, 0, 30) // reduce loop
* B[i] = B[i] + A[i, j, k]
* \endcode
*
* If the rf loop is j and rf_axis is 0, the transformation is
* divided into 2 steps:
* 1. get the rf block where the reduce loop j is transformed to the
* serial loop with no accumulation and a new rf tensor is created.
* The axis j will be placed in the rf_axis of the new rf_tensor.
* The rf_block is as follows:
* \code
* for (i, 0, 10) // serial loop
* for (j, 0, 20) // rf loop j is transformed to the serial loop
* rf_B_init[j, i] = 0
* for (k, 0, 30) // reduce loop.
* rf_B[j, i] = rf_B[j, i] + A[i, j, k]
* \endcode
* 2. do reduction of the rf loop j to get the final result block:
* \code
* for (i, 0, 10) // serial loop
* B_init[i] = 0
* for (j, 0, 20) // rf reduction loop
* B[i] = B[i] + rf_B[j, i]
* \endcode
*/
Expr
FactorizeReduction
(
const
Expr
&
rf_loop
,
int
rf_axis
);
/*!
* \brief Annotate a block with a key-value pair to set as its attribute
* \param block The block to be annotated
...
...
@@ -451,9 +474,10 @@ class IRSchedule {
utils
::
LinearRandomEngine
::
StateType
ForkSeed
()
const
;
private:
std
::
unique_ptr
<
Schedule
Impl
>
impl_
;
std
::
unique_ptr
<
Schedule
Base
>
impl_
;
mutable
ScheduleDesc
trace_
;
// trace the scheduling process
mutable
utils
::
LinearRandomEngine
::
StateType
rand_seed_
;
bool
is_dynamic_shape_
;
};
/*!
...
...
paddle/cinn/ir/schedule/ir_schedule_error.cc
View file @
01a10755
...
...
@@ -14,7 +14,7 @@
#include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace
cinn
{
namespace
ir
{
...
...
@@ -23,14 +23,14 @@ std::string IRScheduleErrorHandler::GeneralErrorMessage() const {
std
::
ostringstream
os
;
os
<<
"[IRScheduleError] An error occurred in the scheduel primitive < "
<<
this
->
primitive_
<<
" >. "
<<
std
::
endl
;
os
<<
this
->
err_msg_
;
os
<<
indent_str_
<<
"[Error info] "
<<
this
->
err_msg_
;
return
os
.
str
();
}
std
::
string
IRScheduleErrorHandler
::
DetailedErrorMessage
()
const
{
std
::
ostringstream
os
;
os
<<
GeneralErrorMessage
();
os
<<
"[Expr info] The Expr of current schedule is:
"
os
<<
indent_str_
<<
"[Expr info] The Expr of current schedule is:
\n
"
<<
this
->
module_expr_
.
GetExprs
()
<<
std
::
endl
;
return
os
.
str
();
}
...
...
paddle/cinn/ir/schedule/ir_schedule_util.cc
View file @
01a10755
...
...
@@ -26,11 +26,11 @@
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
...
...
@@ -40,7 +40,7 @@ namespace ir {
Tensor
GetTensor
(
const
Expr
&
block
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
auto
find_tensor
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_tensor
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
},
true
);
CHECK_EQ
(
find_tensor
.
size
(),
1U
)
<<
"One block should only have one Store node!(except for root block)"
;
...
...
@@ -52,13 +52,13 @@ Tensor GetTensor(const Expr& block) {
Tensor
GetReadTensor
(
const
Expr
&
block
,
int
index
)
{
CHECK
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
());
auto
find_tensor
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_tensor
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
},
true
);
CHECK_EQ
(
find_tensor
.
size
(),
1U
)
<<
"One block should only have one Store node!(except for root block)"
;
std
::
vector
<
Tensor
>
res
;
auto
find_read_tensor
=
ir
::
CollectIRNodesWithoutTensor
(
block
,
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
As
<
ir
::
Load
>
())
res
.
push_back
(
x
->
As
<
ir
::
Load
>
()
->
tensor
.
as_tensor_ref
());
return
x
->
As
<
ir
::
Load
>
();
...
...
@@ -86,41 +86,43 @@ void SetCudaAxisInfo(Expr* lowered_func) {
auto
func_body
=
lowered_func
->
as_lowered_func_ref
()
->
body
;
CudaAxisInfo
info
;
auto
block_nodes
=
ir
::
CollectIRNodes
(
func_body
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
As
<
ir
::
For
>
()
&&
x
->
As
<
ir
::
For
>
()
->
bind_info
().
valid
())
{
auto
bind_info
=
x
->
As
<
ir
::
For
>
()
->
bind_info
();
info
.
set_valid
(
true
);
if
(
bind_info
.
for_type
==
ForType
::
GPUThread
)
{
CHECK
(
common
::
is_zero
(
x
->
As
<
ir
::
For
>
()
->
min
));
CHECK
(
x
->
As
<
ir
::
For
>
()
->
extent
.
is_constant
());
int
range
=
x
->
As
<
ir
::
For
>
()
->
extent
.
get_constant
();
range
=
range
>
info
.
block_dim
(
bind_info
.
offset
)
?
range
:
info
.
block_dim
(
bind_info
.
offset
);
VLOG
(
3
)
<<
"Set block dim["
<<
bind_info
.
offset
<<
"] with range "
<<
range
;
info
.
set_block_dim
(
bind_info
.
offset
,
range
);
}
else
if
(
bind_info
.
for_type
==
ForType
::
GPUBlock
)
{
CHECK
(
common
::
is_zero
(
x
->
As
<
ir
::
For
>
()
->
min
));
CHECK
(
x
->
As
<
ir
::
For
>
()
->
extent
.
is_constant
());
int
range
=
x
->
As
<
ir
::
For
>
()
->
extent
.
get_constant
();
range
=
range
>
info
.
grid_dim
(
bind_info
.
offset
)
?
range
:
info
.
grid_dim
(
bind_info
.
offset
);
info
.
set_grid_dim
(
bind_info
.
offset
,
range
);
VLOG
(
3
)
<<
"Set grid dim["
<<
bind_info
.
offset
<<
"] with range "
<<
range
;
}
else
{
LOG
(
FATAL
)
<<
"The for loop's bind info should be gpu block or thread!"
;
}
}
return
(
x
->
As
<
ir
::
For
>
()
&&
x
->
As
<
ir
::
For
>
()
->
bind_info
().
valid
());
});
auto
block_nodes
=
ir
::
ir_utils
::
CollectIRNodes
(
func_body
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
As
<
ir
::
For
>
()
&&
x
->
As
<
ir
::
For
>
()
->
bind_info
().
valid
())
{
auto
bind_info
=
x
->
As
<
ir
::
For
>
()
->
bind_info
();
info
.
set_valid
(
true
);
if
(
bind_info
.
for_type
==
ForType
::
GPUThread
)
{
CHECK
(
common
::
is_zero
(
x
->
As
<
ir
::
For
>
()
->
min
));
CHECK
(
x
->
As
<
ir
::
For
>
()
->
extent
.
is_constant
());
int
range
=
x
->
As
<
ir
::
For
>
()
->
extent
.
get_constant
();
range
=
range
>
info
.
block_dim
(
bind_info
.
offset
)
?
range
:
info
.
block_dim
(
bind_info
.
offset
);
VLOG
(
3
)
<<
"Set block dim["
<<
bind_info
.
offset
<<
"] with range "
<<
range
;
info
.
set_block_dim
(
bind_info
.
offset
,
range
);
}
else
if
(
bind_info
.
for_type
==
ForType
::
GPUBlock
)
{
CHECK
(
common
::
is_zero
(
x
->
As
<
ir
::
For
>
()
->
min
));
CHECK
(
x
->
As
<
ir
::
For
>
()
->
extent
.
is_constant
());
int
range
=
x
->
As
<
ir
::
For
>
()
->
extent
.
get_constant
();
range
=
range
>
info
.
grid_dim
(
bind_info
.
offset
)
?
range
:
info
.
grid_dim
(
bind_info
.
offset
);
info
.
set_grid_dim
(
bind_info
.
offset
,
range
);
VLOG
(
3
)
<<
"Set grid dim["
<<
bind_info
.
offset
<<
"] with range "
<<
range
;
}
else
{
LOG
(
FATAL
)
<<
"The for loop's bind info should be gpu block or thread!"
;
}
}
return
(
x
->
As
<
ir
::
For
>
()
&&
x
->
As
<
ir
::
For
>
()
->
bind_info
().
valid
());
});
lowered_func
->
as_lowered_func_ref
()
->
cuda_axis_info
=
info
;
}
bool
Contains
(
const
Expr
&
container
,
const
Expr
&
expr
)
{
auto
find_expr
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_expr
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
container
,
[
&
](
const
Expr
*
x
)
{
return
(
x
->
node_type
()
==
expr
.
node_type
()
&&
*
x
==
expr
);
...
...
@@ -219,6 +221,14 @@ void ReplaceExpr(Expr* source,
return
;
}
// Rewrites `source` in place by running a MappingVarToExprMutator constructed
// from `replacing_map` over it. Does nothing when the map is empty.
void ReplaceExpr(Expr* source,
                 const std::map<Var, Expr, CompVar>& replacing_map) {
  if (!replacing_map.empty()) {
    MappingVarToExprMutator substitute(replacing_map);
    substitute(source);
  }
}
std
::
vector
<
int
>
ValidateFactors
(
const
std
::
vector
<
int
>&
factors
,
int
total_extent
,
const
ModuleExpr
&
module_expr
)
{
...
...
@@ -283,13 +293,13 @@ void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis) {
auto
*
rf_for
=
rf_loop
.
As
<
ir
::
For
>
();
CHECK
(
rf_for
)
<<
"Expr param of Rfactor must be For node! Please check."
;
// check the rf_loop only has one schedule block
auto
block_nodes
=
ir
::
CollectIRNodesWithoutTensor
(
auto
block_nodes
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
rf_loop
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ScheduleBlockRealize
>
();
},
true
);
CHECK_EQ
(
block_nodes
.
size
(),
1U
)
<<
"Rfactor Loop should only have one schedule block"
;
auto
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
rf_loop
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
Store
>
();
},
true
);
CHECK_EQ
(
find_store
.
size
(),
1U
);
auto
indice
=
find_store
.
begin
()
->
As
<
Store
>
()
->
indices
;
...
...
@@ -322,9 +332,9 @@ void CHECKRfactorValidation(const Expr& rf_loop, int rf_axis) {
}
std
::
vector
<
Expr
>
GetLoopsOfExpr
(
const
Expr
&
expr
,
const
Expr
&
root
)
{
auto
loop_nodes
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
For
>
()
&&
Contains
(
*
x
,
expr
);
});
auto
loop_nodes
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
For
>
()
&&
Contains
(
*
x
,
expr
);
});
std
::
vector
<
Expr
>
result
(
loop_nodes
.
begin
(),
loop_nodes
.
end
());
if
(
result
.
empty
())
LOG
(
FATAL
)
<<
"Didn't find expr's :
\n
"
...
...
@@ -346,8 +356,8 @@ IterRange GetAccessedRange(const Expr& index,
var_maxs
.
emplace_back
(
range
.
min
+
range
.
extent
-
1
);
}
Expr
indice_min
=
optim
::
IRCopy
(
index
);
Expr
indice_max
=
optim
::
IRCopy
(
index
);
Expr
indice_min
=
ir
::
ir_utils
::
IRCopy
(
index
);
Expr
indice_max
=
ir
::
ir_utils
::
IRCopy
(
index
);
// replace the var by the corresponding iter_value
ReplaceExpr
(
&
indice_min
,
iter_vars
,
var_mins
);
ReplaceExpr
(
&
indice_max
,
iter_vars
,
var_maxs
);
...
...
@@ -357,8 +367,16 @@ IterRange GetAccessedRange(const Expr& index,
Expr
indice_extent
;
Expr
mod_extent
(
0
);
if
(
indice_min
.
As
<
Mod
>
()
&&
indice_min
.
As
<
Mod
>
()
->
b
().
is_constant
())
if
(
indice_min
.
As
<
Mod
>
()
&&
indice_min
.
As
<
Mod
>
()
->
b
().
is_constant
())
{
Expr
mod_right_min
=
indice_min
.
As
<
Mod
>
()
->
a
();
Expr
mod_right_max
=
indice_max
.
As
<
Mod
>
()
->
a
();
Expr
mod_right_extent
=
common
::
AutoSimplify
(
mod_right_max
-
mod_right_min
+
1
);
mod_extent
=
indice_min
.
As
<
Mod
>
()
->
b
();
if
(
mod_right_extent
.
get_constant
()
<
mod_extent
.
get_constant
())
{
mod_extent
=
mod_right_extent
;
}
}
if
(
indice_min
==
indice_max
)
{
if
(
common
::
is_zero
(
mod_extent
))
{
...
...
@@ -406,7 +424,7 @@ std::vector<IterRange> CalculateTensorRegions(
std
::
vector
<
IterRange
>
result
;
for
(
int
i
=
0
;
i
<
tensor_indices
.
size
();
++
i
)
{
Expr
binded_index
=
optim
::
IRCopy
(
tensor_indices
[
i
]);
Expr
binded_index
=
ir
::
ir_utils
::
IRCopy
(
tensor_indices
[
i
]);
ReplaceExpr
(
&
binded_index
,
iter_vars
,
iter_values
);
auto
range
=
GetAccessedRange
(
binded_index
,
loop_vars
,
loop_ranges
);
...
...
@@ -439,8 +457,8 @@ Expr GetNthAccessExpr(const Expr& block, int index, bool is_write) {
->
body
;
if
(
is_write
)
{
std
::
vector
<
Expr
>
find_store_vec
;
auto
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
auto
find_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
As
<
ir
::
Store
>
())
find_store_vec
.
push_back
(
*
x
);
return
x
->
As
<
ir
::
Store
>
();
});
...
...
@@ -450,8 +468,8 @@ Expr GetNthAccessExpr(const Expr& block, int index, bool is_write) {
return
store_index
;
}
else
{
std
::
vector
<
Expr
>
find_load_vec
;
auto
find_load
=
ir
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
auto
find_load
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
As
<
ir
::
Load
>
())
find_load_vec
.
push_back
(
*
x
);
return
x
->
As
<
ir
::
Load
>
();
});
...
...
@@ -526,7 +544,7 @@ void FindInsertionPoint(const Expr& root, CacheBlockInfo* info, bool is_write) {
Expr
find_tensor
=
is_write
?
Expr
(
info
->
write_tensor
)
:
Expr
(
info
->
read_tensor
);
auto
find_produce_read
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
&&
x
->
As
<
ir
::
Store
>
()
->
tensor
==
find_tensor
;
});
...
...
@@ -654,7 +672,7 @@ Expr ConstructOtherStmtChain(const std::vector<Expr>& stmts,
const
std
::
vector
<
int
>
reordered_indices
)
{
Expr
new_loop
;
for
(
int
i
=
reordered_indices
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
Expr
temp
=
optim
::
IRCopy
(
loops
[
reordered_indices
[
i
]]);
Expr
temp
=
ir
::
ir_utils
::
IRCopy
(
loops
[
reordered_indices
[
i
]]);
CHECK
(
temp
.
defined
());
CHECK
(
temp
.
As
<
ir
::
For
>
());
if
(
new_loop
.
defined
())
{
...
...
@@ -675,9 +693,9 @@ Expr ConstructNewLoopChain(const std::vector<Expr>& chain,
// In each IfThenElse node, find the vars its condition depends on.
for
(
auto
&
if_expr
:
if_nodes
)
{
CHECK
(
if_expr
.
As
<
IfThenElse
>
());
auto
var_set
=
ir
::
CollectIRNodes
(
if_expr
.
As
<
IfThenElse
>
()
->
condition
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_var
();
});
auto
var_set
=
ir
::
ir_utils
::
CollectIRNodes
(
if_expr
.
As
<
IfThenElse
>
()
->
condition
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_var
();
});
std
::
set
<
std
::
string
>
var_name_set
;
for
(
auto
&
i
:
var_set
)
var_name_set
.
insert
(
i
.
as_var
()
->
name
);
condition_vars
.
push_back
(
var_name_set
);
...
...
@@ -693,10 +711,10 @@ Expr ConstructNewLoopChain(const std::vector<Expr>& chain,
Expr
temp
;
if
(
loop_set
.
count
(
loop_in_chain
))
{
CHECK_GE
(
index
,
0
);
temp
=
optim
::
IRCopy
(
ordered_loops
[
index
]);
temp
=
ir
::
ir_utils
::
IRCopy
(
ordered_loops
[
index
]);
--
index
;
}
else
{
temp
=
optim
::
IRCopy
(
loop_in_chain
);
temp
=
ir
::
ir_utils
::
IRCopy
(
loop_in_chain
);
}
CHECK
(
temp
.
defined
());
CHECK
(
temp
.
As
<
ir
::
For
>
());
...
...
@@ -863,9 +881,9 @@ std::vector<Expr> GetProducers(const Expr& block, const Expr& root) {
std
::
string
block_name
=
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
;
ir
::
CollectIRNodesWithoutTensor
(
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
producer_tensor_names
,
&
block_name
](
const
Expr
*
x
)
{
auto
*
load
=
x
->
As
<
ir
::
Load
>
();
const
ir
::
Load
*
load
=
x
->
As
<
ir
::
Load
>
();
if
(
load
)
{
producer_tensor_names
.
insert
(
load
->
tensor
.
as_tensor
()
->
name
);
if
(
load
->
tensor
.
as_tensor
()
->
name
==
block_name
)
{
...
...
@@ -874,20 +892,36 @@ std::vector<Expr> GetProducers(const Expr& block, const Expr& root) {
}
return
true
;
}
const
ir
::
Store
*
store
=
x
->
As
<
ir
::
Store
>
();
if
(
store
)
{
std
::
set
<
ir
::
Expr
>
call_nodes
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
store
->
value
,
[](
const
ir
::
Expr
*
x
)
{
return
x
->
As
<
ir
::
Call
>
();
});
for
(
ir
::
Expr
call
:
call_nodes
)
{
const
std
::
vector
<
ir
::
Expr
>&
read_args
=
call
.
As
<
ir
::
Call
>
()
->
read_args
;
for
(
const
ir
::
Expr
&
arg
:
read_args
)
{
if
(
arg
.
as_tensor
())
{
producer_tensor_names
.
insert
(
arg
.
as_tensor_ref
()
->
name
);
}
}
}
}
return
false
;
});
// traverse each of other blocks and filter those ones which contain at least
// one producer tensor;
auto
find_blocks
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
block
,
&
root
](
const
Expr
*
x
)
{
auto
find_blocks
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
block
,
&
root
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
*
x
!=
block
&&
*
x
!=
root
;
});
for
(
auto
&&
cur
:
find_blocks
)
{
auto
*
cur_block
=
cur
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
CHECK
(
cur_block
)
<<
"block result should be a ScheduleBlockRealize"
;
auto
find_stores
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_stores
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
cur_block
->
body
,
[
&
producer_tensor_names
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
&&
producer_tensor_names
.
count
(
...
...
@@ -905,32 +939,44 @@ std::vector<Expr> GetConsumers(const Expr& block, const Expr& root) {
std
::
string
block_tensor
=
GetTensor
(
block
)
->
name
;
if
(
IsReduceInitTensorName
(
block_tensor
))
{
std
::
string
consumer_name
=
GetOriginalReduceTensorName
(
block_tensor
);
auto
consumer
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
==
consumer_name
;
});
auto
consumer
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
==
consumer_name
;
});
CHECK_EQ
(
consumer
.
size
(),
1
);
return
{
*
consumer
.
begin
()};
}
auto
find_block
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
*
x
!=
block
&&
*
x
!=
root
;
});
auto
find_block
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
*
x
!=
block
&&
*
x
!=
root
;
});
for
(
auto
&
i
:
find_block
)
{
CHECK
(
i
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
auto
block_body
=
i
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
body
;
auto
find_load
=
ir
::
CollectIRNodesWithoutTensor
(
block_body
,
[
&
](
const
Expr
*
x
)
{
auto
find_load_or_call
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block_body
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
As
<
ir
::
Call
>
())
{
const
std
::
vector
<
ir
::
Expr
>&
read_args
=
x
->
As
<
ir
::
Call
>
()
->
read_args
;
for
(
const
ir
::
Expr
&
arg
:
read_args
)
{
if
(
arg
.
as_tensor
()
&&
arg
.
as_tensor_ref
()
->
name
==
block_tensor
)
{
return
true
;
}
}
}
return
x
->
As
<
ir
::
Load
>
()
&&
x
->
As
<
ir
::
Load
>
()
->
tensor
.
as_tensor_ref
()
->
name
==
block_tensor
;
});
if
(
!
find_load
.
empty
())
consumers
.
emplace_back
(
i
);
if
(
!
find_load
_or_call
.
empty
())
consumers
.
emplace_back
(
i
);
}
return
consumers
;
}
...
...
@@ -938,7 +984,7 @@ std::vector<Expr> GetConsumers(const Expr& block, const Expr& root) {
void
CheckComputeAtValidation
(
const
Expr
&
block
,
const
Expr
&
loop
,
const
Expr
&
root
)
{
auto
find_block
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_block
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
*
x
==
block
;
...
...
@@ -946,13 +992,13 @@ void CheckComputeAtValidation(const Expr& block,
true
);
CHECK
(
!
find_block
.
empty
())
<<
"Didn't find block in root!"
;
auto
find_loop
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_loop
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
For
>
()
&&
*
x
==
loop
;
},
true
);
CHECK
(
!
find_loop
.
empty
())
<<
"Didn't find loop in root!"
;
auto
find_block_in_loop
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_block_in_loop
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
loop
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
ScheduleBlockRealize
>
()
&&
*
x
==
block
;
...
...
@@ -1005,10 +1051,10 @@ std::vector<IterRange> CalculateRequiredRegions(
std
::
set
<
Expr
>
provided_nodes
;
if
(
is_store_provided
)
{
provided_nodes
=
ir
::
CollectIRNodesWithoutTensor
(
provided_nodes
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
});
}
else
{
provided_nodes
=
ir
::
CollectIRNodesWithoutTensor
(
provided_nodes
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
();
});
}
...
...
@@ -1025,9 +1071,9 @@ std::vector<IterRange> CalculateRequiredRegions(
for
(
const
Expr
&
req_block
:
required_blocks
)
{
CHECK
(
req_block
.
As
<
ir
::
ScheduleBlockRealize
>
());
Expr
block_body
=
optim
::
IRCopy
(
req_block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
body
);
ir
::
ir_utils
::
IRCopy
(
req_block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
body
);
auto
iter_vars
=
req_block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
iter_vars
;
...
...
@@ -1036,7 +1082,7 @@ std::vector<IterRange> CalculateRequiredRegions(
// Notice that we look for For nodes in loop's body instead of loop
// itself.
auto
find_loops
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_loops
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
loop
.
As
<
ir
::
For
>
()
->
body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
For
>
()
&&
Contains
(
*
x
,
req_block
);
});
...
...
@@ -1052,15 +1098,15 @@ std::vector<IterRange> CalculateRequiredRegions(
std
::
set
<
Expr
>
required_nodes
;
if
(
is_store_provided
)
{
required_nodes
=
ir
::
CollectIRNodesWithoutTensor
(
block_body
,
[
&
](
const
Expr
*
x
)
{
required_nodes
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
()
&&
x
->
As
<
ir
::
Load
>
()
->
tensor
.
as_tensor_ref
()
->
name
==
provided_tensor_name
;
});
}
else
{
required_nodes
=
ir
::
CollectIRNodesWithoutTensor
(
block_body
,
[
&
](
const
Expr
*
x
)
{
required_nodes
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
block_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
&&
x
->
As
<
ir
::
Store
>
()
->
tensor
.
as_tensor_ref
()
->
name
==
provided_tensor_name
;
...
...
@@ -1105,7 +1151,7 @@ std::vector<IterRange> CalculateRequiredRegions(
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
iter_values
[
i
].
is_constant
());
if
(
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
iter_values
[
i
].
as_var
())
{
auto
find_for_loops
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
For
>
()
&&
x
->
As
<
ir
::
For
>
()
->
loop_var
->
name
==
block
.
As
<
ir
::
ScheduleBlockRealize
>
()
...
...
@@ -1134,13 +1180,13 @@ Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block,
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
body
;
// 1. Check the schedule block to be inlined is not a reduce tensor.
auto
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
},
true
);
CHECK_EQ
(
find_store
.
size
(),
1U
);
Expr
tensor
=
(
*
find_store
.
begin
()).
As
<
ir
::
Store
>
()
->
tensor
;
CHECK
(
!
tensor
.
as_tensor_ref
()
->
is_reduce_tensor
());
// 2. Check this schedule block is the only writer of the tensor.
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
find_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
&&
...
...
@@ -1151,8 +1197,8 @@ Expr CheckComputeInlineValidationAndGetStore(const Expr& schedule_block,
CHECK_EQ
(
find_store
.
size
(),
1U
);
// 3. Check there is no overlap between the buffers the schedule block reads
// and writes.
auto
find_load
=
ir
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
auto
find_load
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
()
&&
x
->
As
<
ir
::
Load
>
()
->
tensor
==
tensor
;
});
CHECK
(
find_load
.
empty
());
...
...
@@ -1166,14 +1212,14 @@ std::tuple<Expr, Expr, Expr> CheckReverseComputeInlineValidationAndGetExprs(
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
body
;
// 1. Check the schedule block to be reverse inlined is not a reduce tensor.
auto
find_inlined_load
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_inlined_load
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
();
},
true
);
CHECK_EQ
(
find_inlined_load
.
size
(),
1U
);
Expr
tensor
=
(
*
find_inlined_load
.
begin
()).
As
<
ir
::
Load
>
()
->
tensor
;
CHECK
(
!
tensor
.
as_tensor_ref
()
->
is_reduce_tensor
());
auto
inlined_load
=
*
find_inlined_load
.
begin
();
// 2. Check this schedule block is the only reader of the tensor.
auto
find_load
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_load
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
()
&&
...
...
@@ -1184,20 +1230,20 @@ std::tuple<Expr, Expr, Expr> CheckReverseComputeInlineValidationAndGetExprs(
CHECK_EQ
(
find_load
.
size
(),
1U
);
// 3. Check there is no overlap between the buffers the schedule block reads
// and writes.
auto
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
auto
find_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
&&
x
->
As
<
ir
::
Store
>
()
->
tensor
==
tensor
;
});
CHECK
(
find_store
.
empty
());
// 4. Get store that will be inlined.
auto
find_inlined_store
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
&&
x
->
As
<
ir
::
Store
>
()
->
tensor
==
tensor
;
});
CHECK_EQ
(
find_inlined_store
.
size
(),
1U
);
auto
inlined_store
=
*
find_inlined_store
.
begin
();
// 5. Get target store.
auto
find_target_store
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_target_store
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
},
true
);
CHECK_EQ
(
find_target_store
.
size
(),
1U
);
auto
target_store
=
*
find_target_store
.
begin
();
...
...
@@ -1206,7 +1252,7 @@ std::tuple<Expr, Expr, Expr> CheckReverseComputeInlineValidationAndGetExprs(
bool
ContainVar
(
const
std
::
vector
<
Expr
>&
exprs
,
const
std
::
string
&
var_name
)
{
for
(
auto
&
expr
:
exprs
)
{
auto
find_expr
=
ir
::
CollectIRNodesWithoutTensor
(
auto
find_expr
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
expr
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
_Var_
>
()
&&
x
->
As
<
_Var_
>
()
->
name
==
var_name
;
...
...
paddle/cinn/ir/schedule/ir_schedule_util.h
View file @
01a10755
...
...
@@ -22,9 +22,9 @@
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/schedule/ir_schedule_error.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/utils/random_engine.h"
#include "paddle/cinn/utils/string.h"
...
...
@@ -193,7 +193,7 @@ Tensor GetReadTensor(const Expr& block, int index);
int
GetLoopExtent
(
const
Expr
&
loop
);
/**
* \brief Given a vector of Ex
o
rs, return whether they contain a var with
* \brief Given a vector of Ex
p
rs, return whether they contain a var with
* specific name.
* @param exprs The given vector of Exprs
* @param var_name The name of specific var
...
...
@@ -241,6 +241,15 @@ void ReplaceExpr(Expr* source,
const
std
::
vector
<
Var
>&
replaced
,
const
std
::
vector
<
Expr
>&
candidates
);
/**
* Replace each Var that is a key of replacing_map with its mapped Expr in source.
* @param source The Expr we will implement the change.
* @param replacing_map The one-to-one corresponded Vars -> Exprs to be
* replaced.
*/
void
ReplaceExpr
(
Expr
*
source
,
const
std
::
map
<
Var
,
Expr
,
CompVar
>&
replacing_map
);
/**
* Validate the factors param of Split. We will check if factors are valid
* and change -1 to positive integer.
...
...
@@ -427,9 +436,11 @@ IterRange RangeUnion(const IterRange& range1, const IterRange& range2);
* \param loop The loop where we will insert the block under it
* @param root The root of the whole AST.
* \param required_blocks vector of ScheduleBlockRealize nodes that require the
* block \param is_store_provided Whether Store nodes of the block provide the
* block
* \param is_store_provided Whether Store nodes of the block provide the
* tensor, true means it is in compute_at case, otherwise false means in
* reverse_compuate_at case \return Each index's range of block's tensor.
* reverse_compute_at case
* \return Each index's range and can_keep_loop flag of block's tensor.
* Indicating the buffer region being required.
*/
std
::
vector
<
IterRange
>
CalculateRequiredRegions
(
...
...
paddle/cinn/ir/schedule/schedule_base.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/ir/schedule/schedule_base.h"
namespace
cinn
{
namespace
ir
{
/**
* Replace a node to another node.
* @param src_sref The node to be changed.
* @param tgt_stmt The node we want.
*/
void
ScheduleBase
::
Replace
(
const
Expr
&
src_sref
,
const
Expr
&
tgt_stmt
)
{
CHECK
(
src_sref
.
As
<
ir
::
For
>
()
||
src_sref
.
As
<
ir
::
Block
>
()
||
src_sref
.
As
<
ir
::
ScheduleBlockRealize
>
());
CHECK
(
tgt_stmt
.
As
<
ir
::
For
>
()
||
tgt_stmt
.
As
<
ir
::
Block
>
()
||
tgt_stmt
.
As
<
ir
::
ScheduleBlockRealize
>
());
if
(
src_sref
==
tgt_stmt
)
{
return
;
}
struct
ForLoopMutator
:
public
ir
::
IRMutator
<>
{
ForLoopMutator
(
const
Expr
&
source
,
const
Expr
&
target
)
:
source_
(
source
),
target_
(
target
)
{}
void
operator
()(
Expr
*
expr
)
{
ir
::
IRMutator
<>::
Visit
(
expr
,
expr
);
}
void
Visit
(
const
ir
::
For
*
op
,
Expr
*
expr
)
override
{
if
(
*
expr
==
source_
)
{
*
expr
=
target_
;
return
;
}
ir
::
IRMutator
<>::
Visit
(
op
,
expr
);
}
void
Visit
(
const
ir
::
ScheduleBlockRealize
*
op
,
Expr
*
expr
)
override
{
if
(
*
expr
==
source_
)
{
*
expr
=
target_
;
return
;
}
ir
::
IRMutator
<>::
Visit
(
op
,
expr
);
}
void
Visit
(
const
ir
::
Block
*
op
,
Expr
*
expr
)
override
{
if
(
*
expr
==
source_
)
{
*
expr
=
target_
;
return
;
}
ir
::
IRMutator
<>::
Visit
(
op
,
expr
);
}
const
Expr
&
source_
;
const
Expr
&
target_
;
};
auto
exprs
=
module_expr_
.
GetExprs
();
ForLoopMutator
mutator
(
src_sref
,
tgt_stmt
);
for
(
auto
&
i
:
exprs
)
{
mutator
(
&
i
);
}
}
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/schedule/schedule_base.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/utils/error.h"
#include "paddle/cinn/utils/random_engine.h"
PD_DECLARE_int32
(
cinn_error_message_level
);
namespace
cinn
{
namespace
ir
{
/**
 * A struct representing a module that contains Expr. This struct is only used
 * in Schedule process.
 */
class ModuleExpr {
 public:
  ModuleExpr() = default;
  ModuleExpr(const ModuleExpr& mod_expr) = default;
  ModuleExpr(ModuleExpr&& mod_expr) = default;

  ModuleExpr& operator=(const ModuleExpr& mod_expr) = default;
  //! Move assignment. Declared explicitly because the user-declared copy
  //! assignment above suppresses the implicit one, which would otherwise make
  //! `a = std::move(b)` fall back to a copy.
  ModuleExpr& operator=(ModuleExpr&& mod_expr) = default;

  //! Build a module from a copy of the given Exprs.
  explicit ModuleExpr(const std::vector<Expr>& exprs) : exprs_(exprs) {}
  //! Build a module by taking over the given Exprs.
  explicit ModuleExpr(std::vector<Expr>&& exprs) : exprs_(std::move(exprs)) {}

  //! Get all the Expr in this ModuleExpr. The vector is returned by value.
  std::vector<Expr> GetExprs() { return exprs_; }

  //! Const overload of GetExprs(); also returns the vector by value.
  std::vector<Expr> GetExprs() const { return exprs_; }

  //! Replace the stored Exprs with a copy of `exprs`.
  void SetExprs(const std::vector<Expr>& exprs) { exprs_ = exprs; }

 private:
  //! Exprs stored in ModuleExpr. Each one is an AST, representing a computation
  //! kernel.
  std::vector<Expr> exprs_;
};
/**
 * Define the interface for scheduling primitives,
 * with subclasses DyScheduleImpl and StScheduleImpl.
 *
 * Instances are handed out through std::unique_ptr<ScheduleBase> (see Make),
 * so the destructor is virtual to make deletion through the base well-defined.
 */
class ScheduleBase {
 public:
  ScheduleBase() = delete;

  /**
   * @param module_expr The module whose Exprs will be scheduled (copied).
   * @param debug_flag Initial value of the debug flag.
   * @param err_msg_level Requested error message level; combined with
   * FLAGS_cinn_error_message_level below.
   */
  explicit ScheduleBase(const ModuleExpr& module_expr,
                        bool debug_flag = false,
                        utils::ErrorMessageLevel err_msg_level =
                            utils::ErrorMessageLevel::kGeneral)
      : module_expr_(module_expr), debug_flag_(debug_flag) {
    // NOTE(review): this is a logical OR, so the value that gets cast is
    // always 0 or 1. Presumably that is intended for a two-level enum;
    // confirm against utils::ErrorMessageLevel.
    err_msg_level_ = static_cast<utils::ErrorMessageLevel>(
        FLAGS_cinn_error_message_level || static_cast<int>(err_msg_level));
  }

  //! Takes over `module_expr`; debug flag and error level keep their in-class
  //! defaults.
  explicit ScheduleBase(ModuleExpr&& module_expr)
      : module_expr_(std::move(module_expr)) {}

  // Declaring the virtual destructor suppresses the implicit move operations,
  // so copy/move are defaulted explicitly to keep them available as before.
  ScheduleBase(const ScheduleBase&) = default;
  ScheduleBase(ScheduleBase&&) = default;
  ScheduleBase& operator=(const ScheduleBase&) = default;
  ScheduleBase& operator=(ScheduleBase&&) = default;
  virtual ~ScheduleBase() = default;

  //! Factory; `is_dynamic` is a selector for the implementation to build
  //! (definition not visible in this header).
  static std::unique_ptr<ScheduleBase> Make(
      const ModuleExpr& module_expr,
      bool debug_flag = false,
      utils::ErrorMessageLevel err_msg_level =
          utils::ErrorMessageLevel::kGeneral,
      bool is_dynamic = false);

  static std::unique_ptr<ScheduleBase> Make(ModuleExpr&& module_expr,
                                            bool is_dynamic = false);

  void SetDebugFlag(bool debug_flag) { debug_flag_ = debug_flag; }

  const ModuleExpr& GetModule() const { return module_expr_; }

  void SetExprs(const std::vector<Expr>& exprs) {
    module_expr_.SetExprs(exprs);
  }

  // ---- Scheduling primitives, implemented by the subclasses ----

  virtual void MergeExprs() = 0;
  virtual bool HasBlock(const std::string& block_name) const = 0;
  virtual std::vector<Expr> GetLoops(const Expr& block) const = 0;
  virtual std::vector<Expr> GetLoops(const std::string& block_name) const = 0;
  virtual std::vector<Expr> GetAllBlocks() const = 0;
  virtual std::vector<Expr> GetChildBlocks(const Expr& expr) const = 0;
  virtual Expr GetBlock(const std::string& block_name) const = 0;

  virtual std::vector<Expr> Split(const Expr& loop,
                                  const std::vector<int>& factors) = 0;
  virtual std::vector<Expr> SamplePerfectTile(
      utils::LinearRandomEngine::StateType* rand_seed,
      const Expr& loop,
      int n,
      int max_innermost_factor) = 0;
  virtual Expr Fuse(const std::vector<Expr>& loops) = 0;
  virtual Expr Fuse(const std::string& block_name,
                    const std::vector<int>& loops_index) = 0;
  virtual Expr Fuse(const Expr& block, const std::vector<int>& loops_index) = 0;
  virtual void ComputeAt(const Expr& block,
                         const Expr& loop,
                         bool keep_unit_loops) = 0;
  virtual void SimpleComputeAt(const Expr& block, const Expr& loop) = 0;
  virtual void ReverseComputeAt(const Expr& block,
                                const Expr& loop,
                                bool keep_unit_loops) = 0;
  virtual Expr GetRootBlock(const Expr& expr) const = 0;
  virtual Expr CacheRead(const Expr& block,
                         int read_buffer_index,
                         const std::string& memory_type) = 0;
  virtual Expr CacheWrite(const Expr& block,
                          int write_buffer_index,
                          const std::string& memory_type) = 0;
  virtual void SyncThreads(const Expr& ir_node, bool after_node = true) = 0;
  virtual void SetBuffer(Expr& block,  // NOLINT
                         const std::string& memory_type,
                         bool fixed = false) = 0;
  virtual Expr Reorder(const std::vector<Expr>& loops) = 0;
  virtual Expr Reorder(const std::string& block_name,
                       const std::vector<int>& loops_index) = 0;
  virtual Expr Reorder(const Expr& block,
                       const std::vector<int>& loops_index) = 0;
  virtual DeviceAPI GetDeviceAPI() const = 0;
  virtual void MutateForType(const Expr& loop,
                             ForType for_type,
                             int factor = -1) = 0;
  virtual void Parallel(const Expr& loop) = 0;
  virtual void Vectorize(const Expr& loop, int factor) = 0;
  virtual void Unroll(const Expr& loop) = 0;
  virtual void ComputeInline(const Expr& schedule_block) = 0;
  virtual void ReverseComputeInline(const Expr& schedule_block) = 0;
  virtual void Bind(const Expr& loop, const std::string& thread_axis) = 0;
  virtual Expr Rfactor(const Expr& rf_loop, int rf_axis) = 0;
  virtual Expr FactorizeReduction(const Expr& rf_loop, int rf_axis) = 0;
  virtual Expr AddUnitLoop(const Expr& block) const = 0;
  virtual void Annotate(const Expr& block,
                        const std::string& key,
                        const attr_t& value) = 0;
  virtual void Unannotate(Expr& block, const std::string& key) = 0;  // NOLINT
  virtual void FlattenLoops(const std::vector<Expr>& loops,
                            const bool force_flat = false) = 0;
  virtual void CopyTransformAndLoopInfo(const Expr& block,
                                        const Expr& block_target) = 0;
  virtual void CopyTransformAndLoopInfo(
      const std::string& block_name, const std::string& block_target_name) = 0;
  virtual Expr SampleCategorical(
      utils::LinearRandomEngine::StateType* rand_seed,
      const std::vector<int>& candidates,
      const std::vector<float>& probs) = 0;

 protected:
  //! Replace the node `src_sref` with `tgt_stmt` (For / Block /
  //! ScheduleBlockRealize only).
  void Replace(const Expr& src_sref, const Expr& tgt_stmt);

  //! The module being scheduled.
  ModuleExpr module_expr_;
  bool debug_flag_{false};
  utils::ErrorMessageLevel err_msg_level_ =
      utils::ErrorMessageLevel::kGeneral;
};
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/schedule/schedule_desc.cc
View file @
01a10755
...
...
@@ -422,6 +422,12 @@ CINN_BUILD_STEP_KIND(SetBuffer)
.
SetApplyFn
(
APPLY_FUNC_UNIFORM
(
FREE_FUNCTION_CONVERTER
(
&
IRSchedule
::
SetBuffer
)));
// Registers the "AddUnitLoop" step kind: one input named "block"; the apply
// function wraps IRSchedule::AddUnitLoop, with the static_cast spelling out
// the member-function signature Expr(const Expr&) that is bound.
CINN_BUILD_STEP_KIND
(
AddUnitLoop
)
.
Inputs
({
"block"
})
.
SetApplyFn
(
APPLY_FUNC_UNIFORM
(
FREE_FUNCTION_CONVERTER
(
static_cast
<
Expr
(
IRSchedule
::*
)(
const
Expr
&
)
>
(
&
IRSchedule
::
AddUnitLoop
))));
CINN_BUILD_STEP_KIND
(
Reorder
).
Inputs
({
"loops"
}).
SetApplyFn
(
APPLY_FUNC_UNIFORM
(
FREE_FUNCTION_CONVERTER
(
static_cast
<
Expr
(
IRSchedule
::*
)(
const
std
::
vector
<
Expr
>&
)
>
(
...
...
@@ -474,6 +480,12 @@ CINN_BUILD_STEP_KIND(Rfactor)
.
SetApplyFn
(
APPLY_FUNC_UNIFORM
(
FREE_FUNCTION_CONVERTER
(
&
IRSchedule
::
Rfactor
)));
// Registers the "FactorizeReduction" step kind: one input named "rf_loop",
// one attribute named "rf_axis"; the apply function wraps
// IRSchedule::FactorizeReduction.
CINN_BUILD_STEP_KIND
(
FactorizeReduction
)
.
Inputs
({
"rf_loop"
})
.
Attrs
({
"rf_axis"
})
.
SetApplyFn
(
APPLY_FUNC_UNIFORM
(
FREE_FUNCTION_CONVERTER
(
&
IRSchedule
::
FactorizeReduction
)));
// Registers the "MergeExprs" step kind; it declares no inputs or attributes
// here and its apply function wraps IRSchedule::MergeExprs.
CINN_BUILD_STEP_KIND
(
MergeExprs
)
.
SetApplyFn
(
APPLY_FUNC_UNIFORM
(
FREE_FUNCTION_CONVERTER
(
&
IRSchedule
::
MergeExprs
)));
...
...
paddle/cinn/ir/schedule_block_graph.cc
View file @
01a10755
...
...
@@ -14,8 +14,8 @@
#include "paddle/cinn/ir/schedule_block_graph.h"
#include "paddle/cinn/common/dfs_topo_walker.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
ir
{
...
...
paddle/cinn/ir/schedule_block_graph.h
View file @
01a10755
...
...
@@ -20,11 +20,9 @@
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/hlir/framework/graph.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
using
Group
=
cinn
::
hlir
::
framework
::
Graph
::
Group
;
namespace
cinn
{
namespace
ir
{
...
...
paddle/cinn/ir/tensor.cc
View file @
01a10755
...
...
@@ -16,6 +16,7 @@
#include <cstring>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/axis.h"
...
...
@@ -23,10 +24,10 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/operation.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/poly/isl_utils.h"
#include "paddle/cinn/poly/stage.h"
...
...
@@ -52,6 +53,67 @@ Tensor _Tensor_::Make(const std::string &name,
return
Tensor
(
n
);
}
// Builds a tensor node by hand, without a caller-supplied FunctionRef: the
// operation is a PlaceholderOp created from the tensor's own name and shape.
// The tensor's type is set from `dtype` afterwards, then the axis is
// initialised.
// NOTE(review): the placeholder is always created with Float(32) regardless
// of `dtype` — confirm this is intended.
Tensor _Tensor_::Make(const std::string &name,
                      Type dtype,
                      const std::vector<Expr> &shape,
                      const std::vector<Expr> &domain,
                      const std::vector<Var> &reduce_axis) {
  CHECK(!name.empty()) << "Cannot set empty Tensor name in Tensor::Make";

  auto node = make_shared<_Tensor_>();
  // Plain field copies; their relative order does not matter.
  node->reduce_axis = reduce_axis;
  node->domain = domain;
  node->shape = shape;
  node->name = name;

  // Reads the name and shape stored just above.
  node->operation = PlaceholderOp::Make(node->name, node->shape, Float(32));
  node->set_type(dtype);
  node->InitAxis();

  return Tensor(node);
}
// (Symbolic shape) Builds a tensor whose operation is the given FunctionRef.
// `shape` is filled with the dim_expr of every entry of `sym_shape`, in order.
Tensor _Tensor_::Make(const std::string &name,
                      Type dtype,
                      const std::vector<Dim> &sym_shape,
                      const std::vector<Expr> &domain,
                      FunctionRef fn,
                      const std::vector<Var> &reduce_axis) {
  CHECK(!name.empty()) << "Tensor name is set empty";
  auto n = make_shared<_Tensor_>();
  n->name = name;
  n->sym_shape = sym_shape;
  // reserve() only allocates capacity and leaves size() at 0, so elements must
  // be appended; indexing with operator[] here would write out of bounds and
  // the shape would stay empty.
  n->shape.reserve(sym_shape.size());
  for (const auto &dim : sym_shape) {
    n->shape.push_back(dim->dim_expr);
  }
  n->domain = domain;
  n->reduce_axis = reduce_axis;
  n->set_type(dtype);
  n->operation = fn;
  n->InitAxis();

  return Tensor(n);
}
// (Symbolic shape) Builds a tensor by hand, without a caller-supplied
// FunctionRef: the operation is a PlaceholderOp created from the tensor's own
// name and shape. `shape` is filled with the dim_expr of every entry of
// `sym_shape`, in order.
// NOTE(review): the placeholder is always created with Float(32) regardless
// of `dtype` — confirm this is intended.
Tensor _Tensor_::Make(const std::string &name,
                      Type dtype,
                      const std::vector<Dim> &sym_shape,
                      const std::vector<Expr> &domain,
                      const std::vector<Var> &reduce_axis) {
  CHECK(!name.empty()) << "Cannot set empty Tensor name in Tensor::Make";
  auto n = make_shared<_Tensor_>();
  n->name = name;
  n->sym_shape = sym_shape;
  // reserve() only allocates capacity and leaves size() at 0, so elements must
  // be appended; indexing with operator[] here would write out of bounds and
  // PlaceholderOp::Make below would receive an empty shape.
  n->shape.reserve(sym_shape.size());
  for (const auto &dim : sym_shape) {
    n->shape.push_back(dim->dim_expr);
  }
  n->domain = domain;
  n->reduce_axis = reduce_axis;
  n->operation = PlaceholderOp::Make(n->name, n->shape, Float(32));
  n->set_type(dtype);
  n->InitAxis();

  return Tensor(n);
}
// Number of dimensions, i.e. the number of entries in the node's shape vector.
size_t Tensor::ndims() const {
  const auto *node = operator->();
  return node->shape.size();
}
...
...
@@ -59,7 +121,7 @@ std::set<std::string> _Tensor_::GetDependTensorNames() const {
std
::
set
<
std
::
string
>
names
;
auto
add_depend_tensors_from_expr
=
[
&
](
Expr
expr
)
{
auto
tensors
=
CollectIRNodes
(
expr
,
[
&
](
const
Expr
*
x
)
{
auto
tensors
=
ir
::
ir_utils
::
CollectIRNodes
(
expr
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
name
!=
this
->
name
;
});
for
(
auto
&
e
:
tensors
)
{
...
...
@@ -514,7 +576,7 @@ bool _Tensor_::IsDependOnStatement(absl::string_view statement) {
std
::
set
<
std
::
string
>
_Tensor_
::
DependingTensorNames
()
{
std
::
set
<
std
::
string
>
res
;
if
(
body
().
defined
())
{
auto
depend_tensors
=
ir
::
CollectIRNodes
(
auto
depend_tensors
=
ir
::
ir_utils
::
CollectIRNodes
(
body
(),
[](
const
Expr
*
x
)
->
bool
{
return
x
->
as_tensor
();
});
for
(
const
auto
&
x
:
depend_tensors
)
{
if
(
x
.
get
()
!=
this
)
{
...
...
@@ -537,7 +599,7 @@ std::vector<Var> _Tensor_::axis_with_reduce() const {
}
bool
_Tensor_
::
Uses
(
const
Tensor
&
other
)
const
{
auto
loads
=
ir
::
CollectIRNodes
(
body
(),
[
&
](
const
Expr
*
x
)
{
auto
loads
=
ir
::
ir_utils
::
CollectIRNodes
(
body
(),
[
&
](
const
Expr
*
x
)
{
auto
*
loadn
=
x
->
As
<
ir
::
Load
>
();
if
(
!
loadn
)
return
false
;
return
loadn
->
tensor
.
as_tensor
()
->
name
==
other
->
name
;
...
...
paddle/cinn/ir/tensor.h
View file @
01a10755
...
...
@@ -25,36 +25,23 @@
#include <utility>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/dim.h"
#include "paddle/cinn/ir/function_base.h"
#include "paddle/cinn/lang/buffer.h"
#include "paddle/cinn/poly/stage.h"
namespace
cinn
{
namespace
ir
{
class
Tensor
;
}
// namespace ir
namespace
lang
{
template
<
typename
T
>
struct
Placeholder
;
void
InitReduceTensor
(
poly
::
StageMap
stages
,
const
ir
::
Tensor
&
tensor
,
const
Target
&
target
=
common
::
DefaultHostTarget
());
}
// namespace lang
namespace
ast_gen_ius
{
class
TensorGroup
;
}
// namespace ast_gen_ius
namespace
ir
{
namespace detail {
// Compile-time "a is less than or equal to b".
constexpr bool LE(int a, int b) { return !(b < a); }
// Compile-time "a is greater than or equal to b".
constexpr bool GE(int a, int b) { return !(a < b); }
}  // namespace detail
class
_Tensor_
;
class
Tensor
;
class
Tensor
:
public
ir
::
IrNodeRef
{
public:
...
...
@@ -84,8 +71,8 @@ class Tensor : public ir::IrNodeRef {
return
operator
()(
std
::
vector
<
Expr
>
({
a
}));
}
template
<
typename
...
Args
>
inline
typename
std
::
enable_if
<
detail
::
GE
(
sizeof
...(
Args
)
,
2
)
,
Expr
>::
type
operator
()(
Args
&&
...
args
)
const
{
inline
typename
std
::
enable_if
<
sizeof
...(
Args
)
>=
2
,
Expr
>::
type
operator
()(
Args
&&
...
args
)
const
{
return
operator
()({
std
::
forward
<
Args
>
(
args
)...});
}
// @}
...
...
@@ -135,6 +122,8 @@ struct WriteCacheRelation;
*/
class
_Tensor_
:
public
ExprNode
<
_Tensor_
>
{
public:
//! Symbolic Shape of this tensor(buffer).
std
::
vector
<
Dim
>
sym_shape
;
//! Shape of this tensor(buffer).
std
::
vector
<
Expr
>
shape
;
//! The domain of each axis(without reduce_axis)
...
...
@@ -163,6 +152,28 @@ class _Tensor_ : public ExprNode<_Tensor_> {
FunctionRef
fn
,
const
std
::
vector
<
Var
>&
reduce_axis
=
{});
// Manual tensor construction, no FunctionRef information
static
Tensor
Make
(
const
std
::
string
&
name
,
Type
dtype
,
const
std
::
vector
<
Expr
>&
shape
,
const
std
::
vector
<
Expr
>&
domain
,
const
std
::
vector
<
Var
>&
reduce_axis
=
{});
//! (Symbolic Shape) Generate a tensor from a function.
static
Tensor
Make
(
const
std
::
string
&
name
,
Type
dtype
,
const
std
::
vector
<
Dim
>&
sym_shape
,
const
std
::
vector
<
Expr
>&
domain
,
FunctionRef
fn
,
const
std
::
vector
<
Var
>&
reduce_axis
=
{});
// (Symbolic Shape) Manual tensor construction, no FunctionRef information
static
Tensor
Make
(
const
std
::
string
&
name
,
Type
dtype
,
const
std
::
vector
<
Dim
>&
sym_shape
,
const
std
::
vector
<
Expr
>&
domain
,
const
std
::
vector
<
Var
>&
reduce_axis
=
{});
void
Verify
()
const
override
;
bool
IsReduceInited
(
poly
::
StageMap
stages
)
const
;
...
...
@@ -288,12 +299,6 @@ class _Tensor_ : public ExprNode<_Tensor_> {
poly
::
StageMap
stages
,
const
Target
&
target
=
common
::
DefaultHostTarget
())
const
;
private:
//! Initialize the axis field after the shape field is assigned.
void
InitAxis
()
const
;
isl
::
set
GenerateIslDomain
()
const
;
/**
* Create the initialization tensor.
* @param stages The stages.
...
...
@@ -304,15 +309,17 @@ class _Tensor_ : public ExprNode<_Tensor_> {
poly
::
StageMap
stages
,
const
Target
&
target
=
common
::
DefaultHostTarget
())
const
;
private:
//! Initialize the axis field after the shape field is assigned.
void
InitAxis
()
const
;
isl
::
set
GenerateIslDomain
()
const
;
//! The names of the tensors depend the same buffer and should schedule before
//! this.
std
::
set
<
std
::
string
>
buffer_depended_tensor_names_
;
friend
Shared
<
poly
::
Stage
>
CreateStage
(
Tensor
tensor
);
friend
void
lang
::
InitReduceTensor
(
poly
::
StageMap
stages
,
const
ir
::
Tensor
&
tensor
,
const
Target
&
target
);
};
Shared
<
poly
::
Stage
>
CreateStage
(
Tensor
tensor
);
...
...
Prev
1
…
18
19
20
21
22
23
24
25
26
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment