OpenDAS / Paddle / Commits / 01a10755

Commit 01a10755, authored Mar 04, 2024 by yuguo-Jack

2.5.2-dtk24.04

parent 63eb0da5
Changes: this page shows 20 changed files with 227 additions and 131 deletions (+227 −131); the commit touches 558 files in total.
paddle/cinn/hlir/op/custom_call.cc                                       +1  −1
paddle/cinn/hlir/op/elementwise.cc                                       +44 −17
paddle/cinn/hlir/op/nn.cc                                                +9  −0
paddle/cinn/hlir/op/op_util.cc                                           +13 −37
paddle/cinn/hlir/op/reduction.cc                                         +68 −27
paddle/cinn/hlir/op/reduction_test.cc                                    +14 −4
paddle/cinn/hlir/op/transform.cc                                         +2  −2
paddle/cinn/hlir/pass/alterlayout_test.cc                                +17 −9
paddle/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc                 +2  −1
paddle/cinn/hlir/pass/common_subexpression_elimination_test.cc           +7  −4
paddle/cinn/hlir/pass/const_propagate_test.cc                            +5  −3
paddle/cinn/hlir/pass/constant_folding_pass_test.cc                      +2  −1
paddle/cinn/hlir/pass/custom_call_pass.cc                                +6  −3
paddle/cinn/hlir/pass/dense_merge_pass_test.cc                           +4  −2
paddle/cinn/hlir/pass/dot_merger_test.cc                                 +4  −2
paddle/cinn/hlir/pass/fusion_merge_pass.cc                               +1  −1
paddle/cinn/hlir/pass/general_fusion_merge_pass.cc                       +1  −1
paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h  +4  −4
paddle/cinn/hlir/pass/opfusion_test.cc                                   +21 −11
paddle/cinn/hlir/pass/reduce_split_pass_test.cc                          +2  −1
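Taken together, the changes on this page fall into four recurring groups: (1) new GenerateEquationsFor* functions registered under a "generate_equations" attribute for the elementwise, relu, fill_constant, and reduction ops; (2) gflags-style DECLARE_*/DEFINE_* macros replaced by Paddle's PD_-prefixed flag macros; (3) test code migrated from the three-argument GraphCompiler constructor to one taking a CompilationContext; and (4) the reduce strategy's cpu_reduce_func renamed to common_reduce_func, with its use gated on the cinn_enable_map_expr and cinn_new_group_scheduler flags.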
paddle/cinn/hlir/op/custom_call.cc

@@ -23,7 +23,7 @@
 #include "paddle/cinn/hlir/pe/nn.h"
 #include "paddle/cinn/hlir/pe/schedule.h"
 #include "paddle/cinn/hlir/pe/transform.h"
-#include "paddle/cinn/ir/utils/ir_printer.h"
+#include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/utils/string.h"
 #ifdef CINN_WITH_CUDNN
paddle/cinn/hlir/op/elementwise.cc

@@ -17,6 +17,7 @@
 #include <iostream>

 #include "absl/types/optional.h"
+#include "paddle/cinn/adt/op_equation_context.h"
 #include "paddle/cinn/hlir/framework/node.h"
 #include "paddle/cinn/hlir/framework/op.h"
 #include "paddle/cinn/hlir/framework/op_strategy.h"
@@ -107,6 +108,13 @@ std::vector<Type> InferDtypeForElementwise(
   return res;
 }

+void GenerateEquationsForElementwise(
+    cinn::adt::config::OpEquationContext* ctx) {
+  CHECK(ctx->GetInTensorsRanks().size() != 0)
+      << "The inputs is empty! Please check again.";
+  ctx->Equal(ctx->GetInIteratorTuple(0), ctx->GetOutIteratorTuple(0));
+}
+
 std::vector<Type> InferDtypeForElementwiseBool(
     const std::vector<Type>& inputs_type, const framework::AttrMapType& attrs) {
   CHECK(!inputs_type.empty())
@@ -157,23 +165,31 @@ std::shared_ptr<OpStrategy> StrategyForScale(
CHECK
(
pack_args
[
1
].
is_string
());
std
::
string
tensor_name
=
pack_args
[
1
].
operator
std
::
string
();
if
(
bias_after_scale
)
{
out
=
Compute
(
A
->
shape
,
[
=
](
const
std
::
vector
<
Expr
>
&
indice
)
{
return
ir
::
Cast
::
Make
(
A
->
type
(),
Expr
(
scale
))
*
A
(
indice
)
+
ir
::
Cast
::
Make
(
A
->
type
(),
Expr
(
bias
));
},
tensor_name
);
}
else
{
out
=
Compute
(
A
->
shape
,
[
=
](
const
std
::
vector
<
Expr
>
&
indice
)
{
return
ir
::
Cast
::
Make
(
A
->
type
(),
Expr
(
scale
))
*
(
A
(
indice
)
+
ir
::
Cast
::
Make
(
A
->
type
(),
Expr
(
bias
)));
},
tensor_name
);
}
// Paddle upscale float16 or bfloat16 compute to float32,
// we made CINN consistent with this behavior of Paddle
bool
should_upscale_fp32
=
A
->
type
()
==
common
::
F16
()
||
A
->
type
()
==
common
::
BF16
();
out
=
Compute
(
A
->
shape
,
[
=
](
const
std
::
vector
<
Expr
>
&
indice
)
{
Expr
cast_scale
=
should_upscale_fp32
?
Expr
(
scale
)
:
ir
::
Cast
::
Make
(
A
->
type
(),
Expr
(
scale
));
Expr
cast_bias
=
should_upscale_fp32
?
Expr
(
bias
)
:
ir
::
Cast
::
Make
(
A
->
type
(),
Expr
(
bias
));
Expr
cast_A_indice
=
should_upscale_fp32
?
ir
::
Cast
::
Make
(
common
::
F32
(),
A
(
indice
))
:
A
(
indice
);
Expr
add_result
=
bias_after_scale
?
cast_scale
*
cast_A_indice
+
cast_bias
:
cast_scale
*
(
cast_A_indice
+
cast_bias
);
return
should_upscale_fp32
?
ir
::
Cast
::
Make
(
A
->
type
(),
add_result
)
:
add_result
;
},
tensor_name
);
auto
stages
=
CreateStages
({
out
});
*
ret
=
CINNValuePack
{{
CINNValue
(
Expr
(
out
.
get
())),
CINNValue
(
stages
)}};
});
...
...
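The replacement collapses the former if/else on bias_after_scale into a single Compute whose lambda chooses the expression at IR-construction time, and upcasts fp16/bf16 inputs to fp32 around the multiply-add to match Paddle's behavior. A minimal standalone sketch of the two algebraic forms (plain C++ on floats, not CINN IR; the function name is ours):

#include <cassert>

// Sketch only: the two scale formulas selected by bias_after_scale,
// mirroring the add_result ternary in the new Compute lambda.
float ScaleRef(float x, float scale, float bias, bool bias_after_scale) {
  return bias_after_scale ? scale * x + bias    // y = s * x + b
                          : scale * (x + bias); // y = s * (x + b)
}

int main() {
  assert(ScaleRef(2.0f, 3.0f, 1.0f, true) == 7.0f);   // 3*2 + 1
  assert(ScaleRef(2.0f, 3.0f, 1.0f, false) == 9.0f);  // 3*(2 + 1)
  return 0;
}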
@@ -413,6 +429,11 @@ std::vector<Type> InferDtypeForFillConstant(
   return {out_type};
 }

+void GenerateEquationsForFillConstant(
+    cinn::adt::config::OpEquationContext* ctx) {
+  // Do nothing
+}
+
 std::vector<std::vector<std::string>> InferLayoutForFillConstant(
     const std::vector<framework::shape_t>& input_shapes,
     const std::vector<std::string>& input_layouts,
@@ -987,6 +1008,9 @@ CINN_REGISTER_HELPER(elementwise_ops) {
                 MakeOpFunction(cinn::hlir::op::InferShapeForElementwise))   \
       .set_attr("inferdtype",                                               \
                 MakeOpFunction(cinn::hlir::op::InferDtypeForElementwise))   \
+      .set_attr(                                                            \
+          "generate_equations",                                             \
+          MakeOpFunction(cinn::hlir::op::GenerateEquationsForElementwise))  \
       .set_attr("inferlayout",                                              \
                 MakeOpFunction(cinn::hlir::op::InferLayoutForElementwise))  \
       .set_attr<cinn::hlir::framework::OpPatternKind>(                      \
@@ -1108,6 +1132,9 @@ CINN_REGISTER_HELPER(elementwise_ops) {
                 MakeOpFunction(cinn::hlir::op::InferShapeForFillConstant))
       .set_attr("inferdtype",
                 MakeOpFunction(cinn::hlir::op::InferDtypeForFillConstant))
+      .set_attr(
+          "generate_equations",
+          MakeOpFunction(cinn::hlir::op::GenerateEquationsForFillConstant))
 #ifndef CINN_WITH_CUDA
       .set_attr("inferlayout",
                 MakeOpFunction(cinn::hlir::op::InferLayoutForFillConstant))
paddle/cinn/hlir/op/nn.cc

@@ -16,6 +16,7 @@
 #include <functional>

+#include "paddle/cinn/adt/op_equation_context.h"
 #include "paddle/cinn/hlir/framework/node.h"
 #include "paddle/cinn/hlir/framework/op.h"
 #include "paddle/cinn/hlir/framework/op_strategy.h"
@@ -78,6 +79,12 @@ std::vector<framework::shape_t> InferShapeForRelu(
   return res;
 }

+void GenerateEquationsForRelu(cinn::adt::config::OpEquationContext* ctx) {
+  CHECK(ctx->GetInTensorsRanks().size() != 0)
+      << "The inputs is empty! Please check again.";
+  ctx->Equal(ctx->GetInIteratorTuple(0), ctx->GetOutIteratorTuple(0));
+}
+
 std::vector<Type> InferDtypeForRelu(const std::vector<Type>& inputs_type,
                                     const framework::AttrMapType& attrs) {
   CHECK(!inputs_type.empty())
@@ -2328,6 +2335,8 @@ CINN_REGISTER_HELPER(nn_ops) {
           "CINNStrategy", cinn::hlir::op::StrategyForRelu)
       .set_attr("infershape",
                 MakeOpFunction(cinn::hlir::op::InferShapeForRelu))
      .set_attr("inferdtype",
                MakeOpFunction(cinn::hlir::op::InferDtypeForRelu))
+      .set_attr("generate_equations",
+                MakeOpFunction(cinn::hlir::op::GenerateEquationsForRelu))
 #ifndef CINN_WITH_CUDA
       .set_attr("inferlayout",
                 MakeOpFunction(cinn::hlir::op::InferLayoutForUnary))
paddle/cinn/hlir/op/op_util.cc

@@ -34,45 +34,21 @@ CINNSchedule GetElementwiseScheduleFunc(
-    common::CINNValuePack arg_pack = args[0];
-    CHECK_GT(arg_pack.size(), 0U)
-        << "arg_pack.size() must contains at least one element.";
-    // TODO(Aurelius84): For NewIrCompiler, the outputs of Compute are
-    // tensor_ref and not Expr.
-    bool is_tensor_stages = arg_pack.size() == 2U && arg_pack[0].is_tensor() &&
-                            arg_pack[1].is_stagemap();
-    if (!is_tensor_stages) {
-      std::vector<Expr> vec_ast;
-      for (int i = 0; i < arg_pack.size(); i++) {
-        if (arg_pack[i].is_expr()) {
-          Expr temp = arg_pack[i];
-          vec_ast.emplace_back(temp);
-        }
-      }
-      CHECK(!vec_ast.empty());
-      ir::ModuleExpr mod_expr(vec_ast);
-      ir::IRSchedule ir_sch(mod_expr);
-      ir_sch.MergeExprs();
-      pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target);
-      std::vector<common::CINNValue> res{
-          common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
-      *ret = common::CINNValuePack{res};
-    } else {
-      CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
-                              "empty! Please check.\n";
-      common::CINNValuePack arg_pack = args[0];
-      Expr out = arg_pack[0];
-      poly::StageMap stages = arg_pack[1];
-      CHECK(out.as_tensor());
-      CHECK_EQ(arg_pack.size(), 2UL);
-      if (target.arch == Target::Arch::NVGPU) {
-        pe::CudaScheduleInjective(
-            stages[out.as_tensor_ref()], output_shapes.front(), target);
-      } else if (target.arch == Target::Arch::X86) {
-        pe::ScheduleInjectiveCPU(stages[out.as_tensor_ref()],
-                                 output_shapes.front(),
-                                 target,
-                                 vectorizable);
-      }
-      *ret = arg_pack;
-    }
+    CHECK(!args.empty()) << "The input argument of ElementwiseSchedule is "
+                            "empty! Please check.\n";
+    common::CINNValuePack arg_pack = args[0];
+    std::vector<Expr> vec_ast;
+    for (int i = 0; i < arg_pack.size(); i++) {
+      if (arg_pack[i].is_expr()) {
+        Expr temp = arg_pack[i];
+        vec_ast.emplace_back(temp);
+      }
+    }
+    CHECK(!vec_ast.empty());
+    ir::ModuleExpr mod_expr(vec_ast);
+    ir::IRSchedule ir_sch(mod_expr);
+    ir_sch.MergeExprs();
+    pe::IRElementwiseSchedule(ir_sch, output_shapes.front(), target);
+    std::vector<common::CINNValue> res{
+        common::CINNValue(ir_sch.GetModule().GetExprs().at(0))};
+    *ret = common::CINNValuePack{res};
   });
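The backported schedule keeps only the Expr-collecting path: every value in arg_pack that is an expression is gathered into vec_ast and rescheduled as one module. A self-contained sketch of that filter-then-collect pattern, using std::variant in place of CINNValuePack (all names are ours, not CINN's):

#include <cassert>
#include <string>
#include <variant>
#include <vector>

// Stand-in for CINNValuePack: a heterogeneous list of values, of which
// only the "expression" entries are collected, as in the new
// GetElementwiseScheduleFunc body. Here int plays the role of Expr.
using Value = std::variant<int, std::string>;

std::vector<int> CollectExprs(const std::vector<Value>& pack) {
  std::vector<int> vec_ast;
  for (const auto& v : pack) {
    if (std::holds_alternative<int>(v)) {
      vec_ast.emplace_back(std::get<int>(v));
    }
  }
  return vec_ast;
}

int main() {
  std::vector<Value> pack{1, std::string("stage_map"), 2};
  assert(CollectExprs(pack).size() == 2);  // the two Expr-like entries
  return 0;
}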
paddle/cinn/hlir/op/reduction.cc

@@ -18,6 +18,7 @@
 #include <iostream>
 #include <vector>

+#include "paddle/cinn/adt/op_equation_context.h"
 #include "paddle/cinn/hlir/framework/node.h"
 #include "paddle/cinn/hlir/framework/op.h"
 #include "paddle/cinn/hlir/framework/op_strategy.h"
@@ -28,6 +29,11 @@
 #include "paddle/cinn/ir/op/ir_operators.h"
 #include "paddle/cinn/ir/schedule/ir_schedule.h"
 #include "paddle/cinn/optim/ir_simplify.h"
+#include "paddle/cinn/runtime/flags.h"
+
+PD_DECLARE_bool(cinn_enable_map_expr);
+PD_DECLARE_bool(cinn_new_group_scheduler);

 namespace cinn {
 namespace hlir {
@@ -58,7 +64,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
const
std
::
string
&
op_name
,
BlockReduceFunc
gpu_reduce_with_last_axis_func
,
BlockReduceFunc
gpu_reduce_without_last_axis_func
,
ReduceFunc
c
pu
_reduce_func
)
{
ReduceFunc
c
ommon
_reduce_func
)
{
std
::
vector
<
int
>
reduce_axes
;
auto
ndim
=
inputs
[
0
]
->
shape
.
size
();
if
(
attrs
.
attr_store
.
count
(
"dim"
))
{
...
...
@@ -127,7 +133,16 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
         << "The type of input argument " << x->name << " of " << op_name
         << " should be bool, but get " << x->type() << "! Please check.";

-    if (target == common::DefaultNVGPUTarget()) {
+    const auto& NaiveCompute = [&]() {
+      VLOG(3) << "Do Reduce Compute!";
+      auto out = common_reduce_func(x, reduce_axes, keep_dim, tensor_name);
+      auto stages = CreateStages({out});
+      std::vector<CINNValue> cinn_values{CINNValue(out), CINNValue(stages)};
+      *ret = CINNValuePack{cinn_values};
+    };
+    if (!FLAGS_cinn_enable_map_expr && !FLAGS_cinn_new_group_scheduler &&
+        target == common::DefaultNVGPUTarget()) {
       if (!WithoutLastDimInReduce(inputs[0]->shape, reduce_axes)) {
         VLOG(3) << "Do Two Step Block Reduce Compute!";
         auto res = gpu_reduce_with_last_axis_func(
@@ -154,12 +169,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
*
ret
=
CINNValuePack
{
cinn_values
};
}
}
else
{
VLOG
(
3
)
<<
"Do Reduce Compute!"
;
auto
out
=
cpu_reduce_func
(
x
,
reduce_axes
,
keep_dim
,
tensor_name
);
auto
stages
=
CreateStages
({
out
});
std
::
vector
<
CINNValue
>
cinn_values
{
CINNValue
(
out
),
CINNValue
(
stages
)};
*
ret
=
CINNValuePack
{
cinn_values
};
NaiveCompute
();
}
});
...
...
@@ -193,7 +203,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
ir
::
ModuleExpr
mod_expr
(
vec_ast
);
ir
::
IRSchedule
ir_sch
(
mod_expr
);
ir_sch
.
MergeExprs
();
if
(
target
.
arch
==
Target
::
Arch
::
NVGPU
)
{
if
(
!
FLAGS_cinn_new_group_scheduler
&&
target
.
arch
==
Target
::
Arch
::
NVGPU
)
{
if
(
!
WithoutLastDimInReduce
(
inputs
[
0
]
->
shape
,
reduce_axes
))
{
if
(
arg_pack
.
size
()
==
4
)
{
CHECK_EQ
(
vec_tensor
.
size
(),
2
);
...
...
@@ -313,7 +323,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
       reduce_op_,                                                  \
       gpu_reduce_with_last_axis_func,                              \
       gpu_reduce_without_last_axis_func,                           \
-      cpu_reduce_func)                                             \
+      common_reduce_func)                                          \
   std::shared_ptr<OpStrategy> StrategyFor##reduce_op_(             \
       const framework::NodeAttr &attrs,                            \
       const std::vector<ir::Tensor> &inputs,                       \
@@ -328,7 +338,7 @@ std::shared_ptr<OpStrategy> StrategyForReduce(
#op_name_, \
gpu_reduce_with_last_axis_func, \
gpu_reduce_without_last_axis_func, \
c
pu
_reduce_func);
\
c
ommon
_reduce_func); \
}
STRATEGY_FOR_REDUCE
(
reduce_sum
,
...
...
@@ -414,6 +424,35 @@ std::vector<shape_t> InferShapeForReduction(
   return {out_shapes};
 }

+void GenerateEquationsForReduction(cinn::adt::config::OpEquationContext* ctx) {
+  CHECK(ctx->GetInTensorsRanks().size() != 0)
+      << "The inputs is empty! Please check again.";
+  const bool keep_dim = ctx->Attr<bool>("keep_dim");
+  const auto& dim = ctx->Attr<std::vector<int>>("dim");
+  const auto& IsReduceAxis = [&](const int in_axis) {
+    return std::find(dim.begin(), dim.end(), in_axis) != dim.end();
+  };
+  const auto& VisitEachAxisPair = [&](const auto& DoEach) {
+    std::size_t out_axis = 0;
+    for (std::size_t in_axis = 0; in_axis < ctx->GetInTensorsRanks().at(0);
+         ++in_axis) {
+      if (IsReduceAxis(in_axis)) {
+        out_axis += keep_dim;
+      } else {
+        DoEach(in_axis, out_axis);
+        out_axis += 1;
+      }
+    }
+  };
+  VisitEachAxisPair([&](const int input_axis, const int output_axis) {
+    ctx->Equal(ctx->GetInIteratorTuple(0)->at(input_axis),
+               ctx->GetOutIteratorTuple(0)->at(output_axis));
+  });
+}
+
 std::vector<Type> InferDtypeForReduction(const std::vector<Type>& inputs_type,
                                          const framework::AttrMapType& attrs) {
   CHECK(!inputs_type.empty())
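VisitEachAxisPair walks the input axes in order; a reduced axis advances the output cursor only when keep_dim is true (the `out_axis += keep_dim` relies on bool-to-int conversion), while a kept axis is paired with the current output axis and then advances it. A standalone re-implementation of just that mapping, with our own test harness (not CINN code):

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Mirrors the axis pairing in GenerateEquationsForReduction: returns
// (input_axis, output_axis) pairs for the non-reduced axes.
std::vector<std::pair<int, int>> AxisPairs(int rank,
                                           const std::vector<int>& dim,
                                           bool keep_dim) {
  std::vector<std::pair<int, int>> pairs;
  std::size_t out_axis = 0;
  for (int in_axis = 0; in_axis < rank; ++in_axis) {
    bool is_reduce = std::find(dim.begin(), dim.end(), in_axis) != dim.end();
    if (is_reduce) {
      out_axis += keep_dim;  // reduced axis survives (as size 1) only if keep_dim
    } else {
      pairs.emplace_back(in_axis, static_cast<int>(out_axis));
      out_axis += 1;
    }
  }
  return pairs;
}

int main() {
  // Rank-3 input reduced over axis 1, keep_dim=false:
  // input axis 0 -> output axis 0, input axis 2 -> output axis 1.
  auto p = AxisPairs(3, {1}, false);
  assert((p == std::vector<std::pair<int, int>>{{0, 0}, {2, 1}}));
  // keep_dim=true: axis 2 now maps to output axis 2 (axis 1 kept as size 1).
  auto q = AxisPairs(3, {1}, true);
  assert((q == std::vector<std::pair<int, int>>{{0, 0}, {2, 2}}));
  return 0;
}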
@@ -477,22 +516,24 @@ std::vector<std::vector<std::string>> InferLayoutForBnOptimize(
 }  // namespace cinn

 CINN_REGISTER_HELPER(reduce_ops) {
-#define CINN_REGISTER_REDUCTION_WITH_DTYPE(op__, op_stragegy__, dtype__)      \
-  CINN_REGISTER_OP(op__)                                                      \
-      .describe(#op__ " function")                                            \
-      .set_num_inputs(1)                                                      \
-      .set_num_outputs(1)                                                     \
-      .set_attr<cinn::hlir::framework::StrategyFunction>(                     \
-          "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__)         \
-      .set_attr("infershape",                                                 \
-                MakeOpFunction(cinn::hlir::op::InferShapeForReduction))       \
-      .set_attr(                                                              \
-          "inferdtype",                                                       \
-          MakeOpFunction(cinn::hlir::op::InferDtypeForReduction##dtype__))    \
-      .set_attr("inferlayout",                                                \
-                MakeOpFunction(cinn::hlir::op::InferLayoutForReduction))      \
-      .set_attr<cinn::hlir::framework::OpPatternKind>(                        \
-          "OpPattern", cinn::hlir::framework::OpPatternKind::kReduction)      \
+#define CINN_REGISTER_REDUCTION_WITH_DTYPE(op__, op_stragegy__, dtype__)      \
+  CINN_REGISTER_OP(op__)                                                      \
+      .describe(#op__ " function")                                            \
+      .set_num_inputs(1)                                                      \
+      .set_num_outputs(1)                                                     \
+      .set_attr<cinn::hlir::framework::StrategyFunction>(                     \
+          "CINNStrategy", cinn::hlir::op::StrategyFor##op_stragegy__)         \
+      .set_attr("infershape",                                                 \
+                MakeOpFunction(cinn::hlir::op::InferShapeForReduction))       \
+      .set_attr(                                                              \
+          "inferdtype",                                                       \
+          MakeOpFunction(cinn::hlir::op::InferDtypeForReduction##dtype__))    \
+      .set_attr("generate_equations",                                         \
+                MakeOpFunction(cinn::hlir::op::GenerateEquationsForReduction)) \
+      .set_attr("inferlayout",                                                \
+                MakeOpFunction(cinn::hlir::op::InferLayoutForReduction))      \
+      .set_attr<cinn::hlir::framework::OpPatternKind>(                        \
+          "OpPattern", cinn::hlir::framework::OpPatternKind::kReduction)      \
       .set_support_level(4);

 #define CINN_REGISTER_REDUCTION(op__, op_stragegy__) \
paddle/cinn/hlir/op/reduction_test.cc

@@ -39,6 +39,9 @@
 #include "paddle/cinn/hlir/pe/nn.h"
 #include "paddle/cinn/runtime/cinn_runtime.h"
 #include "paddle/cinn/runtime/cuda/cuda_module.h"
+
+PD_DECLARE_bool(cinn_new_group_scheduler);

 namespace cinn {
 namespace hlir {
 namespace framework {
@@ -362,6 +365,9 @@ void TestCaseForReduce(const float init_val,
   dim3 block;
   grid = {c, 1, 1};
-  block = {n * w * h > 1024 ? 1024 : n * w * h, 1, 1};
+  int block_dim_x = n * w * h > 1024 ? 1024 : n * w * h;
+  if (FLAGS_cinn_new_group_scheduler) {
+    block_dim_x = 1;
+  }
+  block = {block_dim_x, 1, 1};

   void* args[] = {&dev_x, &dev_z};
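The test now sizes the CUDA block as the reduced extent clamped to the usual 1024-threads-per-block limit, and forces a single thread when the new group scheduler is active. The clamp itself, extracted as a standalone helper (our naming, not part of the test):

#include <algorithm>
#include <cassert>

// Clamp the x-dimension of a CUDA block to the common 1024-thread limit,
// mirroring `n * w * h > 1024 ? 1024 : n * w * h` in TestCaseForReduce.
int BlockDimX(int n, int w, int h, bool new_group_scheduler) {
  if (new_group_scheduler) return 1;  // the new scheduler expects one thread here
  return std::min(n * w * h, 1024);
}

int main() {
  assert(BlockDimX(32, 32, 2, false) == 1024);  // 2048 clamped to 1024
  assert(BlockDimX(4, 4, 4, false) == 64);
  assert(BlockDimX(32, 32, 2, true) == 1);
  return 0;
}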
@@ -531,7 +537,8 @@ TEST(Operator, Operator_Reduction_Case_Warp_Reduce) {
std
::
vector
<
int
>
dim
=
{
1
};
auto
res
=
GenReduceCode
(
shape
,
dim
,
"Operator_Reduction_Case_Warp_Reduce"
);
CHECK
(
res
.
second
.
find
(
"threadIdx.x < 32"
)
!=
std
::
string
::
npos
);
if
(
!
FLAGS_cinn_new_group_scheduler
)
CHECK
(
res
.
second
.
find
(
"threadIdx.x < 32"
)
!=
std
::
string
::
npos
);
}
TEST
(
Operator
,
Operator_Reduction_Case_Block_Reduce
)
{
...
...
@@ -544,7 +551,8 @@ TEST(Operator, Operator_Reduction_Case_Block_Reduce) {
std
::
vector
<
int
>
dim
=
{
1
};
auto
res
=
GenReduceCode
(
shape
,
dim
,
"Operator_Reduction_Case_Block_Reduce"
);
CHECK
(
res
.
second
.
find
(
"threadIdx.x < 32"
)
==
std
::
string
::
npos
);
if
(
!
FLAGS_cinn_new_group_scheduler
)
CHECK
(
res
.
second
.
find
(
"threadIdx.x < 32"
)
==
std
::
string
::
npos
);
}
TEST
(
Operator
,
Operator_Reduction_Case_Warp_Reduce_Case_1
)
{
...
...
@@ -558,7 +566,8 @@ TEST(Operator, Operator_Reduction_Case_Warp_Reduce_Case_1) {
auto
res
=
GenReduceCode
(
shape
,
dim
,
"Operator_Reduction_Case_Warp_Reduce_Case_1"
);
CHECK
(
res
.
second
.
find
(
"threadIdx.x < 32"
)
!=
std
::
string
::
npos
);
if
(
!
FLAGS_cinn_new_group_scheduler
)
CHECK
(
res
.
second
.
find
(
"threadIdx.x < 32"
)
!=
std
::
string
::
npos
);
}
TEST
(
Operator
,
Operator_Reduction_Case_Block_Reduce_Case_1
)
{
...
...
@@ -572,7 +581,8 @@ TEST(Operator, Operator_Reduction_Case_Block_Reduce_Case_1) {
auto
res
=
GenReduceCode
(
shape
,
dim
,
"Operator_Reduction_Case_Block_Reduce_Case_2"
);
CHECK
(
res
.
second
.
find
(
"threadIdx.x < 32"
)
==
std
::
string
::
npos
);
if
(
!
FLAGS_cinn_new_group_scheduler
)
CHECK
(
res
.
second
.
find
(
"threadIdx.x < 32"
)
==
std
::
string
::
npos
);
}
}
// namespace framework
}
// namespace hlir
...
...
paddle/cinn/hlir/op/transform.cc

@@ -25,7 +25,7 @@
 #include "paddle/cinn/hlir/pe/ir_schedule_pe.h"
 #include "paddle/cinn/hlir/pe/nn.h"
 #include "paddle/cinn/hlir/pe/schedule.h"
-#include "paddle/cinn/ir/utils/ir_printer.h"
+#include "paddle/cinn/ir/ir_printer.h"
 #include "paddle/cinn/utils/string.h"

 namespace cinn {
@@ -2044,7 +2044,7 @@ CINN_REGISTER_HELPER(transform_ops) {
       // pointers, the code generated by operator fusion will have out-of-bounds
       // access. It should not fuse with any other injective operators, though
       // scatter_add is injective. turn KNonFusible to kInjective will fail
-      // /Paddle/python/paddle/fluid/tests/unittests/test_index_select_op.py
+      // /Paddle/python/paddle/base/tests/unittests/test_index_select_op.py
       .set_attr<cinn::hlir::framework::OpPatternKind>(
           "OpPattern", cinn::hlir::framework::OpPatternKind::kNonFusible)
       .set_support_level(4);
paddle/cinn/hlir/pass/alterlayout_test.cc

@@ -25,7 +25,7 @@
 #include "paddle/cinn/hlir/pass/use_pass.h"
 #include "paddle/cinn/utils/data_util.h"

-DEFINE_string(model_dir, "", "");
+PD_DEFINE_string(model_dir, "", "");

 namespace cinn {
 namespace frontend {
@@ -76,7 +76,8 @@ TEST(conv, conv) {
   auto scope = BuildScope(target, graph);
   LOG(INFO) << "graph: \n" << graph->Visualize();

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto runtime_program = gc.Build();
   scope->Var<hlir::framework::Tensor>("A");
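Every test in this file (and in the other pass tests below) migrates from the three-argument GraphCompiler constructor to one taking a single CompilationContext that bundles graph, scope, and target. A minimal self-contained illustration of that parameter-object refactor (all names and types below are our simplification, not CINN's real classes):

#include <cassert>
#include <string>

// Before: Compiler(target, scope, graph). After: a context object carries
// the collaborators, so the compiler constructor takes one argument and
// new options can be added later without touching every call site.
struct CompilationContext {
  std::string graph;
  std::string scope;
  std::string target;
};

class GraphCompiler {
 public:
  explicit GraphCompiler(const CompilationContext& ctx) : ctx_(ctx) {}
  std::string Build() const { return ctx_.graph + "@" + ctx_.target; }

 private:
  CompilationContext ctx_;
};

int main() {
  CompilationContext context{"conv_graph", "scope", "x86"};
  GraphCompiler gc(context);
  assert(gc.Build() == "conv_graph@x86");
  return 0;
}

The parameter-object shape is presumably why the migration touches so many tests at once: each call site shrinks to constructing the context and handing it over.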
The seven remaining hunks in this file apply the identical replacement

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);

at the corresponding compile site in each test:

@@ -122,7 +123,8 @@ TEST(conv_relu_conv, conv_relu_conv) {
@@ -171,7 +173,8 @@ TEST(conv_add_conv, conv_add_conv) {
@@ -227,7 +230,8 @@ TEST(conv_bn_conv, conv_bn_conv) {
@@ -283,7 +287,8 @@ TEST(conv_pool2d_conv, conv_pool2d_conv) {
@@ -334,7 +339,8 @@ TEST(conv_softmax_conv, conv_softmax_conv) {
@@ -382,7 +388,8 @@ TEST(conv_sigmoid_conv, conv_sigmoid_conv) {
@@ -434,7 +441,8 @@ TEST(conv_mul_conv, conv_mul_conv) {
paddle/cinn/hlir/pass/check_fusion_accuracy_pass_test.cc

@@ -46,7 +46,8 @@ void RunTest(const Target& target,
              const std::shared_ptr<Graph>& graph,
              const std::vector<std::string>& input_names) {
   auto scope = BuildScope(target, graph);
-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);

   for (size_t i = 0; i < input_names.size(); ++i) {
     scope->Var<hlir::framework::Tensor>(input_names[i]);
paddle/cinn/hlir/pass/common_subexpression_elimination_test.cc

@@ -37,7 +37,7 @@
 #include "paddle/cinn/hlir/framework/pass.h"
 #include "paddle/cinn/utils/data_util.h"

-DEFINE_string(model_dir, "", "");
+PD_DEFINE_string(model_dir, "", "");

 namespace cinn {
 namespace frontend {
@@ -71,7 +71,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case1) {
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"BuildNonFusedGroupsPass"
);
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
GraphCompiler
gc
(
target
,
scope
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
auto
runtime_program
=
gc
.
Build
();
auto
&
prerun_instrs
=
runtime_program
->
GetPreRunInstructions
();
auto
&
run_instrs
=
runtime_program
->
GetRunInstructions
();
...
...
@@ -115,7 +116,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case2) {
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"BuildNonFusedGroupsPass"
);
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
GraphCompiler
gc
(
target
,
scope
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
auto
runtime_program
=
gc
.
Build
();
auto
&
prerun_instrs
=
runtime_program
->
GetPreRunInstructions
();
auto
&
run_instrs
=
runtime_program
->
GetRunInstructions
();
...
...
@@ -180,7 +182,8 @@ TEST(common_subexpression_elimination, common_subexpression_elimination_case3) {
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
GraphCompiler
gc
(
target
,
scope
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
auto
runtime_program
=
gc
.
Build
();
auto
&
prerun_instrs
=
runtime_program
->
GetPreRunInstructions
();
auto
&
run_instrs
=
runtime_program
->
GetRunInstructions
();
...
...
paddle/cinn/hlir/pass/const_propagate_test.cc

@@ -25,7 +25,7 @@
 #include "paddle/cinn/hlir/pass/use_pass.h"
 #include "paddle/cinn/utils/data_util.h"

-DEFINE_string(model_dir, "", "");
+PD_DEFINE_string(model_dir, "", "");

 namespace cinn {
 namespace frontend {
@@ -57,7 +57,8 @@ TEST(const_conv, const_conv) {
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"OpFusionPass"
);
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
GraphCompiler
gc
(
target
,
scope
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
auto
runtime_program
=
gc
.
Build
();
auto
&
prerun_instrs
=
runtime_program
->
GetPreRunInstructions
();
auto
&
run_instrs
=
runtime_program
->
GetRunInstructions
();
...
...
@@ -101,7 +102,8 @@ TEST(const_bn, const_bn) {
hlir
::
framework
::
ApplyPass
(
graph
.
get
(),
"FusionMergePass"
);
auto
scope
=
BuildScope
(
target
,
graph
);
hlir
::
framework
::
GraphCompiler
gc
(
target
,
scope
,
graph
);
hlir
::
framework
::
CompilationContext
context
(
graph
,
scope
,
target
);
hlir
::
framework
::
GraphCompiler
gc
(
context
);
auto
runtime_program
=
gc
.
Build
();
auto
&
prerun_instrs
=
runtime_program
->
GetPreRunInstructions
();
auto
&
run_instrs
=
runtime_program
->
GetRunInstructions
();
...
...
paddle/cinn/hlir/pass/constant_folding_pass_test.cc

@@ -46,7 +46,8 @@ std::unordered_map<std::string, std::vector<float>> RunModelTest(
   hlir::framework::ApplyPasses(graph.get(), passes);
   auto scope = BuildScope(target, graph);

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto run_program = gc.Build();

   for (auto& data : input_data) {
paddle/cinn/hlir/pass/custom_call_pass.cc

@@ -17,7 +17,8 @@
 #include "paddle/cinn/hlir/op/external_api_registry.h"
 #include "paddle/cinn/utils/string.h"

-DECLARE_string(cinn_custom_call_deny_ops);
+PD_DECLARE_string(cinn_custom_call_deny_ops);
+PD_DECLARE_bool(cinn_use_cutlass);

 namespace cinn {
 namespace hlir {
@@ -72,8 +73,10 @@ class GraphAlterHelper {
         }
       }

-      node->attrs.attr_store["original_op"] = node->op()->name;
-      node->attrs.op = framework::Operator::Get("custom_call");
+      if (!FLAGS_cinn_use_cutlass || node->op()->name != "matmul") {
+        node->attrs.attr_store["original_op"] = node->op()->name;
+        node->attrs.op = framework::Operator::Get("custom_call");
+      }
     }
   }
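Read through De Morgan, the new guard rewrites a node to custom_call unless the cutlass flag is on and the op is matmul; that is, with cinn_use_cutlass enabled, matmul alone is exempted from the generic custom-call rewrite. A small truth-table check of that predicate (standalone, our naming):

#include <cassert>
#include <string>

// The guard from custom_call_pass.cc: rewrite unless
// (use_cutlass && op == "matmul").
bool ShouldRewriteToCustomCall(bool use_cutlass, const std::string& op) {
  return !use_cutlass || op != "matmul";
}

int main() {
  assert(ShouldRewriteToCustomCall(false, "matmul"));  // flag off: rewrite
  assert(ShouldRewriteToCustomCall(true, "conv2d"));   // other ops: rewrite
  assert(!ShouldRewriteToCustomCall(true, "matmul"));  // left for CUTLASS
  return 0;
}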
paddle/cinn/hlir/pass/dense_merge_pass_test.cc

@@ -46,7 +46,8 @@ void RunModelTest(Program& program,  // NOLINT
   hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
   auto scope = BuildScope(target, graph);

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto run_program = gc.Build();

   for (int idx = 0; idx < inputs.size(); ++idx) {

A second, identical hunk (@@ -72,7 +73,8 @@ void RunModelTest) applies the same replacement at the file's other compile site.
paddle/cinn/hlir/pass/dot_merger_test.cc

@@ -45,7 +45,8 @@ void RunModelTest(Program& program,  // NOLINT
   hlir::framework::ApplyPass(graph.get(), "FusionMergePass");
   auto scope = BuildScope(target, graph);

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto run_program = gc.Build();

   for (int idx = 0; idx < inputs.size(); ++idx) {

A second, identical hunk (@@ -71,7 +72,8 @@ void RunModelTest) applies the same replacement at the file's other compile site.
paddle/cinn/hlir/pass/fusion_merge_pass.cc

@@ -14,7 +14,7 @@
 #include "paddle/cinn/hlir/pass/fusion_merge_pass_util.h"

-DECLARE_bool(enhance_vertical_fusion_with_recompute);
+PD_DECLARE_bool(enhance_vertical_fusion_with_recompute);

 namespace cinn {
 namespace hlir {
paddle/cinn/hlir/pass/general_fusion_merge_pass.cc

@@ -26,7 +26,7 @@
 #include "paddle/cinn/hlir/pass/general_fusion_merge_pass/lightware_fuse_pass_ctx.h"
 #include "paddle/cinn/hlir/pass/general_fusion_merge_pass_utils.h"

-DECLARE_bool(enhance_vertical_fusion_with_recompute);
+PD_DECLARE_bool(enhance_vertical_fusion_with_recompute);

 namespace cinn {
 namespace hlir {
paddle/cinn/hlir/pass/general_fusion_merge_pass/fusion_pass_registrar.h

@@ -52,11 +52,11 @@ class FusionPassRegistrar final : public Registrar {

 #define CINN_REGISTER_FUSION_PASS(pass_name, pass_class)                  \
   STATIC_ASSERT_GLOBAL_NAMESPACE(                                         \
-      __reg_pass__##pass_name,                                            \
+      __reg_cinn_fusion_pass__##pass_name,                                \
       "CINN_REGISTER_FUSION_PASS must be called in global namespace");    \
   static ::cinn::hlir::pass::FusionPassRegistrar<pass_class>              \
-      __pass_registrar_##pass_name##__(#pass_name);                       \
-  int TouchFusionPassRegistrar_##pass_name() {                            \
-    __pass_registrar_##pass_name##__.Touch();                             \
+      __cinn_fusion_pass_registrar_##pass_name##__(#pass_name);           \
+  int TouchCinnFusionPassRegistrar_##pass_name() {                        \
+    __cinn_fusion_pass_registrar_##pass_name##__.Touch();                 \
     return 0;                                                             \
   }
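The rename gives the registrar symbols a cinn_fusion_-specific prefix so they cannot collide with identically named registration symbols elsewhere in the Paddle build. A self-contained demo of the token pasting involved (a simplified macro of our own, not the real registrar):

#include <cassert>

// Simplified version of the registrar macro: pasting the pass name into
// a uniquely prefixed symbol, as fusion_pass_registrar.h now does with
// __cinn_fusion_pass_registrar_##pass_name##__.
#define REGISTER_DEMO_PASS(pass_name)                          \
  static int __cinn_fusion_pass_registrar_##pass_name##__ = 1; \
  int TouchCinnFusionPassRegistrar_##pass_name() {             \
    return __cinn_fusion_pass_registrar_##pass_name##__;       \
  }

REGISTER_DEMO_PASS(MyPass)  // defines TouchCinnFusionPassRegistrar_MyPass

int main() {
  assert(TouchCinnFusionPassRegistrar_MyPass() == 1);
  return 0;
}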
paddle/cinn/hlir/pass/opfusion_test.cc

@@ -25,7 +25,7 @@
 #include "paddle/cinn/hlir/pass/use_pass.h"
 #include "paddle/cinn/utils/data_util.h"

-DEFINE_string(model_dir, "", "");
+PD_DEFINE_string(model_dir, "", "");

 namespace cinn {
 namespace frontend {
@@ -80,7 +80,8 @@ TEST(complex2, complex2) {
   auto scope = BuildScope(target, graph);
   LOG(INFO) << "graph: \n" << graph->Visualize();

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto runtime_program = gc.Build();
   scope->Var<hlir::framework::Tensor>("A");
The nine remaining hunks in this file make the same GraphCompiler-to-CompilationContext replacement at the compile site of each test:

@@ -135,7 +136,8 @@ TEST(complex1, complex1) {
@@ -172,7 +174,8 @@ TEST(fuse_add_relu, fuse_add_relu) {
@@ -210,7 +213,8 @@ TEST(fuse_add, fuse_add) {
@@ -268,7 +272,8 @@ TEST(conv_bn_conv, conv_bn_conv) {
@@ -319,7 +324,8 @@ TEST(fuse_conv_add, fuse_conv_add) {
@@ -377,7 +383,8 @@ TEST(conv_add_mul, conv_add_mul) {
@@ -426,7 +433,8 @@ TEST(fuse_conv_add1, fuse_conv_add1) {
@@ -465,7 +473,8 @@ TEST(transpose_reshape_concat, transpose_reshape_concat) {
@@ -517,7 +526,8 @@ TEST(conv_bn, conv_bn) {
paddle/cinn/hlir/pass/reduce_split_pass_test.cc

@@ -30,7 +30,8 @@ std::unordered_map<std::string, std::vector<float>> RunModelTest(
   hlir::framework::ApplyPasses(graph.get(), passes);
   auto scope = BuildScope(target, graph);

-  hlir::framework::GraphCompiler gc(target, scope, graph);
+  hlir::framework::CompilationContext context(graph, scope, target);
+  hlir::framework::GraphCompiler gc(context);
   auto run_program = gc.Build();

   for (auto& data : input_data) {