Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
992bec46
Commit
992bec46
authored
Oct 08, 2023
by
“yuguo”
Browse files
2.5
parent
0259837d
Changes
357
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3517 additions
and
0 deletions
+3517
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
...n/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
+246
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
...nn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
+76
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
...o_schedule/search_space/auto_gen_rule/auto_inline_test.cc
+527
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc
...n/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc
+136
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h
...nn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h
+57
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
...o_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
+122
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc
...uto_schedule/search_space/auto_gen_rule/mix_rules_test.cc
+69
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc
...schedule/search_space/auto_gen_rule/multi_level_tiling.cc
+460
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h
..._schedule/search_space/auto_gen_rule/multi_level_tiling.h
+144
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
...ule/search_space/auto_gen_rule/multi_level_tiling_test.cc
+575
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc
...inn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc
+38
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h
...cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h
+49
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc
...uto_schedule/search_space/auto_gen_rule/skip_rule_test.cc
+126
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc
...n/auto_schedule/search_space/auto_gen_rule/test_helper.cc
+273
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h
...nn/auto_schedule/search_space/auto_gen_rule/test_helper.h
+98
-0
paddle/cinn/auto_schedule/search_space/block_sampler.cc
paddle/cinn/auto_schedule/search_space/block_sampler.cc
+109
-0
paddle/cinn/auto_schedule/search_space/block_sampler.h
paddle/cinn/auto_schedule/search_space/block_sampler.h
+126
-0
paddle/cinn/auto_schedule/search_space/block_sampler_test.cc
paddle/cinn/auto_schedule/search_space/block_sampler_test.cc
+78
-0
paddle/cinn/auto_schedule/search_space/rule_sampler.cc
paddle/cinn/auto_schedule/search_space/rule_sampler.cc
+85
-0
paddle/cinn/auto_schedule/search_space/rule_sampler.h
paddle/cinn/auto_schedule/search_space/rule_sampler.h
+123
-0
No files found.
Too many changes to show.
To preserve performance only
357 of 357+
files are displayed.
Plain diff
Email patch
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h"
#include <memory>
#include <set>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
auto_schedule
{
AutoInline
::
AutoInline
(
const
common
::
Target
&
target
,
const
std
::
unordered_set
<
std
::
string
>&
no_inline_output_names
)
:
AutoGenRule
(
target
),
no_inline_output_names_
(
no_inline_output_names
)
{}
bool
AutoInline
::
CanInlineIntoConsumer
(
const
Expr
&
sche_block_realize_expr
,
ir
::
IRSchedule
*
ir_sch
)
const
{
const
ir
::
ScheduleBlockRealize
*
sche_block_realize
=
sche_block_realize_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
const
ir
::
ScheduleBlock
*
sche_block
=
sche_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
ir
::
Expr
compute_body
=
sche_block
->
body
;
ir
::
Expr
root
=
ir_sch
->
GetRootBlock
(
sche_block_realize_expr
);
// Check the schedule block to be inlined is not a reduce tensor.
std
::
set
<
ir
::
Expr
>
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
});
if
(
find_store
.
size
()
!=
1UL
)
{
return
false
;
}
ir
::
Expr
tensor_expr
=
(
*
find_store
.
begin
()).
As
<
ir
::
Store
>
()
->
tensor
;
ir
::
Tensor
tensor
=
tensor_expr
.
as_tensor_ref
();
if
(
tensor
->
is_reduce_tensor
())
{
return
false
;
}
// LoweredFunc output can be tensor name or tensor buffer name
if
(
no_inline_output_names_
.
find
(
tensor
->
name
)
!=
no_inline_output_names_
.
end
()
||
no_inline_output_names_
.
find
(
tensor
->
buffer
->
name
)
!=
no_inline_output_names_
.
end
())
{
return
false
;
}
// write_buffers.size() = 1 and read_buffers is empty, means const
// we can inline to consumer
if
(
sche_block
->
read_buffers
.
empty
())
{
return
true
;
}
// Check this schedule block is the only writer of the tensor.
find_store
=
ir
::
CollectIRNodesWithoutTensor
(
root
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
&&
(
x
->
As
<
ir
::
Store
>
()
->
tensor
).
as_tensor_ref
()
->
name
==
tensor
->
name
;
});
if
(
find_store
.
size
()
!=
1UL
)
{
return
false
;
}
// Check there is no overlap between the buffers the schedule block reads and
// writes.
std
::
set
<
ir
::
Expr
>
find_load
=
ir
::
CollectIRNodesWithoutTensor
(
compute_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Load
>
()
&&
x
->
As
<
ir
::
Load
>
()
->
tensor
==
tensor_expr
;
});
if
(
!
find_load
.
empty
())
{
return
false
;
}
ir
::
Expr
store
=
*
(
find_store
.
begin
());
ir
::
ComputeInliner
inliner
(
store
.
As
<
ir
::
Store
>
()
->
tensor
.
as_tensor_ref
(),
store
);
if
(
!
inliner
.
BodyPatternAllowInline
())
{
return
false
;
}
ir
::
LeafBlockRemovalPlan
remove_plan
(
sche_block_realize_expr
,
&
inliner
.
src_stmt
,
&
inliner
.
tgt_stmt
);
remove_plan
(
&
root
);
if
(
!
inliner
.
src_stmt
.
defined
()
||
!
inliner
.
tgt_stmt
.
defined
())
{
return
false
;
}
VLOG
(
6
)
<<
"Found store Expr "
<<
store
<<
", which CanInlineIntoConsumer"
;
return
true
;
}
AutoInlineType
AutoInline
::
AnalyzeInlineType
(
const
Expr
&
sche_block_realize_expr
,
ir
::
IRSchedule
*
ir_sch
)
const
{
const
ir
::
ScheduleBlockRealize
*
sche_block_realize
=
sche_block_realize_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
const
ir
::
ScheduleBlock
*
sche_block
=
sche_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
// Inline if the block has only 1 write buffer
if
(
sche_block
->
write_buffers
.
size
()
!=
1
)
{
return
AutoInlineType
::
kCannotInline
;
}
std
::
unordered_set
<
ir
::
IrNodeTy
>
no_inline_node_types
=
{
ir
::
IrNodeTy
::
IfThenElse
};
if
(
ContainsNodeType
(
sche_block
->
body
,
no_inline_node_types
))
{
return
AutoInlineType
::
kCannotInline
;
}
// InlineIntoConsumer other than above situations
if
(
CanInlineIntoConsumer
(
sche_block_realize_expr
,
ir_sch
))
{
return
AutoInlineType
::
kInlineIntoConsumer
;
}
// TODO(zhhsplendid): We don't have ReverseComputeInline in IRSchedule now,
// so we just do kInlineIntoConsumer here. Add CanInlineIntoProducer
// once ReverseComputeInline is ready.
return
AutoInlineType
::
kCannotInline
;
}
RuleApplyType
AutoInline
::
Init
(
ir
::
IRSchedule
*
ir_schedule
)
{
ir_schedule_
=
ir_schedule
;
all_block_realizes_
=
ir_schedule_
->
GetAllBlocks
();
apply_indices_and_type_
.
clear
();
num_applicable_
=
0
;
for
(
size_t
i
=
0
;
i
<
all_block_realizes_
.
size
();
++
i
)
{
ir
::
ScheduleBlockRealize
*
sche_block_realize
=
all_block_realizes_
[
i
].
As
<
ir
::
ScheduleBlockRealize
>
();
AnalyzeScheduleBlockReadWriteBuffer
(
sche_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
AutoInlineType
type
=
AnalyzeInlineType
(
all_block_realizes_
[
i
],
ir_schedule_
);
if
(
type
!=
AutoInlineType
::
kCannotInline
)
{
++
num_applicable_
;
apply_indices_and_type_
.
push_back
({
i
,
type
});
}
}
return
num_applicable_
>
0
?
RuleApplyType
::
kApplyAndPruneOtherRules
:
RuleApplyType
::
kCannotApply
;
}
void
AutoInline
::
Apply
(
int
index
)
{
CHECK
(
ir_schedule_
!=
nullptr
)
<<
"Run AutoInline::Apply without Init"
;
CHECK
(
num_applicable_
>
0
&&
apply_indices_and_type_
.
size
()
==
num_applicable_
)
<<
"AutoInline::Apply pre-condition doesn't meet"
;
CHECK
(
index
>=
0
&&
num_applicable_
>
index
)
<<
"Invalid index for AutoInline::Apply, the index needs 0 <= index && "
"index < NumberApplicable(), "
<<
"Currently index = "
<<
index
<<
", NumberApplicable() = "
<<
num_applicable_
;
int
apply_index
=
apply_indices_and_type_
[
index
].
first
;
Apply
(
ir_schedule_
,
all_block_realizes_
[
apply_index
]);
return
;
}
std
::
string
AutoInline
::
GetRuleName
()
const
{
return
"AutoInline"
;
}
RuleApplyType
AutoInline
::
AnalyseApplyType
(
SearchState
state
,
const
std
::
string
&
block_name
)
const
{
Expr
block_expr
=
state
->
ir_schedule
.
GetBlock
(
block_name
);
auto
*
block_realize
=
block_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
CHECK
(
block_realize
)
<<
"stmt is not a ScheduleBlockRealize:"
<<
block_expr
;
AnalyzeScheduleBlockReadWriteBuffer
(
block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
AutoInlineType
type
=
AnalyzeInlineType
(
block_expr
,
&
state
->
ir_schedule
);
return
type
==
AutoInlineType
::
kCannotInline
?
RuleApplyType
::
kCannotApply
:
RuleApplyType
::
kApplyAndPruneOtherRules
;
}
std
::
vector
<
SearchState
>
AutoInline
::
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
{
SearchState
new_state
=
state
.
Copy
();
Expr
block_expr
=
new_state
->
ir_schedule
.
GetBlock
(
block_name
);
Apply
(
&
new_state
->
ir_schedule
,
block_expr
);
return
{
new_state
};
}
void
AutoInline
::
Apply
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
)
{
auto
*
block_realize
=
block_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
CHECK
(
block_realize
)
<<
"stmt is not a ScheduleBlockRealize:"
<<
block_expr
;
AnalyzeScheduleBlockReadWriteBuffer
(
block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
AutoInlineType
type
=
AnalyzeInlineType
(
block_expr
,
ir_schedule
);
if
(
type
==
AutoInlineType
::
kInlineIntoConsumer
)
{
VLOG
(
6
)
<<
"Apply ComputeInline on "
<<
block_expr
;
ir_schedule
->
ComputeInline
(
block_expr
);
VLOG
(
6
)
<<
"After ComputeInline: "
<<
block_expr
;
}
else
if
(
type
==
AutoInlineType
::
kInlineIntoProducer
)
{
// TODO(zhhsplendid): We don't have ReverseComputeInline in IRSchedule now,
// so we just do kInlineIntoConsumer here. Add CanInlineIntoConsumer
// once ReverseComputeInline is ready.
// ir_schedule->ReverseComputeInline(all_block_realizes_[apply_index]);
}
// Make sure re-apply the AutoInline won't be error.
// AutoInline changes the read and write buffers of schedule blocks,
// we need to re-analyze
all_block_realizes_
=
ir_schedule
->
GetAllBlocks
();
for
(
size_t
i
=
0
;
i
<
all_block_realizes_
.
size
();
++
i
)
{
ir
::
ScheduleBlockRealize
*
sche_block_realize
=
all_block_realizes_
[
i
].
As
<
ir
::
ScheduleBlockRealize
>
();
ir
::
ScheduleBlock
*
sche_block
=
sche_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
sche_block
->
read_buffers
=
{};
sche_block
->
write_buffers
=
{};
AnalyzeScheduleBlockReadWriteBuffer
(
sche_block
);
}
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
* The types of the AutoInline
*/
enum
class
AutoInlineType
:
int
{
// The block cannot be inlined
kCannotInline
=
0
,
// Inline this block into the consumer
kInlineIntoConsumer
,
// Inline this block into the producer
kInlineIntoProducer
,
};
class
AutoInline
:
public
AutoGenRule
{
public:
AutoInline
(
const
common
::
Target
&
target
,
const
std
::
unordered_set
<
std
::
string
>&
no_inline_output_names
);
~
AutoInline
()
=
default
;
RuleApplyType
Init
(
ir
::
IRSchedule
*
ir_schedule
)
override
;
void
Apply
(
int
index
)
override
;
std
::
string
GetRuleName
()
const
override
;
AutoInlineType
AnalyzeInlineType
(
const
Expr
&
sche_block_realize_expr
,
ir
::
IRSchedule
*
ir_sch
)
const
;
bool
CanInlineIntoConsumer
(
const
Expr
&
sche_block_realize_expr
,
ir
::
IRSchedule
*
ir_sch
)
const
;
RuleApplyType
AnalyseApplyType
(
SearchState
state
,
const
std
::
string
&
block_name
)
const
override
;
std
::
vector
<
SearchState
>
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
override
;
private:
void
Apply
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
);
// NOLINT
private:
std
::
vector
<
ir
::
Expr
>
all_block_realizes_
;
std
::
vector
<
std
::
pair
<
int
,
AutoInlineType
>>
apply_indices_and_type_
;
std
::
unordered_set
<
std
::
string
>
no_inline_output_names_
;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_inline.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <iostream>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/function_base.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/runtime/flags.h"
#include "paddle/cinn/utils/string.h"
#include "test/cpp/cinn/concrete_program_builder.h"
namespace
cinn
{
namespace
auto_schedule
{
using
::
cinn
::
hlir
::
framework
::
Graph
;
using
::
cinn
::
hlir
::
framework
::
OpLowerer
;
TEST
(
AutoInline
,
SingleLoopInline
)
{
srand
(
0
);
Context
::
Global
().
ResetNameId
();
Target
target
=
common
::
DefaultHostTarget
();
Expr
M
(
32
);
Placeholder
<
float
>
A
(
"A"
,
{
M
});
ir
::
Tensor
B
=
Compute
(
{
M
},
[
&
](
Var
i
)
{
return
A
(
i
)
*
ir
::
Expr
(
2.
f
);
},
"B"
);
ir
::
Tensor
C
=
Compute
(
{
M
},
[
&
](
Var
i
)
{
return
B
(
i
)
+
ir
::
Expr
(
1.
f
);
},
"C"
);
poly
::
StageMap
stages
=
CreateStages
({
A
,
B
,
C
});
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
lang
::
LowerVec
(
"TestAutoInline_SingleLoopInline"
,
stages
,
{
A
,
C
},
{},
{},
nullptr
,
target
,
true
);
VLOG
(
6
)
<<
"Expr after lowering:"
;
VLOG
(
6
)
<<
funcs
[
0
]
->
body
;
/*
* We have to use ComputeAt to put two Tensor loops together to create IR
* test case for AutoInline.
*/
ir
::
IRSchedule
ir_sch
(
ir
::
ModuleExpr
(
std
::
vector
<
ir
::
Expr
>
{
funcs
[
0
]
->
body
}));
SearchState
state
(
ir_sch
,
0
,
{});
ir
::
Expr
block_b
=
ir_sch
.
GetBlock
(
"B"
);
std
::
vector
<
ir
::
Expr
>
loops
=
ir_sch
.
GetLoops
(
"C"
);
ir_sch
.
ComputeAt
(
block_b
,
loops
[
0
]);
ir
::
ModuleExpr
mod_expr_before_inline
=
ir_sch
.
GetModule
();
VLOG
(
6
)
<<
"Expr after ComputeAt:"
;
VLOG
(
6
)
<<
mod_expr_before_inline
.
GetExprs
()[
0
];
AutoInline
auto_inline
(
target
,
{
"C"
});
EXPECT_EQ
(
auto_inline
.
Init
(
&
ir_sch
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
EXPECT_EQ
(
auto_inline
.
NumberApplicable
(),
1
);
auto_inline
.
ApplyRandomly
();
std
::
vector
<
ir
::
Expr
>
exprs
=
ir_sch
.
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
// ApplyOnBlock
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
state
,
"B"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
auto_inline
.
ApplyOnBlock
(
state
,
"B"
);
auto
test_func
=
[](
ir
::
IRSchedule
*
ir_sch
)
{
ir
::
ModuleExpr
mod_expr_after_inline
=
ir_sch
->
GetModule
();
std
::
vector
<
ir
::
Expr
>
exprs
=
mod_expr_after_inline
.
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
std
::
stringstream
ss
;
ss
<<
exprs
[
0
];
std
::
string
expr_str
=
ss
.
str
();
VLOG
(
6
)
<<
"After AutoInline:"
;
VLOG
(
6
)
<<
expr_str
;
std
::
string
target_str
=
R"ROC({
ScheduleBlock(root)
{
{
serial for (i, 0, 32)
{
ScheduleBlock(C)
{
i0 = axis.bind(i)
read_buffers(_A[i0(0:32)])
write_buffers(_C[i0(0:32)])
C[i0] = ((A[i0] * 2.00000000f) + 1.00000000f)
}
}
}
}
})ROC"
;
EXPECT_EQ
(
expr_str
,
target_str
);
};
test_func
(
&
ir_sch
);
test_func
(
&
new_states
[
0
]
->
ir_schedule
);
// Cannot inline above expr again
EXPECT_EQ
(
auto_inline
.
Init
(
&
ir_sch
),
RuleApplyType
::
kCannotApply
);
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
new_states
[
0
],
"C"
),
RuleApplyType
::
kCannotApply
);
}
TEST
(
AutoInline
,
AddReluInline
)
{
srand
(
0
);
Context
::
Global
().
ResetNameId
();
Target
target
=
common
::
DefaultHostTarget
();
frontend
::
NetBuilder
builder
(
"test"
);
auto
a
=
builder
.
CreateInput
(
Float
(
32
),
{
1
,
64
,
112
,
112
},
"A"
);
auto
b
=
builder
.
CreateInput
(
Float
(
32
),
{
64
},
"B"
);
auto
c
=
builder
.
Add
(
a
,
b
,
1
);
auto
d
=
builder
.
Relu
(
c
);
frontend
::
Program
program
=
builder
.
Build
();
auto
graph
=
std
::
make_shared
<
Graph
>
(
program
,
target
);
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"OpFusionPass"
);
const
auto
&
dtype_dict
=
graph
->
GetAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
common
::
Type
>>
(
"inferdtype"
);
const
auto
&
shape_dict
=
graph
->
GetAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
auto
op_lowerer
=
std
::
make_unique
<
hlir
::
framework
::
OpLowerer
>
(
dtype_dict
,
shape_dict
,
target
);
EXPECT_EQ
(
graph
->
fusion_groups
.
size
(),
1UL
);
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
op_lowerer
->
Lower
(
graph
->
fusion_groups
[
0
],
/*apply_op_schedule = */
false
,
/*apply_group_schedule=*/
false
);
VLOG
(
6
)
<<
"Expr before auto inline: "
<<
funcs
[
0
]
->
body
;
ir
::
ModuleExpr
mod_expr_before_inline
(
std
::
vector
<
Expr
>
({
funcs
[
0
]
->
body
}));
ir
::
IRSchedule
ir_sch
(
mod_expr_before_inline
);
SearchState
state
(
ir_sch
,
0
,
{});
AutoInline
auto_inline
(
target
,
{
"var_2"
});
EXPECT_EQ
(
auto_inline
.
Init
(
&
ir_sch
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
EXPECT_EQ
(
auto_inline
.
NumberApplicable
(),
2
);
auto_inline
.
Apply
(
1
);
ir
::
ModuleExpr
mod_expr_after_inline
=
ir_sch
.
GetModule
();
std
::
vector
<
ir
::
Expr
>
exprs
=
mod_expr_after_inline
.
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
std
::
stringstream
ss
;
ss
<<
exprs
[
0
];
std
::
string
expr_str
=
ss
.
str
();
VLOG
(
6
)
<<
"After AutoInline:"
;
VLOG
(
6
)
<<
expr_str
;
// Auto Inline again
EXPECT_EQ
(
auto_inline
.
Init
(
&
ir_sch
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
EXPECT_EQ
(
auto_inline
.
NumberApplicable
(),
1
);
auto_inline
.
Apply
(
0
);
// ApplyOnBlock
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
state
,
"var_1"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
auto_inline
.
ApplyOnBlock
(
state
,
"var_1"
);
// Auto Inline again
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
new_states
[
0
],
"var_3"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
new_states
=
auto_inline
.
ApplyOnBlock
(
new_states
[
0
],
"var_3"
);
auto
test_func
=
[](
ir
::
IRSchedule
*
ir_sch
)
{
ir
::
ModuleExpr
final_mod_expr
=
ir_sch
->
GetModule
();
auto
exprs
=
final_mod_expr
.
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
std
::
stringstream
ss
;
ss
<<
exprs
[
0
];
std
::
string
expr_str
=
ss
.
str
();
VLOG
(
6
)
<<
"Final AutoInline:"
;
VLOG
(
6
)
<<
expr_str
;
std
::
string
target_str
=
R"ROC({
ScheduleBlock(root)
{
{
serial for (i, 0, 1)
{
serial for (j, 0, 64)
{
serial for (k, 0, 112)
{
serial for (a, 0, 112)
{
ScheduleBlock(var_2)
{
i0, i1, i2, i3 = axis.bind(0, j, k, a)
read_buffers(_A[i0(0:1), i1(0:64), i2(0:112), i3(0:112)], _B[i1(0:64)])
write_buffers(_var_2[i0(0:1), i1(0:64), i2(0:112), i3(0:112)])
var_2[i0, i1, i2, i3] = cinn_max((A[i0, i1, i2, i3] + B[i1]), 0.00000000f)
}
}
}
}
}
}
}
})ROC"
;
EXPECT_EQ
(
expr_str
,
target_str
);
};
test_func
(
&
ir_sch
);
test_func
(
&
new_states
[
0
]
->
ir_schedule
);
// Cannot inline above expr again
EXPECT_EQ
(
auto_inline
.
Init
(
&
ir_sch
),
RuleApplyType
::
kCannotApply
);
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
new_states
[
0
],
"var_2"
),
RuleApplyType
::
kCannotApply
);
}
#ifdef CINN_WITH_CUDA
class
TestAutoInline
:
public
TestAutoGenRuleBase
{};
/* The single chain graph composed of multiple blocks can be inlined into one.
*
* Before AutoInline: The output of the previous block is the input of another
* block. Loop1: x1 = Add() Loop2: x2 = Multiply(x1) Loop3: x3 = Add(x2) Loop4:
* x4 = Relu(x3)
*
* After AutoInline: All loops are inlined into a loop.
* Loop:
* Add(Multiply(Add(Relu())))
*/
TEST_F
(
TestAutoInline
,
SingleChain
)
{
Target
target
=
common
::
DefaultNVGPUTarget
();
Initialize
(
target
);
std
::
vector
<
std
::
string
>
input_names
=
{
"bias"
,
"conv_output"
,
"bn_scale"
,
"bn_offset"
};
std
::
vector
<
std
::
string
>
output_names
=
{
"var_6"
,
"var_5"
,
"var_1"
,
"var"
,
"var_0"
,
"var_4"
,
"var_3"
};
std
::
vector
<
int32_t
>
conv_output_shape
=
{
1
,
512
,
56
,
56
};
int32_t
channel
=
conv_output_shape
[
1
];
std
::
vector
<
tests
::
VariableInfo
>
inputs_varinfo
(
{{
"conv_output"
,
conv_output_shape
},
{
"bias"
,
{
channel
,
1
,
1
}},
{
"bn_scale"
,
{
channel
,
1
,
1
}},
{
"bn_offset"
,
{
channel
,
1
,
1
}}});
// Construct the computation graph and convert it to ir::Expr
Context
::
Global
().
ResetNameId
();
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
tests
::
BiasBnReLUBuilder
().
Build
(
inputs_varinfo
));
SearchState
state
(
ir_schedule
,
0
,
{});
std
::
vector
<
ir
::
Expr
>
func_bodys
=
ir_schedule
.
GetModule
().
GetExprs
();
ASSERT_EQ
(
func_bodys
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Original Expr:
\n
"
<<
func_bodys
[
0
];
// Apply AutoInline for every block that can be inline
AutoInline
auto_inline
(
target_
,
{
output_names
.
front
()});
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
state
,
"var_3"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
auto_inline
.
ApplyOnBlock
(
state
,
"var_3"
);
std
::
vector
<
std
::
string
>
inline_block_names
(
{
"var_4"
,
"var_5"
,
"var_6"
,
"var"
,
"var_0"
,
"var_1"
});
for
(
const
auto
&
inline_block_name
:
inline_block_names
)
{
new_states
=
auto_inline
.
ApplyOnBlock
(
new_states
[
0
],
inline_block_name
);
}
std
::
vector
<
ir
::
Expr
>
exprs
=
new_states
[
0
]
->
ir_schedule
.
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Expr after AutoInline applied on block: "
<<
exprs
[
0
];
// build ir::Module and debug source code
auto
build_module_auto
=
BuildIRModule
(
new_states
[
0
]
->
ir_schedule
);
auto
build_module_manually
=
BuildIRModule
(
MakeIRSchedule
(
tests
::
BiasBnReLUBuilder
().
Build
(
inputs_varinfo
),
-
1
,
true
));
auto
source_code_auto
=
GenSourceCode
(
build_module_auto
);
VLOG
(
6
)
<<
" auto-schedule source code:
\n
"
<<
source_code_auto
;
auto
source_code_manually
=
GenSourceCode
(
build_module_manually
);
VLOG
(
6
)
<<
" manually-schedule source code:
\n
"
<<
source_code_manually
;
CheckResult
(
GenExecutableKernel
(
build_module_auto
),
GenExecutableKernel
(
build_module_manually
),
input_names
,
output_names
,
{{
conv_output_shape
[
1
],
1
,
1
},
conv_output_shape
,
conv_output_shape
,
conv_output_shape
},
{
conv_output_shape
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
},
{
1
}},
target
);
}
/* An op can be inlined into multiple consumers at the same time.
*
* Before AutoInline: The output of Exp is used by Add and Multiply.
* Loop1:
* x = Exp()
* Loop2:
* y = Add(x)
* Loop3:
* z = Multiply(x)
*
* After AutoInline: Exp is inlined into Add and Multiply.
* Loop:
* y = Add(Exp())
* z = Multiply(Exp())
*/
TEST_F
(
TestAutoInline
,
InlineToMultiConsumers
)
{
Target
target
=
common
::
DefaultNVGPUTarget
();
Initialize
(
target
);
std
::
vector
<
std
::
string
>
input_names
=
{
"x"
};
std
::
vector
<
std
::
string
>
output_names
=
{
"var_2"
,
"var_1"
,
"var_0"
};
std
::
vector
<
int32_t
>
input_shape
{
256
,
256
};
std
::
vector
<
tests
::
VariableInfo
>
inputs_varinfo
({{
"x"
,
input_shape
}});
// Construct the computation graph and convert it to ir::Expr
Context
::
Global
().
ResetNameId
();
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
tests
::
ExpTwoConsumersOpBuilder
().
Build
(
inputs_varinfo
));
SearchState
state
(
ir_schedule
,
0
,
{});
std
::
vector
<
ir
::
Expr
>
func_bodys
=
ir_schedule
.
GetModule
().
GetExprs
();
ASSERT_EQ
(
func_bodys
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Original Expr:
\n
"
<<
func_bodys
[
0
];
// Apply AutoInline for every block that can be inline
AutoInline
auto_inline
(
target_
,
{
output_names
.
front
()});
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
state
,
"var_0"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
auto_inline
.
ApplyOnBlock
(
state
,
"var_1"
);
new_states
=
auto_inline
.
ApplyOnBlock
(
state
,
"var_0"
);
std
::
vector
<
ir
::
Expr
>
exprs
=
new_states
[
0
]
->
ir_schedule
.
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Expr after AutoInline applied on block: "
<<
exprs
[
0
];
// build ir::Module and debug source code
auto
build_module_auto
=
BuildIRModule
(
new_states
[
0
]
->
ir_schedule
);
auto
build_module_manually
=
BuildIRModule
(
MakeIRSchedule
(
tests
::
ExpTwoConsumersOpBuilder
().
Build
(
inputs_varinfo
),
-
1
,
true
));
auto
source_code_auto
=
GenSourceCode
(
build_module_auto
);
VLOG
(
6
)
<<
" auto-schedule source code:
\n
"
<<
source_code_auto
;
auto
source_code_manually
=
GenSourceCode
(
build_module_manually
);
VLOG
(
6
)
<<
" manually-schedule source code:
\n
"
<<
source_code_manually
;
CheckResult
(
GenExecutableKernel
(
build_module_auto
),
GenExecutableKernel
(
build_module_manually
),
input_names
,
output_names
,
{
input_shape
},
{
input_shape
,
{
1
},
{
1
}},
target
);
}
/* Operators of type elementwise or injective can all be inlined.
*
* Before AutoInline: A graph of Gather, Add and Subtract
* Loop1:
* x1 = Gather()
* Loop2:
* x2 = Add(x1)
* Loop3:
* y1 = Gather()
* Loop4:
* z1 = Subtract(y1, x1)
*
* After AutoInline: All loops are inlined to one
* z1 = Subtract(Gather(), Add(Gather()))
*/
TEST_F
(
TestAutoInline
,
OnlySpatialOp
)
{
Target
target
=
common
::
DefaultNVGPUTarget
();
Initialize
(
target
);
std
::
vector
<
std
::
string
>
input_names
=
{
"x"
,
"y"
};
std
::
vector
<
std
::
string
>
output_names
=
{
"var_6"
,
"var_4"
,
"constant_idx_last"
,
"constant_idx_first"
,
"var_2"
,
"var_5"
};
std
::
vector
<
int32_t
>
input_shape
{
256
,
256
};
std
::
vector
<
tests
::
VariableInfo
>
inputs_varinfo
(
{{
"x"
,
input_shape
},
{
"y"
,
input_shape
}});
// Construct the computation graph and convert it to ir::Expr
Context
::
Global
().
ResetNameId
();
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
tests
::
GatherAddSubBuilder
().
Build
(
inputs_varinfo
));
SearchState
state
(
ir_schedule
,
0
,
{});
std
::
vector
<
ir
::
Expr
>
func_bodys
=
ir_schedule
.
GetModule
().
GetExprs
();
ASSERT_EQ
(
func_bodys
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Original Expr:
\n
"
<<
func_bodys
[
0
];
// Apply AutoInline for every block that can be inline
AutoInline
auto_inline
(
target_
,
{
output_names
.
front
()});
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
state
,
"constant_idx_first"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
auto_inline
.
ApplyOnBlock
(
state
,
"constant_idx_first"
);
std
::
vector
<
std
::
string
>
inline_block_names
(
{
"constant_idx_last"
,
"var_2"
,
"var_5"
,
"var_4"
});
for
(
const
auto
&
inline_block_name
:
inline_block_names
)
{
new_states
=
auto_inline
.
ApplyOnBlock
(
new_states
[
0
],
inline_block_name
);
}
std
::
vector
<
ir
::
Expr
>
exprs
=
new_states
[
0
]
->
ir_schedule
.
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Expr after AutoInline applied on block: "
<<
exprs
[
0
];
// build ir::Module and debug source code
auto
build_module_auto
=
BuildIRModule
(
new_states
[
0
]
->
ir_schedule
);
auto
build_module_manually
=
BuildIRModule
(
MakeIRSchedule
(
tests
::
GatherAddSubBuilder
().
Build
(
inputs_varinfo
),
-
1
,
true
));
auto
source_code_auto
=
GenSourceCode
(
build_module_auto
);
VLOG
(
6
)
<<
" auto-schedule source code:
\n
"
<<
source_code_auto
;
auto
source_code_manually
=
GenSourceCode
(
build_module_manually
);
VLOG
(
6
)
<<
" manually-schedule source code:
\n
"
<<
source_code_manually
;
CheckResult
(
GenExecutableKernel
(
build_module_auto
),
GenExecutableKernel
(
build_module_manually
),
input_names
,
output_names
,
{
input_shape
,
input_shape
},
{
input_shape
,
{
1
},
{
1
},
{
1
},
{
1
},
{
1
}},
target
);
}
/* An op that does not read data can be directly inlined.
*
* Before AutoInline: fill_constant op is in a separate loop.
* Loop1:
* x = fill_constant()
* Loop2:
* y = Add(x)
*
* After AutoInline: fill_constant op is inlined into other loop
* Loop:
* y = Add(fill_constant())
*/
TEST_F
(
TestAutoInline
,
NoReadBufferOp
)
{
Target
target
=
common
::
DefaultNVGPUTarget
();
Initialize
(
target
);
std
::
vector
<
std
::
string
>
input_names
=
{
"x"
};
std
::
vector
<
std
::
string
>
output_names
=
{
"var_0"
,
"fill_constant"
};
std
::
vector
<
int32_t
>
input_shape
{
256
,
256
};
std
::
vector
<
tests
::
VariableInfo
>
inputs_varinfo
({{
"x"
,
input_shape
}});
// Construct the computation graph and convert it to ir::Expr
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
tests
::
FillConstantAddBuilder
().
Build
(
inputs_varinfo
));
SearchState
state
(
ir_schedule
,
0
,
{});
std
::
vector
<
ir
::
Expr
>
func_bodys
=
ir_schedule
.
GetModule
().
GetExprs
();
ASSERT_EQ
(
func_bodys
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Original Expr:
\n
"
<<
func_bodys
[
0
];
// Apply AutoInline for every block that can be inline
AutoInline
auto_inline
(
target_
,
{
output_names
.
front
()});
EXPECT_EQ
(
auto_inline
.
AnalyseApplyType
(
state
,
"fill_constant"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
auto_inline
.
ApplyOnBlock
(
state
,
"fill_constant"
);
std
::
vector
<
ir
::
Expr
>
exprs
=
new_states
[
0
]
->
ir_schedule
.
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Expr after AutoInline applied on block: "
<<
exprs
[
0
];
// build ir::Module and debug source code
auto
build_module_auto
=
BuildIRModule
(
new_states
[
0
]
->
ir_schedule
);
auto
build_module_manually
=
BuildIRModule
(
MakeIRSchedule
(
tests
::
FillConstantAddBuilder
().
Build
(
inputs_varinfo
),
-
1
,
true
));
auto
source_code_auto
=
GenSourceCode
(
build_module_auto
);
VLOG
(
6
)
<<
" auto-schedule source code:
\n
"
<<
source_code_auto
;
auto
source_code_manually
=
GenSourceCode
(
build_module_manually
);
VLOG
(
6
)
<<
" manually-schedule source code:
\n
"
<<
source_code_manually
;
CheckResult
(
GenExecutableKernel
(
build_module_auto
),
GenExecutableKernel
(
build_module_manually
),
input_names
,
output_names
,
{
input_shape
},
{
input_shape
,
{
1
}},
target
);
}
/* An op can be inlined into multiple producers at the same time.
*/
// TEST_F(TestAutoInline, InlineToMultiProducers) {
// TODO(6clc): Complete the unit test, once ReverseComputeInline is ready.
// }
#endif
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
#include <glog/logging.h>
#include <cstdlib>
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
auto_schedule
{
static
std
::
vector
<
int
>
auto_unroll_options
=
{
0
,
8
,
32
,
128
};
bool
AutoUnroll
::
MeetCondition
(
const
ir
::
ScheduleBlock
*
schedule_block
)
const
{
// whether any block has reduce iter
auto
has_reduce_iter
=
[](
const
Expr
*
x
)
{
auto
*
block_realize
=
x
->
As
<
ir
::
ScheduleBlockRealize
>
();
if
(
block_realize
)
{
auto
*
schedule_block
=
block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
CHECK
(
schedule_block
)
<<
"schedule_block field is not a ScheduleBlock"
;
for
(
auto
&&
var
:
schedule_block
->
iter_vars
)
{
if
(
var
->
is_reduce_axis
)
{
VLOG
(
6
)
<<
"find ScheduleBlockRealize:"
<<
*
x
<<
" has reduce_axis:"
<<
var
;
return
true
;
}
}
}
return
false
;
};
// whether has any for-loop with non-serial type
auto
has_nonserial_loop
=
[](
const
Expr
*
x
)
{
if
(
x
->
As
<
ir
::
For
>
()
&&
x
->
As
<
ir
::
For
>
()
->
for_type
()
!=
ir
::
ForType
::
Serial
)
{
VLOG
(
6
)
<<
"find non-serial loop:"
<<
*
x
;
return
true
;
}
return
false
;
};
auto
find_target_exprs
=
ir
::
CollectIRNodesWithoutTensor
(
schedule_block
->
body
,
[
&
has_reduce_iter
,
&
has_nonserial_loop
](
const
Expr
*
x
)
{
return
has_reduce_iter
(
x
)
||
has_nonserial_loop
(
x
);
});
return
!
find_target_exprs
.
empty
();
}
RuleApplyType
AutoUnroll
::
Init
(
ir
::
IRSchedule
*
ir_schedule
)
{
ir_schedule_
=
ir_schedule
;
auto
block_realizes
=
ir_schedule_
->
GetAllBlocks
();
// A schedule block can perform `auto_unroll` rule should meet two conditions:
// (1) it is a root block
// (2) MeetCondition returns true with it
applicable_schedule_blocks_
.
clear
();
std
::
set
<
Expr
>
deduplicate_results
;
for
(
size_t
i
=
0
;
i
<
block_realizes
.
size
();
++
i
)
{
// find root block
Expr
root_block
=
ir_schedule_
->
GetRootBlock
(
block_realizes
[
i
]);
auto
*
block_realize
=
root_block
.
As
<
ir
::
ScheduleBlockRealize
>
();
CHECK
(
block_realize
)
<<
"stmt is not a ScheduleBlockRealize:"
<<
root_block
;
auto
*
schedule_block
=
block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
CHECK
(
schedule_block
)
<<
"schedule_block field is not a ScheduleBlock:"
<<
Expr
(
block_realize
);
if
(
MeetCondition
(
schedule_block
))
{
deduplicate_results
.
emplace
(
root_block
);
}
}
applicable_schedule_blocks_
=
{
deduplicate_results
.
begin
(),
deduplicate_results
.
end
()};
num_applicable_
=
applicable_schedule_blocks_
.
size
();
VLOG
(
6
)
<<
"Collect applicable_schedule_blocks_:"
<<
num_applicable_
;
return
num_applicable_
>
0
?
RuleApplyType
::
kApplyAndPruneOtherRules
:
RuleApplyType
::
kCannotApply
;
}
void
AutoUnroll
::
Apply
(
int
index
)
{
CHECK_LT
(
index
,
applicable_schedule_blocks_
.
size
())
<<
"invalid apply index:"
<<
index
;
auto
applied_block
=
applicable_schedule_blocks_
.
at
(
index
);
int
max_step
=
auto_unroll_options
[
std
::
rand
()
%
auto_unroll_options
.
size
()];
ir_schedule_
->
Annotate
(
applied_block
,
ir
::
attr
::
auto_unroll_max_step
,
max_step
);
return
;
}
RuleApplyType
AutoUnroll
::
AnalyseApplyType
(
SearchState
state
,
const
std
::
string
&
block_name
)
const
{
Expr
block_expr
=
state
->
ir_schedule
.
GetBlock
(
block_name
);
Expr
root_block
=
state
->
ir_schedule
.
GetRootBlock
(
block_expr
);
auto
*
block_realize
=
root_block
.
As
<
ir
::
ScheduleBlockRealize
>
();
CHECK
(
block_realize
)
<<
"stmt is not a ScheduleBlockRealize:"
<<
root_block
;
auto
*
schedule_block
=
block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
CHECK
(
schedule_block
)
<<
"schedule_block field is not a ScheduleBlock:"
<<
Expr
(
block_realize
);
return
MeetCondition
(
schedule_block
)
?
RuleApplyType
::
kApplyAndPruneOtherRules
:
RuleApplyType
::
kCannotApply
;
}
std
::
vector
<
SearchState
>
AutoUnroll
::
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
{
SearchState
new_state
=
state
.
Copy
();
Expr
block_expr
=
new_state
->
ir_schedule
.
GetBlock
(
block_name
);
Expr
applied_block
=
new_state
->
ir_schedule
.
GetRootBlock
(
block_expr
);
int
max_step
=
auto_unroll_options
[
std
::
rand
()
%
auto_unroll_options
.
size
()];
new_state
->
ir_schedule
.
Annotate
(
applied_block
,
ir
::
attr
::
auto_unroll_max_step
,
max_step
);
return
{
new_state
};
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
// This rule can be applied in a ScheduleBlock has reduce axis or has loops with
// non-serial type. As a result, it will set a attribute with key named
// ir::attr::auto_unroll_max_step and value indicating max permitted unrolled
// step in the applied ScheduleBlock. Finally, UnrollLoop pass will do unroll
// based on actual situation.
class
AutoUnroll
:
public
AutoGenRule
{
public:
explicit
AutoUnroll
(
const
common
::
Target
&
target
)
:
AutoGenRule
(
target
)
{}
~
AutoUnroll
()
=
default
;
RuleApplyType
Init
(
ir
::
IRSchedule
*
init_schedule
)
override
;
void
Apply
(
int
index
)
override
;
std
::
string
GetRuleName
()
const
override
{
return
"AutoUnroll"
;
}
RuleApplyType
AnalyseApplyType
(
SearchState
state
,
const
std
::
string
&
block_name
)
const
override
;
std
::
vector
<
SearchState
>
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
override
;
private:
bool
MeetCondition
(
const
ir
::
ScheduleBlock
*
schedule_block
)
const
;
private:
std
::
vector
<
Expr
>
applicable_schedule_blocks_
;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_unroll.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/lang/lower.h"
namespace
cinn
{
namespace
auto_schedule
{
TEST
(
AutoUnroll
,
Init
)
{
using
namespace
ir
;
// NOLINT
Expr
M
(
100
);
Expr
N
(
4
);
Placeholder
<
float
>
A
(
"A"
,
{
M
,
N
});
Placeholder
<
float
>
B
(
"B"
,
{
M
,
N
});
Tensor
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
,
j
)
*
B
(
i
,
j
);
},
"C"
);
#ifdef CINN_WITH_CUDA
Target
target
=
common
::
DefaultNVGPUTarget
();
#else
Target
target
=
common
::
DefaultHostTarget
();
#endif
auto
stages
=
CreateStages
({
C
});
auto
funcs
=
cinn
::
lang
::
LowerVec
(
"test_init"
,
stages
,
{
A
,
B
,
C
},
{},
{},
nullptr
,
target
,
true
);
auto
ast_expr
=
funcs
[
0
]
->
body
;
ir
::
IRSchedule
init_schedule
(
ir
::
ModuleExpr
({
ast_expr
}));
AutoUnroll
test_rule
(
target
);
// not meet specific condition
ASSERT_EQ
(
test_rule
.
Init
(
&
init_schedule
),
RuleApplyType
::
kCannotApply
);
}
TEST
(
AutoUnroll
,
UnrollableApply
)
{
using
namespace
ir
;
// NOLINT
Expr
M
(
100
);
Expr
N
(
4
);
Expr
K
(
32
);
Placeholder
<
float
>
A
(
"A"
,
{
M
,
K
});
Placeholder
<
float
>
B
(
"B"
,
{
K
,
N
});
Var
k
(
K
.
as_int32
(),
"k0"
);
Tensor
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
ReduceSum
(
A
(
i
,
k
)
*
B
(
k
,
j
),
{
k
});
},
"C"
);
#ifdef CINN_WITH_CUDA
Target
target
=
common
::
DefaultNVGPUTarget
();
#else
Target
target
=
common
::
DefaultHostTarget
();
#endif
auto
stages
=
CreateStages
({
C
});
auto
funcs
=
cinn
::
lang
::
LowerVec
(
"test_unrollable"
,
stages
,
{
A
,
B
,
C
},
{},
{},
nullptr
,
target
,
true
);
auto
ast_expr
=
funcs
[
0
]
->
body
;
auto
*
init_block_realize
=
ast_expr
.
As
<
ir
::
Block
>
()
->
stmts
.
front
().
As
<
ir
::
ScheduleBlockRealize
>
();
auto
*
init_schedule_block
=
init_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
ASSERT_NE
(
init_schedule_block
,
nullptr
);
ASSERT_TRUE
(
init_schedule_block
->
attrs
.
empty
());
VLOG
(
6
)
<<
"Before auto-unroll:
\n
"
<<
ast_expr
;
AutoUnroll
test_rule
(
target
);
ir
::
IRSchedule
ir_schedule
(
ir
::
ModuleExpr
({
ast_expr
}));
SearchState
state
(
ir_schedule
,
0
,
{});
ASSERT_EQ
(
test_rule
.
Init
(
&
ir_schedule
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
EXPECT_EQ
(
test_rule
.
NumberApplicable
(),
1
);
test_rule
.
ApplyRandomly
();
// ApplyOnBlock
EXPECT_EQ
(
test_rule
.
AnalyseApplyType
(
state
,
"C"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
std
::
vector
<
cinn
::
auto_schedule
::
SearchState
>
states
=
test_rule
.
ApplyOnBlock
(
state
,
"C"
);
auto
test_func
=
[](
IRSchedule
*
ir_sch
)
{
Expr
applied_expr
=
ir_sch
->
GetModule
().
GetExprs
().
front
();
auto
*
applied_block_realize
=
applied_expr
.
As
<
ir
::
Block
>
()
->
stmts
.
front
()
.
As
<
ir
::
ScheduleBlockRealize
>
();
auto
*
applied_schedule_block
=
applied_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
ASSERT_FALSE
(
applied_schedule_block
->
attrs
.
empty
());
EXPECT_EQ
(
applied_schedule_block
->
attrs
.
count
(
ir
::
attr
::
auto_unroll_max_step
),
1
);
const
auto
&
attr_value
=
applied_schedule_block
->
attrs
.
at
(
ir
::
attr
::
auto_unroll_max_step
);
const
int
*
max_step
=
absl
::
get_if
<
int
>
(
&
attr_value
);
EXPECT_NE
(
max_step
,
nullptr
);
EXPECT_LE
(
*
max_step
,
128
);
VLOG
(
6
)
<<
"After auto-unroll:max_step="
<<
*
max_step
<<
", Ast:
\n
"
<<
ir_sch
->
GetModule
().
GetExprs
().
front
();
};
test_func
(
&
ir_schedule
);
test_func
(
&
states
[
0
]
->
ir_schedule
);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/mix_rules_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
namespace
cinn
{
namespace
auto_schedule
{
class
TestMixRules
:
public
TestAutoGenRuleBase
{
public:
std
::
vector
<
std
::
string
>
default_input_names
=
{
"X"
,
"Y"
};
std
::
vector
<
std
::
string
>
default_output_names
=
{
"temp_matmul_out"
};
};
TEST_F
(
TestMixRules
,
2
DMatmulOnMultiTilingRelated
)
{
frontend
::
Program
matmul_op
=
tests
::
OpBuilder
(
"matmul"
).
Build
({{
"X"
,
{
32
,
32
}},
{
"Y"
,
{
32
,
32
}}});
Initialize
(
common
::
DefaultNVGPUTarget
());
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
matmul_op
);
std
::
vector
<
ir
::
Expr
>
func_bodys
=
ir_schedule
.
GetModule
().
GetExprs
();
ASSERT_EQ
(
func_bodys
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Original Expr:
\n
"
<<
func_bodys
[
0
];
// Apply MultiLevelTiling
MultiLevelTiling
multi_level_tiling
(
target_
,
MultiLevelTiling
::
kConfigs
.
at
(
target_
.
arch
));
multi_level_tiling
.
Init
(
&
ir_schedule
);
ASSERT_EQ
(
multi_level_tiling
.
NumberApplicable
(),
1
);
multi_level_tiling
.
ApplyRandomly
();
VLOG
(
6
)
<<
"after MultiLevelTiling Expr:
\n
"
<<
func_bodys
[
0
];
// build ir::Module and debug source code
auto
ir_module
=
BuildIRModule
(
ir_schedule
);
auto
source_code
=
GenSourceCode
(
ir_module
);
VLOG
(
6
)
<<
"scheduled source code:
\n
"
<<
source_code
;
// execute and check precision
CheckResult
(
GenExecutableKernel
(
ir_module
),
GenExecutableKernel
(
BuildIRModule
(
MakeIRSchedule
(
matmul_op
,
/* apply_manual_schedule */
true
))),
default_input_names
,
default_output_names
,
{{
32
,
32
},
{
32
,
32
}},
{{
32
,
32
}},
target_
);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include <glog/logging.h>
#include <algorithm>
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
auto_schedule
{
MultiLevelTiling
::
MultiLevelTiling
(
const
common
::
Target
&
target
,
const
Config
&
config
)
:
AutoGenRule
(
target
),
config_
(
config
)
{
for
(
int
i
=
0
;
i
<
config_
.
tile_struct
.
size
();
++
i
)
{
if
(
config_
.
tile_struct
[
i
]
==
'S'
)
{
s_indices_
.
push_back
(
i
);
}
else
if
(
config_
.
tile_struct
[
i
]
==
'R'
)
{
r_indices_
.
push_back
(
i
);
}
else
{
CHECK
(
false
)
<<
"Illegal tiling structure string"
;
}
}
}
bool
MultiLevelTiling
::
MeetCondition
(
const
ir
::
ScheduleBlockRealize
&
sche_block_realize
)
const
{
return
NeedsMultiLevelTiling
(
sche_block_realize
);
}
RuleApplyType
MultiLevelTiling
::
Init
(
ir
::
IRSchedule
*
ir_schedule
)
{
ir_schedule_
=
ir_schedule
;
all_block_realizes_
=
ir_schedule_
->
GetAllBlocks
();
applicable_indices_
.
clear
();
num_applicable_
=
0
;
for
(
size_t
i
=
0
;
i
<
all_block_realizes_
.
size
();
++
i
)
{
ir
::
ScheduleBlockRealize
*
sche_block_realize
=
all_block_realizes_
[
i
].
As
<
ir
::
ScheduleBlockRealize
>
();
AnalyzeScheduleBlockReadWriteBuffer
(
sche_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
if
(
MeetCondition
(
*
sche_block_realize
))
{
++
num_applicable_
;
applicable_indices_
.
push_back
(
i
);
}
}
return
num_applicable_
>
0
?
RuleApplyType
::
kApplyAndPruneOtherRules
:
RuleApplyType
::
kCannotApply
;
}
void
MultiLevelTiling
::
Apply
(
int
index
)
{
CHECK
(
ir_schedule_
!=
nullptr
)
<<
"Run MultiLevelTiling::Apply without Init"
;
CHECK
(
num_applicable_
>
0
&&
applicable_indices_
.
size
()
==
num_applicable_
)
<<
"MultiLevelTiling::Apply pre-condition doesn't meet"
;
CHECK
(
index
>=
0
&&
num_applicable_
>
index
)
<<
"Invalid index for MultiLevelTiling::Apply, the index needs 0 <= "
"index && index < NumberApplicable(), "
<<
"Currently index = "
<<
index
<<
", NumberApplicable() = "
<<
num_applicable_
;
int
apply_index
=
applicable_indices_
[
index
];
std
::
string
block_name
=
all_block_realizes_
[
apply_index
]
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
;
Expr
block_expr
=
all_block_realizes_
[
apply_index
];
ApplyTiling
(
ir_schedule_
,
block_expr
);
block_expr
=
ir_schedule_
->
GetBlock
(
block_name
);
ApplyCacheRead
(
ir_schedule_
,
block_expr
);
block_expr
=
ir_schedule_
->
GetBlock
(
block_name
);
ApplyCacheWrite
(
ir_schedule_
,
block_expr
);
VLOG
(
4
)
<<
"Returning the result of MultiLevelTiling"
;
return
;
}
std
::
string
MultiLevelTiling
::
GetRuleName
()
const
{
return
"MultiLevelTiling"
;
}
RuleApplyType
MultiLevelTiling
::
AnalyseApplyType
(
SearchState
state
,
const
std
::
string
&
block_name
)
const
{
Expr
block_expr
=
state
->
ir_schedule
.
GetBlock
(
block_name
);
auto
*
block_realize
=
block_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
CHECK
(
block_realize
)
<<
"stmt is not a ScheduleBlockRealize:"
<<
block_expr
;
AnalyzeScheduleBlockReadWriteBuffer
(
block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
());
return
NeedsMultiLevelTiling
(
*
block_realize
)
?
RuleApplyType
::
kApplyAndPruneOtherRules
:
RuleApplyType
::
kCannotApply
;
}
std
::
vector
<
SearchState
>
MultiLevelTiling
::
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
{
SearchState
new_state
=
state
.
Copy
();
ir
::
IRSchedule
*
ir_sch
=
&
new_state
->
ir_schedule
;
Expr
block_expr
=
ir_sch
->
GetBlock
(
block_name
);
ApplyTiling
(
ir_sch
,
block_expr
);
block_expr
=
ir_sch
->
GetBlock
(
block_name
);
ApplyCacheRead
(
ir_sch
,
block_expr
);
block_expr
=
ir_sch
->
GetBlock
(
block_name
);
ApplyCacheWrite
(
ir_sch
,
block_expr
);
VLOG
(
4
)
<<
"Returning the result of MultiLevelTiling"
;
return
{
new_state
};
}
void
MultiLevelTiling
::
ApplyTiling
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
)
{
ir
::
ScheduleBlockRealize
*
sche_block_realize
=
block_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
ir
::
ScheduleBlock
*
sche_block
=
sche_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
tile_loops_
.
clear
();
tile_loops_
.
resize
(
config_
.
tile_struct
.
size
());
std
::
vector
<
Expr
>
for_exprs
=
ir_schedule
->
GetLoops
(
block_expr
);
VLOG
(
5
)
<<
"The number of loops to split in MultiLevelTiling is "
<<
for_exprs
.
size
();
for
(
int
i
=
for_exprs
.
size
()
-
1
;
i
>=
0
;
--
i
)
{
ir
::
For
*
ir_for
=
for_exprs
[
i
].
As
<
ir
::
For
>
();
VLOG
(
6
)
<<
"Applying Split for MultiLevelTiling on: "
<<
Expr
(
ir_for
);
const
std
::
vector
<
int
>*
idx
=
nullptr
;
if
(
sche_block
->
iter_vars
[
i
]
->
is_reduce_axis
)
{
idx
=
&
r_indices_
;
}
else
{
idx
=
&
s_indices_
;
}
// TODO(zhhsplendid): support more iterator variable types
int
extent
=
ir_for
->
extent
.
as_int32
();
// maybe int64?
int
num_split
=
idx
->
size
();
if
(
num_split
>
1
)
{
std
::
vector
<
Expr
>
tile_split_factor
=
ir_schedule
->
SamplePerfectTile
(
Expr
(
ir_for
),
num_split
,
64
);
std
::
vector
<
Expr
>
splited
=
ir_schedule
->
Split
(
Expr
(
ir_for
),
tile_split_factor
);
VLOG
(
6
)
<<
"Finish Split for MultiLevelTiling on above loop"
;
for
(
int
j
=
0
;
j
<
num_split
;
++
j
)
{
tile_loops_
[
idx
->
at
(
j
)].
push_back
(
splited
[
j
]);
}
}
else
{
tile_loops_
[
idx
->
at
(
0
)].
push_back
(
for_exprs
[
i
]);
}
}
VLOG
(
5
)
<<
"Finish Split in MultiLevelTiling, before Reorder."
;
// Have to GetLoops again because Split can change Block Expr(s)
for_exprs
=
ir_schedule
->
GetLoops
(
sche_block
->
name
);
std
::
unordered_map
<
std
::
string
,
int
>
loop_var_name_to_idx
;
for
(
int
i
=
0
;
i
<
for_exprs
.
size
();
++
i
)
{
loop_var_name_to_idx
[
for_exprs
[
i
].
As
<
ir
::
For
>
()
->
loop_var
->
name
]
=
i
;
}
CHECK
(
loop_var_name_to_idx
.
size
()
==
for_exprs
.
size
())
<<
"Loops contain duplicate loop var names after split"
;
std
::
vector
<
Expr
>
splited_loops
;
for
(
auto
&
t
:
tile_loops_
)
{
std
::
reverse
(
t
.
begin
(),
t
.
end
());
for
(
auto
&
tile_loop_expr
:
t
)
{
const
ir
::
For
*
tile_loop
=
tile_loop_expr
.
As
<
ir
::
For
>
();
CHECK
(
tile_loop
)
<<
"tiles store non For Expr"
;
int
idx
=
loop_var_name_to_idx
[
tile_loop
->
loop_var
->
name
];
splited_loops
.
push_back
(
for_exprs
[
idx
]);
}
}
Expr
reordered_expr
=
ir_schedule
->
Reorder
(
splited_loops
);
VLOG
(
5
)
<<
"Finish Reorder in MultiLevelTiling, now do Fuse and Binding on "
"the main loop chain"
;
int
num_binds
=
std
::
min
(
config_
.
bind_axis
.
size
(),
tile_loops_
.
size
());
for
(
int
i
=
0
;
i
<
num_binds
;
++
i
)
{
loop_var_name_to_idx
.
clear
();
for_exprs
=
ir_schedule
->
GetLoops
(
sche_block
->
name
);
for
(
int
j
=
0
;
j
<
for_exprs
.
size
();
++
j
)
{
loop_var_name_to_idx
[
for_exprs
[
j
].
As
<
ir
::
For
>
()
->
loop_var
->
name
]
=
j
;
}
CHECK
(
loop_var_name_to_idx
.
size
()
==
for_exprs
.
size
())
<<
"Loops contain duplicate loop var names before Fusion"
;
// Some loops extent may exceed the limited max factor (For example,
// exceed the limit number of CUDA threads), here we check whether
// the fused loop extent, which is the production of extends of loops
// to be fused, is less or equal to the max factor.
//
// If yes, we fuse those loops and bind the fused loop
// If no, we bind the first loop whose extent is less than the factor.
int
extent_prod
=
1
;
int
first_idx_less_than_max_factor
=
-
1
;
for
(
int
j
=
0
;
j
<
tile_loops_
[
i
].
size
();
++
j
)
{
const
ir
::
For
*
tile_loop
=
tile_loops_
[
i
][
j
].
As
<
ir
::
For
>
();
CHECK
(
tile_loop
)
<<
"tiles store non For Expr"
;
int
idx
=
loop_var_name_to_idx
[
tile_loop
->
loop_var
->
name
];
tile_loops_
[
i
][
j
]
=
for_exprs
[
idx
];
int
extent
=
tile_loop
->
extent
.
as_int32
();
// maybe int64?
extent_prod
*=
extent
;
if
(
first_idx_less_than_max_factor
==
-
1
&&
extent
<=
max_factor_
)
{
first_idx_less_than_max_factor
=
idx
;
}
}
if
(
extent_prod
<=
max_factor_
)
{
Expr
fused
=
ir_schedule
->
Fuse
(
tile_loops_
[
i
]);
ir_schedule
->
Bind
(
fused
,
config_
.
bind_axis
[
i
]);
}
else
if
(
first_idx_less_than_max_factor
!=
-
1
)
{
ir_schedule
->
Bind
(
for_exprs
[
first_idx_less_than_max_factor
],
config_
.
bind_axis
[
i
]);
}
}
VLOG
(
5
)
<<
"Do Fuse and Binding on the non-main loop chains"
;
Expr
sche_block_top_loop
=
ir_schedule
->
GetLoops
(
sche_block
->
name
)[
0
];
if
(
reordered_expr
.
As
<
ir
::
Block
>
())
{
for
(
Expr
&
top_loop
:
reordered_expr
.
As
<
ir
::
Block
>
()
->
stmts
)
{
if
(
top_loop
!=
sche_block_top_loop
)
{
std
::
vector
<
Expr
>
scan_loop_blocks
=
ir_schedule
->
GetAllBlocks
();
Expr
other_loop_chain_schedule
;
for
(
Expr
&
block
:
scan_loop_blocks
)
{
std
::
vector
<
Expr
>
loop_chain
=
ir_schedule
->
GetLoops
(
block
);
if
(
loop_chain
[
0
]
==
top_loop
)
{
other_loop_chain_schedule
=
block
;
break
;
}
}
if
(
!
other_loop_chain_schedule
.
defined
())
{
LOG
(
WARNING
)
<<
"Has non-main loop chain, but not corresponding "
"ScheduleBlock in MultiLevelTiling"
;
continue
;
}
std
::
string
other_loop_schedule_name
=
other_loop_chain_schedule
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
;
VLOG
(
6
)
<<
"Found other_loop_schedule_name = "
<<
other_loop_schedule_name
;
int
fuse_index
=
0
;
for
(
int
i
=
0
;
i
<
num_binds
;
++
i
)
{
for_exprs
=
ir_schedule
->
GetLoops
(
other_loop_schedule_name
);
// Some loops extent may exceed the limited max factor (For example,
// exceed the limit number of CUDA threads), here we check whether
// the fused loop extent, which is the production of extends of loops
// to be fused, is less or equal to the max factor.
//
// If yes, we fuse those loops and bind the fused loop
// If no, we bind the first loop whose extent is less than the factor.
int
extent_prod
=
1
;
int
first_idx_less_than_max_factor
=
-
1
;
for
(
int
j
=
0
;
j
<
tile_loops_
[
i
].
size
();
++
j
)
{
int
extent
=
for_exprs
[
fuse_index
+
j
].
As
<
ir
::
For
>
()
->
extent
.
as_int32
();
extent_prod
*=
extent
;
if
(
first_idx_less_than_max_factor
==
-
1
&&
extent
<=
max_factor_
)
{
first_idx_less_than_max_factor
=
fuse_index
+
j
;
}
}
if
(
extent_prod
<=
max_factor_
)
{
std
::
vector
<
Expr
>
loops_to_fuse
(
for_exprs
.
begin
()
+
fuse_index
,
for_exprs
.
begin
()
+
fuse_index
+
tile_loops_
[
i
].
size
());
Expr
fused
=
ir_schedule
->
Fuse
(
loops_to_fuse
);
ir_schedule
->
Bind
(
fused
,
config_
.
bind_axis
[
i
]);
fuse_index
+=
1
;
}
else
if
(
first_idx_less_than_max_factor
!=
-
1
)
{
ir_schedule
->
Bind
(
for_exprs
[
first_idx_less_than_max_factor
],
config_
.
bind_axis
[
i
]);
fuse_index
+=
tile_loops_
[
i
].
size
();
}
}
}
}
}
}
void
MultiLevelTiling
::
ApplyCacheRead
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
)
{
ir
::
ScheduleBlockRealize
*
sch_block_realize
=
block_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
ir
::
ScheduleBlock
*
sch_block
=
sch_block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
std
::
string
block_name
=
sch_block
->
name
;
// Analyze which buffers can be cached
std
::
vector
<
int
>
read_buffer_indexes
;
for
(
int
i
=
0
;
i
<
sch_block
->
read_buffers
.
size
();
++
i
)
{
bool
is_read_write
=
false
;
for
(
int
j
=
0
;
j
<
sch_block
->
write_buffers
.
size
();
++
j
)
{
if
(
sch_block
->
read_buffers
[
i
]
==
sch_block
->
write_buffers
[
j
])
{
is_read_write
=
true
;
break
;
}
}
if
(
!
is_read_write
)
{
read_buffer_indexes
.
push_back
(
i
);
}
}
// Schedule
for
(
int
read_buffer_index
:
read_buffer_indexes
)
{
for
(
int
level
:
config_
.
read_cache_levels
)
{
// 1.find target loop
const
auto
loops
=
tile_loops_
.
at
(
level
-
1
);
if
(
loops
.
size
()
==
0
)
{
continue
;
}
// 2.Do CacheRead and get the cache block
ir
::
Expr
cache_block
=
ir_schedule
->
CacheRead
(
block_expr
,
read_buffer_index
,
config_
.
read_cache_memory_type
);
std
::
string
cache_block_name
=
cache_block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
;
std
::
string
target_for_loop_name
=
loops
.
back
().
As
<
ir
::
For
>
()
->
loop_var
->
name
;
// 3.Place the cache_block under target_for_loop
// The original block expr is invalid after the CacheRead schedule,
// so we reacquire the block expr after the schedule according to the
// block name
block_expr
=
ir_schedule
->
GetBlock
(
block_name
);
std
::
vector
<
Expr
>
for_exprs
=
ir_schedule
->
GetLoops
(
block_expr
);
for
(
const
Expr
&
for_expr
:
for_exprs
)
{
if
(
for_expr
.
As
<
ir
::
For
>
()
->
loop_var
->
name
.
find
(
target_for_loop_name
)
!=
std
::
string
::
npos
)
{
ir_schedule
->
ComputeAt
(
cache_block
,
for_expr
,
true
);
break
;
}
}
// 4.Threads under the same block cooperative fetch data from global
// memory.
Expr
new_cache_block
=
ir_schedule
->
GetBlock
(
cache_block_name
);
auto
cache_block_loops
=
ir_schedule
->
GetLoops
(
new_cache_block
);
std
::
vector
<
std
::
string
>
compute_at_extra_var
=
utils
::
Split
(
absl
::
get
<
std
::
string
>
(
new_cache_block
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
attrs
.
at
(
"compute_at_extra_var"
)),
","
);
std
::
vector
<
Expr
>
buffer_loops
;
// int nthreads = 1;
for
(
const
Expr
&
for_expr
:
cache_block_loops
)
{
if
(
std
::
find
(
compute_at_extra_var
.
begin
(),
compute_at_extra_var
.
end
(),
for_expr
.
As
<
ir
::
For
>
()
->
loop_var
->
name
)
!=
compute_at_extra_var
.
end
())
{
buffer_loops
.
push_back
(
for_expr
);
}
}
auto
fused_buffer_loop
=
ir_schedule
->
Fuse
(
buffer_loops
);
// TODO(BiynXu): Implement vectorize fetching data and pass in vector
// length
ir_schedule
->
Annotate
(
ir_schedule
->
GetBlock
(
cache_block_name
),
ir
::
attr
::
cooperative_process
,
0
);
}
}
}
void
MultiLevelTiling
::
ApplyCacheWrite
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
)
{
ir
::
Expr
cache_block
=
ir_schedule
->
CacheWrite
(
block_expr
,
0
,
config_
.
write_cache_memory_type
);
for
(
int
level
:
config_
.
write_cache_levels
)
{
const
auto
loops
=
tile_loops_
.
at
(
level
-
1
);
if
(
loops
.
size
()
==
0
)
{
continue
;
}
std
::
string
target_for_loop_name
=
loops
.
back
().
As
<
ir
::
For
>
()
->
loop_var
->
name
;
// Because the block name is changed in CacheWrite, we need to calculate the
// derived name according to the logic of CacheWrite and find the loop
// structure according to the derived name.
const
std
::
string
original_block_name
=
block_expr
.
As
<
ir
::
ScheduleBlockRealize
>
()
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
()
->
name
;
const
std
::
string
derivative_block_name
=
original_block_name
+
"_"
+
config_
.
write_cache_memory_type
+
"_temp_buffer"
;
std
::
vector
<
Expr
>
for_exprs
=
ir_schedule
->
GetLoops
(
derivative_block_name
);
for
(
const
Expr
&
for_expr
:
for_exprs
)
{
if
(
for_expr
.
As
<
ir
::
For
>
()
->
loop_var
->
name
.
find
(
target_for_loop_name
)
!=
std
::
string
::
npos
)
{
ir_schedule
->
ReverseComputeAt
(
ir_schedule
->
GetBlock
(
original_block_name
),
for_expr
,
true
);
}
}
const
std
::
string
reduce_init_block_name
=
original_block_name
+
"__reduce_init"
;
for_exprs
=
ir_schedule
->
GetLoops
(
derivative_block_name
);
for
(
const
Expr
&
for_expr
:
for_exprs
)
{
if
(
for_expr
.
As
<
ir
::
For
>
()
->
loop_var
->
name
.
find
(
target_for_loop_name
)
!=
std
::
string
::
npos
&&
ir_schedule
->
HasBlock
(
reduce_init_block_name
))
{
ir_schedule
->
SimpleComputeAt
(
ir_schedule
->
GetBlock
(
reduce_init_block_name
),
for_expr
);
}
}
}
}
const
std
::
unordered_map
<
common
::
Target
::
Arch
,
MultiLevelTiling
::
Config
>
MultiLevelTiling
::
kConfigs
{
{
common
::
Target
::
Arch
::
NVGPU
,
MultiLevelTiling
::
Config
{
/*bind_axis*/
std
::
vector
<
std
::
string
>
{
"blockIdx.x"
,
"threadIdx.x"
},
/*tile_struct*/
std
::
string
(
"SSSRRSRS"
),
/*read_cache_memory_type*/
std
::
string
(
"shared"
),
/*read_cache_levels*/
std
::
vector
<
int
>
{
4
},
/*write_cache_memory_type*/
std
::
string
(
"local"
),
/*write_cache_levels*/
std
::
vector
<
int
>
{
3
},
}},
{
common
::
Target
::
Arch
::
X86
,
MultiLevelTiling
::
Config
{
/*bind_axis*/
std
::
vector
<
std
::
string
>
{},
/*tile_struct*/
std
::
string
(
"SSRSRS"
),
/*read_cache_memory_type*/
std
::
string
(
"local"
),
/*read_cache_levels*/
std
::
vector
<
int
>
{
3
},
/*write_cache_memory_type*/
std
::
string
(
"local"
),
/*write_cache_levels*/
std
::
vector
<
int
>
{
2
},
}}};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
class
MultiLevelTiling
:
public
AutoGenRule
{
public:
struct
Config
{
// Which thread axis each tiled loop is bound to
std
::
vector
<
std
::
string
>
bind_axis
;
// Use char 'S' and 'R' to represent tile structure.
// S means space tiling level and R means reduce tiling level
//
// For example, if tile_struct_ = "SSRSRS" and we are doing matrix
// multiplication, i, j are the spatial indices and k is the reduce index,
// the tiling result will be i_0, j0, i1, j1, k0, i2, j2, k1, i3, j3
std
::
string
tile_struct
;
// The storage type of read cache
std
::
string
read_cache_memory_type
;
// Which tiled levels are read cache block inserted at
std
::
vector
<
int
>
read_cache_levels
;
// The storage type of write cache
std
::
string
write_cache_memory_type
;
// Which tiled levels are write cache block inserted at
std
::
vector
<
int
>
write_cache_levels
;
};
static
const
std
::
unordered_map
<
common
::
Target
::
Arch
,
Config
>
kConfigs
;
MultiLevelTiling
(
const
common
::
Target
&
target
,
const
Config
&
config
);
~
MultiLevelTiling
()
=
default
;
// initialize the AutoGenRule, it must be called before further actions.
// Returns false if the rule cannot be applied on the mod_expr, true otherwise
RuleApplyType
Init
(
ir
::
IRSchedule
*
init_schedule
)
override
;
// Applies rule on the ir::ModuleExpr for a schedule block specified by index
// between 0 (inclusive) and NumberApplicable() (exclusive)
void
Apply
(
int
index
)
override
;
// Returns the name of the rule, used for debug.
std
::
string
GetRuleName
()
const
override
;
// Returns true if sche_block_realize is applicable by MultiLevelTiling
bool
MeetCondition
(
const
ir
::
ScheduleBlockRealize
&
sche_block_realize
)
const
;
RuleApplyType
AnalyseApplyType
(
SearchState
state
,
const
std
::
string
&
block_name
)
const
override
;
std
::
vector
<
SearchState
>
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
override
;
// Sample pair of integer type (a, b) such as a * b = extent
template
<
typename
T
>
std
::
vector
<
T
>
SampleSplitTwo
(
T
extent
)
const
{
std
::
vector
<
std
::
vector
<
T
>>
candidates
;
for
(
T
div
=
1
;
div
<=
sqrt
(
extent
);
++
div
)
{
if
(
extent
%
div
==
0
)
{
candidates
.
push_back
({
T
(
div
),
extent
/
div
});
}
}
if
(
candidates
.
size
()
==
0
)
{
return
{
1
,
T
(
extent
)};
}
int
index
=
rand
()
%
candidates
.
size
();
// NOLINT
std
::
vector
<
T
>
pick
=
candidates
[
index
];
if
(
rand
()
%
2
!=
0
)
{
// NOLINT
T
tmp
=
pick
[
0
];
pick
[
0
]
=
pick
[
1
];
pick
[
1
]
=
tmp
;
}
return
pick
;
}
// Sample num_split integers whose product equals extent
template
<
typename
T
>
std
::
vector
<
T
>
SampleTileSplit
(
T
extent
,
int
num_split
)
const
{
CHECK_GT
(
num_split
,
0
)
<<
"num_split in SampleTileSplit must be greater than 0"
;
if
(
num_split
==
1
)
{
return
{
extent
};
}
std
::
vector
<
T
>
two_split
=
SampleSplitTwo
<
T
>
(
extent
);
if
(
num_split
==
2
)
{
return
two_split
;
}
int
half
=
num_split
>>
1
;
std
::
vector
<
T
>
result
=
SampleTileSplit
<
T
>
(
two_split
[
0
],
half
);
std
::
vector
<
T
>
remind
=
SampleTileSplit
<
T
>
(
two_split
[
1
],
num_split
-
half
);
result
.
insert
(
result
.
end
(),
remind
.
begin
(),
remind
.
end
());
return
result
;
}
private:
void
ApplyTiling
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
);
// NOLINT
void
ApplyCacheRead
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
);
// NOLINT
void
ApplyCacheWrite
(
ir
::
IRSchedule
*
ir_schedule
,
ir
::
Expr
&
block_expr
);
// NOLINT
private:
std
::
vector
<
ir
::
Expr
>
all_block_realizes_
;
std
::
vector
<
int
>
applicable_indices_
;
Config
config_
;
std
::
vector
<
int
>
s_indices_
;
std
::
vector
<
int
>
r_indices_
;
std
::
vector
<
std
::
vector
<
ir
::
Expr
>>
tile_loops_
;
// A factor to limit the split factor within max thread number per block
int
max_factor_
=
1024
;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/multi_level_tiling.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <iostream>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
#include "paddle/cinn/utils/string.h"
#include "test/cpp/cinn/program_builder.h"
namespace
cinn
{
namespace
auto_schedule
{
TEST
(
MultiLevelTile
,
SampleSplitTwo
)
{
srand
(
0
);
Context
::
Global
().
ResetNameId
();
#ifdef CINN_WITH_CUDA
Target
target
=
common
::
DefaultNVGPUTarget
();
#else
Target
target
=
common
::
DefaultHostTarget
();
#endif
MultiLevelTiling
multi_level_tiling
(
target
,
MultiLevelTiling
::
kConfigs
.
at
(
target
.
arch
));
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
size_t
number_to_split
=
rand
()
%
65535
+
2
;
// NOLINT, random number in [2, 2^16]
std
::
vector
<
size_t
>
split
=
multi_level_tiling
.
SampleSplitTwo
<
size_t
>
(
number_to_split
);
EXPECT_EQ
(
split
.
size
(),
2UL
);
EXPECT_EQ
(
split
[
0
]
*
split
[
1
],
number_to_split
);
}
}
TEST
(
MultiLevelTile
,
SampleTileSplit
)
{
srand
(
0
);
Context
::
Global
().
ResetNameId
();
#ifdef CINN_WITH_CUDA
Target
target
=
common
::
DefaultNVGPUTarget
();
#else
Target
target
=
common
::
DefaultHostTarget
();
#endif
MultiLevelTiling
multi_level_tiling
(
target
,
MultiLevelTiling
::
kConfigs
.
at
(
target
.
arch
));
for
(
int
i
=
0
;
i
<
100
;
++
i
)
{
int
number_to_split
=
rand
()
%
65535
+
2
;
// NOLINT, random number in [2, 2^16]
int
split_size
=
rand
()
%
5
+
1
;
// NOLINT, random in [1, 5]
std
::
vector
<
int
>
split
=
multi_level_tiling
.
SampleTileSplit
<
int
>
(
number_to_split
,
split_size
);
EXPECT_EQ
(
split
.
size
(),
static_cast
<
size_t
>
(
split_size
));
int
product
=
1
;
for
(
int
num
:
split
)
{
product
*=
num
;
}
EXPECT_EQ
(
product
,
number_to_split
);
}
}
TEST
(
MultiLevelTile
,
SimpleLoops
)
{
srand
(
0
);
Context
::
Global
().
ResetNameId
();
#ifdef CINN_WITH_CUDA
Target
target
=
common
::
DefaultNVGPUTarget
();
#else
Target
target
=
common
::
DefaultHostTarget
();
#endif
Expr
M
(
32
);
Expr
N
(
128
);
Placeholder
<
float
>
A
(
"A"
,
{
M
});
Placeholder
<
float
>
B
(
"B"
,
{
N
});
ir
::
Tensor
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
poly
::
StageMap
stages
=
CreateStages
({
C
});
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
lang
::
LowerVec
(
"TestMultiLevelTile_SimpleLoops"
,
stages
,
{
C
},
{},
{},
nullptr
,
target
,
true
);
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
VLOG
(
6
)
<<
"Expr before MultiLevelTiling: "
;
VLOG
(
6
)
<<
ast_expr
;
MultiLevelTiling
multi_level_tiling
(
target
,
MultiLevelTiling
::
kConfigs
.
at
(
target
.
arch
));
ir
::
IRSchedule
ir_schedule
(
ir
::
ModuleExpr
({
ast_expr
}));
SearchState
state
(
ir_schedule
,
0
,
{});
EXPECT_EQ
(
multi_level_tiling
.
Init
(
&
ir_schedule
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
EXPECT_EQ
(
multi_level_tiling
.
NumberApplicable
(),
1
);
multi_level_tiling
.
ApplyRandomly
();
// ApplyOnBlock
EXPECT_EQ
(
multi_level_tiling
.
AnalyseApplyType
(
state
,
"C"
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
multi_level_tiling
.
ApplyOnBlock
(
state
,
"C"
);
auto
test_func
=
[](
ir
::
IRSchedule
*
ir_sch
)
{
std
::
vector
<
ir
::
Expr
>
exprs
=
ir_sch
->
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
std
::
stringstream
ss
;
ss
<<
exprs
[
0
];
std
::
string
expr_str
=
ss
.
str
();
VLOG
(
6
)
<<
expr_str
;
};
test_func
(
&
ir_schedule
);
test_func
(
&
new_states
[
0
]
->
ir_schedule
);
}
// TODO(SunNy820828449): fix in future
/*
TEST(MulitLevelTile, MatrixMultiply) {
srand(0);
Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
Expr M(32);
Expr N(32);
Expr K(32);
Placeholder<float> A("A", {M, K});
Placeholder<float> B("B", {K, N});
Var k(K.as_int32(), "reduce_axis_k");
ir::Tensor C = Compute(
{M, N}, [&](Var i, Var j) { return ReduceSum(A(i, k) * B(k, j), {k}); },
"C");
poly::StageMap stages = CreateStages({C});
std::vector<ir::LoweredFunc> funcs =
lang::LowerVec("TestMultiLevelTile_MatrixMultiply", stages, {C}, {}, {},
nullptr, target, true);
ir::Expr ast_expr = funcs[0]->body;
VLOG(6) << "Expr before MultiLevelTiling: ";
VLOG(6) << ast_expr;
MultiLevelTiling multi_level_tiling(target,
MultiLevelTiling::kConfigs.at(target.arch)); ir::IRSchedule
ir_schedule(ir::ModuleExpr({ast_expr})); SearchState state(ir_schedule, 0, {});
EXPECT_EQ(multi_level_tiling.Init(&ir_schedule),
RuleApplyType::kApplyAndPruneOtherRules);
EXPECT_EQ(multi_level_tiling.NumberApplicable(), 1);
multi_level_tiling.ApplyRandomly();
// ApplyOnBlock
EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state, "C"),
RuleApplyType::kApplyAndPruneOtherRules); auto new_states =
multi_level_tiling.ApplyOnBlock(state, "C");
auto test_func = [](ir::IRSchedule* ir_sch) {
std::vector<ir::Expr> exprs = ir_sch->GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
std::stringstream ss;
ss << exprs[0];
std::string expr_str = ss.str();
VLOG(6) << expr_str;
};
test_func(&ir_schedule);
test_func(&new_states[0]->ir_schedule);
}
*/
class
TestMultiLevelTiling
:
public
TestAutoGenRuleBase
{
public:
int
fixed_rand_seed
=
1
;
std
::
vector
<
std
::
string
>
default_input_names
;
std
::
vector
<
std
::
string
>
default_output_names
;
};
TEST_F
(
TestMultiLevelTiling
,
Matmul
)
{
default_input_names
=
{
"X"
,
"Y"
};
default_output_names
=
{
"temp_matmul_out"
};
std
::
vector
<
int32_t
>
X_shape
=
{
32
,
32
};
std
::
vector
<
int32_t
>
Y_shape
=
{
32
,
32
};
std
::
vector
<
int32_t
>
out_shape
=
{
32
,
32
};
Initialize
(
common
::
DefaultNVGPUTarget
());
frontend
::
Program
matmul_op
=
tests
::
OpBuilder
(
"matmul"
).
Build
({{
"X"
,
X_shape
},
{
"Y"
,
Y_shape
}});
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
matmul_op
,
fixed_rand_seed
);
SearchState
state
(
ir_schedule
);
VLOG
(
6
)
<<
"Original state:
\n
"
<<
state
->
DebugString
();
// Apply MultiLevelTiling
MultiLevelTiling
multi_level_tiling
(
target_
,
MultiLevelTiling
::
kConfigs
.
at
(
target_
.
arch
));
EXPECT_EQ
(
multi_level_tiling
.
AnalyseApplyType
(
state
,
default_output_names
[
0
]),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
multi_level_tiling
.
ApplyOnBlock
(
state
,
default_output_names
[
0
]);
VLOG
(
6
)
<<
"After MultiLevelTiling, state:
\n
"
<<
new_states
[
0
]
->
DebugString
();
std
::
string
ir
=
GetIR
(
new_states
[
0
]
->
ir_schedule
);
std
::
string
expected_ir
=
R"ROC(Expr 0 {
{
ScheduleBlock(root)
{
{
thread_bind[blockIdx.x] for (i_j_fused, 0, 4)
{
thread_bind[threadIdx.x] for (i_0_j_0_fused, 0, 1)
{
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 1)
{
serial for (i_2, 0, 1)
{
serial for (j_2, 0, 1)
{
serial for (i_3, 0, 8)
{
serial for (j_3, 0, 32)
{
ScheduleBlock(temp_matmul_out__reduce_init)
{
i0, i1 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)))
{
temp_matmul_out__reduce_init[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = 0.00000000f
}
}
}
}
}
}
{
serial for (reduce_k_0, 0, 4)
{
serial for (ax0_0_ax1_0_fused, 0, 256)
{
ScheduleBlock(Y_reshape_shared_temp_buffer)
{
v0, v1 = axis.bind(((ax0_0_ax1_0_fused / 32) + (8 * reduce_k_0)), ((ax0_0_ax1_0_fused % 32) + (32 * j_1)))
attrs(compute_at_extra_var:ax0_0,ax1_0, cooperative_process:0)
{
Y_reshape_shared_temp_buffer[v0, v1] = Y_reshape[v0, v1]
}
}
}
serial for (ax0_ax1_fused, 0, 64)
{
ScheduleBlock(X_reshape_shared_temp_buffer)
{
v0, v1 = axis.bind(((ax0_ax1_fused / 8) + ((8 * i_0_j_0_fused) + ((8 * i_1) + (8 * i_j_fused)))), ((ax0_ax1_fused % 8) + (8 * reduce_k_0)))
attrs(compute_at_extra_var:ax0,ax1, cooperative_process:0)
{
X_reshape_shared_temp_buffer[v0, v1] = X_reshape[v0, v1]
}
}
}
serial for (reduce_k_1, 0, 1)
{
serial for (i_2, 0, 1)
{
serial for (j_2, 0, 1)
{
serial for (reduce_k_2, 0, 8)
{
serial for (i_3, 0, 8)
{
serial for (j_3, 0, 32)
{
ScheduleBlock(temp_matmul_out_local_temp_buffer)
{
i0_0, i1_0, i2 = axis.bind(((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3)), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)))
read_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)], _X[i(undefined:undefined), reduce_k(undefined:undefined)], _Y[reduce_k(undefined:undefined), j(undefined:undefined)])
write_buffers(_temp_matmul_out[i(undefined:undefined), j(undefined:undefined)])
{
temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] = (temp_matmul_out_local_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((32 * j_1) + ((32 * j_2) + j_3))] + (X_reshape_shared_temp_buffer[((8 * i_0_j_0_fused) + ((8 * i_1) + ((8 * i_2) + ((8 * i_j_fused) + i_3)))), ((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2))] * Y_reshape_shared_temp_buffer[((8 * reduce_k_0) + ((8 * reduce_k_1) + reduce_k_2)), ((32 * j_1) + ((32 * j_2) + j_3))]))
}
}
}
}
}
}
}
}
}
serial for (ax0_1, 0, 8)
{
serial for (ax1_1, 0, 32)
{
ScheduleBlock(temp_matmul_out)
{
v0, v1 = axis.bind((((8 * i_0_j_0_fused) + ((8 * i_1) + (8 * i_j_fused))) + ax0_1), ((32 * j_1) + ax1_1))
attrs(reverse_compute_at_extra_var:ax0_1,ax1_1)
{
temp_matmul_out[v0, v1] = temp_matmul_out_local_temp_buffer[v0, v1]
}
}
}
}
}
}
}
}
}
}
}
}
} // end Expr 0
)ROC"
;
ASSERT_EQ
(
ir
,
expected_ir
);
// build ir::Module and debug source code
auto
ir_module
=
BuildIRModule
(
new_states
[
0
]
->
ir_schedule
);
auto
source_code
=
GenSourceCode
(
ir_module
);
VLOG
(
6
)
<<
"scheduled source code:
\n
"
<<
source_code
;
// execute and check precision
CheckResult
(
GenExecutableKernel
(
ir_module
),
GenExecutableKernel
(
BuildIRModule
(
MakeIRSchedule
(
matmul_op
,
fixed_rand_seed
,
/* apply_manual_schedule*/
true
))),
default_input_names
,
default_output_names
,
{
X_shape
,
Y_shape
},
{
out_shape
},
target_
);
}
TEST_F
(
TestMultiLevelTiling
,
ReduceSum
)
{
default_input_names
=
{
"X"
};
default_output_names
=
{
"var_0_tmp"
};
std
::
vector
<
int32_t
>
X_shape
=
{
1
,
16
,
32
};
std
::
vector
<
int32_t
>
out_shape
=
{
1
,
16
,
1
};
std
::
vector
<
int32_t
>
reduce_dim
=
{
2
};
Initialize
(
common
::
DefaultNVGPUTarget
());
frontend
::
Program
reduce_sum_op
=
tests
::
OpBuilder
(
"reduce_sum"
)
.
Build
({{
"X"
,
X_shape
}},
{{
"dim"
,
reduce_dim
},
{
"keep_dim"
,
false
}});
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
reduce_sum_op
);
SearchState
state
(
ir_schedule
);
VLOG
(
6
)
<<
"Original state:
\n
"
<<
state
->
DebugString
();
// Apply MultiLevelTiling
MultiLevelTiling
multi_level_tiling
(
target_
,
MultiLevelTiling
::
kConfigs
.
at
(
target_
.
arch
));
// EXPECT_EQ(multi_level_tiling.AnalyseApplyType(state,
// default_output_names[0]), RuleApplyType::kCannotApply);
}
TEST_F
(
TestMultiLevelTiling
,
Pool2d
)
{
default_input_names
=
{
"input"
};
default_output_names
=
{
"var_0"
,
"pad_temp_0"
};
std
::
vector
<
std
::
vector
<
int32_t
>>
input_shapes
{{
2
,
8
,
16
,
16
}};
std
::
vector
<
std
::
vector
<
int32_t
>>
output_shapes
{{
2
,
8
,
8
,
8
},
{
2
,
8
,
18
,
18
}};
std
::
string
pooling_type
=
"max"
;
std
::
vector
<
int
>
ksize
{
3
,
3
};
std
::
vector
<
int
>
strides
{
2
,
2
};
std
::
vector
<
int
>
paddings
{
1
,
1
,
1
,
1
};
bool
ceil_mode
=
false
;
bool
exclusive
=
true
;
bool
global_pooling
=
false
;
std
::
string
data_format
=
"NCHW"
;
bool
adaptive
=
false
;
std
::
string
padding_algorithm
=
"EXPLICIT"
;
frontend
::
Program
pool2d_program
=
tests
::
OpBuilder
(
"pool2d"
).
Build
(
{{
"input"
,
input_shapes
[
0
]}},
{{
"pool_type"
,
pooling_type
},
{
"kernel_size"
,
ksize
},
{
"stride_size"
,
strides
},
{
"padding_size"
,
paddings
},
{
"ceil_mode"
,
ceil_mode
},
{
"exclusive"
,
exclusive
},
{
"global_pooling"
,
global_pooling
},
{
"data_format"
,
data_format
},
{
"adaptive"
,
adaptive
},
{
"padding_algorithm"
,
padding_algorithm
}});
Initialize
(
common
::
DefaultNVGPUTarget
());
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
pool2d_program
,
fixed_rand_seed
);
SearchState
state
(
ir_schedule
);
VLOG
(
6
)
<<
"Original state:
\n
"
<<
state
->
DebugString
();
// Apply MultiLevelTiling
MultiLevelTiling
::
Config
mlt_config
=
{
/*bind_axis*/
std
::
vector
<
std
::
string
>
{
"blockIdx.x"
,
"threadIdx.x"
},
/*tile_struct*/
std
::
string
(
"SSRS"
),
/*read_cache_memory_type*/
std
::
string
(
"shared"
),
/*read_cache_levels*/
std
::
vector
<
int
>
{
3
},
/*write_cache_memory_type*/
std
::
string
(
"local"
),
/*write_cache_levels*/
std
::
vector
<
int
>
{
2
},
};
MultiLevelTiling
multi_level_tiling
(
target_
,
mlt_config
);
EXPECT_EQ
(
multi_level_tiling
.
AnalyseApplyType
(
state
,
default_output_names
[
0
]),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
new_states
=
multi_level_tiling
.
ApplyOnBlock
(
state
,
default_output_names
[
0
]);
VLOG
(
6
)
<<
"After MultiLevelTiling, state:
\n
"
<<
new_states
[
0
]
->
DebugString
();
std
::
string
ir
=
GetIR
(
new_states
[
0
]
->
ir_schedule
);
std
::
string
expected_ir
=
R"ROC(Expr 0 {
{
ScheduleBlock(root)
{
{
serial for (i, 0, 2)
{
serial for (j, 0, 8)
{
serial for (k, 0, 18)
{
serial for (a, 0, 18)
{
ScheduleBlock(pad_temp_0)
{
i0, i1, i2, i3 = axis.bind(i, j, k, a)
{
pad_temp_0[i, j, k, a] = select(((a < 17) and ((a >= 1) and ((k < 17) and (k >= 1)))), input[i, j, (-1 + k), (-1 + a)], -3.40282347e+38f)
}
}
}
}
}
}
{
thread_bind[blockIdx.x] for (i_j_k_a_fused, 0, 16)
{
thread_bind[threadIdx.x] for (i_0_j_0_k_0_a_0_fused, 0, 4)
{
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 4)
{
serial for (k_1, 0, 1)
{
serial for (a_1, 0, 4)
{
ScheduleBlock(var_0__reduce_init)
{
i0_0, i1_0, i2_0, i3_0 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1))
{
var_0__reduce_init[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = -3.40282347e+38f
}
}
}
}
}
}
{
serial for (kernel_idx, 0, 3)
{
serial for (kernel_idx_0, 0, 3)
{
serial for (ax0_ax1_ax2_ax3_fused, 0, 28)
{
ScheduleBlock(pad_temp_0_shared_temp_buffer)
{
v0, v1, v2, v3 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + ((ax0_ax1_ax2_ax3_fused / 7) / 4))), (((ax0_ax1_ax2_ax3_fused / 7) % 4) + (4 * (((i_j_k_a_fused / 2) / 2) % 2))), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + kernel_idx)), ((ax0_ax1_ax2_ax3_fused % 7) + ((8 * (i_j_k_a_fused % 2)) + kernel_idx_0)))
attrs(compute_at_extra_var:ax0,ax1,ax2,ax3, cooperative_process:0)
{
pad_temp_0_shared_temp_buffer[v0, v1, v2, v3] = pad_temp_0[v0, v1, v2, v3]
}
}
}
serial for (i_1, 0, 1)
{
serial for (j_1, 0, 4)
{
serial for (k_1, 0, 1)
{
serial for (a_1, 0, 4)
{
ScheduleBlock(var_0_local_temp_buffer)
{
i0_1, i1_1, i2_1, i3_1, i4, i5 = axis.bind(((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1), kernel_idx, kernel_idx_0)
read_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)], _pad_temp_0[i(undefined:undefined), j(undefined:undefined)])
write_buffers(_var_0[i(undefined:undefined), j(undefined:undefined), k(undefined:undefined), a(undefined:undefined)])
{
var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((4 * ((i_j_k_a_fused / 2) % 2)) + ((i_0_j_0_k_0_a_0_fused % 4) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)] = cinn_max(var_0_local_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((i_0_j_0_k_0_a_0_fused % 4) + ((4 * ((i_j_k_a_fused / 2) % 2)) + k_1)), ((4 * (i_j_k_a_fused % 2)) + a_1)], pad_temp_0_shared_temp_buffer[((((i_j_k_a_fused / 2) / 2) / 2) + ((i_0_j_0_k_0_a_0_fused / 4) + i_1)), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + j_1), ((8 * ((i_j_k_a_fused / 2) % 2)) + ((2 * (i_0_j_0_k_0_a_0_fused % 4)) + ((2 * k_1) + kernel_idx))), ((8 * (i_j_k_a_fused % 2)) + ((2 * a_1) + kernel_idx_0))])
}
}
}
}
}
}
}
}
serial for (ax0_0, 0, 1)
{
serial for (ax1_0, 0, 4)
{
serial for (ax2_0, 0, 1)
{
serial for (ax3_0, 0, 4)
{
ScheduleBlock(var_0)
{
v0, v1, v2, v3 = axis.bind((((((i_j_k_a_fused / 2) / 2) / 2) + (i_0_j_0_k_0_a_0_fused / 4)) + ax0_0), ((4 * (((i_j_k_a_fused / 2) / 2) % 2)) + ax1_0), (((4 * ((i_j_k_a_fused / 2) % 2)) + (i_0_j_0_k_0_a_0_fused % 4)) + ax2_0), ((4 * (i_j_k_a_fused % 2)) + ax3_0))
attrs(reverse_compute_at_extra_var:ax0_0,ax1_0,ax2_0,ax3_0)
{
var_0[v0, v1, v2, v3] = var_0_local_temp_buffer[v0, v1, v2, v3]
}
}
}
}
}
}
}
}
}
}
}
}
}
} // end Expr 0
)ROC"
;
ASSERT_EQ
(
ir
,
expected_ir
);
// build ir::Module and debug source code
auto
ir_module
=
BuildIRModule
(
new_states
[
0
]
->
ir_schedule
);
auto
source_code
=
GenSourceCode
(
ir_module
);
VLOG
(
6
)
<<
"scheduled source code:
\n
"
<<
source_code
;
// execute and check precision
CheckResult
(
GenExecutableKernel
(
ir_module
),
GenExecutableKernel
(
BuildIRModule
(
MakeIRSchedule
(
pool2d_program
,
fixed_rand_seed
,
/* apply_manual_schedule*/
true
))),
default_input_names
,
default_output_names
,
input_shapes
,
output_shapes
,
target_
);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
#include <string>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
namespace
cinn
{
namespace
auto_schedule
{
SkipRule
::
SkipRule
(
const
common
::
Target
&
target
)
:
AutoGenRule
(
target
)
{}
RuleApplyType
SkipRule
::
Init
(
ir
::
IRSchedule
*
ir_schedule
)
{
ir_schedule_
=
ir_schedule
;
num_applicable_
=
1
;
return
RuleApplyType
::
kApply
;
}
std
::
string
SkipRule
::
GetRuleName
()
const
{
return
"SkipRule"
;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
class
SkipRule
:
public
AutoGenRule
{
public:
explicit
SkipRule
(
const
common
::
Target
&
target
);
~
SkipRule
()
=
default
;
RuleApplyType
Init
(
ir
::
IRSchedule
*
init_schedule
)
override
;
void
Apply
(
int
index
)
override
{}
std
::
string
GetRuleName
()
const
override
;
RuleApplyType
AnalyseApplyType
(
SearchState
state
,
const
std
::
string
&
block_name
)
const
override
{
return
RuleApplyType
::
kApply
;
}
std
::
vector
<
SearchState
>
ApplyOnBlock
(
SearchState
state
,
const
std
::
string
&
block_name
)
override
{
return
{
state
};
}
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/skip_rule.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cstdlib>
#include <iostream>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/poly/stage.h"
namespace
cinn
{
namespace
auto_schedule
{
TEST
(
SkipRule
,
Basic
)
{
srand
(
0
);
Context
::
Global
().
ResetNameId
();
#ifdef CINN_WITH_CUDA
Target
target
=
common
::
DefaultNVGPUTarget
();
#else
Target
target
=
common
::
DefaultHostTarget
();
#endif
Expr
M
(
32
);
Expr
N
(
128
);
Placeholder
<
float
>
A
(
"A"
,
{
M
});
Placeholder
<
float
>
B
(
"B"
,
{
N
});
ir
::
Tensor
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
poly
::
StageMap
stages
=
CreateStages
({
C
});
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
lang
::
LowerVec
(
"TestSkipRule_Basic"
,
stages
,
{
C
},
{},
{},
nullptr
,
target
,
true
);
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
VLOG
(
6
)
<<
"Expr before SkipRule: "
;
VLOG
(
6
)
<<
ast_expr
;
SkipRule
skip_rule
(
target
);
ir
::
IRSchedule
ir_schedule
(
ir
::
ModuleExpr
({
ast_expr
}));
SearchState
state
(
ir_schedule
,
0
,
{});
EXPECT_EQ
(
skip_rule
.
Init
(
&
ir_schedule
),
RuleApplyType
::
kApply
);
EXPECT_EQ
(
skip_rule
.
NumberApplicable
(),
1
);
skip_rule
.
ApplyRandomly
();
// ApplyOnBlock
EXPECT_EQ
(
skip_rule
.
AnalyseApplyType
(
state
,
"C"
),
RuleApplyType
::
kApply
);
std
::
vector
<
cinn
::
auto_schedule
::
SearchState
>
states
=
skip_rule
.
ApplyOnBlock
(
state
,
"C"
);
auto
test_func
=
[
&
ast_expr
](
ir
::
IRSchedule
*
ir_sch
)
{
std
::
vector
<
ir
::
Expr
>
exprs
=
ir_sch
->
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
EXPECT_EQ
(
ast_expr
,
exprs
[
0
]);
};
test_func
(
&
ir_schedule
);
test_func
(
&
states
[
0
]
->
ir_schedule
);
}
TEST
(
SkipRule
,
ApplyOnSpecificBlock
)
{
srand
(
0
);
Context
::
Global
().
ResetNameId
();
#ifdef CINN_WITH_CUDA
Target
target
=
common
::
DefaultNVGPUTarget
();
#else
Target
target
=
common
::
DefaultHostTarget
();
#endif
Expr
M
(
32
);
Expr
N
(
128
);
Placeholder
<
float
>
A
(
"A"
,
{
M
});
Placeholder
<
float
>
B
(
"B"
,
{
N
});
ir
::
Tensor
C
=
Compute
(
{
M
,
N
},
[
&
](
Var
i
,
Var
j
)
{
return
A
(
i
)
+
B
(
j
);
},
"C"
);
poly
::
StageMap
stages
=
CreateStages
({
C
});
std
::
vector
<
ir
::
LoweredFunc
>
funcs
=
lang
::
LowerVec
(
"TestSkipRule_Basic"
,
stages
,
{
C
},
{},
{},
nullptr
,
target
,
true
);
ir
::
Expr
ast_expr
=
funcs
[
0
]
->
body
;
VLOG
(
6
)
<<
"Expr before SkipRule: "
;
VLOG
(
6
)
<<
ast_expr
;
SkipRule
skip_rule
(
target
);
ir
::
IRSchedule
ir_schedule
(
ir
::
ModuleExpr
({
ast_expr
}));
SearchState
state
(
ir_schedule
,
0
,
{});
EXPECT_EQ
(
skip_rule
.
AnalyseApplyType
(
state
,
"C"
),
RuleApplyType
::
kApply
);
std
::
vector
<
cinn
::
auto_schedule
::
SearchState
>
states
=
skip_rule
.
ApplyOnBlock
(
state
,
"C"
);
std
::
vector
<
ir
::
Expr
>
exprs
=
states
[
0
]
->
ir_schedule
.
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
EXPECT_EQ
(
ast_expr
,
exprs
[
0
]);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <memory.h>
#include <stdlib.h>
#include "paddle/cinn/auto_schedule/analysis/analyze_ir.h"
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/hlir/framework/op.h"
#include "paddle/cinn/hlir/framework/op_lowering.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/hlir/framework/tensor.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#ifdef CINN_WITH_CUDA
#include <cuda_runtime.h>
#endif
namespace
cinn
{
namespace
auto_schedule
{
using
::
cinn
::
hlir
::
framework
::
Instruction
;
using
::
cinn
::
hlir
::
framework
::
Scope
;
using
::
cinn
::
hlir
::
framework
::
Shape
;
using
::
cinn
::
hlir
::
framework
::
Tensor
;
void
TestAutoGenRuleBase
::
Initialize
(
const
common
::
Target
&
target
)
{
target_
=
target
;
backend_compier_
=
backends
::
Compiler
::
Create
(
target
);
}
ir
::
IRSchedule
TestAutoGenRuleBase
::
MakeIRSchedule
(
const
frontend
::
Program
&
test_program
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
,
bool
apply_manual_schedule
)
{
Context
::
Global
().
ResetNameId
();
auto
graph
=
std
::
make_shared
<
hlir
::
framework
::
Graph
>
(
test_program
,
target_
);
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"OpFusionPass"
);
LOG_IF
(
WARNING
,
graph
->
fusion_groups
.
size
()
>
1
)
<<
"Test Graph has more than 1 group"
;
auto
&
dtype_dict
=
graph
->
GetMutableAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
common
::
Type
>>
(
"inferdtype"
);
auto
&
shape_dict
=
graph
->
GetMutableAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
hlir
::
framework
::
OpLowerer
op_lowerer
(
dtype_dict
,
shape_dict
,
target_
);
lowered_funcs_
=
op_lowerer
.
Lower
(
graph
->
fusion_groups
.
front
(),
/*apply_op_schedule = */
apply_manual_schedule
,
/*apply_group_schedule = */
apply_manual_schedule
);
CHECK
(
!
lowered_funcs_
.
empty
())
<<
"lowered_funcs_ is empty"
;
std
::
vector
<
Expr
>
bodys
;
for
(
auto
&&
func
:
lowered_funcs_
)
{
bodys
.
emplace_back
(
func
->
body
);
}
return
ir
::
IRSchedule
(
ir
::
ModuleExpr
({
std
::
move
(
bodys
)}),
rand_seed
);
}
std
::
string
TestAutoGenRuleBase
::
GetIR
(
const
ir
::
IRSchedule
&
schedule
)
{
const
auto
&
exprs
=
schedule
.
GetModule
().
GetExprs
();
std
::
stringstream
module_stream
;
for
(
auto
i
=
0
;
i
<
exprs
.
size
();
++
i
)
{
module_stream
<<
"Expr "
<<
i
<<
" {
\n
"
<<
exprs
.
at
(
i
)
<<
"
\n
} // end Expr "
<<
i
<<
"
\n
"
;
}
return
module_stream
.
str
();
}
ir
::
Module
TestAutoGenRuleBase
::
BuildIRModule
(
const
ir
::
IRSchedule
&
schedule
)
{
auto
&&
updated_bodys
=
schedule
.
GetModule
().
GetExprs
();
CHECK_EQ
(
lowered_funcs_
.
size
(),
updated_bodys
.
size
())
<<
"associated exprs size not equal"
;
ir
::
Module
::
Builder
builder
(
"test_bulder"
,
this
->
target_
);
for
(
int
i
=
0
;
i
<
lowered_funcs_
.
size
();
++
i
)
{
ir
::
Expr
func_body
=
updated_bodys
.
at
(
i
);
const
ir
::
LoweredFunc
&
ori_func
=
lowered_funcs_
.
at
(
i
);
auto
&&
new_func
=
UpdateFuncWithNewBody
(
target_
,
ori_func
,
func_body
);
builder
.
AddFunction
(
new_func
);
}
return
builder
.
Build
();
}
std
::
string
TestAutoGenRuleBase
::
GenSourceCode
(
const
ir
::
Module
&
ir_module
)
{
std
::
unique_ptr
<
backends
::
CodeGenC
>
codegen
;
#ifdef CINN_WITH_CUDA
if
(
target_
==
common
::
DefaultNVGPUTarget
())
{
codegen
=
std
::
make_unique
<
backends
::
CodeGenCUDA_Dev
>
(
this
->
target_
);
}
else
{
codegen
=
std
::
make_unique
<
backends
::
CodeGenCX86
>
(
this
->
target_
,
CodeGenCX86
::
Feature
::
AVX512
);
}
#else
codegen
=
std
::
make_unique
<
backends
::
CodeGenCX86
>
(
this
->
target_
,
CodeGenCX86
::
Feature
::
AVX512
);
#endif
codegen
->
SetInlineBuiltinCodes
(
false
);
return
codegen
->
Compile
(
ir_module
,
CodeGenC
::
OutputKind
::
CImpl
);
}
raw_func_type
TestAutoGenRuleBase
::
GenExecutableKernel
(
const
ir
::
Module
&
ir_module
)
{
auto
&&
func_name
=
lowered_funcs_
.
front
()
->
name
;
// Compile to machine code
backend_compier_
->
Build
(
ir_module
);
auto
test_func_ptr
=
reinterpret_cast
<
void
(
*
)(
void
**
,
int32_t
)
>
(
backend_compier_
->
Lookup
(
func_name
));
return
test_func_ptr
;
}
void
MemoryCopy
(
const
float
*
src
,
float
*
dst
,
int
numel
,
std
::
string
type
)
{
#ifdef CINN_WITH_CUDA
if
(
type
==
"DeviceToHost"
)
{
cudaMemcpy
(
dst
,
src
,
numel
*
sizeof
(
float
),
cudaMemcpyDeviceToHost
);
return
;
}
else
if
(
type
==
"HostToDevice"
)
{
cudaMemcpy
(
dst
,
src
,
numel
*
sizeof
(
float
),
cudaMemcpyHostToDevice
);
return
;
}
#endif
if
(
type
==
"HostToHost"
)
{
for
(
size_t
i
=
0
;
i
<
numel
;
++
i
)
{
dst
[
i
]
=
src
[
i
];
}
}
else
{
LOG
(
FATAL
)
<<
"Unknown memory copy type"
;
}
}
void
AddDataToScope
(
Scope
*
scope
,
const
common
::
Target
&
target
,
float
*
data_ptr
,
std
::
string
name
,
const
std
::
vector
<
int
>&
shape
)
{
auto
*
var
=
scope
->
Var
<
Tensor
>
(
name
);
auto
&
tensor
=
absl
::
get
<
Tensor
>
(
*
var
);
CHECK
(
shape
.
size
())
<<
"The size of shape can not be 0."
;
Shape
cinn_shape
(
shape
);
tensor
->
Resize
(
cinn_shape
);
auto
*
tgt_data_ptr
=
tensor
->
mutable_data
<
float
>
(
target
);
std
::
string
mem_cpy_type
=
target
==
common
::
DefaultNVGPUTarget
()
?
"DeviceToHost"
:
"HostToHost"
;
MemoryCopy
(
data_ptr
,
tgt_data_ptr
,
cinn_shape
.
numel
(),
mem_cpy_type
);
}
void
CheckResult
(
raw_func_type
test_func
,
raw_func_type
expected_func
,
const
std
::
vector
<
std
::
string
>&
input_names
,
const
std
::
vector
<
std
::
string
>&
output_names
,
const
std
::
vector
<
std
::
vector
<
int
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
int
>>&
output_shapes
,
const
common
::
Target
&
target
)
{
CHECK
(
input_names
.
size
())
<<
"The number of inputs must be greater than 0."
;
CHECK
(
output_names
.
size
())
<<
"The number of outputs must be greater than 0."
;
CHECK_EQ
(
input_names
.
size
(),
input_shapes
.
size
())
<<
"The quantity of input_names and input_shapes must be equal."
;
CHECK_EQ
(
output_names
.
size
(),
output_shapes
.
size
())
<<
"The quantity of output_names and output_shapes must be equal."
;
// Initialize data
std
::
vector
<
float
*>
input_data_ptrs
(
input_names
.
size
());
for
(
int
i
=
0
;
i
<
input_shapes
.
size
();
++
i
)
{
int
input_data_numel
=
std
::
accumulate
(
input_shapes
[
i
].
begin
(),
input_shapes
[
i
].
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
input_data_ptrs
[
i
]
=
reinterpret_cast
<
float
*>
(
malloc
(
input_data_numel
*
sizeof
(
float
)));
for
(
int
j
=
0
;
j
<
input_data_numel
;
++
j
)
{
input_data_ptrs
[
i
][
j
]
=
(
rand
()
*
1.
f
)
/
RAND_MAX
;
// NOLINT
}
}
std
::
vector
<
float
*>
test_output_data_ptrs
(
output_names
.
size
());
std
::
vector
<
float
*>
expected_output_data_ptrs
(
output_names
.
size
());
std
::
vector
<
int
>
output_data_numels
(
output_shapes
.
size
());
for
(
int
i
=
0
;
i
<
output_shapes
.
size
();
++
i
)
{
output_data_numels
[
i
]
=
std
::
accumulate
(
output_shapes
[
i
].
begin
(),
output_shapes
[
i
].
end
(),
1
,
[](
int
a
,
int
b
)
{
return
a
*
b
;
});
test_output_data_ptrs
[
i
]
=
reinterpret_cast
<
float
*>
(
malloc
(
output_data_numels
[
i
]
*
sizeof
(
float
)));
memset
(
test_output_data_ptrs
[
i
],
0
,
output_data_numels
[
i
]
*
sizeof
(
float
));
expected_output_data_ptrs
[
i
]
=
reinterpret_cast
<
float
*>
(
malloc
(
output_data_numels
[
i
]
*
sizeof
(
float
)));
memset
(
expected_output_data_ptrs
[
i
],
0
,
output_data_numels
[
i
]
*
sizeof
(
float
));
}
auto
launch_kernel_fn
=
[
&
](
raw_func_type
&
raw_func
,
std
::
vector
<
float
*>&
output_data_ptrs
)
{
// Initialize scope
Scope
scope
;
// Initialize input data in scope.
for
(
int
i
=
0
;
i
<
input_names
.
size
();
++
i
)
{
AddDataToScope
(
&
scope
,
target
,
input_data_ptrs
[
i
],
input_names
[
i
],
input_shapes
[
i
]);
}
// Initialize output data in scope.
for
(
int
i
=
0
;
i
<
output_names
.
size
();
++
i
)
{
AddDataToScope
(
&
scope
,
target
,
output_data_ptrs
[
i
],
output_names
[
i
],
output_shapes
[
i
]);
}
// Create Instruction and run
Instruction
instr
(
target
,
&
scope
,
input_names
,
output_names
);
CHECK
(
raw_func
)
<<
"The raw_func can not be nullptr."
;
instr
.
SetLoweredFunc
(
reinterpret_cast
<
void
*>
(
raw_func
));
// should call Finalize explicitly before Run
instr
.
Finalize
();
instr
.
Run
();
// data
for
(
int
i
=
0
;
i
<
output_names
.
size
();
++
i
)
{
const
float
*
result_ptr
=
scope
.
GetTensor
(
output_names
[
i
])
->
data
<
float
>
();
std
::
string
mem_cpy_type
=
target
==
common
::
DefaultNVGPUTarget
()
?
"DeviceToHost"
:
"HostToHost"
;
MemoryCopy
(
result_ptr
,
output_data_ptrs
[
i
],
output_data_numels
[
i
],
mem_cpy_type
);
}
};
// launch and execute test and expected kernel separately
launch_kernel_fn
(
test_func
,
test_output_data_ptrs
);
launch_kernel_fn
(
expected_func
,
expected_output_data_ptrs
);
// Check result
for
(
int
i
=
0
;
i
<
output_shapes
.
size
();
++
i
)
{
for
(
int
j
=
0
;
j
<
output_data_numels
[
i
];
++
j
)
{
ASSERT_NEAR
(
test_output_data_ptrs
[
i
][
j
],
expected_output_data_ptrs
[
i
][
j
],
1e-4
);
}
}
// Free memory
for
(
auto
ptr
:
input_data_ptrs
)
{
free
(
ptr
);
}
for
(
auto
ptr
:
test_output_data_ptrs
)
{
free
(
ptr
);
}
for
(
auto
ptr
:
expected_output_data_ptrs
)
{
free
(
ptr
);
}
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/utils/random_engine.h"
namespace
cinn
{
namespace
auto_schedule
{
/* @brief: Function pointer of executable code compiled by CINN.
* @params-1: Pointers to all arguments, including input and output.
* @params-2: The number of Arguments.
* @return: void
*/
using
raw_func_type
=
void
(
*
)(
void
**
,
int32_t
);
// A base utility class for testing AutoGenRule
class
TestAutoGenRuleBase
:
public
::
testing
::
Test
{
public:
void
SetUp
()
override
{
srand
(
0
);
Context
::
Global
().
ResetNameId
();
}
// Initialize context for specified target
void
Initialize
(
const
common
::
Target
&
target
);
// construct an ir::IRSchedule by lowering the specified for following
// AutoGenRule test
ir
::
IRSchedule
MakeIRSchedule
(
const
frontend
::
Program
&
test_program
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
=
-
1
,
bool
apply_manual_schedule
=
false
);
// Get the IR of bodies in IRSchedule
std
::
string
GetIR
(
const
ir
::
IRSchedule
&
schedule
);
// build ir::Module from the original lowered funcs with their bodies updated
// by the schedule
ir
::
Module
BuildIRModule
(
const
ir
::
IRSchedule
&
schedule
);
// generate source code with the built ir module
std
::
string
GenSourceCode
(
const
ir
::
Module
&
ir_module
);
// generate executable kernel function with the built ir module
raw_func_type
GenExecutableKernel
(
const
ir
::
Module
&
ir_module
);
protected:
common
::
Target
target_
;
std
::
vector
<
ir
::
LoweredFunc
>
lowered_funcs_
;
std
::
unique_ptr
<
backends
::
Compiler
>
backend_compier_
;
};
/* @brief: Interface for checking function correctness.
* @params-1: Function pointer of the function to be tested.
* @params-2: Expected function pointer for comparison.
* @params-3: Names of input data.
* @params-4: Names of output data.
* @params-5: Shapes of the input data, each input corresponds to a
* std::vector<int>.
* @params-6: Shapes of the output data, each output corresponds to a
* std::vector<int>.
* @params-7: The Target expressing computing platform and architecture of the
* function to be tested.
* @return: void
*/
void
CheckResult
(
raw_func_type
test_func
,
raw_func_type
expected_func
,
const
std
::
vector
<
std
::
string
>&
input_names
,
const
std
::
vector
<
std
::
string
>&
output_names
,
const
std
::
vector
<
std
::
vector
<
int
>>&
input_shapes
,
const
std
::
vector
<
std
::
vector
<
int
>>&
output_shapes
,
const
common
::
Target
&
target
);
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/block_sampler.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/block_sampler.h"
#include <algorithm>
#include "paddle/cinn/ir/ir.h"
namespace
cinn
{
namespace
auto_schedule
{
std
::
unique_ptr
<
BlockSampler
>
BlockSampler
::
Make
(
const
std
::
vector
<
ir
::
Expr
>&
all_blocks
,
bool
default_remove_policy
,
const
std
::
string
&
strategy
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
,
const
std
::
vector
<
int
>&
weights
)
{
CHECK_GT
(
all_blocks
.
size
(),
0
)
<<
"Empty block list"
;
if
(
strategy
==
"traversal"
)
{
VLOG
(
6
)
<<
"Init TraversalBlockSampler with block num = "
<<
all_blocks
.
size
();
return
std
::
make_unique
<
TraversalBlockSampler
>
(
all_blocks
,
default_remove_policy
);
}
else
if
(
strategy
==
"probabilistic"
)
{
VLOG
(
6
)
<<
"Init ProbabilisticBlockSampler with block num = "
<<
all_blocks
.
size
();
return
std
::
make_unique
<
ProbabilisticBlockSampler
>
(
all_blocks
,
default_remove_policy
,
rand_seed
,
weights
);
}
LOG
(
FATAL
)
<<
"Unimplemented strategy:"
<<
strategy
;
return
nullptr
;
}
BlockSampler
::
BlockSampler
(
const
std
::
vector
<
ir
::
Expr
>&
all_blocks
,
bool
default_remove_policy
)
{
default_remove_policy_
=
default_remove_policy
;
std
::
transform
(
all_blocks
.
begin
(),
all_blocks
.
end
(),
std
::
back_inserter
(
all_blocks_
),
[](
const
ir
::
Expr
&
block_expr
)
{
const
ir
::
ScheduleBlockRealize
*
block_realize
=
block_expr
.
As
<
ir
::
ScheduleBlockRealize
>
();
const
ir
::
ScheduleBlock
*
block
=
block_realize
->
schedule_block
.
As
<
ir
::
ScheduleBlock
>
();
return
block
->
name
;
});
}
std
::
string
TraversalBlockSampler
::
NextBlock
(
bool
remove
)
{
if
(
cur_idx_
<
all_blocks_
.
size
())
{
VLOG
(
6
)
<<
"[TraversalBlockSampler] next block: "
<<
all_blocks_
.
at
(
cur_idx_
);
std
::
string
block_name
=
all_blocks_
.
at
(
cur_idx_
);
if
(
remove
)
{
++
cur_idx_
;
}
return
block_name
;
}
VLOG
(
6
)
<<
"[TraversalBlockSampler] next block: empty"
;
return
""
;
}
ProbabilisticBlockSampler
::
ProbabilisticBlockSampler
(
const
std
::
vector
<
ir
::
Expr
>&
all_blocks
,
bool
default_remove_policy
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
,
const
std
::
vector
<
int
>&
weights
)
:
BlockSampler
(
all_blocks
,
default_remove_policy
),
weights_
(
weights
),
rand_seed_
(
rand_seed
)
{
if
(
weights
.
empty
())
{
weights_
.
resize
(
all_blocks
.
size
(),
1
);
}
else
{
CHECK_EQ
(
all_blocks
.
size
(),
weights_
.
size
());
}
remains_
=
all_blocks
.
size
();
}
std
::
string
ProbabilisticBlockSampler
::
NextBlock
(
bool
remove
)
{
if
(
remains_
==
0
)
{
return
""
;
}
int
block_idx
=
utils
::
SampleDiscreteFromDistribution
<
int
>
(
weights_
,
&
rand_seed_
);
if
(
remove
)
{
weights_
[
block_idx
]
=
0
;
--
remains_
;
}
VLOG
(
6
)
<<
"[ProbabilisticBlockSampler] next block: "
<<
all_blocks_
.
at
(
block_idx
);
return
all_blocks_
.
at
(
block_idx
);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/block_sampler.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <random>
#include <vector>
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/utils/random_engine.h"
namespace
cinn
{
namespace
auto_schedule
{
class
SearchState
;
// Select the next block to be operated for SearchState during the search
// process
class
BlockSampler
{
public:
/**
* @brief Create a BlockSampler with the specific strategy name and necessary
* construct parameters.
* @param all_blocks All possible blocks to be sampled.
* @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The block sampling strategy.
* Currently, the available strategies are "traversal" and
* "probabilistic", where "traversal" means to select blocks one by one until
* all blocks are traversed, and "probabilistic" means randomly picking blocks
* according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/
static
std
::
unique_ptr
<
BlockSampler
>
Make
(
const
std
::
vector
<
ir
::
Expr
>&
all_blocks
,
bool
default_remove_policy
=
true
,
const
std
::
string
&
strategy
=
"traversal"
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
=
0
,
const
std
::
vector
<
int
>&
weights
=
{});
// Return the name of sample strategy
virtual
const
char
*
Name
()
const
=
0
;
// Reset associated states to sample at the beginning
virtual
void
Reset
()
=
0
;
// Select a block with default remove policy.
std
::
string
NextBlock
()
{
return
NextBlock
(
default_remove_policy_
);
}
protected:
// A BlockSampler object should be created with the static function Make()
BlockSampler
(
const
std
::
vector
<
ir
::
Expr
>&
all_blocks
,
bool
default_remove_policy
);
// Select a block to apply rule
// The param remove is used to determine whether to delete the next block
// after selecting it, If remove == true, it will not be sampled in the
// future.
virtual
std
::
string
NextBlock
(
bool
remove
)
=
0
;
// The names of all blocks
// Because the Block Expr will be changed in the search process, the name is
// saved for indexing
std
::
vector
<
std
::
string
>
all_blocks_
;
// The default policy to determine whether to delete the next block after
// selecting it.
bool
default_remove_policy_
;
};
// Sample blocks with traversal strategy,
// witch means to select blocks one by one until all blocks are traversed.
class
TraversalBlockSampler
:
public
BlockSampler
{
public:
TraversalBlockSampler
(
const
std
::
vector
<
ir
::
Expr
>&
all_blocks
,
bool
default_remove_policy
)
:
BlockSampler
(
all_blocks
,
default_remove_policy
),
cur_idx_
(
0
)
{}
const
char
*
Name
()
const
override
{
return
"traversal"
;
}
void
Reset
()
override
{
cur_idx_
=
0
;
}
private:
std
::
string
NextBlock
(
bool
remove
)
override
;
private:
int
cur_idx_
;
};
// Sample blocks with probabilistic strategy,
// witch means randomly picking blocks according to the given distribution.
class
ProbabilisticBlockSampler
:
public
BlockSampler
{
public:
ProbabilisticBlockSampler
(
const
std
::
vector
<
ir
::
Expr
>&
all_blocks
,
bool
default_remove_policy
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
=
0
,
const
std
::
vector
<
int
>&
weights
=
{});
const
char
*
Name
()
const
override
{
return
"probabilistic"
;
}
void
Reset
()
override
{}
private:
std
::
string
NextBlock
(
bool
remove
)
override
;
private:
std
::
vector
<
int
>
weights_
;
utils
::
LinearRandomEngine
::
StateType
rand_seed_
;
int
remains_
;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/block_sampler_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/block_sampler.h"
#include <gtest/gtest.h>
#include "paddle/cinn/ir/ir.h"
namespace
cinn
{
namespace
auto_schedule
{
std
::
vector
<
ir
::
Expr
>
CreateTestBlocks
()
{
std
::
vector
<
ir
::
Expr
>
blocks
;
for
(
int
i
=
0
;
i
<
3
;
++
i
)
{
ir
::
Expr
block
=
ir
::
ScheduleBlock
::
Make
(
{},
{},
{},
"block_"
+
std
::
to_string
(
i
),
ir
::
Expr
());
blocks
.
push_back
(
ir
::
ScheduleBlockRealize
::
Make
({},
block
));
}
return
blocks
;
}
TEST
(
BlockSampler
,
Make
)
{
std
::
vector
<
ir
::
Expr
>
mock_blocks
=
CreateTestBlocks
();
auto
traversal_block_sampler
=
BlockSampler
::
Make
(
mock_blocks
,
true
,
"traversal"
);
ASSERT_STREQ
(
traversal_block_sampler
->
Name
(),
"traversal"
);
auto
probabilistic_block_sampler
=
BlockSampler
::
Make
(
mock_blocks
,
true
,
"probabilistic"
);
ASSERT_STREQ
(
probabilistic_block_sampler
->
Name
(),
"probabilistic"
);
}
TEST
(
TraversalBlockSampler
,
NextBlock
)
{
std
::
vector
<
ir
::
Expr
>
blocks
=
CreateTestBlocks
();
auto
traversal_block_sampler
=
BlockSampler
::
Make
(
blocks
,
true
,
"traversal"
);
ASSERT_EQ
(
"block_0"
,
traversal_block_sampler
->
NextBlock
());
ASSERT_EQ
(
"block_1"
,
traversal_block_sampler
->
NextBlock
());
ASSERT_EQ
(
"block_2"
,
traversal_block_sampler
->
NextBlock
());
ASSERT_EQ
(
""
,
traversal_block_sampler
->
NextBlock
());
traversal_block_sampler
->
Reset
();
ASSERT_EQ
(
"block_0"
,
traversal_block_sampler
->
NextBlock
());
traversal_block_sampler
=
BlockSampler
::
Make
(
blocks
,
false
,
"traversal"
);
ASSERT_EQ
(
"block_0"
,
traversal_block_sampler
->
NextBlock
());
ASSERT_EQ
(
"block_0"
,
traversal_block_sampler
->
NextBlock
());
}
TEST
(
ProbabilisticBlockSampler
,
NextBlock
)
{
std
::
vector
<
ir
::
Expr
>
blocks
=
CreateTestBlocks
();
auto
probabilistic_block_sampler
=
BlockSampler
::
Make
(
blocks
,
false
,
"probabilistic"
,
0
,
{
4
,
2
,
1
});
std
::
string
block_name
;
for
(
int
i
=
0
;
i
<
20
;
++
i
)
{
block_name
=
probabilistic_block_sampler
->
NextBlock
();
VLOG
(
6
)
<<
"next block name: "
<<
block_name
;
}
probabilistic_block_sampler
=
BlockSampler
::
Make
(
blocks
,
true
,
"probabilistic"
,
0
,
{
4
,
2
,
1
});
probabilistic_block_sampler
->
NextBlock
();
probabilistic_block_sampler
->
NextBlock
();
probabilistic_block_sampler
->
NextBlock
();
ASSERT_EQ
(
""
,
probabilistic_block_sampler
->
NextBlock
());
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/rule_sampler.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/rule_sampler.h"
#include <algorithm>
#include <random>
namespace
cinn
{
namespace
auto_schedule
{
std
::
unique_ptr
<
RuleSampler
>
RuleSampler
::
Make
(
const
std
::
vector
<
AutoGenRule
*>&
potential_rules
,
bool
default_remove_policy
,
const
std
::
string
&
strategy
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
,
const
std
::
vector
<
int
>&
weights
)
{
CHECK_GT
(
potential_rules
.
size
(),
0
)
<<
"Empty rule list"
;
if
(
strategy
==
"traversal"
)
{
return
std
::
make_unique
<
TraversalRuleSampler
>
(
potential_rules
,
default_remove_policy
);
}
else
if
(
strategy
==
"probabilistic"
)
{
return
std
::
make_unique
<
ProbabilisticRuleSampler
>
(
potential_rules
,
default_remove_policy
,
rand_seed
,
weights
);
}
LOG
(
FATAL
)
<<
"Unimplemented strategy:"
<<
strategy
;
return
nullptr
;
}
AutoGenRule
*
TraversalRuleSampler
::
NextRule
(
bool
remove
)
{
if
(
cur_idx_
<
potential_rules_
->
size
())
{
AutoGenRule
*
rule
=
potential_rules_
->
at
(
cur_idx_
);
if
(
remove
)
{
++
cur_idx_
;
}
return
rule
;
}
return
nullptr
;
}
ProbabilisticRuleSampler
::
ProbabilisticRuleSampler
(
const
std
::
vector
<
AutoGenRule
*>&
potential_rules
,
bool
default_remove_policy
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
,
const
std
::
vector
<
int
>&
weights
)
:
RuleSampler
(
potential_rules
,
default_remove_policy
),
weights_
(
weights
),
rand_seed_
(
utils
::
LinearRandomEngine
::
NormalizeState
(
rand_seed
))
{
if
(
weights
.
empty
())
{
weights_
.
resize
(
potential_rules
.
size
(),
1
);
}
else
{
CHECK_EQ
(
potential_rules
.
size
(),
weights_
.
size
());
}
remains_
=
potential_rules
.
size
();
}
AutoGenRule
*
ProbabilisticRuleSampler
::
NextRule
(
bool
remove
)
{
if
(
remains_
==
0
)
{
return
nullptr
;
}
int
rule_idx
=
utils
::
SampleDiscreteFromDistribution
<
int
>
(
weights_
,
&
rand_seed_
);
if
(
remove
)
{
weights_
[
rule_idx
]
=
0
;
--
remains_
;
}
return
potential_rules_
->
at
(
rule_idx
);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/rule_sampler.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <random>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/utils/random_engine.h"
namespace
cinn
{
namespace
auto_schedule
{
class
SearchState
;
// Select the next potential rule for the SearchState during the search process.
class
RuleSampler
{
public:
/**
* @brief Create a RuleSampler with the specific strategy name and necessary
* construct parameters.
* @param potential_rules All possible rules to be sampled.
* @param default_remove_policy The default option to determine whether to
* delete the next block after selecting it.
* @param strategy The rule sampling strategy.
* Currently, the available strategies are "traversal" and
* "probabilistic", where "traversal" means to select rules one by one until
* all rules are traversed, and "probabilistic" means randomly picking rules
* according to the given distribution.
* @param weights Used for the probabilistic policy, giving each candidate a
* weight.
*/
static
std
::
unique_ptr
<
RuleSampler
>
Make
(
const
std
::
vector
<
AutoGenRule
*>&
potential_rules
,
bool
default_remove_policy
=
true
,
const
std
::
string
&
strategy
=
"traversal"
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
=
0
,
const
std
::
vector
<
int
>&
weights
=
{});
// Return the name of sample strategy
virtual
const
char
*
Name
()
const
=
0
;
// Reset associated states to sample at the beginning
virtual
void
Reset
()
=
0
;
// Select a rule with default remove policy.
AutoGenRule
*
NextRule
()
{
return
NextRule
(
default_remove_policy_
);
}
protected:
// A RuleSampler object should be created with the static function Make()
RuleSampler
(
const
std
::
vector
<
AutoGenRule
*>&
potential_rules
,
bool
default_remove_policy
)
:
potential_rules_
(
&
potential_rules
),
default_remove_policy_
(
default_remove_policy
)
{}
// Select a rule to apply.
// The param remove is used to determine whether to delete the next rule after
// selecting it, If remove == true, it will not be sampled in the future.
virtual
AutoGenRule
*
NextRule
(
bool
remove
)
=
0
;
// The pointer refers to all potential rules
const
std
::
vector
<
AutoGenRule
*>*
potential_rules_
;
// The default policy to determine whether to delete the next rule after
// selecting it.
bool
default_remove_policy_
;
};
// Sample rules with traversal strategy,
// witch means to select rules one by one until all rules are traversed.
class
TraversalRuleSampler
:
public
RuleSampler
{
public:
TraversalRuleSampler
(
const
std
::
vector
<
AutoGenRule
*>&
potential_rules
,
bool
default_remove_policy
)
:
RuleSampler
(
potential_rules
,
default_remove_policy
),
cur_idx_
(
0
)
{}
const
char
*
Name
()
const
override
{
return
"traversal"
;
}
void
Reset
()
override
{
cur_idx_
=
0
;
}
private:
AutoGenRule
*
NextRule
(
bool
remove
)
override
;
private:
int
cur_idx_
;
};
// Sample rules with probabilistic strategy,
// which means randomly picking rules according to the given distribution.
class
ProbabilisticRuleSampler
:
public
RuleSampler
{
public:
ProbabilisticRuleSampler
(
const
std
::
vector
<
AutoGenRule
*>&
potential_rules
,
bool
default_remove_policy
,
utils
::
LinearRandomEngine
::
StateType
rand_seed
=
0
,
const
std
::
vector
<
int
>&
weights
=
{});
const
char
*
Name
()
const
override
{
return
"probabilistic"
;
}
void
Reset
()
override
{}
private:
AutoGenRule
*
NextRule
(
bool
remove
)
override
;
private:
std
::
vector
<
int
>
weights_
;
utils
::
LinearRandomEngine
::
StateType
rand_seed_
;
int
remains_
;
};
}
// namespace auto_schedule
}
// namespace cinn
Prev
1
…
5
6
7
8
9
10
11
12
13
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment