Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
558
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
171 additions
and
430 deletions
+171
-430
paddle/cinn/optim/cast_simplify_test.cc
paddle/cinn/optim/cast_simplify_test.cc
+6
-8
paddle/cinn/optim/collect_undefined_vars.cc
paddle/cinn/optim/collect_undefined_vars.cc
+0
-111
paddle/cinn/optim/collect_undefined_vars.h
paddle/cinn/optim/collect_undefined_vars.h
+0
-36
paddle/cinn/optim/compute_inline_expand.cc
paddle/cinn/optim/compute_inline_expand.cc
+9
-8
paddle/cinn/optim/eliminate_broadcast_in_forloop.cc
paddle/cinn/optim/eliminate_broadcast_in_forloop.cc
+9
-9
paddle/cinn/optim/extern_call_process.cc
paddle/cinn/optim/extern_call_process.cc
+1
-1
paddle/cinn/optim/fold_cinn_call_arguments.cc
paddle/cinn/optim/fold_cinn_call_arguments.cc
+2
-2
paddle/cinn/optim/insert_debug_log_callee.cc
paddle/cinn/optim/insert_debug_log_callee.cc
+2
-2
paddle/cinn/optim/ir_simplify.cc
paddle/cinn/optim/ir_simplify.cc
+125
-25
paddle/cinn/optim/ir_simplify.h
paddle/cinn/optim/ir_simplify.h
+2
-0
paddle/cinn/optim/lower_function_call_bind_vars.cc
paddle/cinn/optim/lower_function_call_bind_vars.cc
+1
-1
paddle/cinn/optim/lower_intrin.cc
paddle/cinn/optim/lower_intrin.cc
+1
-1
paddle/cinn/optim/map_extern_call.cc
paddle/cinn/optim/map_extern_call.cc
+1
-1
paddle/cinn/optim/optimize.cc
paddle/cinn/optim/optimize.cc
+8
-7
paddle/cinn/optim/optimize_test.cc
paddle/cinn/optim/optimize_test.cc
+1
-1
paddle/cinn/optim/remove_nested_block.cc
paddle/cinn/optim/remove_nested_block.cc
+0
-123
paddle/cinn/optim/remove_nested_block.h
paddle/cinn/optim/remove_nested_block.h
+0
-33
paddle/cinn/optim/remove_nested_block_test.cc
paddle/cinn/optim/remove_nested_block_test.cc
+0
-58
paddle/cinn/optim/remove_schedule_block.cc
paddle/cinn/optim/remove_schedule_block.cc
+2
-2
paddle/cinn/optim/remove_schedule_block_test.cc
paddle/cinn/optim/remove_schedule_block_test.cc
+1
-1
No files found.
Too many changes to show.
To preserve performance only
558 of 558+
files are displayed.
Plain diff
Email patch
paddle/cinn/optim/cast_simplify_test.cc
View file @
01a10755
...
...
@@ -12,13 +12,11 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/cast_simplify.h"
#include <gtest/gtest.h>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/ir_simplify.h"
namespace
cinn
::
optim
{
TEST
(
CastSimplify
,
same_type
)
{
...
...
@@ -26,7 +24,7 @@ TEST(CastSimplify, same_type) {
Expr
a
=
ir
::
Cast
::
Make
(
Int
(
32
),
n
);
LOG
(
INFO
)
<<
n
->
type
();
LOG
(
INFO
)
<<
a
;
Cast
Simplify
(
&
a
);
Simplify
Cast
(
&
a
);
ASSERT_EQ
(
utils
::
GetStreamCnt
(
a
),
"n"
);
}
...
...
@@ -34,7 +32,7 @@ TEST(CastSimplify, Imm_int) {
Expr
a
=
ir
::
Cast
::
Make
(
Int
(
64
),
Expr
(
1
));
Expr
c
=
ir
::
Cast
::
Make
(
Int
(
32
),
a
);
LOG
(
INFO
)
<<
c
;
Cast
Simplify
(
&
c
);
Simplify
Cast
(
&
c
);
LOG
(
INFO
)
<<
c
;
ASSERT_EQ
(
utils
::
GetStreamCnt
(
c
),
"1"
);
ASSERT_EQ
(
c
.
type
(),
Int
(
32
));
...
...
@@ -44,7 +42,7 @@ TEST(CastSimplify, Imm_double) {
Expr
a
=
ir
::
Cast
::
Make
(
Float
(
64
),
Expr
(
2.33
));
Expr
c
=
ir
::
Cast
::
Make
(
Int
(
32
),
a
);
LOG
(
INFO
)
<<
c
;
Cast
Simplify
(
&
c
);
Simplify
Cast
(
&
c
);
LOG
(
INFO
)
<<
c
;
ASSERT_EQ
(
utils
::
GetStreamCnt
(
c
),
"2"
);
ASSERT_EQ
(
c
.
type
(),
Int
(
32
));
...
...
@@ -54,7 +52,7 @@ TEST(CastSimplify, Imm_uint) {
Expr
a
=
ir
::
Cast
::
Make
(
UInt
(
64
),
Expr
(
1
));
Expr
c
=
ir
::
Cast
::
Make
(
UInt
(
32
),
a
);
LOG
(
INFO
)
<<
c
;
Cast
Simplify
(
&
c
);
Simplify
Cast
(
&
c
);
LOG
(
INFO
)
<<
c
;
ASSERT_EQ
(
utils
::
GetStreamCnt
(
c
),
"1"
);
ASSERT_EQ
(
c
.
type
(),
UInt
(
32
));
...
...
paddle/cinn/optim/collect_undefined_vars.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/collect_undefined_vars.h"
#include <set>
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace
cinn
::
optim
{
namespace
{
struct
Mutator
:
public
ir
::
IRMutator
<>
{
using
ir
::
IRMutator
<>::
Visit
;
std
::
vector
<
std
::
string
>
undefined_vars
;
std
::
set
<
std
::
string
>
defined_vars
;
std
::
set
<
std
::
string
>
used_vars
;
void
CollectVarDef
(
const
std
::
string
&
var
)
{
CHECK
(
!
defined_vars
.
count
(
var
))
<<
"var "
<<
var
<<
" has been defined, please check"
;
CHECK
(
!
used_vars
.
count
(
var
))
<<
"var "
<<
var
<<
" is wrongly used before definition"
;
defined_vars
.
insert
(
var
);
}
void
ClearVar
(
const
std
::
string
&
var
)
{
defined_vars
.
erase
(
var
);
used_vars
.
erase
(
var
);
}
void
CollectVarUse
(
const
std
::
string
&
var
)
{
used_vars
.
insert
(
var
);
if
(
defined_vars
.
count
(
var
)
==
0
)
{
undefined_vars
.
push_back
(
var
);
}
}
void
Visit
(
const
ir
::
Let
*
op
,
Expr
*
expr
)
final
{
Expr
symbol
=
op
->
symbol
;
auto
var
=
symbol
.
as_var_ref
();
CHECK
(
var
.
defined
());
CollectVarDef
(
var
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
Let
>
();
Visit
(
&
node
->
body
,
&
node
->
body
);
}
void
Visit
(
const
ir
::
For
*
op
,
Expr
*
expr
)
final
{
CollectVarDef
(
op
->
loop_var
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
For
>
();
Visit
(
&
node
->
min
,
&
node
->
min
);
Visit
(
&
node
->
extent
,
&
node
->
extent
);
Visit
(
&
node
->
body
,
&
node
->
body
);
ClearVar
(
op
->
loop_var
->
name
);
}
void
Visit
(
const
ir
::
Load
*
op
,
Expr
*
expr
)
final
{
auto
tensor
=
op
->
tensor
.
as_tensor_ref
();
CollectVarUse
(
tensor
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
Load
>
();
for
(
auto
&
idx
:
node
->
indices
)
Visit
(
&
idx
,
&
idx
);
}
void
Visit
(
const
ir
::
Store
*
op
,
Expr
*
expr
)
final
{
auto
tensor
=
op
->
tensor
.
as_tensor_ref
();
CollectVarUse
(
tensor
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
Store
>
();
for
(
auto
&
idx
:
node
->
indices
)
Visit
(
&
idx
,
&
idx
);
Visit
(
&
node
->
value
,
&
node
->
value
);
}
void
Visit
(
const
ir
::
_Var_
*
op
,
Expr
*
expr
)
final
{
CollectVarUse
(
op
->
name
);
auto
*
node
=
expr
->
As
<
ir
::
_Var_
>
();
if
(
node
->
lower_bound
.
defined
())
{
Visit
(
&
node
->
lower_bound
,
&
node
->
lower_bound
);
}
if
(
node
->
upper_bound
.
defined
())
{
Visit
(
&
node
->
upper_bound
,
&
node
->
upper_bound
);
}
}
void
Visit
(
const
ir
::
Reduce
*
op
,
Expr
*
expr
)
final
{
for
(
auto
&
axis
:
op
->
reduce_axis
)
{
CollectVarDef
(
axis
->
name
);
}
auto
*
node
=
expr
->
As
<
ir
::
Reduce
>
();
if
(
node
->
init
.
defined
())
Visit
(
&
node
->
init
,
&
node
->
init
);
Visit
(
&
node
->
body
,
&
node
->
body
);
}
};
}
// namespace
std
::
vector
<
std
::
string
>
CollectUndefinedVars
(
Expr
*
e
)
{
Mutator
mutator
;
mutator
.
Visit
(
e
,
e
);
return
mutator
.
undefined_vars
;
}
}
// namespace cinn::optim
paddle/cinn/optim/collect_undefined_vars.h
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/ir/ir.h"
namespace
cinn
::
optim
{
/**
* Collect undefined vars in the scope.
*
* e.g.
*
* The expression:
* for i
* for j
* a[i, j] = b[i, j]
*
* here a, b are vars without definition
*/
std
::
vector
<
std
::
string
>
CollectUndefinedVars
(
Expr
*
e
);
}
// namespace cinn::optim
paddle/cinn/optim/compute_inline_expand.cc
View file @
01a10755
...
...
@@ -18,8 +18,8 @@
#include <string>
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace
cinn
{
...
...
@@ -150,7 +150,7 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> {
}
ir
::
IRMutator
<>::
Visit
(
&
node
->
tensor
,
&
node
->
tensor
);
for
(
int
i
=
0
;
i
<
node
->
indices
.
size
();
i
++
)
{
auto
temp
=
optim
::
IRCopy
(
node
->
indices
[
i
]);
auto
temp
=
ir
::
ir_utils
::
IRCopy
(
node
->
indices
[
i
]);
ir
::
IRMutator
<>::
Visit
(
&
temp
,
&
temp
);
node
->
indices
[
i
]
=
temp
;
}
...
...
@@ -159,7 +159,7 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> {
}
else
{
ir
::
IRMutator
<>::
Visit
(
&
node
->
tensor
,
&
node
->
tensor
);
for
(
int
i
=
0
;
i
<
node
->
indices
.
size
();
i
++
)
{
auto
temp
=
optim
::
IRCopy
(
node
->
indices
[
i
]);
auto
temp
=
ir
::
ir_utils
::
IRCopy
(
node
->
indices
[
i
]);
ir
::
IRMutator
<>::
Visit
(
&
temp
,
&
temp
);
node
->
indices
[
i
]
=
temp
;
}
...
...
@@ -167,7 +167,7 @@ struct TensorInlineExpandMutator : public ir::IRMutator<> {
}
else
{
ir
::
IRMutator
<>::
Visit
(
&
node
->
tensor
,
&
node
->
tensor
);
for
(
int
i
=
0
;
i
<
node
->
indices
.
size
();
i
++
)
{
auto
temp
=
optim
::
IRCopy
(
node
->
indices
[
i
]);
auto
temp
=
ir
::
ir_utils
::
IRCopy
(
node
->
indices
[
i
]);
ir
::
IRMutator
<>::
Visit
(
&
temp
,
&
temp
);
node
->
indices
[
i
]
=
temp
;
}
...
...
@@ -225,7 +225,7 @@ void ComputeInlineExpand(Expr *expr,
poly
::
StageMap
stages
,
std
::
map
<
std
::
string
,
ir
::
Tensor
>
*
all_tensor_map
)
{
// the inline tensors contained in the expression.
auto
inline_tensors
=
ir
::
CollectIRNodes
(
*
expr
,
[
&
](
const
Expr
*
x
)
{
auto
inline_tensors
=
ir
::
ir_utils
::
CollectIRNodes
(
*
expr
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
stages
[
x
->
as_tensor
()]
->
inlined
();
});
...
...
@@ -240,9 +240,10 @@ void ComputeInlineExpand(Expr *expr,
TensorInlineExpandMutator
(
tensor
->
name
,
all_tensor_map
,
stages
)(
expr
);
}
inline_tensors
=
ir
::
CollectLoadTensors
(
*
expr
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
stages
[
x
->
as_tensor
()]
->
inlined
();
});
inline_tensors
=
ir
::
ir_utils
::
CollectLoadTensors
(
*
expr
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
stages
[
x
->
as_tensor
()]
->
inlined
();
});
}
}
...
...
paddle/cinn/optim/eliminate_broadcast_in_forloop.cc
View file @
01a10755
...
...
@@ -17,10 +17,10 @@
#include <tuple>
#include <vector>
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/
utils/
ir_visitor.h"
#include "paddle/cinn/
optim
/ir_replace.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/
ir/utils
/ir_replace.h"
namespace
cinn
{
namespace
optim
{
...
...
@@ -36,9 +36,9 @@ struct EliminateBroadcastInForloop : public ir::IRMutator<Expr*> {
auto
*
node
=
expr
->
As
<
ir
::
Store
>
();
auto
broadcasts
=
ir
::
CollectIRNodes
(
node
->
value
,
[
&
](
const
Expr
*
expr
)
{
return
expr
->
As
<
ir
::
Broadcast
>
();
});
auto
broadcasts
=
ir
::
ir_utils
::
CollectIRNodes
(
node
->
value
,
[
&
](
const
Expr
*
expr
)
{
return
expr
->
As
<
ir
::
Broadcast
>
();
});
std
::
vector
<
Expr
>
let_exprs
;
Var
tmp
;
...
...
@@ -54,7 +54,7 @@ struct EliminateBroadcastInForloop : public ir::IRMutator<Expr*> {
std
::
tie
(
let_expr
,
tmp
)
=
CreateTmpLet
(
broadcast
);
let_exprs
.
push_back
(
let_expr
);
optim
::
IrReplace
(
expr
,
broadcast
,
tmp
);
cinn
::
ir
::
ir_utils
::
IrReplace
(
expr
,
broadcast
,
tmp
);
}
// insert the let expressions to the outer forloop.
...
...
@@ -79,7 +79,7 @@ struct EliminateBroadcastInForloop : public ir::IRMutator<Expr*> {
}
bool
ContainsLoopVar
(
Expr
expr
,
Var
loop_var
)
{
return
!
ir
::
CollectIRNodes
(
expr
,
[
&
](
const
Expr
*
e
)
->
bool
{
return
!
ir
::
ir_utils
::
CollectIRNodes
(
expr
,
[
&
](
const
Expr
*
e
)
->
bool
{
return
e
->
As
<
ir
::
_Var_
>
()
&&
e
->
As
<
ir
::
_Var_
>
()
->
name
==
loop_var
->
name
;
}).
empty
();
...
...
paddle/cinn/optim/extern_call_process.cc
View file @
01a10755
...
...
@@ -14,7 +14,7 @@
#include "paddle/cinn/optim/extern_call_process.h"
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace
cinn
{
namespace
optim
{
...
...
paddle/cinn/optim/fold_cinn_call_arguments.cc
View file @
01a10755
...
...
@@ -17,8 +17,8 @@
#include <unordered_set>
#include <vector>
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
paddle/cinn/optim/insert_debug_log_callee.cc
View file @
01a10755
...
...
@@ -19,8 +19,8 @@
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
#include "paddle/cinn/utils/string.h"
...
...
paddle/cinn/optim/ir_simplify.cc
View file @
01a10755
...
...
@@ -24,18 +24,19 @@
#include "paddle/cinn/common/arithmatic.h"
#include "paddle/cinn/common/cas.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/ir_visitor.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
namespace
optim
{
using
namespace
ir
;
// NOLINT
using
common
::
bfloat16
;
using
common
::
ExprToGinacConverter
;
using
common
::
float16
;
using
utils
::
GetStreamCnt
;
using
utils
::
Replace
;
...
...
@@ -53,9 +54,9 @@ void PartialSimplify(
}
//! Simplify the expression but Load.
struct
Simplify
ButStoreLoad
Mutator
:
public
ir
::
IRMutator
<
ir
::
Expr
*>
{
struct
Simplify
NoPureMath
Mutator
:
public
ir
::
IRMutator
<
ir
::
Expr
*>
{
common
::
cas_intervals_t
&
var_intervals
;
explicit
Simplify
ButStoreLoad
Mutator
(
explicit
Simplify
NoPureMath
Mutator
(
common
::
cas_intervals_t
&
var_intervals
)
// NOLINT
:
var_intervals
(
var_intervals
)
{}
...
...
@@ -76,19 +77,6 @@ struct SimplifyButStoreLoadMutator : public ir::IRMutator<ir::Expr*> {
__
(
Max
)
#undef __
void
Visit
(
const
Ramp
*
op
,
Expr
*
expr
)
override
{
auto
*
node
=
expr
->
As
<
Ramp
>
();
CHECK
(
common
::
IsPureMath
(
node
->
base
));
CHECK
(
common
::
IsPureMath
(
node
->
stride
));
PartialSimplify
(
&
node
->
base
,
var_intervals
);
PartialSimplify
(
&
node
->
stride
,
var_intervals
);
}
void
Visit
(
const
Cast
*
op
,
Expr
*
expr
)
override
{
auto
*
node
=
expr
->
As
<
Cast
>
();
Visit
(
&
node
->
v
(),
&
node
->
v
());
}
void
Visit
(
const
PolyFor
*
op
,
Expr
*
expr
)
override
{
auto
*
node
=
expr
->
As
<
ir
::
PolyFor
>
();
node
->
condition
=
common
::
SolveInequality
(
op
->
condition
,
op
->
iterator
);
...
...
@@ -138,7 +126,7 @@ struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
if
(
common
::
IsPureMath
(
idx
))
{
PartialSimplify
(
&
idx
,
var_intervals_
);
}
else
{
Simplify
ButStoreLoad
Mutator
mutator
(
var_intervals_
);
Simplify
NoPureMath
Mutator
mutator
(
var_intervals_
);
mutator
(
&
idx
);
}
}
...
...
@@ -176,7 +164,7 @@ struct SimplifyStoreMutator : public ir::IRMutator<ir::Expr*> {
if
(
common
::
IsPureMath
(
idx
))
{
PartialSimplify
(
&
idx
,
var_intervals_
);
}
else
{
Simplify
ButStoreLoad
Mutator
mutator
(
var_intervals_
);
Simplify
NoPureMath
Mutator
mutator
(
var_intervals_
);
mutator
(
&
idx
);
}
}
...
...
@@ -215,8 +203,8 @@ struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
CHECK
(
common
::
IsPureMath
(
node
->
stride
))
<<
node
->
stride
<<
"is not a pure math!"
;
Simplify
(
&
node
->
base
);
Simplify
(
&
node
->
stride
);
Partial
Simplify
(
&
node
->
base
);
Partial
Simplify
(
&
node
->
stride
);
}
// ramp + ramp
void
Visit
(
const
Add
*
op
,
Expr
*
expr
)
override
{
...
...
@@ -317,6 +305,33 @@ struct SimplifyBlocksMutator : public ir::IRMutator<> {
expr
->
As
<
ir
::
Block
>
()
->
stmts
=
stmts
;
}
}
void
Visit
(
const
ScheduleBlock
*
op
,
Expr
*
expr
)
override
{
auto
*
node
=
expr
->
As
<
ScheduleBlock
>
();
CHECK
(
node
);
for
(
auto
&
var
:
node
->
iter_vars
)
{
if
(
var
->
lower_bound
.
defined
())
{
Visit
(
&
var
->
lower_bound
,
&
var
->
lower_bound
);
}
if
(
var
->
upper_bound
.
defined
())
{
Visit
(
&
var
->
upper_bound
,
&
var
->
upper_bound
);
}
}
for
(
auto
&
buffer_region
:
node
->
read_buffers
)
{
Visit
(
&
buffer_region
,
&
buffer_region
);
}
for
(
auto
&
buffer_region
:
node
->
write_buffers
)
{
Visit
(
&
buffer_region
,
&
buffer_region
);
}
if
(
node
->
body
.
As
<
Block
>
())
{
if
(
node
->
body
.
As
<
Block
>
()
->
stmts
.
size
()
==
1
)
{
node
->
body
=
node
->
body
.
As
<
Block
>
()
->
stmts
[
0
];
}
}
Visit
(
&
(
node
->
body
),
&
(
node
->
body
));
}
};
struct
SimplifyForLoopsMutator
:
public
ir
::
IRMutator
<>
{
...
...
@@ -359,23 +374,108 @@ struct SimplifyForLoopsMutator : public ir::IRMutator<> {
}
};
template
<
typename
CastType
,
typename
T
>
CastType
NormCastValue
(
T
value
)
{
if
(
type_of
<
CastType
>
().
is_uint
()
||
type_of
<
T
>
().
is_uint
())
{
// not support uint
return
static_cast
<
CastType
>
(
value
);
}
if
(
std
::
isinf
(
value
))
{
return
std
::
numeric_limits
<
CastType
>::
infinity
();
}
else
if
(
std
::
isnan
(
value
))
{
return
std
::
numeric_limits
<
CastType
>::
signaling_NaN
();
}
else
if
(
value
>=
static_cast
<
T
>
(
std
::
numeric_limits
<
CastType
>::
max
()))
{
return
std
::
numeric_limits
<
CastType
>::
max
();
}
else
if
(
value
<=
static_cast
<
T
>
(
std
::
numeric_limits
<
CastType
>::
lowest
()))
{
return
std
::
numeric_limits
<
CastType
>::
lowest
();
}
return
static_cast
<
CastType
>
(
value
);
}
struct
SimplifyCastMutator
:
public
ir
::
IRMutator
<>
{
void
operator
()(
Expr
*
expr
)
{
ir
::
IRMutator
<
ir
::
Expr
*>::
Visit
(
expr
,
expr
);
}
void
Visit
(
const
ir
::
Cast
*
op
,
Expr
*
expr
)
{
auto
*
node
=
expr
->
As
<
ir
::
Cast
>
();
ir
::
IRMutator
<
ir
::
Expr
*>::
Visit
(
&
node
->
v
(),
&
node
->
v
());
if
(
op
->
type
()
==
op
->
v
().
type
())
{
*
expr
=
op
->
v
();
return
;
}
#define __CAST_TO_TYPE(type__) \
if (auto* i = op->v().As<ir::IntImm>()) { \
*expr = Expr(static_cast<type__>(i->value)); \
} else if (auto* f = op->v().As<ir::FloatImm>()) { \
*expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \
} else if (auto* u = op->v().As<ir::UIntImm>()) { \
*expr = Expr(static_cast<type__>(u->value)); \
} else { \
CINN_NOT_IMPLEMENTED \
}
if
(
op
->
v
().
is_constant
())
{
if
(
op
->
type
()
==
type_of
<
int8_t
>
())
{
__CAST_TO_TYPE
(
int8_t
)
}
else
if
(
op
->
type
()
==
type_of
<
int16_t
>
())
{
__CAST_TO_TYPE
(
int16_t
)
}
else
if
(
op
->
type
()
==
type_of
<
int32_t
>
())
{
__CAST_TO_TYPE
(
int32_t
)
}
else
if
(
op
->
type
()
==
type_of
<
int64_t
>
())
{
__CAST_TO_TYPE
(
int64_t
)
}
else
if
(
op
->
type
()
==
type_of
<
uint8_t
>
())
{
__CAST_TO_TYPE
(
uint8_t
)
}
else
if
(
op
->
type
()
==
type_of
<
uint16_t
>
())
{
__CAST_TO_TYPE
(
uint16_t
)
}
else
if
(
op
->
type
()
==
type_of
<
uint32_t
>
())
{
__CAST_TO_TYPE
(
uint32_t
)
}
else
if
(
op
->
type
()
==
type_of
<
uint64_t
>
())
{
__CAST_TO_TYPE
(
uint64_t
)
}
else
if
(
op
->
type
()
==
type_of
<
float
>
())
{
__CAST_TO_TYPE
(
float
)
}
else
if
(
op
->
type
()
==
type_of
<
double
>
())
{
__CAST_TO_TYPE
(
double
)
}
else
if
(
op
->
type
()
==
type_of
<
bool
>
())
{
__CAST_TO_TYPE
(
bool
)
}
else
if
(
op
->
type
()
==
type_of
<
uint32_t
>
())
{
__CAST_TO_TYPE
(
uint32_t
)
}
else
if
(
op
->
type
()
==
type_of
<
uint64_t
>
())
{
__CAST_TO_TYPE
(
uint64_t
)
}
else
if
(
op
->
type
()
==
type_of
<
bfloat16
>
())
{
// Cannot simplify!!! pass
__CAST_TO_TYPE
(
bfloat16
)
}
else
if
(
op
->
type
()
==
type_of
<
float16
>
())
{
// Cannot simplify!!! pass
__CAST_TO_TYPE
(
float16
)
}
else
{
CINN_NOT_IMPLEMENTED
}
}
#undef __CAST_TO_TYPE
}
};
}
// namespace
void
Simplify
(
Expr
*
expr
)
{
VLOG
(
3
)
<<
"Begin Simplify "
<<
*
expr
;
optim
::
CastSimplify
(
expr
);
SimplifyCastMutator
()
(
expr
);
SimplifyRampMutator
()(
expr
);
SimplifyLoadMutator
()(
expr
);
SimplifyStoreMutator
()(
expr
);
SimplifyIfThenElseMutator
()(
expr
);
common
::
cas_intervals_t
var_intervals
;
Simplify
ButStoreLoad
Mutator
mutator
(
var_intervals
);
Simplify
NoPureMath
Mutator
mutator
(
var_intervals
);
mutator
(
expr
);
ReplaceFracWithDivMutator
()(
expr
);
}
void
SimplifyCast
(
Expr
*
expr
)
{
SimplifyCastMutator
()(
expr
);
}
void
SimplifyForLoops
(
Expr
*
expr
)
{
SimplifyForLoopsMutator
()(
expr
);
}
void
SimplifyBlocks
(
Expr
*
expr
)
{
SimplifyBlocksMutator
()(
expr
);
}
...
...
paddle/cinn/optim/ir_simplify.h
View file @
01a10755
...
...
@@ -30,6 +30,8 @@ namespace optim {
*/
void
Simplify
(
Expr
*
expr
);
void
SimplifyCast
(
Expr
*
expr
);
void
SimplifyForLoops
(
Expr
*
expr
);
void
SimplifyBlocks
(
Expr
*
expr
);
...
...
paddle/cinn/optim/lower_function_call_bind_vars.cc
View file @
01a10755
...
...
@@ -17,7 +17,7 @@
#include <string>
#include <vector>
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace
cinn
{
namespace
optim
{
...
...
paddle/cinn/optim/lower_intrin.cc
View file @
01a10755
...
...
@@ -19,8 +19,8 @@
#include "paddle/cinn/backends/llvm/llvm_intrin_rule.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/intrinsic_ops.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/registry.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace
cinn
{
namespace
optim
{
...
...
paddle/cinn/optim/map_extern_call.cc
View file @
01a10755
...
...
@@ -16,7 +16,7 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/hlir/op/op_util.h"
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/runtime/cpu/host_intrinsics.h"
namespace
cinn
{
...
...
paddle/cinn/optim/optimize.cc
View file @
01a10755
...
...
@@ -14,12 +14,11 @@
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/schedule/ir_schedule_util.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/call_arg_list_to_pod_value.h"
#include "paddle/cinn/optim/cast_bool_to_int8.h"
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/optim/eliminate_broadcast_in_forloop.h"
#include "paddle/cinn/optim/extern_call_process.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
...
...
@@ -28,9 +27,9 @@
#include "paddle/cinn/optim/lower_function_call_bind_vars.h"
#include "paddle/cinn/optim/lower_intrin.h"
#include "paddle/cinn/optim/map_extern_call.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/remove_schedule_block.h"
#include "paddle/cinn/optim/replace_const_param_to_integer.h"
#include "paddle/cinn/optim/replace_cross_thread_reduction.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/optim/unroll_loops.h"
...
...
@@ -44,13 +43,14 @@ Expr Optimize(Expr e,
bool
runtime_debug_info
,
bool
remove_gpu_for_loops
)
{
CHECK
(
e
.
defined
());
auto
copied
=
IRCopy
(
e
);
auto
copied
=
ir
::
ir_utils
::
IRCopy
(
e
);
FoldCINNCallArguments
(
&
copied
);
TransformPolyForToFor
(
&
copied
);
ReplaceConstParamToInteger
(
&
copied
);
// Simplify already contains CastSimplify
Simplify
(
&
copied
);
ReplaceCrossThreadReduction
(
&
copied
);
UnrollLoop
(
&
copied
);
VLOG
(
4
)
<<
"After Optimize UnrollLoop:"
<<
copied
;
...
...
@@ -66,8 +66,8 @@ Expr Optimize(Expr e,
CudaSyncThreadsDropIfThenElse
(
&
copied
);
#endif
RemoveNested
Block
(
&
copied
);
VLOG
(
4
)
<<
"After
Optimize RemoveNested
Block:"
<<
copied
;
Simplify
Block
s
(
&
copied
);
VLOG
(
4
)
<<
"After
Simplify
Block
s
:"
<<
copied
;
MapExternCall
(
&
copied
,
target
);
VLOG
(
10
)
<<
"After Optimize MapExternCall:"
<<
copied
;
...
...
@@ -86,7 +86,8 @@ Expr Optimize(Expr e,
}
ir
::
Module
Optimize
(
const
ir
::
Module
&
module
,
const
Target
&
target
)
{
auto
copied
=
IRCopy
(
Expr
(
module
));
auto
copied
=
ir
::
ir_utils
::
IRCopy
(
Expr
(
module
));
ReplaceCrossThreadReduction
(
&
copied
);
UnrollLoop
(
&
copied
);
VectorizeLoops
(
&
copied
,
Target
());
VLOG
(
10
)
<<
"After VectorizeLoops:"
<<
copied
.
as_module_ref
();
...
...
paddle/cinn/optim/optimize_test.cc
View file @
01a10755
...
...
@@ -17,7 +17,7 @@
#include <gtest/gtest.h>
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
paddle/cinn/optim/remove_nested_block.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
optim
{
Expr
GetExprInsideBlock
(
Expr
op
)
{
Expr
node
=
op
;
while
(
node
.
As
<
ir
::
Block
>
())
{
auto
&
stmts
=
node
.
As
<
ir
::
Block
>
()
->
stmts
;
if
(
stmts
.
size
()
==
1
)
{
node
=
stmts
.
front
();
}
else
{
break
;
}
}
return
node
;
}
// This will remove the nested blocks, but it will also remove the block outside
// the forloop's body.
struct
NestedBlockSimplifer
:
public
ir
::
IRMutator
<
Expr
*>
{
void
operator
()(
ir
::
Expr
*
expr
)
{
Visit
(
expr
);
}
private:
void
Visit
(
ir
::
Expr
*
expr
)
{
ir
::
IRMutator
<>::
Visit
(
expr
,
expr
);
}
void
Visit
(
const
ir
::
Block
*
expr
,
Expr
*
op
)
override
{
auto
*
node
=
op
->
As
<
ir
::
Block
>
();
if
(
node
->
stmts
.
size
()
==
1
)
{
*
op
=
GetExprInsideBlock
(
*
op
);
IRMutator
::
Visit
(
op
,
op
);
}
else
{
IRMutator
::
Visit
(
expr
,
op
);
}
}
};
struct
NestedBlockRemover
:
public
ir
::
IRMutator
<
Expr
*>
{
void
operator
()(
ir
::
Expr
*
expr
)
{
Visit
(
expr
);
}
private:
void
Visit
(
ir
::
Expr
*
expr
)
{
ir
::
IRMutator
<>::
Visit
(
expr
,
expr
);
}
void
Visit
(
const
ir
::
Block
*
expr
,
Expr
*
op
)
override
{
auto
*
node
=
op
->
As
<
ir
::
Block
>
();
std
::
vector
<
ir
::
Expr
>
new_exprs
;
bool
detect_nested
=
false
;
for
(
auto
it
=
node
->
stmts
.
begin
();
it
!=
node
->
stmts
.
end
();
it
++
)
{
auto
*
block
=
it
->
As
<
ir
::
Block
>
();
if
(
block
)
{
detect_nested
=
true
;
new_exprs
.
insert
(
std
::
end
(
new_exprs
),
block
->
stmts
.
begin
(),
block
->
stmts
.
end
());
}
else
{
new_exprs
.
push_back
(
*
it
);
}
}
node
->
stmts
=
new_exprs
;
IRMutator
::
Visit
(
expr
,
op
);
}
};
// add block outside forloop's body.
struct
AddBlockToForloop
:
public
ir
::
IRMutator
<>
{
void
operator
()(
ir
::
Expr
*
expr
)
{
ir
::
IRMutator
<>::
Visit
(
expr
,
expr
);
}
void
Visit
(
const
ir
::
For
*
expr
,
Expr
*
op
)
override
{
auto
*
node
=
op
->
As
<
ir
::
For
>
();
if
(
!
node
->
body
.
As
<
ir
::
Block
>
())
{
node
->
body
=
ir
::
Block
::
Make
({
node
->
body
});
}
ir
::
IRMutator
<>::
Visit
(
expr
,
op
);
}
void
Visit
(
const
ir
::
PolyFor
*
expr
,
Expr
*
op
)
override
{
auto
*
node
=
op
->
As
<
ir
::
PolyFor
>
();
if
(
!
node
->
body
.
As
<
ir
::
Block
>
())
{
node
->
body
=
ir
::
Block
::
Make
({
node
->
body
});
}
ir
::
IRMutator
<>::
Visit
(
expr
,
op
);
}
void
Visit
(
const
ir
::
_LoweredFunc_
*
expr
,
Expr
*
op
)
override
{
auto
*
node
=
op
->
As
<
ir
::
_LoweredFunc_
>
();
if
(
!
node
->
body
.
As
<
ir
::
Block
>
())
{
node
->
body
=
ir
::
Block
::
Make
({
node
->
body
});
}
ir
::
IRMutator
<>::
Visit
(
expr
,
op
);
}
};
void
RemoveNestedBlock
(
Expr
*
e
)
{
NestedBlockRemover
()(
e
);
NestedBlockSimplifer
()(
e
);
AddBlockToForloop
()(
e
);
}
}
// namespace optim
}
// namespace cinn
paddle/cinn/optim/remove_nested_block.h
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/**
* This file implements the strategy to remove the unnecessary nested block.
*/
#pragma once
#include <vector>
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/ir.h"
namespace
cinn
{
namespace
optim
{
/**
* Remove the unecessary nested block.
*/
void
RemoveNestedBlock
(
Expr
*
e
);
}
// namespace optim
}
// namespace cinn
paddle/cinn/optim/remove_nested_block_test.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/remove_nested_block.h"
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
namespace
optim
{
TEST
(
RemoveNestedBlock
,
basic
)
{
auto
block0
=
ir
::
Block
::
Make
({
Expr
(
1.
f
),
Expr
(
1.
f
)});
auto
block1
=
ir
::
Block
::
Make
({
block0
});
auto
e
=
Expr
(
block1
);
std
::
string
origin
=
utils
::
GetStreamCnt
(
e
);
EXPECT_EQ
(
origin
,
utils
::
Trim
(
R"ROC(
{
{
1.00000000f
1.00000000f
}
}
)ROC"
));
std
::
cout
<<
"origin:
\n
"
<<
e
<<
std
::
endl
;
RemoveNestedBlock
(
&
e
);
std
::
cout
<<
"e:
\n
"
<<
e
<<
std
::
endl
;
EXPECT_EQ
(
utils
::
GetStreamCnt
(
e
),
utils
::
Trim
(
R"ROC(
{
1.00000000f
1.00000000f
}
)ROC"
));
}
}
// namespace optim
}
// namespace cinn
paddle/cinn/optim/remove_schedule_block.cc
View file @
01a10755
...
...
@@ -14,8 +14,8 @@
#include "paddle/cinn/optim/remove_schedule_block.h"
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
namespace
cinn
{
...
...
paddle/cinn/optim/remove_schedule_block_test.cc
View file @
01a10755
...
...
@@ -21,8 +21,8 @@
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
Prev
1
…
21
22
23
24
25
26
27
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment