Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
01a10755
Commit
01a10755
authored
Mar 04, 2024
by
yuguo-Jack
Browse files
2.5.2-dtk24.04
parent
63eb0da5
Changes
565
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
640 additions
and
229 deletions
+640
-229
paddle/cinn/ir/utils/ir_replace.h
paddle/cinn/ir/utils/ir_replace.h
+4
-3
paddle/cinn/ir/utils/ir_verify.cc
paddle/cinn/ir/utils/ir_verify.cc
+10
-6
paddle/cinn/ir/utils/ir_verify.h
paddle/cinn/ir/utils/ir_verify.h
+6
-3
paddle/cinn/lang/CMakeLists.txt
paddle/cinn/lang/CMakeLists.txt
+2
-0
paddle/cinn/lang/lower.cc
paddle/cinn/lang/lower.cc
+125
-27
paddle/cinn/lang/lower.h
paddle/cinn/lang/lower.h
+17
-0
paddle/cinn/lang/lower_impl.cc
paddle/cinn/lang/lower_impl.cc
+16
-17
paddle/cinn/lang/lower_impl.h
paddle/cinn/lang/lower_impl.h
+2
-3
paddle/cinn/lang/lower_tensor_group.cc
paddle/cinn/lang/lower_tensor_group.cc
+237
-0
paddle/cinn/lang/lower_tensor_group.h
paddle/cinn/lang/lower_tensor_group.h
+72
-0
paddle/cinn/lang/lower_test.cc
paddle/cinn/lang/lower_test.cc
+136
-5
paddle/cinn/lang/packed_func_test.cc
paddle/cinn/lang/packed_func_test.cc
+1
-1
paddle/cinn/lang/placeholder.h
paddle/cinn/lang/placeholder.h
+1
-1
paddle/cinn/lang/placeholder_test.cc
paddle/cinn/lang/placeholder_test.cc
+1
-1
paddle/cinn/optim/CMakeLists.txt
paddle/cinn/optim/CMakeLists.txt
+4
-8
paddle/cinn/optim/buffer_assign.cc
paddle/cinn/optim/buffer_assign.cc
+4
-4
paddle/cinn/optim/call_arg_list_to_pod_value.cc
paddle/cinn/optim/call_arg_list_to_pod_value.cc
+1
-1
paddle/cinn/optim/cast_bool_to_int8.cc
paddle/cinn/optim/cast_bool_to_int8.cc
+1
-1
paddle/cinn/optim/cast_simplify.cc
paddle/cinn/optim/cast_simplify.cc
+0
-117
paddle/cinn/optim/cast_simplify.h
paddle/cinn/optim/cast_simplify.h
+0
-31
No files found.
Too many changes to show.
To preserve performance only
565 of 565+
files are displayed.
Plain diff
Email patch
paddle/cinn/
optim
/ir_replace.h
→
paddle/cinn/
ir/utils
/ir_replace.h
View file @
01a10755
...
...
@@ -18,10 +18,11 @@
#include "paddle/cinn/ir/ir.h"
namespace
cinn
{
namespace
optim
{
namespace
ir
{
namespace
ir_utils
{
//! Replace the variable \p v to expression \p e in expression \p expr.
void
IrReplace
(
ir
::
Expr
*
expr
,
ir
::
Expr
from
,
ir
::
Expr
to
);
}
// namespace
optim
}
// namespace ir_utils
}
// namespace
ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_verify.cc
View file @
01a10755
...
...
@@ -14,10 +14,13 @@
#include "paddle/cinn/ir/utils/ir_verify.h"
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace
cinn
::
ir
{
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
namespace
{
struct
IrVerifyVisitor
:
public
ir
::
IRMutator
<>
{
using
ir
::
IRMutator
<>::
Visit
;
...
...
@@ -30,10 +33,11 @@ struct IrVerifyVisitor : public ir::IRMutator<> {
NODETY_FORALL
(
__
)
#undef __
};
}
// namespace
void
IrVerify
(
Expr
e
)
{
IrVerifyVisitor
visitor
;
visitor
.
Visit
(
&
e
,
&
e
);
}
}
// namespace cinn::ir
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/ir/utils/ir_verify.h
View file @
01a10755
...
...
@@ -15,8 +15,11 @@
#pragma once
#include "paddle/cinn/ir/ir.h"
namespace
cinn
::
ir
{
namespace
cinn
{
namespace
ir
{
namespace
ir_utils
{
void
IrVerify
(
Expr
e
);
}
// namespace cinn::ir
}
// namespace ir_utils
}
// namespace ir
}
// namespace cinn
paddle/cinn/lang/CMakeLists.txt
View file @
01a10755
...
...
@@ -7,6 +7,8 @@ gather_srcs(
compute.cc
placeholder.cc
lower.cc
lower_impl.cc
lower_tensor_group.cc
builtin.cc
lower_impl.cc
packed_func.cc
)
...
...
paddle/cinn/lang/lower.cc
View file @
01a10755
...
...
@@ -22,14 +22,16 @@
#include <utility>
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/lang/lower_impl.h"
#include "paddle/cinn/lang/lower_tensor_group.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
namespace
lang
{
using
ast_gen_ius
::
TensorGroup
;
using
ir
::
Tensor
;
using
poly
::
Stage
;
...
...
@@ -38,7 +40,7 @@ std::vector<ir::Argument> GetArgs(
std
::
vector
<
ir
::
Argument
>
res
;
std
::
map
<
std
::
string
,
std
::
set
<
const
ir
::
Load
*>>
name2loads
;
std
::
map
<
std
::
string
,
std
::
set
<
const
ir
::
Store
*>>
name2stores
;
auto
load_or_store_nodes
=
ir
::
CollectIRNodesWithoutTensor
(
auto
load_or_store_nodes
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
func_body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
()
||
x
->
As
<
ir
::
Load
>
();
});
...
...
@@ -84,6 +86,49 @@ std::vector<ir::Argument> GetArgs(
return
res
;
}
//! Collect the temporary tensors from a computational graph.
std
::
vector
<
ir
::
Buffer
>
GetTempBuffers
(
const
std
::
vector
<
Tensor
>&
tensor_args
,
const
TensorGroup
&
tensor_group
,
Expr
body
)
{
std
::
unordered_set
<
std
::
string
>
tensor_arg_names
;
std
::
unordered_set
<
std
::
string
>
buffer_arg_names
;
for
(
auto
&
tensor
:
tensor_args
)
{
tensor_arg_names
.
insert
(
tensor
->
name
);
if
(
tensor
->
buffer
.
defined
())
{
buffer_arg_names
.
insert
(
tensor
->
buffer
->
name
);
}
}
std
::
map
<
std
::
string
,
ir
::
Buffer
>
name_to_buffer
;
// used to avoid duplication.
auto
all_temp_tensors
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
buffer
.
defined
()
&&
(
!
tensor_group
.
Contain
(
x
->
as_tensor
()
->
name
)
||
((
!
buffer_arg_names
.
count
(
x
->
as_tensor
()
->
buffer
->
name
)
&&
!
tensor_arg_names
.
count
(
x
->
as_tensor
()
->
name
))
||
utils
::
Endswith
(
x
->
as_tensor
()
->
buffer
->
name
,
"temp_buffer"
)));
});
for
(
auto
&
e
:
all_temp_tensors
)
{
auto
buffer_name
=
e
.
as_tensor
()
->
buffer
->
name
;
if
(
!
name_to_buffer
.
count
(
buffer_name
))
{
name_to_buffer
[
buffer_name
]
=
e
.
as_tensor
()
->
buffer
;
}
else
{
// Just copy from old code, but why?
if
(
e
.
as_tensor
()
->
buffer
->
numel
()
<
name_to_buffer
[
buffer_name
]
->
numel
())
{
name_to_buffer
[
buffer_name
]
=
e
.
as_tensor
()
->
buffer
;
}
}
}
std
::
vector
<
ir
::
Buffer
>
temp_buffers
;
for
(
auto
&
i
:
name_to_buffer
)
{
temp_buffers
.
push_back
(
i
.
second
);
}
return
temp_buffers
;
}
//! Collect the temporary tensors from a computational graph.
std
::
vector
<
ir
::
Buffer
>
GetTempBuffers
(
const
std
::
vector
<
Tensor
>&
tensor_args
,
const
poly
::
StageMap
&
stage_map
,
...
...
@@ -100,7 +145,7 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<Tensor>& tensor_args,
name_to_buffer
;
// used to avoid duplication.
auto
all_temp_tensors
=
ir
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
buffer
.
defined
()
&&
(
!
stage_map
->
Lookup
(
x
->
as_tensor
()
->
name
)
||
!
stage_map
[
x
->
as_tensor
()]
->
inlined
())
&&
...
...
@@ -120,7 +165,8 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<Tensor>& tensor_args,
}
}
// visit the ir body and update the map of name_to_buffer
auto
update_map
=
ir
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
auto
update_map
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
buffer
.
defined
())
{
auto
buffer_name
=
x
->
as_tensor
()
->
buffer
->
name
;
if
(
name_to_buffer
.
count
(
buffer_name
)
&&
...
...
@@ -150,7 +196,7 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<ir::Argument>& args,
name_to_buffer
;
// used to avoid duplication.
auto
all_temp_tensors
=
ir
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
buffer
.
defined
()
&&
(
!
buffer_arg_names
.
count
(
x
->
as_tensor
()
->
buffer
->
name
)
||
utils
::
Endswith
(
x
->
as_tensor
()
->
buffer
->
name
,
"temp_buffer"
));
...
...
@@ -167,7 +213,8 @@ std::vector<ir::Buffer> GetTempBuffers(const std::vector<ir::Argument>& args,
}
}
// visit the ir body and update the map of name_to_buffer
auto
update_map
=
ir
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
auto
update_map
=
ir
::
ir_utils
::
CollectIRNodesWithoutTensor
(
body
,
[
&
](
const
Expr
*
x
)
{
if
(
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
buffer
.
defined
())
{
auto
buffer_name
=
x
->
as_tensor
()
->
buffer
->
name
;
if
(
name_to_buffer
.
count
(
buffer_name
)
&&
...
...
@@ -205,7 +252,7 @@ void InitReduceTensor(StageMap stages,
tensor
->
InitReduction
(
stages
,
target
);
}
auto
uninited_reduce_tensors
=
ir
::
CollectIRNodes
(
tensor
->
body
(),
[
&
](
const
Expr
*
x
)
{
ir
::
ir_utils
::
CollectIRNodes
(
tensor
->
body
(),
[
&
](
const
Expr
*
x
)
{
return
x
&&
x
->
defined
()
&&
x
->
as_tensor
()
&&
x
->
as_tensor
()
->
is_reduce_tensor
()
&&
!
x
->
as_tensor
()
->
IsReduceInited
(
stages
);
...
...
@@ -216,6 +263,57 @@ void InitReduceTensor(StageMap stages,
}
}
std
::
set
<
ir
::
Tensor
>
CollectTempTensorsFromCtrlDepends
(
ast_gen_ius
::
TensorGroup
*
tensor_group
,
const
std
::
vector
<
Tensor
>&
tensor_args
)
{
std
::
set
<
ir
::
Tensor
>
res
;
for
(
const
ir
::
Tensor
&
a
:
tensor_group
->
GetAllTensors
())
{
for
(
const
ir
::
Tensor
&
t
:
tensor_group
->
GetCrtlDepTensors
(
a
->
name
))
{
res
.
emplace
(
t
);
}
}
for
(
const
ir
::
Tensor
&
t
:
tensor_args
)
{
if
(
res
.
count
(
t
))
{
res
.
erase
(
t
);
}
}
return
res
;
}
ir
::
LoweredFunc
LowerToAst
(
const
std
::
string
&
name
,
const
std
::
vector
<
Tensor
>&
tensor_args
,
ast_gen_ius
::
TensorGroup
*
tensor_group
,
const
Target
&
target
)
{
std
::
vector
<
ir
::
LoweredFunc
>
result
=
LowerToAstVec
(
name
,
tensor_args
,
tensor_group
,
target
);
CHECK_EQ
(
result
.
size
(),
1UL
)
<<
"LowerToAst contains not only 1 LoweredFunc, "
"use LowerToAstVec instead."
;
return
result
[
0
];
}
std
::
vector
<
ir
::
LoweredFunc
>
LowerToAstVec
(
const
std
::
string
&
name
,
const
std
::
vector
<
Tensor
>&
tensor_args
,
ast_gen_ius
::
TensorGroup
*
tensor_group
,
const
Target
&
target
)
{
std
::
set
<
ir
::
Tensor
>
ctrl_deps
=
CollectTempTensorsFromCtrlDepends
(
tensor_group
,
tensor_args
);
auto
lower_instance
=
detail
::
LowerTensorGroup
(
name
,
tensor_args
,
{},
tensor_group
,
std
::
vector
<
Tensor
>
(
ctrl_deps
.
begin
(),
ctrl_deps
.
end
()),
target
);
std
::
vector
<
ir
::
LoweredFunc
>
result
=
lower_instance
();
for
(
auto
&
res
:
result
)
{
if
(
target
==
common
::
DefaultNVGPUTarget
())
{
res
->
device_api
=
ir
::
DeviceAPI
::
GPU
;
}
}
return
result
;
}
ir
::
LoweredFunc
Lower
(
const
std
::
string
&
name
,
StageMap
stages
,
const
std
::
vector
<
Tensor
>&
tensor_args
,
...
...
paddle/cinn/lang/lower.h
View file @
01a10755
...
...
@@ -20,6 +20,7 @@
#include <string>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/lowered_func.h"
#include "paddle/cinn/ir/module.h"
...
...
@@ -73,6 +74,22 @@ std::vector<ir::LoweredFunc> LowerVec(
const
Target
&
target
=
common
::
DefaultHostTarget
(),
bool
support_ir_schedule
=
false
);
ir
::
LoweredFunc
LowerToAst
(
const
std
::
string
&
name
,
const
std
::
vector
<
Tensor
>
&
tensor_args
,
ast_gen_ius
::
TensorGroup
*
tensor_group
,
const
Target
&
target
=
common
::
DefaultHostTarget
());
std
::
vector
<
ir
::
LoweredFunc
>
LowerToAstVec
(
const
std
::
string
&
name
,
const
std
::
vector
<
Tensor
>
&
tensor_args
,
ast_gen_ius
::
TensorGroup
*
tensor_group
,
const
Target
&
target
=
common
::
DefaultHostTarget
());
std
::
vector
<
ir
::
Buffer
>
GetTempBuffers
(
const
std
::
vector
<
Tensor
>
&
tensor_args
,
const
ast_gen_ius
::
TensorGroup
&
tensor_group
,
Expr
body
);
std
::
vector
<
ir
::
Argument
>
GetArgs
(
const
Expr
&
func_body
,
const
std
::
vector
<
std
::
string
>
&
input_output_nodes
);
...
...
paddle/cinn/lang/lower_impl.cc
View file @
01a10755
...
...
@@ -23,9 +23,9 @@
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/stage.h"
...
...
@@ -35,7 +35,7 @@ namespace lang {
namespace
detail
{
void
CheckNoIslCallRemains
(
Expr
*
expr
)
{
auto
isl_calls
=
ir
::
CollectIRNodes
(
*
expr
,
[](
const
Expr
*
expr
)
{
auto
isl_calls
=
ir
::
ir_utils
::
CollectIRNodes
(
*
expr
,
[](
const
Expr
*
expr
)
{
return
expr
->
As
<
ir
::
Call
>
()
&&
expr
->
As
<
ir
::
Call
>
()
->
is_isl_call
();
});
#ifdef CINN_DEBUG
...
...
@@ -223,7 +223,7 @@ void CreateCompGraphWithInlineTensors(common::Graph* graph,
// collect dependency tensors of t
// here we just collect the tensors in Load nodes
// NOTE there may be some other cases.
auto
deps
=
ir
::
CollectLoadTensors
(
auto
deps
=
ir
::
ir_utils
::
CollectLoadTensors
(
t
->
body
(),
[](
const
Expr
*
x
)
{
return
x
->
as_tensor
();
});
for
(
const
auto
&
dep
:
deps
)
{
auto
e_tensor
=
dep
.
as_tensor_ref
();
...
...
@@ -342,8 +342,7 @@ std::vector<ir::Argument> LowerImpl::GenerateFunctionArgumentList(
CheckArgsUnique
();
std
::
vector
<
ir
::
Argument
>
args
;
optim
::
TensorWriteTeller
teller
;
teller
.
Collect
(
&
fn_body
);
auto
teller
=
ir
::
ir_utils
::
CollectTensorNeedsWrite
(
&
fn_body
);
std
::
set
<
std
::
string
>
arg_names
;
...
...
@@ -358,7 +357,7 @@ std::vector<ir::Argument> LowerImpl::GenerateFunctionArgumentList(
for
(
auto
&
tensor
:
tensor_args_
)
{
auto
*
tensor_node
=
tensor
.
As
<
ir
::
_Tensor_
>
();
bool
is_output
=
teller
.
IsWrite
(
tensor
->
name
);
bool
is_output
=
teller
.
count
(
tensor
->
name
);
VLOG
(
1
)
<<
"tensor argument "
<<
tensor
->
name
<<
" buffer "
<<
tensor
->
buffer
->
name
;
...
...
@@ -396,8 +395,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
std
::
vector
<
ir
::
Argument
>
in_args
;
std
::
vector
<
ir
::
Argument
>
out_args
;
optim
::
TensorWriteTeller
teller
;
teller
.
Collect
(
&
func_iterator
);
auto
teller
=
ir
::
ir_utils
::
CollectTensorNeedsWrite
(
&
func_iterator
);
std
::
set
<
std
::
string
>
arg_names
;
std
::
set
<
std
::
string
>
all_tensor_names
;
...
...
@@ -410,11 +408,12 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
in_args
.
emplace_back
(
scalar
,
ir
::
Argument
::
IO
::
kInput
);
}
auto
all_tensors
=
ir
::
CollectIRNodes
(
func_iterator
,
[
&
](
const
Expr
*
x
)
{
auto
all_tensors
=
ir
::
ir_utils
::
CollectIRNodes
(
func_iterator
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_tensor
()
&&
!
stages_
[
x
->
as_tensor
()]
->
inlined
();
});
auto
all_vars
=
ir
::
CollectIRNodes
(
auto
all_vars
=
ir
::
ir_utils
::
CollectIRNodes
(
func_iterator
,
[
&
](
const
Expr
*
x
)
{
return
x
->
as_var
();
});
for
(
auto
&
i
:
all_tensors
)
{
...
...
@@ -448,7 +447,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
VLOG
(
3
)
<<
"In tensor_args_, it has : "
<<
tensor
->
name
;
if
(
temp_tensor_names
.
count
(
tensor
->
name
)
>
0
)
continue
;
if
(
all_tensor_names
.
count
(
tensor
->
name
)
==
0
)
continue
;
bool
is_output
=
teller
.
IsWrite
(
tensor
->
name
);
bool
is_output
=
teller
.
count
(
tensor
->
name
);
VLOG
(
3
)
<<
"tensor argument "
<<
tensor
->
name
<<
" buffer "
<<
tensor
->
buffer
->
name
;
...
...
@@ -485,7 +484,7 @@ std::vector<ir::Argument> LowerImpl::GenFuncArgForSplitKernel(
VLOG
(
3
)
<<
"Tensor "
<<
tensor
->
name
;
if
(
tensor
->
buffer
.
defined
()
&&
!
arg_names
.
count
(
tensor
->
buffer
->
name
))
{
bool
is_output
=
teller
.
IsWrite
(
tensor
->
name
)
&&
teller
.
IsWrite
(
tensor
->
name
);
teller
.
count
(
tensor
->
name
)
&&
teller
.
count
(
tensor
->
name
);
if
(
is_output
)
out_args
.
emplace_back
(
tensor
->
buffer
,
ir
::
Argument
::
IO
::
kOutput
);
}
...
...
@@ -590,7 +589,7 @@ std::vector<ir::LoweredFunc> LowerImpl::operator()() {
Reference
(
&
arg
)
->
buffer
=
tensor_map
.
at
(
arg
->
name
)
->
buffer
;
}
}
auto
store_exprs
=
ir
::
CollectIRNodes
(
auto
store_exprs
=
ir
::
ir_utils
::
CollectIRNodes
(
func_iterator
,
[](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
});
std
::
vector
<
ir
::
Tensor
>
new_temp_tensors
;
for
(
auto
&
expr
:
store_exprs
)
{
...
...
@@ -655,7 +654,7 @@ std::vector<ir::LoweredFunc> LowerImpl::operator()() {
if
(
support_ir_schedule_
)
{
optim
::
TransformPolyForToFor
(
&
func
->
body
);
optim
::
RemoveNested
Block
(
&
func
->
body
);
optim
::
Simplify
Block
s
(
&
func
->
body
);
func
->
body
=
ir
::
Block
::
Make
({
func
->
body
});
result
.
push_back
(
ir
::
LoweredFunc
(
func
.
get
()));
num_func
++
;
...
...
paddle/cinn/lang/lower_impl.h
View file @
01a10755
...
...
@@ -27,14 +27,13 @@
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/buffer_assign.h"
#include "paddle/cinn/optim/compute_inline_expand.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/optim/remove_nested_block.h"
#include "paddle/cinn/optim/replace_call_with_expr.h"
#include "paddle/cinn/optim/tensor_write_tell.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/ast_gen.h"
...
...
paddle/cinn/lang/lower_tensor_group.cc
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/lang/lower_tensor_group.h"
#include <algorithm>
#include <queue>
#include <string>
#include <unordered_set>
#include "paddle/cinn/ast_gen_ius/ast_gen.h"
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/optim/ir_simplify.h"
#include "paddle/cinn/optim/replace_var_with_expr.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/stage.h"
namespace
cinn
{
namespace
lang
{
namespace
detail
{
LowerTensorGroup
::
LowerTensorGroup
(
const
std
::
string
&
fn_name
,
const
std
::
vector
<
ir
::
Tensor
>&
tensor_args
,
const
std
::
vector
<
ir
::
Var
>&
scalar_args
,
ast_gen_ius
::
TensorGroup
*
tensor_group
,
const
std
::
vector
<
ir
::
Tensor
>&
temp_tensor_args
,
const
Target
&
target
)
:
fn_name_
(
fn_name
),
tensor_args_
(
tensor_args
),
scalar_args_
(
scalar_args
),
tensor_group_
(
tensor_group
),
temp_tensor_args_
(
temp_tensor_args
),
target_
(
target
)
{}
std
::
vector
<
ir
::
LoweredFunc
>
LowerTensorGroup
::
operator
()()
{
std
::
vector
<
ir
::
LoweredFunc
>
result
;
int
num_func
=
0
;
// 1. Generate function body
std
::
vector
<
ir
::
Expr
>
func_bodies
=
GenerateFunctionBody
(
tensor_group_
);
for
(
ir
::
Expr
&
func_body
:
func_bodies
)
{
func_body
=
ir
::
ScheduleBlockRealize
::
Make
(
{},
ir
::
ScheduleBlock
::
Make
(
{},
{},
{},
common
::
UniqName
(
"root"
),
func_body
));
// 2. Assign buffer to tensors
auto
tensor_map
=
tensor_group_
->
AllocateBuffers
();
// copy the tensor(with buffer assigned) back to func's args.
for
(
auto
&
arg
:
tensor_args_
)
{
if
(
arg
->
is_placeholder_node
()
||
arg
->
buffer
.
defined
())
{
continue
;
}
if
(
arg
->
body
().
As
<
ir
::
Call
>
()
&&
arg
->
body
().
type
().
is_void
())
{
continue
;
// extern call
}
if
(
tensor_map
.
find
(
arg
->
name
)
==
tensor_map
.
end
())
{
LOG
(
INFO
)
<<
"Didn't find arg tensor "
<<
arg
->
name
<<
"in tensor_map.
\n
"
<<
"The function is "
<<
fn_name_
<<
"
\n
And all the arg tensors are:
\n
"
;
for
(
auto
&
i
:
tensor_args_
)
{
LOG
(
INFO
)
<<
i
->
name
;
}
LOG
(
FATAL
)
<<
"Fatal Error!"
;
}
Reference
(
&
arg
)
->
buffer
=
tensor_map
.
at
(
arg
->
name
)
->
buffer
;
}
// 3. Collect temp tensor buffers
std
::
set
<
std
::
string
>
temp_tensor_names
;
for
(
auto
&
t
:
temp_tensor_args_
)
{
temp_tensor_names
.
insert
(
t
->
name
);
}
// Some store tensors are also temp tensors;
auto
store_exprs
=
ir
::
ir_utils
::
CollectIRNodes
(
func_body
,
[](
const
Expr
*
x
)
{
return
x
->
As
<
ir
::
Store
>
();
});
for
(
auto
&
expr
:
store_exprs
)
{
auto
*
store_node
=
expr
.
As
<
ir
::
Store
>
();
CHECK
(
store_node
);
auto
*
tensor
=
store_node
->
tensor
.
As
<
ir
::
_Tensor_
>
();
CHECK
(
tensor
);
VLOG
(
3
)
<<
"In store_exprs, its name is : "
<<
tensor
->
name
;
CHECK
(
tensor
->
buffer
.
defined
());
if
(
tensor
->
buffer
->
memory_type
!=
ir
::
MemoryType
::
Heap
)
{
temp_tensor_names
.
insert
(
store_node
->
tensor
.
as_tensor_ref
()
->
name
);
}
}
std
::
vector
<
ir
::
Buffer
>
temp_buffers
;
std
::
unordered_set
<
std
::
string
>
buffer_name_set
;
for
(
const
std
::
string
&
name
:
temp_tensor_names
)
{
if
(
!
tensor_map
.
count
(
name
))
{
continue
;
}
ir
::
Tensor
&
t
=
tensor_map
[
name
];
if
(
t
->
buffer
.
defined
()
&&
!
buffer_name_set
.
count
(
t
->
buffer
->
name
))
{
temp_buffers
.
push_back
(
t
->
buffer
);
buffer_name_set
.
insert
(
t
->
buffer
->
name
);
}
}
// 4. Handle function args
std
::
vector
<
ir
::
Argument
>
func_args
=
GenerateFunctionArgumentList
(
func_body
);
// 5. Actual function make
std
::
string
actual_fn_name
=
fn_name_
;
if
(
num_func
>
0
)
{
actual_fn_name
+=
"_"
+
std
::
to_string
(
num_func
);
VLOG
(
3
)
<<
"Making func :"
<<
actual_fn_name
;
}
for
(
auto
&
i
:
func_args
)
{
VLOG
(
3
)
<<
"func_args is : "
<<
i
.
name
();
}
for
(
auto
&
i
:
temp_buffers
)
{
VLOG
(
3
)
<<
"temp_buffers is : "
<<
i
->
name
;
}
ir
::
LoweredFunc
func
=
ir
::
_LoweredFunc_
::
Make
(
actual_fn_name
,
func_args
,
func_body
,
temp_buffers
);
// 6. Final clean up
optim
::
SimplifyBlocks
(
&
func
->
body
);
func
->
body
=
ir
::
Block
::
Make
({
func
->
body
});
result
.
push_back
(
ir
::
LoweredFunc
(
func
.
get
()));
num_func
++
;
}
return
result
;
}
std
::
vector
<
ir
::
Argument
>
LowerTensorGroup
::
GenerateFunctionArgumentList
(
Expr
fn_body
)
{
std
::
vector
<
ir
::
Argument
>
args
;
auto
teller
=
ir
::
ir_utils
::
CollectTensorNeedsWrite
(
&
fn_body
);
std
::
set
<
std
::
string
>
arg_names
;
for
(
auto
&
scalar
:
scalar_args_
)
{
CHECK
(
!
arg_names
.
count
(
scalar
->
name
));
auto
*
scalar_node
=
scalar
.
As
<
ir
::
_Var_
>
();
CHECK
(
scalar_node
->
type
().
valid
());
arg_names
.
insert
(
scalar
->
name
);
args
.
emplace_back
(
scalar
,
ir
::
Argument
::
IO
::
kInput
);
}
for
(
auto
&
tensor
:
tensor_args_
)
{
auto
*
tensor_node
=
tensor
.
As
<
ir
::
_Tensor_
>
();
bool
is_output
=
teller
.
count
(
tensor
->
name
);
VLOG
(
6
)
<<
"tensor argument "
<<
tensor
->
name
<<
", buffer "
<<
tensor
->
buffer
->
name
<<
", is output: "
<<
is_output
;
// avoid duplicate
if
(
!
tensor_node
->
buffer
.
defined
())
{
continue
;
}
// if a argument is already marked as kInput, mark it as kOutput and move
// it to the back.
if
(
arg_names
.
count
(
tensor_node
->
buffer
->
name
))
{
auto
it
=
std
::
find_if
(
args
.
begin
(),
args
.
end
(),
[
&
](
const
ir
::
Argument
&
x
)
{
return
x
.
name
()
==
tensor_node
->
buffer
->
name
;
});
CHECK
(
it
!=
args
.
end
());
if
(
it
->
is_input
())
{
args
.
erase
(
it
);
}
else
if
(
it
->
is_output
())
{
continue
;
}
}
arg_names
.
insert
(
tensor_node
->
buffer
->
name
);
auto
io
=
is_output
?
ir
::
Argument
::
IO
::
kOutput
:
ir
::
Argument
::
IO
::
kInput
;
VLOG
(
6
)
<<
"Collect "
<<
(
is_output
?
"W"
:
"R"
)
<<
" argument "
<<
tensor
->
buffer
->
name
;
args
.
emplace_back
(
tensor_node
->
buffer
,
io
);
}
return
args
;
}
std
::
vector
<
ir
::
Expr
>
LowerTensorGroup
::
GenerateFunctionBody
(
ast_gen_ius
::
TensorGroup
*
tensor_group
)
{
// TODO(zhhsplendid): GetGenFuncTopoOrder() may remove args
std
::
vector
<
ir
::
Tensor
>
ordered_tensors
=
tensor_group
->
GetGenFuncTopoOrder
();
std
::
vector
<
ir
::
Expr
>
result
;
std
::
vector
<
ir
::
Expr
>
bodies
;
for
(
const
ir
::
Tensor
&
tensor
:
ordered_tensors
)
{
VLOG
(
6
)
<<
"tensor_name = "
<<
tensor
->
name
;
if
(
!
tensor
->
is_placeholder_node
()
&&
tensor
->
has_expression
())
{
VLOG
(
6
)
<<
"ast_gen_ius::AstGen::Build for Tensor "
<<
tensor
;
bodies
.
emplace_back
(
ast_gen_ius
::
AstGen
::
Build
(
tensor
,
tensor_group
));
bool
gpu_local
=
tensor
->
buffer
.
defined
()
&&
(
tensor
->
buffer
->
memory_type
==
ir
::
MemoryType
::
GPUShared
||
tensor
->
buffer
->
memory_type
==
ir
::
MemoryType
::
GPULocal
);
if
(
target_
==
common
::
DefaultNVGPUTarget
()
&&
!
gpu_local
)
{
result
.
push_back
(
bodies
.
size
()
==
1
?
bodies
[
0
]
:
ir
::
Block
::
Make
(
bodies
));
bodies
.
clear
();
}
}
}
if
(
!
bodies
.
empty
())
{
result
.
push_back
(
bodies
.
size
()
==
1
?
bodies
[
0
]
:
ir
::
Block
::
Make
(
bodies
));
bodies
.
clear
();
}
return
result
;
}
}
// namespace detail
}
// namespace lang
}
// namespace cinn
paddle/cinn/lang/lower_tensor_group.h
0 → 100644
View file @
01a10755
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <absl/container/flat_hash_map.h>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <stack>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/common/graph_utils.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/optim/buffer_assign.h"
#include "paddle/cinn/optim/compute_inline_expand.h"
#include "paddle/cinn/optim/fold_cinn_call_arguments.h"
#include "paddle/cinn/optim/optimize.h"
#include "paddle/cinn/optim/replace_call_with_expr.h"
#include "paddle/cinn/optim/transform_gpu_forloop.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
#include "paddle/cinn/poly/ast_gen.h"
namespace
cinn
{
namespace
lang
{
namespace
detail
{
class
LowerTensorGroup
{
public:
LowerTensorGroup
(
const
std
::
string
&
fn_name
,
const
std
::
vector
<
ir
::
Tensor
>&
tensor_args
,
const
std
::
vector
<
ir
::
Var
>&
scalar_args
,
ast_gen_ius
::
TensorGroup
*
tensor_group
,
const
std
::
vector
<
ir
::
Tensor
>&
temp_tensor_args
=
{},
const
Target
&
target
=
common
::
DefaultHostTarget
());
std
::
vector
<
ir
::
LoweredFunc
>
operator
()();
std
::
vector
<
ir
::
Expr
>
GenerateFunctionBody
(
ast_gen_ius
::
TensorGroup
*
tensor_group
);
std
::
vector
<
ir
::
Argument
>
GenerateFunctionArgumentList
(
ir
::
Expr
fn_body
);
private:
const
std
::
string
&
fn_name_
;
const
std
::
vector
<
ir
::
Tensor
>&
tensor_args_
;
const
std
::
vector
<
Var
>&
scalar_args_
;
std
::
vector
<
ir
::
Tensor
>
temp_tensor_args_
;
ast_gen_ius
::
TensorGroup
*
tensor_group_
;
Target
target_
;
};
}
// namespace detail
}
// namespace lang
}
// namespace cinn
paddle/cinn/lang/lower_test.cc
View file @
01a10755
...
...
@@ -18,6 +18,7 @@
#include <set>
#include "paddle/cinn/ast_gen_ius/tensor_group.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/lang/buffer.h"
#include "paddle/cinn/lang/compute.h"
...
...
@@ -27,6 +28,10 @@
namespace
cinn
{
namespace
lang
{
#define TEST_SOUTPUT(x, out) \
LOG(INFO) << "\n" << x << std::endl; \
EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out));
TEST
(
lower
,
basic
)
{
auto
M
=
Expr
(
100
);
auto
N
=
Expr
(
15
);
...
...
@@ -42,10 +47,6 @@ TEST(lower, basic) {
LOG
(
INFO
)
<<
"lower_size "
<<
lower_funcs
;
#define TEST_SOUTPUT(x, out) \
std::cout << "\n" << x << std::endl; \
EXPECT_EQ(utils::GetStreamCnt(x), utils::Trim(out));
auto
out
=
R"ROC(
{
serial for (i, 0, 100)
...
...
@@ -77,7 +78,7 @@ TEST(lower, more_complex) {
auto
lower_funcs
=
Lower
(
"cal_C"
,
stages
,
{
A
,
B
,
C
});
std
::
cout
<<
"func:
\n
"
<<
Expr
(
lower_funcs
->
self
())
<<
std
::
endl
;
LOG
(
INFO
)
<<
"func:
\n
"
<<
Expr
(
lower_funcs
->
self
())
<<
std
::
endl
;
}
//! To support training, the dynamic shape support is vital. We test the
...
...
@@ -157,5 +158,135 @@ TEST(lower, temp_buffer_collects) {
}
}
TEST
(
lower_to_ast
,
basic
)
{
Context
::
Global
().
ResetNameId
();
auto
M
=
Expr
(
100
);
auto
N
=
Expr
(
15
);
Placeholder
<
float
>
A
(
"A"
,
{
Expr
(
M
),
Expr
(
N
)});
ir
::
Tensor
B
=
Compute
(
{
M
,
N
},
[
=
](
Var
i
,
Var
j
)
->
Expr
{
return
A
(
i
,
j
)
+
1.
f
;
},
"B"
);
ast_gen_ius
::
TensorGroup
tensor_group
({
B
});
ir
::
LoweredFunc
lower_func
=
LowerToAst
(
"cal_B"
,
{
A
,
B
},
&
tensor_group
);
LOG
(
INFO
)
<<
"lower_func "
<<
lower_func
;
auto
out
=
R"ROC(
function cal_B (_A, _B)
{
ScheduleBlock(root)
{
serial for (i, 0, 100)
{
serial for (j, 0, 15)
{
ScheduleBlock(B)
{
i0, i1 = axis.bind(i, j)
B[i0, i1] = (A[i0, i1] + 1.00000000f)
}
}
}
}
}
)ROC"
;
TEST_SOUTPUT
(
lower_func
,
out
);
}
TEST
(
lower_to_ast
,
three_dim
)
{
Context
::
Global
().
ResetNameId
();
Expr
M
(
100
);
Expr
N
(
15
);
Expr
K
(
200
);
Placeholder
<
float
>
A
(
"A"
,
{
Expr
(
M
),
Expr
(
N
)});
Placeholder
<
float
>
B
(
"B"
,
{
Expr
(
N
),
Expr
(
K
)});
auto
C
=
Compute
(
{
M
,
N
,
K
},
[
=
](
Var
i
,
Var
j
,
Var
k
)
->
Expr
{
return
A
(
i
,
j
)
*
B
(
j
,
k
);
},
"C"
);
ast_gen_ius
::
TensorGroup
tensor_group
({
C
});
ir
::
LoweredFunc
lower_func
=
LowerToAst
(
"cal_C"
,
{
A
,
B
,
C
},
&
tensor_group
);
LOG
(
INFO
)
<<
"func:
\n
"
<<
lower_func
<<
std
::
endl
;
auto
out
=
R"ROC(
function cal_C (_A, _B, _C)
{
ScheduleBlock(root)
{
serial for (i, 0, 100)
{
serial for (j, 0, 15)
{
serial for (k, 0, 200)
{
ScheduleBlock(C)
{
i0, i1, i2 = axis.bind(i, j, k)
C[i0, i1, i2] = (A[i0, i1] * B[i1, i2])
}
}
}
}
}
}
)ROC"
;
TEST_SOUTPUT
(
lower_func
,
out
);
}
TEST
(
lower_to_ast
,
matmul_with_reduce_sum
)
{
Context
::
Global
().
ResetNameId
();
Placeholder
<
float
>
A
(
"A"
,
{
Expr
(
100
),
Expr
(
20
)});
Placeholder
<
float
>
B
(
"B"
,
{
Expr
(
20
),
Expr
(
50
)});
Target
target
{};
// C = A * B
Var
k
(
20
,
"k0"
);
Tensor
C
=
Compute
(
{
Expr
(
100
),
Expr
(
50
)},
[
&
](
Var
i
,
Var
j
)
{
return
lang
::
ReduceSum
(
A
(
i
,
k
)
*
B
(
k
,
j
),
{
k
});
},
"C"
);
ast_gen_ius
::
TensorGroup
tensor_group
({
C
});
ir
::
LoweredFunc
lower_func
=
LowerToAst
(
"matmul"
,
{
A
,
B
,
C
},
&
tensor_group
);
LOG
(
INFO
)
<<
"func:
\n
"
<<
lower_func
<<
std
::
endl
;
auto
out
=
R"ROC(
function matmul (_A, _B, _C)
{
ScheduleBlock(root)
{
serial for (i, 0, 100)
{
serial for (j, 0, 50)
{
ScheduleBlock(C__reduce_init)
{
i0, i1 = axis.bind(i, j)
C__reduce_init[i0, i1] = 0.00000000f
}
serial for (k0, 0, 20)
{
ScheduleBlock(C)
{
i0_0, i1_0, i2 = axis.bind(i, j, k0)
C[i0_0, i1_0] = (C[i0_0, i1_0] + (A[i0_0, i2] * B[i2, i1_0]))
}
}
}
}
}
}
)ROC"
;
TEST_SOUTPUT
(
lower_func
,
out
);
}
}
// namespace lang
}
// namespace cinn
paddle/cinn/lang/packed_func_test.cc
View file @
01a10755
...
...
@@ -16,8 +16,8 @@
#include <gtest/gtest.h>
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/op/ir_operators.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/utils/string.h"
namespace
cinn
{
...
...
paddle/cinn/lang/placeholder.h
View file @
01a10755
...
...
@@ -19,9 +19,9 @@
#include "paddle/cinn/common/common.h"
#include "paddle/cinn/ir/buffer.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/operation.h"
#include "paddle/cinn/ir/tensor.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace
cinn
{
...
...
paddle/cinn/lang/placeholder_test.cc
View file @
01a10755
...
...
@@ -16,7 +16,7 @@
#include <gtest/gtest.h>
#include "paddle/cinn/ir/
utils/
ir_printer.h"
#include "paddle/cinn/ir/ir_printer.h"
namespace
cinn
{
namespace
lang
{
...
...
paddle/cinn/optim/CMakeLists.txt
View file @
01a10755
...
...
@@ -3,11 +3,8 @@ core_gather_headers()
gather_srcs
(
cinnapi_src
SRCS
remove_nested_block.cc
replace_call_with_expr.cc
ir_replace.cc
replace_var_with_expr.cc
tensor_write_tell.cc
ir_simplify.cc
optimize.cc
vectorize_loops.cc
...
...
@@ -23,19 +20,16 @@ gather_srcs(
compute_inline_expand.cc
buffer_assign.cc
replace_const_param_to_integer.cc
cast_simplify.cc
lower_intrin.cc
cast_bool_to_int8.cc
collect_undefined_vars.cc
var_mod_simplify.cc
remove_schedule_block.cc
)
remove_schedule_block.cc
replace_cross_thread_reduction.cc
)
if
(
WITH_CUDA
)
gather_srcs
(
cinnapi_src SRCS transform_gpu_forloop.cc
)
endif
()
cinn_cc_test
(
test_remove_nested_block SRCS remove_nested_block_test.cc DEPS
cinncore
)
cinn_cc_test
(
test_ir_simplify SRCS ir_simplify_test.cc DEPS cinncore
)
cinn_cc_test
(
test_replace_call_with_expr SRCS replace_call_with_expr_test.cc
DEPS cinncore
)
...
...
@@ -62,3 +56,5 @@ cinn_cc_test(test_cast_simplify SRCS cast_simplify_test.cc DEPS cinncore)
cinn_cc_test
(
test_remove_schedule_block SRCS remove_schedule_block_test.cc DEPS
cinncore
)
cinn_cc_test
(
test_unroll_loops SRCS unroll_loops_test.cc DEPS cinncore
)
cinn_cc_test
(
test_replace_cross_thread_reduction SRCS
replace_cross_thread_reduction_test.cc DEPS cinncore
)
paddle/cinn/optim/buffer_assign.cc
View file @
01a10755
...
...
@@ -15,10 +15,10 @@
#include "paddle/cinn/optim/buffer_assign.h"
#include "paddle/cinn/common/union_find.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/ir/ir_printer.h"
#include "paddle/cinn/ir/utils/ir_replace.h"
#include "paddle/cinn/lang/lower_impl.h"
#include "paddle/cinn/optim/ir_replace.h"
namespace
cinn
{
namespace
optim
{
...
...
@@ -73,7 +73,7 @@ std::map<std::string, ir::Tensor> InitialAssignBuffer(
// unify all the tensor occurance with a global one, e.g. there are multiple
// tensor B exists in the expression, replace them with a shared one.
ir
::
CollectIRNodes
(
*
expr
,
[
&
](
const
Expr
*
x
)
->
bool
{
ir
::
ir_utils
::
CollectIRNodes
(
*
expr
,
[
&
](
const
Expr
*
x
)
->
bool
{
auto
*
t
=
x
->
as_tensor
();
if
(
t
&&
!
stages
[
t
]
->
inlined
())
{
Reference
(
x
)
=
Expr
(
all_tensor_map
.
at
(
t
->
name
));
...
...
paddle/cinn/optim/call_arg_list_to_pod_value.cc
View file @
01a10755
...
...
@@ -19,7 +19,7 @@
#include <vector>
#include "paddle/cinn/common/ir_util.h"
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
#include "paddle/cinn/runtime/intrinsic.h"
namespace
cinn
{
...
...
paddle/cinn/optim/cast_bool_to_int8.cc
View file @
01a10755
...
...
@@ -16,7 +16,7 @@
#include <glog/logging.h>
#include "paddle/cinn/ir/
utils/
ir_mutator.h"
#include "paddle/cinn/ir/ir_mutator.h"
namespace
cinn
::
optim
{
...
...
paddle/cinn/optim/cast_simplify.cc
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/optim/cast_simplify.h"
#include "paddle/cinn/ir/utils/ir_mutator.h"
namespace
cinn
::
optim
{
using
cinn
::
common
::
bfloat16
;
using
cinn
::
common
::
float16
;
namespace
{
// Clamp-and-cast a constant scalar from type T to CastType.
//
// Finite values outside CastType's representable range saturate to
// numeric_limits<CastType>::max()/lowest(); infinities and NaNs are mapped
// to the corresponding CastType special values. When either side is an
// unsigned type the value is cast directly without normalization
// ("not support uint" in the original code).
//
// Fix: the infinity branch previously returned +infinity() for BOTH signs,
// so casting -inf produced +inf. The sign is now preserved via std::signbit.
template <typename CastType, typename T>
CastType NormCastValue(T value) {
  if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
    // not support uint
    return static_cast<CastType>(value);
  }

  if (std::isinf(value)) {
    // Preserve the sign of the infinity instead of collapsing to +inf.
    return std::signbit(value) ? -std::numeric_limits<CastType>::infinity()
                               : std::numeric_limits<CastType>::infinity();
  } else if (std::isnan(value)) {
    // NOTE(review): signaling_NaN mirrors the original behavior here; a
    // plain static_cast of a NaN would normally yield a quiet NaN — confirm
    // before changing.
    return std::numeric_limits<CastType>::signaling_NaN();
  } else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
    // Saturate rather than performing an out-of-range conversion.
    return std::numeric_limits<CastType>::max();
  } else if (value <=
             static_cast<T>(std::numeric_limits<CastType>::lowest())) {
    return std::numeric_limits<CastType>::lowest();
  }
  return static_cast<CastType>(value);
}
// Rewrites ir::Cast nodes: removes casts whose operand already has the
// target type, and folds casts of constant immediates into typed literals.
// Fix: removed the duplicate (unreachable) uint32_t and uint64_t branches
// that appeared a second time later in the else-if chain.
struct Mutator : ir::IRMutator<> {
  using ir::IRMutator<>::Visit;

  void Visit(const ir::Cast* op, Expr* expr) {
    auto* node = expr->As<ir::Cast>();

    // Simplify the operand first so nested casts collapse bottom-up.
    Visit(&node->v(), &node->v());

    // A cast to the operand's own type is a no-op; strip it.
    if (op->type() == op->v().type()) {
      *expr = op->v();
      return;
    }

// Replace `cast<type__>(imm)` with an immediate of the target type. Float
// sources go through NormCastValue so out-of-range values saturate and
// inf/nan stay well-defined.
#define __CAST_TO_TYPE(type__)                                          \
  if (auto* i = op->v().As<ir::IntImm>()) {                             \
    *expr = Expr(static_cast<type__>(i->value));                        \
  } else if (auto* f = op->v().As<ir::FloatImm>()) {                    \
    *expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \
  } else if (auto* u = op->v().As<ir::UIntImm>()) {                     \
    *expr = Expr(static_cast<type__>(u->value));                        \
  } else {                                                              \
    CINN_NOT_IMPLEMENTED                                                \
  }

    if (op->v().is_constant()) {
      if (op->type() == type_of<int8_t>()) {
        __CAST_TO_TYPE(int8_t)
      } else if (op->type() == type_of<int16_t>()) {
        __CAST_TO_TYPE(int16_t)
      } else if (op->type() == type_of<int32_t>()) {
        __CAST_TO_TYPE(int32_t)
      } else if (op->type() == type_of<int64_t>()) {
        __CAST_TO_TYPE(int64_t)
      } else if (op->type() == type_of<uint8_t>()) {
        __CAST_TO_TYPE(uint8_t)
      } else if (op->type() == type_of<uint16_t>()) {
        __CAST_TO_TYPE(uint16_t)
      } else if (op->type() == type_of<uint32_t>()) {
        __CAST_TO_TYPE(uint32_t)
      } else if (op->type() == type_of<uint64_t>()) {
        __CAST_TO_TYPE(uint64_t)
      } else if (op->type() == type_of<float>()) {
        __CAST_TO_TYPE(float)
      } else if (op->type() == type_of<double>()) {
        __CAST_TO_TYPE(double)
      } else if (op->type() == type_of<bool>()) {
        __CAST_TO_TYPE(bool)
      } else if (op->type() == type_of<bfloat16>()) {
        // Cannot simplify!!! pass
        __CAST_TO_TYPE(bfloat16)
      } else if (op->type() == type_of<float16>()) {
        // Cannot simplify!!! pass
        __CAST_TO_TYPE(float16)
      } else {
        CINN_NOT_IMPLEMENTED
      }
    }
#undef __CAST_TO_TYPE
  }
};
}
// namespace
// Run the Cast-folding mutator over the whole expression tree in place.
void CastSimplify(Expr* e) { Mutator().Visit(e, e); }
}
// namespace cinn::optim
paddle/cinn/optim/cast_simplify.h
deleted
100644 → 0
View file @
63eb0da5
// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/ir.h"
namespace
cinn
::
optim
{
/**
* Simplify the Cast nodes.
*
* There are several patterns:
* 1. the source and target type are the same, drop the Cast node
* 2. for intermediate numbers, just replace the Cast node with a Node of the
* target type
*/
void
CastSimplify
(
Expr
*
e
);
}
// namespace cinn::optim
Prev
1
…
20
21
22
23
24
25
26
27
28
29
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment