Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
992bec46
Commit
992bec46
authored
Oct 08, 2023
by
“yuguo”
Browse files
2.5
parent
0259837d
Changes
357
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2182 additions
and
0 deletions
+2182
-0
paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc
paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc
+80
-0
paddle/cinn/auto_schedule/cost_model/expr_cost_model.h
paddle/cinn/auto_schedule/cost_model/expr_cost_model.h
+46
-0
paddle/cinn/auto_schedule/cost_model/feature.cc
paddle/cinn/auto_schedule/cost_model/feature.cc
+182
-0
paddle/cinn/auto_schedule/cost_model/feature.h
paddle/cinn/auto_schedule/cost_model/feature.h
+189
-0
paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
+311
-0
paddle/cinn/auto_schedule/cost_model/feature_extractor.h
paddle/cinn/auto_schedule/cost_model/feature_extractor.h
+60
-0
paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc
...e/cinn/auto_schedule/cost_model/feature_extractor_test.cc
+166
-0
paddle/cinn/auto_schedule/cost_model/feature_test.cc
paddle/cinn/auto_schedule/cost_model/feature_test.cc
+28
-0
paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc
paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc
+149
-0
paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h
paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h
+78
-0
paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
+70
-0
paddle/cinn/auto_schedule/database/CMakeLists.txt
paddle/cinn/auto_schedule/database/CMakeLists.txt
+7
-0
paddle/cinn/auto_schedule/database/database.cc
paddle/cinn/auto_schedule/database/database.cc
+131
-0
paddle/cinn/auto_schedule/database/database.h
paddle/cinn/auto_schedule/database/database.h
+106
-0
paddle/cinn/auto_schedule/database/database_test.cc
paddle/cinn/auto_schedule/database/database_test.cc
+72
-0
paddle/cinn/auto_schedule/database/jsonfile_database.cc
paddle/cinn/auto_schedule/database/jsonfile_database.cc
+108
-0
paddle/cinn/auto_schedule/database/jsonfile_database.h
paddle/cinn/auto_schedule/database/jsonfile_database.h
+57
-0
paddle/cinn/auto_schedule/database/jsonfile_database_test.cc
paddle/cinn/auto_schedule/database/jsonfile_database_test.cc
+253
-0
paddle/cinn/auto_schedule/measure/CMakeLists.txt
paddle/cinn/auto_schedule/measure/CMakeLists.txt
+7
-0
paddle/cinn/auto_schedule/measure/measure.h
paddle/cinn/auto_schedule/measure/measure.h
+82
-0
No files found.
Too many changes to show.
To preserve performance only
357 of 357+
files are displayed.
Plain diff
Email patch
paddle/cinn/auto_schedule/cost_model/expr_cost_model.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/expr_cost_model.h"
#include <glog/logging.h>
#include <atomic>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/feature.h"
#include "paddle/cinn/auto_schedule/cost_model/feature_extractor.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
// Predicts the cost of a single ModuleExpr sample for `target`.
//
// Returns SearchState::NOT_INIT_COST while the model has not been trained
// (trained_times_ is still 0). Otherwise the sample is converted into its
// fixed-size feature vector and handed to XgbCostModel::Predict as a batch
// containing that one vector; the first prediction is returned.
float ExprCostModel::Predict(const ir::ModuleExpr& sample,
                             const common::Target& target) const {
  const bool never_trained = (trained_times_.load() == 0);
  if (never_trained) {
    return SearchState::NOT_INIT_COST;
  }

  FeatureExtractor extractor;
  Feature sample_feature = extractor.Extract(sample, target);
  std::vector<float> sample_vector = sample_feature.ToFixedSizeVector();

  // The base model predicts on batches, so wrap the single vector in braces.
  std::vector<float> predictions = XgbCostModel::Predict({sample_vector});
  return predictions[0];
}
// Trains the underlying XgbCostModel on `samples` with their `labels`.
//
// The trained counter is reset to 1, every sample is converted into a
// fixed-size feature vector for `target`, and the resulting matrix is passed
// to XgbCostModel::Train. `samples` and `labels` must have the same length
// and no sample may be nullptr; both conditions are enforced with CHECK.
void ExprCostModel::Train(const std::vector<const ir::ModuleExpr*>& samples,
                          const std::vector<float>& labels,
                          const common::Target& target) {
  trained_times_.store(1);

  const size_t sample_count = samples.size();
  CHECK_EQ(sample_count, labels.size())
      << "Samples must have same size as labels";

  std::vector<std::vector<float>> feature_rows(sample_count);
  FeatureExtractor extractor;
  auto row = feature_rows.begin();
  for (const ir::ModuleExpr* sample : samples) {
    CHECK(sample != nullptr) << "Train samples cannot be nullptr";
    Feature sample_feature = extractor.Extract(*sample, target);
    *row = sample_feature.ToFixedSizeVector();
    ++row;
  }

  XgbCostModel::Train(feature_rows, labels);
}
// Feeds additional `samples`/`labels` to the underlying XgbCostModel.
//
// Increments the trained counter, converts every sample into a fixed-size
// feature vector for `target`, and forwards the matrix to
// XgbCostModel::Update. `samples` and `labels` must have the same length and
// no sample may be nullptr; both conditions are enforced with CHECK.
void ExprCostModel::Update(const std::vector<const ir::ModuleExpr*>& samples,
                           const std::vector<float>& labels,
                           const common::Target& target) {
  ++trained_times_;

  const size_t sample_count = samples.size();
  CHECK_EQ(sample_count, labels.size())
      << "Samples must have same size as labels";

  std::vector<std::vector<float>> feature_rows(sample_count);
  FeatureExtractor extractor;
  for (size_t row = 0; row != sample_count; ++row) {
    const ir::ModuleExpr* sample = samples[row];
    CHECK(sample != nullptr) << "Train samples cannot be nullptr";
    Feature sample_feature = extractor.Extract(*sample, target);
    feature_rows[row] = sample_feature.ToFixedSizeVector();
  }

  XgbCostModel::Update(feature_rows, labels);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/expr_cost_model.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <atomic>
#include <vector>
#include "paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * Cost model that trains and predicts directly on ir::ModuleExpr samples.
 *
 * It is a thin layer on top of XgbCostModel: the implementation turns each
 * ModuleExpr into a fixed-size feature vector and forwards the numeric data
 * to the base class.
 */
class ExprCostModel : public XgbCostModel {
 public:
  // Returns the predicted cost of `sample` on `target`. The implementation
  // returns SearchState::NOT_INIT_COST as long as the model is untrained.
  virtual float Predict(const ir::ModuleExpr& sample,
                        const common::Target& target) const;

  // Trains the model on `samples` paired element-wise with `labels`.
  void Train(const std::vector<const ir::ModuleExpr*>& samples,
             const std::vector<float>& labels,
             const common::Target& target);

  // Updates the model with further `samples` paired element-wise with
  // `labels`.
  void Update(const std::vector<const ir::ModuleExpr*>& samples,
              const std::vector<float>& labels,
              const common::Target& target);

 private:
  // Number of training rounds so far; 0 means the model was never trained.
  std::atomic<int> trained_times_{0};
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/feature.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/feature.h"
#include <glog/logging.h>
#include <cstdint>
#include <vector>
#include "paddle/cinn/common/target.h"
namespace
cinn
{
namespace
auto_schedule
{
// Builds a Feature for an unknown target.
//
// Delegates to the target-taking constructor so that the root-block setup
// lives in exactly one place.
Feature::Feature() : Feature(common::UnkTarget()) {}
// Builds a Feature for `target` containing only the root loop block.
//
// The mem-initializers are listed in the order the members are declared in
// the class, which is the order they are actually initialized in (avoids
// -Wreorder warnings).
Feature::Feature(const common::Target &target)
    : stack_encoded_feature_(1),  // one LoopBlockFeature acting as root block
      current_loop_block_index_(0),
      parent_indices_(1, -1),  // the root block has no parent
      target_(target) {}
// Flattens the variable-length list of loop block features into one vector
// of LoopBlockFeature::kTotalSize + 1 floats.
//
// Layout of the result:
//   [0]      1 if the target equals common::DefaultNVGPUTarget(), else 0
//   [1..28]  arithmetic / memory / reduce-broadcast counts, each multiplied
//            by the product of loop lengths from the root down to the block
//   [29..33] how many loop blocks use each ForOptimizeFeatureEnum value
//   [34..41] thread / vectorize lengths, each multiplied by the product of
//            loop lengths of the enclosing blocks (the block's own length
//            excluded)
// Every slot is finally scaled with slog().
//
// The loop-length products are computed in 64-bit integers: realistic loop
// nests (for example 2048 x 2048 x 2048) exceed the range of int, and signed
// int overflow is undefined behavior.
std::vector<float> Feature::ToFixedSizeVector() {
  std::vector<float> ret(LoopBlockFeature::kTotalSize + 1, 0);
  if (target_ == common::DefaultNVGPUTarget()) {
    ret[0] = 1;
  }  // else 0 for other cases

  // iter_multi_num[i] is the product of the loop lengths on the path from
  // the root block to block i (inclusive); the root itself counts as 1.
  std::vector<int64_t> iter_multi_num;
  iter_multi_num.reserve(stack_encoded_feature_.size());

  for (size_t i = 0; i < stack_encoded_feature_.size(); ++i) {
    const LoopBlockFeature &loop_feature = stack_encoded_feature_[i];

    int64_t loop_prod = 1;
    int64_t parent_prod = 1;
    if (i != 0) {
      parent_prod = iter_multi_num[parent_indices_[i]];
      loop_prod = parent_prod * loop_feature.loop_length;
    }
    iter_multi_num.push_back(loop_prod);

    int j = 1;  // slot 0 is reserved for the target flag
    // Adds count * scale to the current slot and advances to the next one.
    const auto add_scaled = [&ret, &j](int count, int64_t scale) {
      ret[j] += static_cast<float>(static_cast<int64_t>(count) * scale);
      ++j;
    };

    // Arithmetic features.
    add_scaled(loop_feature.float_add_or_sub, loop_prod);
    add_scaled(loop_feature.float_mul, loop_prod);
    add_scaled(loop_feature.float_div_or_mod, loop_prod);
    add_scaled(loop_feature.float_cmp, loop_prod);
    add_scaled(loop_feature.float_math_func, loop_prod);
    add_scaled(loop_feature.float_other_call, loop_prod);
    add_scaled(loop_feature.int_add_or_sub, loop_prod);
    add_scaled(loop_feature.int_mul, loop_prod);
    add_scaled(loop_feature.int_div_or_mod, loop_prod);
    add_scaled(loop_feature.int_cmp, loop_prod);
    add_scaled(loop_feature.int_math_func, loop_prod);
    add_scaled(loop_feature.int_other_call, loop_prod);
    add_scaled(loop_feature.bool_op, loop_prod);
    add_scaled(loop_feature.select_op, loop_prod);

    // Memory features.
    add_scaled(loop_feature.mem_alloc, loop_prod);
    add_scaled(loop_feature.mem_free, loop_prod);
    add_scaled(loop_feature.mem_read, loop_prod);
    add_scaled(loop_feature.mem_write, loop_prod);

    // Reduce and broadcast features.
    add_scaled(loop_feature.float_reduce_sum_or_sub, loop_prod);
    add_scaled(loop_feature.float_reduce_mul, loop_prod);
    add_scaled(loop_feature.float_reduce_div, loop_prod);
    add_scaled(loop_feature.float_reduce_max_or_min, loop_prod);
    add_scaled(loop_feature.float_broadcast, loop_prod);
    add_scaled(loop_feature.int_reduce_sum_or_sub, loop_prod);
    add_scaled(loop_feature.int_reduce_mul, loop_prod);
    add_scaled(loop_feature.int_reduce_div, loop_prod);
    add_scaled(loop_feature.int_reduce_max_or_min, loop_prod);
    add_scaled(loop_feature.int_broadcast, loop_prod);

    // Count of loop blocks per optimization kind.
    ret[j + static_cast<int>(loop_feature.loop_opt_type)] += 1;
    j += LoopBlockFeature::kOptApplySize;

    // Thread features are scaled by the enclosing loops only.
    add_scaled(loop_feature.len_blockIdx_x, parent_prod);
    add_scaled(loop_feature.len_blockIdx_y, parent_prod);
    add_scaled(loop_feature.len_blockIdx_z, parent_prod);
    add_scaled(loop_feature.len_threadIdx_x, parent_prod);
    add_scaled(loop_feature.len_threadIdx_y, parent_prod);
    add_scaled(loop_feature.len_threadIdx_z, parent_prod);
    add_scaled(loop_feature.len_vthread, parent_prod);
    add_scaled(loop_feature.vectorize_factor, parent_prod);
  }

  for (float &value : ret) {
    value = slog(value);
  }
  return ret;
}
void
Feature
::
IntoLoopBlock
()
{
stack_encoded_feature_
.
emplace_back
(
LoopBlockFeature
());
stack_encoded_feature_
[
current_loop_block_index_
].
num_sub_loops
+=
1
;
parent_indices_
.
push_back
(
current_loop_block_index_
);
current_loop_block_index_
=
stack_encoded_feature_
.
size
()
-
1
;
}
void
Feature
::
ExitLoopBlock
()
{
current_loop_block_index_
=
parent_indices_
[
current_loop_block_index_
];
}
// Returns a mutable reference to the loop block features are currently
// collected on. The reference points into stack_encoded_feature_, so it may
// be invalidated by a later IntoLoopBlock() that grows the vector.
LoopBlockFeature &Feature::CurrentLoopBlock() {
  const int index = current_loop_block_index_;
  return stack_encoded_feature_[index];
}
// Read-only counterpart of CurrentLoopBlock(): returns the loop block
// features are currently collected on.
const LoopBlockFeature &Feature::CurrentLoopBlock() const {
  const int index = current_loop_block_index_;
  return stack_encoded_feature_[index];
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/feature.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <vector>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
/* Kind of optimization applied to a loop.
 *
 * The numeric values are used as offsets into the fixed-size feature
 * vector, so they must stay contiguous and start at 0.
 */
enum class ForOptimizeFeatureEnum : int {
  kNone = 0,
  kGpuBind = 1,
  kParallel = 2,
  kUnroll = 3,
  kVectorize = 4
};
/* Scales a feature number logarithmically: returns log2(|x| + 1).
 *
 * The sign of x is not preserved, so slog(-x) == slog(x).
 */
inline float slog(float x) {
  const float magnitude = std::fabs(x);
  return std::log2(magnitude + 1);
}
// Feature counters collected for one loop block (the root compute block is
// represented by one of these as well). All counters start at zero.
class LoopBlockFeature {
 public:
  // TODO(zhhsplendid): distinguish more types such as float16, float32,
  // float64, etc. The speed gap between float and int is larger than the gap
  // between different bit widths, so only int and float are distinguished
  // here.

  /* Arithmetic features */
  int float_add_or_sub{0};
  int float_mul{0};
  int float_div_or_mod{0};
  int float_cmp{0};
  int float_math_func{0};
  int float_other_call{0};  // like simple assign, cast, etc.

  int int_add_or_sub{0};
  int int_mul{0};
  int int_div_or_mod{0};
  int int_cmp{0};
  int int_math_func{0};
  int int_other_call{0};  // like simple assign, cast, etc.

  int bool_op{0};
  int select_op{0};

  // 6 float counters + 6 int counters + bool_op + select_op.
  static constexpr int kArithSize = 6 * 2 + 2;

  /**
   * Buffer memory features: the number of memory operations.
   * Operations of different sizes can differ in speed; a possible TODO is to
   * also collect operand sizes (alloc size, write size, and so on).
   */
  int mem_alloc{0};
  int mem_free{0};
  int mem_read{0};
  int mem_write{0};

  static constexpr int kMemSize = 4;

  /**
   * Reduce and Broadcast features
   */
  int float_reduce_sum_or_sub{0};
  int float_reduce_mul{0};
  int float_reduce_div{0};
  int float_reduce_max_or_min{0};
  int float_broadcast{0};

  int int_reduce_sum_or_sub{0};
  int int_reduce_mul{0};
  int int_reduce_div{0};
  int int_reduce_max_or_min{0};
  int int_broadcast{0};

  static constexpr int kReduceBroadcastSize = 10;

  /* Loop type features */
  // A possible TODO: add a loop position (Inner, Outer, Middle) feature.
  ForOptimizeFeatureEnum loop_opt_type{ForOptimizeFeatureEnum::kNone};

  // One slot per ForOptimizeFeatureEnum value.
  static constexpr int kOptApplySize = 5;

  /* Thread features, filled in when the loop is optimized by GPU or CPU
   * parallelism. Unused in other cases.
   */
  int len_blockIdx_x{0};
  int len_blockIdx_y{0};
  int len_blockIdx_z{0};
  int len_threadIdx_x{0};
  int len_threadIdx_y{0};
  int len_threadIdx_z{0};
  int len_vthread{0};  // length of virtual thread
  int vectorize_factor{0};

  static constexpr int kThreadFeatureSize = 8;

  // Total number of feature slots one loop block contributes.
  static constexpr int kTotalSize = kArithSize + kMemSize +
                                    kReduceBroadcastSize + kOptApplySize +
                                    kThreadFeatureSize;

  /* Non-feature attributes, maintained during feature extraction */

  // Number of loop blocks nested directly inside this one.
  int num_sub_loops{0};

  // Number of repeats of this loop; -1 represents unknown.
  int loop_length{1};
};
/**
 * Feature of an Expr, consumed by the cost model.
 *
 * Loop blocks are recorded while the IR is traversed (IntoLoopBlock /
 * ExitLoopBlock / CurrentLoopBlock) and later flattened into a fixed-size
 * vector with ToFixedSizeVector.
 */
class Feature {
 public:
  Feature();

  explicit Feature(const common::Target& target);

  // Converts the variable-length loop block features into a fixed-size
  // vector.
  std::vector<float> ToFixedSizeVector();

  // Call when the traversal enters a loop block, so that a new
  // LoopBlockFeature starts being collected.
  void IntoLoopBlock();

  // Call when the traversal leaves a loop block.
  void ExitLoopBlock();

  // The loop block features should currently be collected on.
  LoopBlockFeature& CurrentLoopBlock();

  // Read-only access to the loop block features are currently collected on.
  const LoopBlockFeature& CurrentLoopBlock() const;

 private:
  // A computation is encoded as a variable-length vector of loop blocks.
  // The root compute block is not a loop, but it is treated as a loop of
  // length 1. Blocks are appended in the order they are entered, and each
  // LoopBlockFeature carries num_sub_loops, the number of loop blocks nested
  // directly inside it.
  //
  // For example, code like:
  //
  // some_compute_0
  // loop1 {
  //   some_compute_1
  //   loop2 {
  //     some_compute_2
  //   }
  // }
  //
  // loop3 {
  //   some_compute_3
  // }
  //
  // is encoded as [loop_block_feature_0, loop_block_feature_1,
  // loop_block_feature_2, loop_block_feature_3], where loop_block_feature_i
  // stores the features of some_compute_i (such as the number of arithmetic
  // operations), and
  //
  // loop_block_feature_0.num_sub_loops = 2
  // loop_block_feature_1.num_sub_loops = 1
  // loop_block_feature_2.num_sub_loops = 0
  // loop_block_feature_3.num_sub_loops = 0
  std::vector<LoopBlockFeature> stack_encoded_feature_;

  // Index into stack_encoded_feature_ of the block being collected.
  int current_loop_block_index_;

  // parent_indices_[i] is the index of the block enclosing block i; the root
  // block stores -1.
  std::vector<int> parent_indices_;

  common::Target target_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/feature_extractor.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/feature_extractor.h"
#include <vector>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/common/type.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "paddle/cinn/optim/transform_polyfor_to_for.h"
namespace
cinn
{
namespace
auto_schedule
{
using
namespace
::
cinn
::
ir
;
// NOLINT
// Nothing to set up: feature_ is default-constructed.
FeatureExtractor::FeatureExtractor() = default;
void
FeatureExtractor
::
Visit
(
const
Expr
*
x
)
{
IRVisitorRequireReImpl
::
Visit
(
x
);
}
// Extracts the Feature of `mod_expr` for `target`.
//
// The internal feature_ is reset to a fresh Feature for `target`, every
// expression returned by mod_expr.GetExprs() is visited in order, and a copy
// of the accumulated feature is returned.
Feature FeatureExtractor::Extract(const ir::ModuleExpr &mod_expr,
                                  const common::Target &target) {
  feature_ = Feature(target);

  const auto &exprs = mod_expr.GetExprs();
  for (const ir::Expr &expr : exprs) {
    Visit(&expr);
  }
  return feature_;
}
// Defines FeatureExtractor::Visit for node types that add nothing to the
// feature counters themselves: the visitor only recurses into every defined
// sub-expression of the node.
#define VisitDoNothing(NodeType)                                   \
  void FeatureExtractor::Visit(const NodeType *x) {                \
    const std::vector<const Expr *> children = x->expr_fields();   \
    for (const Expr *child : children) {                           \
      if (!child->defined()) {                                     \
        continue;                                                  \
      }                                                            \
      Visit(child);                                                \
    }                                                              \
  }

// Immediates.
VisitDoNothing(IntImm);
VisitDoNothing(UIntImm);
VisitDoNothing(FloatImm);
VisitDoNothing(StringImm);

// Structural nodes.
VisitDoNothing(Block);
VisitDoNothing(_Module_);
VisitDoNothing(_Var_);
VisitDoNothing(_LoweredFunc_);
VisitDoNothing(ScheduleBlock);
VisitDoNothing(ScheduleBlockRealize);
VisitDoNothing(Ramp);
VisitDoNothing(_Buffer_);
VisitDoNothing(_BufferRange_);
// Defines FeatureExtractor::Visit as a no-op for node types whose expression
// fields are not traversed at all.
#define NotVisitExprFields(NodeType) \
  void FeatureExtractor::Visit(const NodeType *) {}

NotVisitExprFields(_Tensor_)
// Defines FeatureExtractor::Visit for node types counted per data type.
//
// When the node's type equals common::F32(), common::F16() or common::F64()
// the current loop block's float_<member> counter grows by the type's lane
// count; otherwise int_<member> does. Afterwards every defined
// sub-expression is visited.
#define VisitForDtypePattern(NodeType, member)                       \
  void FeatureExtractor::Visit(const NodeType *x) {                  \
    const bool is_float = x->type() == common::F32() ||              \
                          x->type() == common::F16() ||              \
                          x->type() == common::F64();                \
    if (is_float) {                                                  \
      feature_.CurrentLoopBlock().float_##member +=                  \
          x->type().lanes();                                         \
    } else {                                                         \
      feature_.CurrentLoopBlock().int_##member += x->type().lanes(); \
    }                                                                \
    const std::vector<const Expr *> children = x->expr_fields();     \
    for (const Expr *child : children) {                             \
      if (!child->defined()) {                                       \
        continue;                                                    \
      }                                                              \
      Visit(child);                                                  \
    }                                                                \
  }

// Additive operations.
VisitForDtypePattern(Add, add_or_sub);
VisitForDtypePattern(Sub, add_or_sub);
VisitForDtypePattern(Minus, add_or_sub);

// Multiplicative operations.
VisitForDtypePattern(Mul, mul);
VisitForDtypePattern(Div, div_or_mod);
VisitForDtypePattern(Mod, div_or_mod);
VisitForDtypePattern(FracOp, div_or_mod);

// Comparisons.
VisitForDtypePattern(EQ, cmp);
VisitForDtypePattern(NE, cmp);
VisitForDtypePattern(GT, cmp);
VisitForDtypePattern(GE, cmp);
VisitForDtypePattern(LT, cmp);
VisitForDtypePattern(LE, cmp);

// Calls.
VisitForDtypePattern(Call, math_func);
VisitForDtypePattern(PrimitiveNode, math_func);
VisitForDtypePattern(Cast, other_call);
VisitForDtypePattern(Let, other_call);
// Defines FeatureExtractor::Visit for n-ary node types counted per data
// type: a node with k operands adds k - 1 to float_<member> (type equals
// F32/F16/F64) or to int_<member> (any other type) of the current loop
// block. Afterwards every defined sub-expression is visited.
#define VisitForMultiOperandsDtypePattern(NodeType, member)          \
  void FeatureExtractor::Visit(const NodeType *x) {                  \
    const bool is_float = x->type() == common::F32() ||              \
                          x->type() == common::F16() ||              \
                          x->type() == common::F64();                \
    if (is_float) {                                                  \
      feature_.CurrentLoopBlock().float_##member +=                  \
          (x->operands().size() - 1);                                \
    } else {                                                         \
      feature_.CurrentLoopBlock().int_##member +=                    \
          (x->operands().size() - 1);                                \
    }                                                                \
    const std::vector<const Expr *> children = x->expr_fields();     \
    for (const Expr *child : children) {                             \
      if (!child->defined()) {                                       \
        continue;                                                    \
      }                                                              \
      Visit(child);                                                  \
    }                                                                \
  }

VisitForMultiOperandsDtypePattern(Sum, add_or_sub);
VisitForMultiOperandsDtypePattern(Product, mul);
// Defines FeatureExtractor::Visit for node types that increment one counter
// of the current loop block by exactly one per node, regardless of data
// type. Afterwards every defined sub-expression is visited.
#define VisitCountMemberPattern(NodeType, member)                    \
  void FeatureExtractor::Visit(const NodeType *x) {                  \
    ++feature_.CurrentLoopBlock().member;                            \
    const std::vector<const Expr *> children = x->expr_fields();     \
    for (const Expr *child : children) {                             \
      if (!child->defined()) {                                       \
        continue;                                                    \
      }                                                              \
      Visit(child);                                                  \
    }                                                                \
  }

// Boolean operations.
VisitCountMemberPattern(And, bool_op);
VisitCountMemberPattern(Or, bool_op);
VisitCountMemberPattern(Not, bool_op);

// Selection-like operations.
VisitCountMemberPattern(Max, select_op);
VisitCountMemberPattern(Min, select_op);
VisitCountMemberPattern(IfThenElse, select_op);
VisitCountMemberPattern(Select, select_op);

// Memory operations.
VisitCountMemberPattern(Alloc, mem_alloc);
VisitCountMemberPattern(Free, mem_free);
VisitCountMemberPattern(Load, mem_read);
VisitCountMemberPattern(Store, mem_write);
/* Visit for loops */
// Opens a new loop block for the For node, records its length, optimization
// kind and thread-binding lengths, visits the node's sub-expressions inside
// that block, and closes the block again.
//
// loop_length is extent - min when both are constants and -1 (unknown)
// otherwise.
void FeatureExtractor::Visit(const For *x) {
  feature_.IntoLoopBlock();

  // NOTE: this reference points into the Feature's internal vector and must
  // not be used after the recursion below, because a nested IntoLoopBlock()
  // may grow that vector.
  LoopBlockFeature &loop_feature = feature_.CurrentLoopBlock();

  if (x->min.is_constant() && x->extent.is_constant()) {
    loop_feature.loop_length = static_cast<int>(x->extent.get_constant() -
                                                x->min.get_constant());
  } else {
    loop_feature.loop_length = -1;  // -1 represents unknown
  }

  if (x->is_parallel()) {
    loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kParallel;
    loop_feature.len_vthread = loop_feature.loop_length;
  } else if (x->is_unrolled()) {
    loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kUnroll;
  } else if (x->is_vectorized()) {
    loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kVectorize;
    loop_feature.vectorize_factor = x->vectorize_info().factor;
  } else if (x->is_binded()) {
    loop_feature.loop_opt_type = ForOptimizeFeatureEnum::kGpuBind;
    const BindInfo &bind_info = x->bind_info();
    // offset selects the axis: 0 -> x, 1 -> y, 2 -> z; other values are
    // ignored.
    const int offset = bind_info.offset;
    if (bind_info.for_type == ForType::GPUBlock) {
      switch (offset) {
        case 0:
          loop_feature.len_blockIdx_x = loop_feature.loop_length;
          break;
        case 1:
          loop_feature.len_blockIdx_y = loop_feature.loop_length;
          break;
        case 2:
          loop_feature.len_blockIdx_z = loop_feature.loop_length;
          break;
        default:
          break;
      }
    } else if (bind_info.for_type == ForType::GPUThread) {
      switch (offset) {
        case 0:
          loop_feature.len_threadIdx_x = loop_feature.loop_length;
          break;
        case 1:
          loop_feature.len_threadIdx_y = loop_feature.loop_length;
          break;
        case 2:
          loop_feature.len_threadIdx_z = loop_feature.loop_length;
          break;
        default:
          break;
      }
    }
  }

  // Skip undefined sub-expressions, as the macro-generated visitors in this
  // file do.
  std::vector<const Expr *> sub_exprs = x->expr_fields();
  for (const Expr *e : sub_exprs) {
    if (e->defined()) {
      Visit(e);
    }
  }

  feature_.ExitLoopBlock();
}
void
FeatureExtractor
::
Visit
(
const
PolyFor
*
x
)
{
Expr
copy
=
optim
::
IRCopy
(
Expr
(
x
));
feature_
.
IntoLoopBlock
();
optim
::
TransformPolyForToFor
(
&
copy
);
ir
::
For
*
loop
=
copy
.
As
<
For
>
();
CHECK
(
loop
!=
nullptr
);
Visit
(
loop
);
feature_
.
ExitLoopBlock
();
}
/* Visit for Reduce and Broadcast */
void
FeatureExtractor
::
Visit
(
const
Reduce
*
x
)
{
if
(
x
->
type
()
==
common
::
F32
()
||
x
->
type
()
==
common
::
F16
()
||
x
->
type
()
==
common
::
F64
())
{
switch
(
x
->
reduce_type
)
{
case
Reduce
::
ReduceType
::
kSum
:
feature_
.
CurrentLoopBlock
().
float_reduce_sum_or_sub
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kSub
:
feature_
.
CurrentLoopBlock
().
float_reduce_sum_or_sub
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kDiv
:
feature_
.
CurrentLoopBlock
().
float_reduce_div
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kMul
:
feature_
.
CurrentLoopBlock
().
float_reduce_mul
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kMax
:
feature_
.
CurrentLoopBlock
().
float_reduce_max_or_min
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kMin
:
feature_
.
CurrentLoopBlock
().
float_reduce_max_or_min
+=
x
->
type
().
lanes
();
break
;
}
}
else
{
switch
(
x
->
reduce_type
)
{
case
Reduce
::
ReduceType
::
kSum
:
feature_
.
CurrentLoopBlock
().
int_reduce_sum_or_sub
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kSub
:
feature_
.
CurrentLoopBlock
().
int_reduce_sum_or_sub
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kDiv
:
feature_
.
CurrentLoopBlock
().
int_reduce_div
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kMul
:
feature_
.
CurrentLoopBlock
().
int_reduce_mul
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kMax
:
feature_
.
CurrentLoopBlock
().
int_reduce_max_or_min
+=
x
->
type
().
lanes
();
break
;
case
Reduce
::
ReduceType
::
kMin
:
feature_
.
CurrentLoopBlock
().
int_reduce_max_or_min
+=
x
->
type
().
lanes
();
break
;
}
}
std
::
vector
<
const
Expr
*>
sub_exprs
=
x
->
expr_fields
();
for
(
const
Expr
*
e
:
sub_exprs
)
{
Visit
(
e
);
}
}
VisitForDtypePattern
(
Broadcast
,
broadcast
);
/* Visit for IntrinsicOp */
// Dispatches an IntrinsicOp to the Visit overload of its concrete intrinsic
// type, selected by getKind().
//
// The kind has already been matched by the switch, so the downcast cannot
// fail; llvm::cast states that directly, whereas llvm::dyn_cast would hand a
// null pointer to Visit on a mismatch.
void FeatureExtractor::Visit(const IntrinsicOp *x) {
  switch (x->getKind()) {
#define __(op__)                            \
  case IntrinsicKind::k##op__:              \
    Visit(llvm::cast<intrinsics::op__>(x)); \
    break;

    INTRINSIC_KIND_FOR_EACH(__)
#undef __
  }
}

// Intrinsics that contribute no feature counts themselves.
VisitDoNothing(intrinsics::BufferGetDataHandle);
VisitDoNothing(intrinsics::BufferGetDataConstHandle);
VisitDoNothing(intrinsics::PodValueToX);
VisitDoNothing(intrinsics::BufferCreate);
VisitDoNothing(intrinsics::GetAddr);
VisitDoNothing(intrinsics::ArgsConstruct);

// Builtin intrinsics are counted as "other call" by data type.
VisitForDtypePattern(intrinsics::BuiltinIntrin, other_call)
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/feature_extractor.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/cost_model/feature.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_visitor.h"
namespace
cinn
{
namespace
auto_schedule
{
// IR visitor that walks the expressions of a ModuleExpr and accumulates
// their cost-model Feature.
class FeatureExtractor : public ir::IRVisitorRequireReImpl<void> {
 public:
  FeatureExtractor();

  // Extracts the Feature of `mod_expr` for `target`. The implementation
  // starts from a fresh Feature on every call.
  Feature Extract(const ir::ModuleExpr& mod_expr,
                  const common::Target& target);

  // Generic entry point for an arbitrary expression.
  void Visit(const Expr* x) override;

  // One Visit override per IR node type listed by NODETY_FORALL.
#define __(op__) void Visit(const ir::op__* x) override;
  NODETY_FORALL(__)
#undef __

  // One virtual Visit per intrinsic kind listed by INTRINSIC_KIND_FOR_EACH.
#define __(op__) virtual void Visit(const ir::intrinsics::op__* x);
  INTRINSIC_KIND_FOR_EACH(__)
#undef __

 private:
  // Feature being accumulated by the current Extract call.
  Feature feature_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/feature_extractor_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/feature_extractor.h"
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <cmath>
#include <unordered_set>
#include <vector>
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/ir_base.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/lang/builtin.h"
#include "paddle/cinn/lang/compute.h"
#include "paddle/cinn/lang/lower.h"
#include "paddle/cinn/lang/placeholder.h"
#include "paddle/cinn/poly/stage.h"
namespace
cinn
{
namespace
auto_schedule
{
// Extracts features from a lowered 32x32 element-wise copy (B(i, j) = A(i, j))
// and checks every slot of the fixed-size feature vector.
TEST(FeatureExtractor, SimpleAssign) {
  Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif

  // Describe the computation and lower it to an IR function.
  ir::Expr M(32);
  ir::Expr N(32);
  lang::Placeholder<float> A("A", {M, N});
  ir::Tensor B = lang::Compute(
      {M, N}, [&](Var i, Var j) { return A(i, j); }, "B");
  poly::StageMap stages = poly::CreateStages({A, B});
  std::vector<ir::LoweredFunc> lowered_funcs = lang::LowerVec(
      "SimpleAssign", stages, {A, B}, {}, {}, nullptr, target, true);

  ir::Expr body_expr = lowered_funcs[0]->body;
  VLOG(6) << "Expr to test: " << body_expr;

  // Run the extractor on a module holding only that function body.
  std::vector<Expr> module_exprs{body_expr};
  ir::ModuleExpr mod_expr(module_exprs);
  FeatureExtractor extractor;
  Feature feature = extractor.Extract(mod_expr, target);
  std::vector<float> feature_vec = feature.ToFixedSizeVector();

  ASSERT_EQ(feature_vec.size(),
            static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));

  // Slots that may be non-zero: 0 (target), 17 (mem_read), 18 (mem_write)
  // and 29 (non-opt loops). Every other slot has to be exactly zero.
  const std::unordered_set<size_t> non_zero_slots = {0, 17, 18, 29};
  VLOG(6) << "Feature data before slog:";
  for (size_t idx = 0; idx < feature_vec.size(); ++idx) {
    VLOG(6) << idx << " " << (std::pow(2, feature_vec[idx]) - 1);
    if (non_zero_slots.count(idx) == 0) {
      ASSERT_EQ(feature_vec[idx], 0);
    }
  }

  // target
#ifdef CINN_WITH_CUDA
  ASSERT_EQ(feature_vec[0], 1);
#else
  ASSERT_EQ(feature_vec[0], 0);
#endif

  // One read and one write per element of the M x N tensor.
  const auto per_element_slog = slog(M.get_constant() * N.get_constant());
  // mem_read
  ASSERT_EQ(feature_vec[17], per_element_slog);
  // mem_write
  ASSERT_EQ(feature_vec[18], per_element_slog);
  // non-opt loops, including root block
  ASSERT_EQ(feature_vec[29], slog(3));
}
// Extracts features from a lowered (2x4) * (4x2) matrix multiply whose last
// loop is bound to threadIdx.x, and checks every slot of the feature vector.
TEST(FeatureExtractor, MatrixMultiply) {
  Context::Global().ResetNameId();
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif

  // C(i, j) = sum over k of A(i, k) * B(k, j)
  ir::Expr M(2);
  ir::Expr N(2);
  ir::Expr K(4);
  lang::Placeholder<float> A("A", {M, K});
  lang::Placeholder<float> B("B", {K, N});
  ir::Var k(K.as_int32(), "reduce_axis_k");
  ir::Tensor C = lang::Compute(
      {M, N},
      [&](Var i, Var j) { return lang::ReduceSum(A(i, k) * B(k, j), {k}); },
      "C");

  poly::StageMap stages = poly::CreateStages({C});
  std::vector<ir::LoweredFunc> lowered_funcs = lang::LowerVec(
      "MatrixMultiply", stages, {C}, {}, {}, nullptr, target, true);

  // Schedule: bind the last loop of the first block to threadIdx.x.
  std::vector<Expr> module_exprs{lowered_funcs[0]->body};
  ir::ModuleExpr mod_expr(module_exprs);
  ir::IRSchedule ir_sch(mod_expr);
  std::vector<ir::Expr> all_blocks = ir_sch.GetAllBlocks();
  std::vector<ir::Expr> block_loops = ir_sch.GetLoops(all_blocks[0]);
  ir_sch.Bind(block_loops.back(), "threadIdx.x");

  ir::Expr scheduled_expr = mod_expr.GetExprs()[0];
  VLOG(6) << "Expr to test: " << scheduled_expr;

  FeatureExtractor extractor;
  Feature feature = extractor.Extract(mod_expr, target);
  std::vector<float> feature_vec = feature.ToFixedSizeVector();

  ASSERT_EQ(feature_vec.size(),
            static_cast<size_t>(LoopBlockFeature::kTotalSize + 1));

  // Only these slots may be non-zero; every other one has to be exactly zero.
  std::unordered_set<size_t> non_zero_slots = {0, 1, 2, 17, 18, 29, 30, 37};
  for (size_t idx = 0; idx < feature_vec.size(); ++idx) {
    VLOG(6) << idx << " " << (std::pow(2, feature_vec[idx]) - 1);
    const bool must_be_zero = (non_zero_slots.count(idx) == 0);
    if (must_be_zero) {
      ASSERT_EQ(feature_vec[idx], 0);
    }
  }

  // target
#ifdef CINN_WITH_CUDA
  ASSERT_EQ(feature_vec[0], 1);
#else
  ASSERT_EQ(feature_vec[0], 0);
#endif

  // Iteration counts of the output loops (M * N) and of the full loop nest
  // including the reduction axis (M * N * K).
  float out_loop = M.get_constant() * N.get_constant();
  float total_loop = out_loop * K.get_constant();

  // float_mul
  ASSERT_EQ(feature_vec[1], slog(total_loop));
  // float_add_or_sub
  ASSERT_EQ(feature_vec[2], slog(total_loop));
  // mem_read
  ASSERT_EQ(feature_vec[17], slog(total_loop * 3));
  // mem_write
  ASSERT_EQ(feature_vec[18], slog(total_loop + out_loop));
  // non-opt loops, including root block
  ASSERT_EQ(feature_vec[29], slog(3));
  // GpuBind loop
  ASSERT_EQ(feature_vec[30], slog(1));
  // GpuBind loop
  ASSERT_EQ(feature_vec[37], slog(out_loop));
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/feature_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/feature.h"
#include <gtest/gtest.h>
#include <pybind11/embed.h>
namespace
cinn
{
namespace
auto_schedule
{
// Placeholder test case for Feature. The body is intentionally empty, so the
// case always passes and currently verifies nothing.
TEST(Feature, Basic) {
  // TODO(zhhsplendid): add some basic tests
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/xgb_cost_model.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h"
#include <dirent.h>
#include <glog/logging.h>
#include <pybind11/embed.h>
#include <pybind11/numpy.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <atomic>
#include <cassert>
#include <cstring>
#include <iostream>
#include <memory>
#include <mutex>
#include <regex>
#include <string>
#include <utility>
#include <vector>
#include "paddle/cinn/common/python_interpreter_guard.h"
namespace
cinn
{
namespace
auto_schedule
{
// Out-of-class definition of the instance counter, starting at zero. The
// XgbCostModel constructor increments it and uses the previous value to
// detect the first instance created in the process.
std::atomic<int> XgbCostModel::xgb_cost_model_count_{0};
// Convert a 1D vector to a numpy array: the vector is first cast to a Python
// object with pybind11::cast, and that object is then wrapped as an array.
template <typename Dtype>
pybind11::array VectorToNumpy(const std::vector<Dtype>& vec) {
  pybind11::object py_values = pybind11::cast(vec);
  return pybind11::array(py_values);
}
// Convert a 2D vector (a vector of equally sized rows) to a 2D numpy array.
//
// An empty outer vector yields an array of shape (0, 0). Otherwise the result
// has shape (vec.size(), vec[0].size()) and the rows are copied into the
// array's buffer one after another.
//
// Every row must be as long as the first one. This is enforced with CHECK_EQ
// rather than assert(), so a ragged input aborts in every build type instead
// of letting memcpy read past the end of a short row when NDEBUG is defined.
template <typename Dtype>
pybind11::array VectorToNumpy(const std::vector<std::vector<Dtype>>& vec) {
  if (vec.empty()) {
    return pybind11::array(pybind11::dtype::of<Dtype>(), {0, 0});
  }

  const size_t num_rows = vec.size();
  const size_t num_cols = vec[0].size();
  std::vector<size_t> shape{num_rows, num_cols};
  pybind11::array ret(pybind11::dtype::of<Dtype>(), shape);
  Dtype* py_data = static_cast<Dtype*>(ret.mutable_data());

  for (size_t i = 0; i < num_rows; ++i) {
    CHECK_EQ(vec[i].size(), num_cols)
        << "Sub vectors must have same size in VectorToNumpy";
    if (num_cols == 0) {
      // Nothing to copy, and data() of an empty vector may be null, which
      // must not be handed to memcpy.
      continue;
    }
    std::memcpy(
        py_data + (num_cols * i), vec[i].data(), num_cols * sizeof(Dtype));
  }
  return ret;
}
// the Pybind default Python interpreter doesn't contain some paths in
// sys.path, so we have to add it.
//
// Note: the Pybind default Python interpreter only uses default Python.
// Something may be wrong when users use virtual Python environment.
void
AddDistPkgToPythonSysPath
()
{
pybind11
::
module
sys_py_mod
=
pybind11
::
module
::
import
(
"sys"
);
// short version such as "3.7", "3.8", ...
std
::
string
py_short_version
=
sys_py_mod
.
attr
(
"version"
).
cast
<
std
::
string
>
().
substr
(
0
,
3
);
std
::
string
site_pkg_str
=
"/usr/local/lib/python"
+
py_short_version
+
"/dist-packages"
;
sys_py_mod
.
attr
(
"path"
).
attr
(
"append"
)(
site_pkg_str
);
// TODO(zhhsplendid): warning to users if setuptools hasn't been installed
DIR
*
site_pkg_dir
=
opendir
(
site_pkg_str
.
c_str
());
if
(
site_pkg_dir
!=
nullptr
)
{
std
::
regex
setuptool_regex
(
"setuptools-.*-py"
+
py_short_version
+
"
\\
.egg"
);
struct
dirent
*
entry
=
nullptr
;
while
((
entry
=
readdir
(
site_pkg_dir
))
!=
nullptr
)
{
if
(
std
::
regex_match
(
entry
->
d_name
,
setuptool_regex
))
{
sys_py_mod
.
attr
(
"path"
).
attr
(
"append"
)(
site_pkg_str
+
"/"
+
entry
->
d_name
);
}
}
closedir
(
site_pkg_dir
);
}
}
// Calls common::PythonInterpreterGuard::Guard(), extends sys.path when this
// is the first XgbCostModel instance, then imports the Python xgboost module
// and creates an empty Booster from it.
XgbCostModel::XgbCostModel() {
  common::PythonInterpreterGuard::Guard();

  // fetch_add returns the value held before the increment, so zero
  // identifies the very first instance.
  const bool is_first_instance = (xgb_cost_model_count_.fetch_add(1) == 0);
  if (is_first_instance) {
    AddDistPkgToPythonSysPath();
  }

  xgb_module_ = pybind11::module::import("xgboost");
  xgb_booster_ = xgb_module_.attr("Booster")();
}
void
XgbCostModel
::
Train
(
const
std
::
vector
<
std
::
vector
<
float
>>&
samples
,
const
std
::
vector
<
float
>&
labels
)
{
update_samples_
=
samples
;
update_labels_
=
labels
;
pybind11
::
array
np_samples
=
VectorToNumpy
<
float
>
(
samples
);
pybind11
::
array
np_labels
=
VectorToNumpy
<
float
>
(
labels
);
pybind11
::
object
dmatrix
=
xgb_module_
.
attr
(
"DMatrix"
)(
np_samples
,
np_labels
);
xgb_booster_
=
xgb_module_
.
attr
(
"train"
)(
pybind11
::
dict
(),
dmatrix
,
pybind11
::
int_
(
kTrainRound_
));
}
std
::
vector
<
float
>
XgbCostModel
::
Predict
(
const
std
::
vector
<
std
::
vector
<
float
>>&
samples
)
const
{
pybind11
::
array
np_samples
=
VectorToNumpy
<
float
>
(
samples
);
pybind11
::
object
dmatrix
=
xgb_module_
.
attr
(
"DMatrix"
)(
np_samples
);
pybind11
::
array
py_result
=
xgb_booster_
.
attr
(
"predict"
)(
dmatrix
);
return
py_result
.
cast
<
std
::
vector
<
float
>>
();
}
void
XgbCostModel
::
Update
(
const
std
::
vector
<
std
::
vector
<
float
>>&
samples
,
const
std
::
vector
<
float
>&
labels
)
{
update_samples_
.
insert
(
update_samples_
.
end
(),
samples
.
begin
(),
samples
.
end
());
update_labels_
.
insert
(
update_labels_
.
end
(),
labels
.
begin
(),
labels
.
end
());
pybind11
::
array
np_samples
=
VectorToNumpy
<
float
>
(
update_samples_
);
pybind11
::
array
np_labels
=
VectorToNumpy
<
float
>
(
update_labels_
);
pybind11
::
object
dmatrix
=
xgb_module_
.
attr
(
"DMatrix"
)(
np_samples
,
np_labels
);
xgb_booster_
=
xgb_module_
.
attr
(
"train"
)(
pybind11
::
dict
(),
dmatrix
,
pybind11
::
int_
(
kTrainRound_
));
}
void
XgbCostModel
::
Save
(
const
std
::
string
&
path
)
{
xgb_booster_
.
attr
(
"save_model"
)(
pybind11
::
str
(
path
));
}
void
XgbCostModel
::
Load
(
const
std
::
string
&
path
)
{
xgb_booster_
.
attr
(
"load_model"
)(
pybind11
::
str
(
path
));
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <pybind11/embed.h>
#include <atomic>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "paddle/cinn/common/cost_model.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * A C++ cost model which calls Python xgboost via pybind.
 *
 * Note: this class handles the Python interpreter lifetime itself. If you
 * have to call other Python functions outside of this class and run into a
 * lifetime conflict, check cinn::common::PythonInterpreterGuard.
 *
 * For cinn::common::PythonInterpreterGuard, see:
 *   cinn/common/python_interpreter_guard.h .cc
 *
 * For pybind interpreter lifetime management, see:
 *
 *   https://pybind11.readthedocs.io/en/stable/advanced/embedding.html#interpreter-lifetime
 *   https://pybind11.readthedocs.io/en/stable/reference.html#_CPPv422initialize_interpreterbiPPCKcb
 */
class XgbCostModel : public CostModel {
 public:
  XgbCostModel();
  ~XgbCostModel() = default;

  // Trains the model on `samples` (one feature vector per sample) and the
  // corresponding `labels`.
  void Train(const std::vector<std::vector<float>>& samples,
             const std::vector<float>& labels) override;

  // Returns the model's predictions for `samples`.
  std::vector<float> Predict(
      const std::vector<std::vector<float>>& samples) const override;

  // Adds `samples`/`labels` to the data seen so far and retrains on all of it.
  void Update(const std::vector<std::vector<float>>& samples,
              const std::vector<float>& labels) override;

  // Saves the model to `path`.
  void Save(const std::string& path) override;

  // Loads the model from `path`.
  void Load(const std::string& path) override;

 private:
  // Atomic int to handle python interpreter lifetime and package dependency
  static std::atomic<int> xgb_cost_model_count_;

  // Default train rounds
  static constexpr int kTrainRound_ = 10;

  // Python xgboost module
  pybind11::module xgb_module_;

  // Object points to Python xgb.Booster()
  pybind11::object xgb_booster_;

  // Samples and labels accumulated by Train()/Update(); Update() retrains on
  // all of them.
  std::vector<std::vector<float>> update_samples_;
  std::vector<float> update_labels_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/cost_model/xgb_cost_model_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/cost_model/xgb_cost_model.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <pybind11/embed.h>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <vector>
namespace
cinn
{
namespace
auto_schedule
{
// Smoke test for XgbCostModel: train on random data, predict, check that a
// model restored through Save/Load predicts the same values, then update.
TEST(CostModel, Basic) {
  XgbCostModel cost_model;

  srand(time(nullptr));

  int batch_size = 16;
  int feature_size = 8;
  std::vector<float> labels(batch_size, 1.0);
  std::vector<std::vector<float>> samples(batch_size,
                                          std::vector<float>(feature_size));
  for (int i = 0; i < batch_size; ++i) {
    for (int j = 0; j < feature_size; ++j) {
      samples[i][j] = rand() % 10;  // NOLINT
    }
  }

  cost_model.Train(samples, labels);
  std::vector<float> pred = cost_model.Predict(samples);

  std::string path = "./test_cost_model.cpp_save_model";
  cost_model.Save(path);

  XgbCostModel load_cost_model;
  load_cost_model.Load(path);
  // The file is not needed once it has been loaded. Removing it here means a
  // failing assertion below cannot leave it behind.
  std::remove(path.c_str());

  // Predict with the *loaded* model. Predicting with cost_model again would
  // compare the trained model with itself and never exercise Save/Load.
  std::vector<float> load_pred = load_cost_model.Predict(samples);

  ASSERT_EQ(pred.size(), load_pred.size());
  for (size_t i = 0; i < pred.size(); ++i) {
    ASSERT_FLOAT_EQ(pred[i], load_pred[i]);
    VLOG(6) << "pred[" << i << "] = " << pred[i];
  }

  cost_model.Update(samples, labels);
  pred = cost_model.Predict(samples);
  for (size_t i = 0; i < pred.size(); ++i) {
    VLOG(6) << "pred[" << i << "] = " << pred[i];
  }
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/database/CMakeLists.txt
0 → 100644
View file @
992bec46
core_gather_headers()

# Sources of the tuning-record database that are added to cinnapi_src.
gather_srcs(cinnapi_src SRCS database.cc jsonfile_database.cc)

# Unit tests, one per database implementation.
cinn_cc_test(test_database SRCS database_test.cc DEPS cinncore)
cinn_cc_test(test_jsonfile_database SRCS jsonfile_database_test.cc DEPS
             cinncore)
paddle/cinn/auto_schedule/database/database.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/database/database.h"
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/util/json_util.h>
#include "paddle/cinn/auto_schedule/database/jsonfile_database.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/schedule_desc.h"
namespace
cinn
{
namespace
auto_schedule
{
// Strict weak ordering on execution_cost: `lhs` sorts before `rhs` when its
// execution cost is lower.
bool TuningRecord::Compare::operator()(const TuningRecord& lhs,
                                       const TuningRecord& rhs) const {
  const double lhs_cost = lhs.execution_cost;
  const double rhs_cost = rhs.execution_cost;
  return lhs_cost < rhs_cost;
}
// Builds a proto::TuningRecord carrying this record's task key, both costs
// and a copy of its schedule trace.
proto::TuningRecord TuningRecord::ToProto() const {
  proto::TuningRecord result;
  result.set_task_key(task_key);
  result.set_execution_cost(execution_cost);
  result.set_predicted_cost(predicted_cost);

  auto* trace_proto = result.mutable_trace();
  trace_proto->CopyFrom(trace);
  return result;
}
Database
::
Database
(
int
capacity_per_task
)
:
capacity_per_task_
(
capacity_per_task
)
{
CHECK_GT
(
capacity_per_task_
,
0
)
<<
"capacity_per_task_ should be greater than 0"
;
}
// Factory: creates the database implementation selected by `config.type`.
//   kMemory   -> plain in-memory Database
//   kJSONFile -> JSONFileDatabase on config.record_file_path, with
//                allow_new_file set to true
// Any other type is a fatal error.
std::unique_ptr<Database> Database::Make(const DatabaseConfig& config) {
  switch (config.type) {
    case DatabaseType::kMemory:
      return std::make_unique<Database>(config.capacity_per_task);
    case DatabaseType::kJSONFile:
      return std::make_unique<JSONFileDatabase>(
          config.capacity_per_task, config.record_file_path, true);
    default:
      break;
  }

  LOG(FATAL) << "Unimplemented database type.";
  return nullptr;
}
// Adds `record` to the in-memory records of its task. If the task then holds
// more than capacity_per_task_ records, the last one in sort order is dropped
// (the set is ordered by TuningRecord::Compare, i.e. by execution_cost).
void Database::Insert(const TuningRecord& record) {
  auto& task_records = key2record_[record.task_key];
  task_records.emplace(record);

  const bool over_capacity = task_records.size() > capacity_per_task_;
  if (over_capacity) {
    auto last_record = task_records.end();
    --last_record;
    task_records.erase(last_record);
  }
}
bool
Database
::
AddRecord
(
const
TuningRecord
&
record
)
{
CHECK
(
!
record
.
task_key
.
empty
())
<<
"task_key of TuningRecord can't be empty"
;
Insert
(
record
);
return
Commit
(
record
);
}
// Returns a copy of all records stored for `task_key`, in the order of the
// underlying set; empty when the key is unknown.
std::vector<TuningRecord> Database::LookUp(const std::string& task_key) {
  const auto found = key2record_.find(task_key);
  if (found == key2record_.end()) {
    return {};
  }

  const auto& task_records = found->second;
  return std::vector<TuningRecord>(task_records.begin(), task_records.end());
}
std
::
vector
<
TuningRecord
>
Database
::
GetTopK
(
const
std
::
string
&
task_key
,
int
k
)
{
auto
fit
=
key2record_
.
find
(
task_key
);
if
(
fit
==
key2record_
.
end
()
||
k
<=
0
)
{
return
{};
}
if
(
k
>
capacity_per_task_
)
{
LOG
(
WARNING
)
<<
"Top k="
<<
k
<<
" is greater than the capacity, will adjust k="
<<
capacity_per_task_
;
k
=
capacity_per_task_
;
}
std
::
vector
<
TuningRecord
>
results
;
results
.
reserve
(
k
);
for
(
const
TuningRecord
&
record
:
fit
->
second
)
{
results
.
emplace_back
(
record
);
if
(
results
.
size
()
==
k
)
{
break
;
}
}
return
results
;
}
// Returns the total number of records stored, summed over all task keys.
size_t Database::Size() {
  size_t total = 0;
  for (const auto& key_and_records : key2record_) {
    total += key_and_records.second.size();
  }
  return total;
}
// Returns the number of records stored for `task_key`, or 0 when the key is
// unknown.
size_t Database::Count(const std::string& task_key) {
  const auto found = key2record_.find(task_key);
  const bool is_known = (found != key2record_.end());
  return is_known ? found->second.size() : 0;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/database/database.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <unordered_map>
#include "paddle/cinn/auto_schedule/auto_schedule.pb.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/ir/schedule/schedule_desc.pb.h"
namespace
cinn
{
namespace
auto_schedule
{
// Record related data about the tuning process of a measure candidate.
struct TuningRecord {
  // the unique key to identify a task
  std::string task_key;
  // the predicted cost of CostModel, unit: us.
  // Zero-initialized so that a default-constructed record never exposes an
  // indeterminate value (e.g. through ToProto()).
  float predicted_cost = 0.0f;
  // the ScheduleDesc of this tuning process
  ir::proto::ScheduleDesc trace;
  // the cost time of the candidate executed during measure, unit: us.
  // Zero-initialized for the same reason as predicted_cost.
  double execution_cost = 0.0;

  TuningRecord() = default;

  // Builds a record from its proto representation.
  explicit TuningRecord(const proto::TuningRecord& record)
      : task_key(record.task_key()),
        predicted_cost(record.predicted_cost()),
        trace(record.trace()),
        execution_cost(record.execution_cost()) {}

  // Builds a record for `task_key` from a search state (its predicted cost
  // and the trace of its IR schedule) and the measured `execution_cost`.
  TuningRecord(const std::string& task_key,
               const SearchState& state,
               double execution_cost)
      : task_key(task_key),
        predicted_cost(state->predicted_cost),
        trace(state->ir_schedule.GetTraceDesc().ToProto()),
        execution_cost(execution_cost) {}

  // convert to proto object
  proto::TuningRecord ToProto() const;

  // a binary compare function that denotes when the left
  // will be sorted in the front of the right
  struct Compare {
    bool operator()(const TuningRecord& lhs, const TuningRecord& rhs) const;
  };
};
// Selects the Database implementation that Database::Make creates.
enum class DatabaseType : int {
  // plain in-memory Database
  kMemory = 0,
  // JSONFileDatabase
  kJSONFile = 1
};
// Options consumed by Database::Make.
struct DatabaseConfig {
  // which implementation to create
  DatabaseType type{DatabaseType::kMemory};
  // the max number of records kept per task
  int capacity_per_task{2};
  // file used when type is kJSONFile
  std::string record_file_path{"/tmp/tuning_record.json"};
};
// A database supports inserting and looking up historical tuning results with
// specified traits. It can be implemented with a concrete storage to
// save/load the underlying data, such as memory, file, database server and so
// on. This base class can be regarded as one using memory as its underlying
// storage medium.
class Database {
 public:
  explicit Database(int capacity_per_task);
  // Virtual because Make() hands out derived objects (e.g. JSONFileDatabase)
  // through std::unique_ptr<Database>; deleting them through a non-virtual
  // base destructor would be undefined behavior.
  virtual ~Database() = default;

  // Create a Database with the specific config
  static std::unique_ptr<Database> Make(const DatabaseConfig& config);

  // add a record into the database
  bool AddRecord(const TuningRecord& record);
  // return all records whose task_keys are equal to the specified key
  std::vector<TuningRecord> LookUp(const std::string& task_key);
  // return the states of the top k in sorted candidates
  std::vector<TuningRecord> GetTopK(const std::string& task_key, int k);
  // return the total number of stored candidates
  size_t Size();
  // return the number of stored candidates with specified key
  size_t Count(const std::string& task_key);

 protected:
  // commit the newly added record into underlying storage; the in-memory
  // base implementation has nothing to persist and always reports success
  virtual bool Commit(const TuningRecord& record) { return true; }
  // insert a newly added record into memory storage
  void Insert(const TuningRecord& record);

  // map task_key to its records
  std::unordered_map<std::string,
                     std::multiset<TuningRecord, TuningRecord::Compare>>
      key2record_;
  // the max number of candidates stored
  const int capacity_per_task_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/database/database_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/database/database.h"
#include <gtest/gtest.h>
#include <vector>
#include "paddle/cinn/auto_schedule/auto_schedule.pb.h"
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
// Fixture providing an in-memory Database with a capacity of 2 records per
// task, pre-filled with 1 record for "k1", 2 for "k2", 3 for "k3" and 1 for
// "k4".
class TestDatabase : public ::testing::Test {
 public:
  TestDatabase() : test_db(2) {
    auto state = SearchState(ir::IRSchedule());
    // Adds one record for `task_key` with the given execution cost.
    const auto add_record = [this, &state](const std::string& task_key,
                                           double execution_cost) {
      test_db.AddRecord(TuningRecord(task_key, state, execution_cost));
    };

    add_record("k1", 1.0);
    add_record("k2", 2.0);
    add_record("k2", 3.0);
    add_record("k3", 3.0);
    add_record("k3", 4.0);
    add_record("k3", 5.0);
    add_record("k4", 4.0);
  }

  void SetUp() override {}

  Database test_db;
};
// Size(), Count() and LookUp() on the pre-filled fixture database.
TEST_F(TestDatabase, Basic) {
  ASSERT_EQ(test_db.Size(), 6);

  std::vector<TuningRecord> k3_records = test_db.LookUp("k3");
  // check the max number of stored candidates will
  // be restricted to capacity_per_task
  ASSERT_EQ(test_db.Count("k3"), 2);
  ASSERT_EQ(k3_records.size(), 2);
  // Of the three "k3" records, the two with the lowest cost are kept.
  EXPECT_EQ(k3_records.front().execution_cost, 3.0);
  EXPECT_EQ(k3_records.back().execution_cost, 4.0);
}
// GetTopK() for an unknown key, for a key with fewer records than requested,
// and after the per-task capacity has evicted the most expensive record.
TEST_F(TestDatabase, GetTopK) {
  ASSERT_TRUE(test_db.GetTopK("k5", 2).empty());
  ASSERT_EQ(test_db.GetTopK("k4", 3).size(), 1);

  // Two cheaper "k4" records push out the original one (cost 4.0).
  const SearchState first_state = SearchState(ir::IRSchedule(), 1.2);
  const SearchState second_state = SearchState(ir::IRSchedule(), 1.0);
  test_db.AddRecord(TuningRecord("k4", first_state, 2.0));
  test_db.AddRecord(TuningRecord("k4", second_state, 3.0));

  std::vector<TuningRecord> top_records = test_db.GetTopK("k4", 3);
  ASSERT_EQ(top_records.size(), 2);
  // Records come back ordered by execution cost (2.0, then 3.0).
  EXPECT_FLOAT_EQ(top_records[0].predicted_cost, 1.2);
  EXPECT_FLOAT_EQ(top_records[1].predicted_cost, 1.0);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/database/jsonfile_database.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/database/jsonfile_database.h"
#include <google/protobuf/message.h>
#include <google/protobuf/text_format.h>
#include <google/protobuf/util/json_util.h>
#include <fstream>
#include "paddle/cinn/auto_schedule/auto_schedule.pb.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/utils/multi_threading.h"
namespace
cinn
{
namespace
auto_schedule
{
// Appends `line` followed by a newline to the file at `file_path` and flushes
// the stream. Aborts (CHECK) if the file cannot be opened for appending.
void AppendLineToFile(const std::string& file_path, const std::string& line) {
  std::ofstream out_stream;
  out_stream.open(file_path, std::ofstream::app);
  CHECK(out_stream.good()) << "Cannot open the file to write: " << file_path;
  out_stream << line << std::endl;
}
// Reads the file at `file_path` and returns its lines in order.
//
// If the file cannot be opened for reading, `allow_new_file` must be true
// (otherwise CHECK aborts); the file is then created and an empty vector is
// returned.
std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
                                           bool allow_new_file) {
  std::ifstream in_stream(file_path);
  if (!in_stream.good()) {
    CHECK(allow_new_file) << "File doesn't exist: " << file_path;
    std::ofstream os(file_path);
    CHECK(os.good()) << "Cannot create new file: " << file_path;
    return {};
  }

  std::vector<std::string> lines;
  std::string current_line;
  while (std::getline(in_stream, current_line)) {
    lines.push_back(current_line);
  }
  return lines;
}
// Loads the tuning records stored in `record_file_path` (one JSON object per
// line) into memory. Every line is parsed into a proto::TuningRecord; only
// records whose task_key is known to the global InitialTaskRegistry are
// inserted. A line that fails to parse aborts via CHECK.
JSONFileDatabase::JSONFileDatabase(int capacity_per_task,
                                   const std::string& record_file_path,
                                   bool allow_new_file)
    : Database(capacity_per_task), record_file_path_(record_file_path) {
  using RecordProto = cinn::auto_schedule::proto::TuningRecord;

  VLOG(3) << "Auto schedule will save/load tuning records on file:"
          << record_file_path;
  auto json_lines = ReadLinesFromFile(record_file_path_, allow_new_file);

  // One slot per line, so each worker invocation writes a distinct element.
  std::vector<RecordProto> parsed_records(json_lines.size());

  // convert JSON string to proto object
  auto parse_line = [&json_lines, &parsed_records](int index) {
    RecordProto parsed;
    auto status = google::protobuf::util::JsonStringToMessage(
        json_lines[index], &parsed);
    CHECK(status.ok()) << "Failed to parse JSON: " << json_lines[index];
    parsed_records[index].Swap(&parsed);
  };
  utils::parallel_run(
      parse_line, utils::SequenceDispatcher(0, json_lines.size()), -1);

  InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
  for (const auto& parsed : parsed_records) {
    std::string task_key = parsed.task_key();
    if (!task_registry->Has(task_key)) {
      continue;
    }
    VLOG(4) << "Add a measured TuningRecord with task_key=" << task_key;
    Insert(TuningRecord(parsed));
  }
}
// convert a TuningRecord object to string in JSON format
std
::
string
JSONFileDatabase
::
RecordToJSON
(
const
TuningRecord
&
record
)
{
proto
::
TuningRecord
record_proto
=
record
.
ToProto
();
std
::
string
json_string
;
auto
status
=
google
::
protobuf
::
util
::
MessageToJsonString
(
record_proto
,
&
json_string
);
CHECK
(
status
.
ok
())
<<
"Failed to serialize record to JSON, task key = "
<<
record
.
task_key
;
VLOG
(
4
)
<<
"json_string =
\n
"
<<
json_string
;
return
json_string
;
}
// Appends the JSON form of `record` as one line to the record file.
// Always returns true.
bool JSONFileDatabase::Commit(const TuningRecord& record) {
  AppendLineToFile(record_file_path_, RecordToJSON(record));
  return true;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/database/jsonfile_database.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/database/database.h"
namespace
cinn
{
namespace
auto_schedule
{
// JSONFileDatabase is a Database that uses a JSON file to save/load its
// underlying data.
class JSONFileDatabase : public Database {
 public:
  /*!
   * \brief Build a JSONFileDatabase object from a json file.
   * \param capacity_per_task The max number of candidates stored.
   * \param record_file_path The path of the json file.
   * \param allow_new_file Whether to create a new file when the given path
   * is not found.
   */
  JSONFileDatabase(int capacity_per_task,
                   const std::string& record_file_path,
                   bool allow_new_file);

  ~JSONFileDatabase() = default;

  /*!
   * \brief Convert a TuningRecord object to a string in JSON format.
   * \param record The record to convert.
   * \return The JSON string.
   */
  std::string RecordToJSON(const TuningRecord& record);

 protected:
  /*!
   * \brief Commit the newly added record into the json file.
   */
  bool Commit(const TuningRecord& record) override;

  // the name of the json file to save tuning records.
  std::string record_file_path_;
};
// Append a line to a file.
// \param file_path The path of the file to append to.
// \param line The text to append.
void AppendLineToFile(const std::string& file_path, const std::string& line);
// Read lines from a json file.
// \param file_path The path of the file to read.
// \param allow_new_file Defaults to true. NOTE(review): presumably has the
// same meaning as the constructor argument of JSONFileDatabase (create the
// file when it is not found) -- confirm against the definition.
// \return The lines read from the file, one string per line.
std::vector<std::string> ReadLinesFromFile(const std::string& file_path,
                                           bool allow_new_file = true);
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/database/jsonfile_database_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/database/jsonfile_database.h"
#include <google/protobuf/util/message_differencer.h>
#include <gtest/gtest.h>
#include <fstream>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/auto_schedule/task/task_registry.h"
#include "paddle/cinn/cinn.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
auto_schedule
{
// Return the lowered IR AST of the example function used in this test.
//
// \param shape Extents of the compute domain; must hold exactly two entries
//        (enforced by CHECK).
// \param target The target passed through to cinn::lang::LowerVec.
// \return The lowered functions produced by LowerVec for "test_func".
std::vector<ir::LoweredFunc> LowerCompute(const std::vector<int>& shape,
                                          const Target& target) {
  CHECK(shape.size() == 2) << "shape should be 2";

  // Turn every extent into an Expr. A range-for avoids the signed/unsigned
  // comparison that an `auto i = 0` index loop against size() would incur.
  std::vector<Expr> domain;
  for (const int extent : shape) {
    domain.emplace_back(extent);
  }

  Placeholder<float> A("A", domain);
  ir::Tensor B, C;

  // B copies A element-wise, C copies B element-wise.
  B = Compute(
      domain, [&A](Var i, Var j) { return A(i, j); }, "B");
  C = Compute(
      domain, [&B](Var i, Var j) { return B(i, j); }, "C");

  // NOTE(review): only A and B are handed to CreateStages/LowerVec; C is
  // built above but not part of the lowered function -- confirm this is
  // intended.
  return cinn::lang::LowerVec(
      "test_func", CreateStages({A, B}), {A, B}, {}, {}, nullptr, target, true);
}
// Build a new IRSchedule from copies (made with optim::IRCopy) of the bodies
// of `lowered_funcs`. A ModuleExpr built from the same copied bodies is also
// registered in the global InitialTaskRegistry under `task_key`.
ir::IRSchedule MakeIRSchedule(
    const std::vector<ir::LoweredFunc>& lowered_funcs,
    const std::string& task_key) {
  std::vector<Expr> copied_bodies;
  for (const auto& lowered_func : lowered_funcs) {
    copied_bodies.emplace_back(optim::IRCopy(lowered_func->body));
  }

  InitialTaskRegistry* registry = InitialTaskRegistry::Global();
  registry->Regist(task_key, ir::ModuleExpr(copied_bodies));

  return ir::IRSchedule(ir::ModuleExpr(copied_bodies));
}
class
TestJSONFileDatabase
:
public
::
testing
::
Test
{
public:
TestJSONFileDatabase
()
:
record_file_path
(
"/tmp/test_record.json"
),
test_db
(
2
,
record_file_path
,
true
)
{}
void
SetUp
()
override
{
lowered_funcs
=
LowerCompute
({
32
,
32
},
target
);
}
void
TearDown
()
override
{
auto
isFileExists
=
[](
const
std
::
string
&
file_path
)
->
bool
{
std
::
ifstream
f
(
file_path
.
c_str
());
return
f
.
good
();
};
if
(
isFileExists
(
record_file_path
))
{
if
(
remove
(
record_file_path
.
c_str
())
==
0
)
{
LOG
(
INFO
)
<<
"Successfully deleted file: "
<<
record_file_path
;
}
else
{
LOG
(
INFO
)
<<
"failed to delete file: "
<<
record_file_path
;
}
}
else
{
LOG
(
INFO
)
<<
"file: "
<<
record_file_path
<<
"does not exist."
;
}
}
std
::
string
record_file_path
;
JSONFileDatabase
test_db
;
std
::
vector
<
ir
::
LoweredFunc
>
lowered_funcs
;
Target
target
=
common
::
DefaultHostTarget
();
};
TEST_F
(
TestJSONFileDatabase
,
Serialize
)
{
ir
::
IRSchedule
ir_sch
=
MakeIRSchedule
(
lowered_funcs
,
"test"
);
auto
fused
=
ir_sch
.
Fuse
(
"B"
,
{
0
,
1
});
VLOG
(
3
)
<<
"after Fuse, Expr: "
<<
fused
;
TuningRecord
record1
(
"test"
,
SearchState
(
std
::
move
(
ir_sch
),
2.0
),
1.0
);
std
::
string
str
=
test_db
.
RecordToJSON
(
record1
);
VLOG
(
3
)
<<
"RecordToJSON: "
<<
str
;
// Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
std
::
string
case1
=
"{
\"
taskKey
\"
:
\"
test
\"
,
\"
executionCost
\"
:1,
\"
predictedCost
\"
:2,
\"
trace
\"
:"
"{
\"
steps
\"
:[{
\"
type
\"
:
\"
FuseWithName
\"
,"
"
\"
outputs
\"
:[
\"
e0
\"
],
\"
attrs
\"
:[{
\"
name
\"
:
\"
loops_index
\"
,
\"
dtype
\"
:"
"
\"
INTS
\"
,
\"
ints
\"
:[0,1]},{
\"
name
\"
:
\"
block_"
"name
\"
,
\"
dtype
\"
:
\"
STRING
\"
,
\"
s
\"
:
\"
B
\"
}]}]}}"
;
std
::
string
case2
=
"{
\"
taskKey
\"
:
\"
test
\"
,
\"
executionCost
\"
:1,
\"
predictedCost
\"
:2,
\"
trace
\"
:"
"{
\"
steps
\"
:[{
\"
type
\"
:
\"
FuseWithName
\"
,"
"
\"
outputs
\"
:[
\"
e0
\"
],
\"
attrs
\"
:[{
\"
name
\"
:
\"
block_name
\"
,
\"
dtype
\"
:"
"
\"
STRING
\"
,
\"
s
\"
:
\"
B
\"
},{
\"
name
\"
:
\"
loops_"
"index
\"
,
\"
dtype
\"
:
\"
INTS
\"
,
\"
ints
\"
:[0,1]}]}]}}"
;
EXPECT_EQ
(
true
,
str
==
case1
||
str
==
case2
);
}
TEST_F
(
TestJSONFileDatabase
,
SaveLoad
)
{
ir
::
IRSchedule
ir_sch1
=
MakeIRSchedule
(
lowered_funcs
,
"k1"
);
auto
fused1
=
ir_sch1
.
Fuse
(
"B"
,
{
0
,
1
});
ir
::
IRSchedule
ir_sch2
=
MakeIRSchedule
(
lowered_funcs
,
"k2"
);
test_db
.
AddRecord
(
TuningRecord
(
"k1"
,
SearchState
(
std
::
move
(
ir_sch1
),
1.5
),
1.0
));
test_db
.
AddRecord
(
TuningRecord
(
"k2"
,
SearchState
(
std
::
move
(
ir_sch2
),
3.5
),
3.0
));
std
::
vector
<
std
::
string
>
strs
=
ReadLinesFromFile
(
record_file_path
);
ASSERT_EQ
(
strs
.
size
(),
2
);
// Because the serialization of protobuf does not guarantee the order, we give
// all possible results.
std
::
string
case1
=
"{
\"
taskKey
\"
:
\"
k1
\"
,
\"
executionCost
\"
:1,
\"
predictedCost
\"
:1.5,
\"
trace
\"
:"
"{
\"
steps
\"
:[{
\"
type
\"
:
\"
FuseWithName
\"
,"
"
\"
outputs
\"
:[
\"
e0
\"
],
\"
attrs
\"
:[{
\"
name
\"
:
\"
loops_index
\"
,
\"
dtype
\"
:"
"
\"
INTS
\"
,
\"
ints
\"
:[0,1]},{
\"
name
\"
:
\"
block_"
"name
\"
,
\"
dtype
\"
:
\"
STRING
\"
,
\"
s
\"
:
\"
B
\"
}]}]}}"
;
std
::
string
case2
=
"{
\"
taskKey
\"
:
\"
k1
\"
,
\"
executionCost
\"
:1,
\"
predictedCost
\"
:1.5,
\"
trace
\"
:"
"{
\"
steps
\"
:[{
\"
type
\"
:
\"
FuseWithName
\"
,"
"
\"
outputs
\"
:[
\"
e0
\"
],
\"
attrs
\"
:[{
\"
name
\"
:
\"
block_name
\"
,
\"
dtype
\"
:"
"
\"
STRING
\"
,
\"
s
\"
:
\"
B
\"
},{
\"
name
\"
:
\"
loops_"
"index
\"
,
\"
dtype
\"
:
\"
INTS
\"
,
\"
ints
\"
:[0,1]}]}]}}"
;
EXPECT_EQ
(
true
,
strs
[
0
]
==
case1
||
strs
[
0
]
==
case2
);
EXPECT_EQ
(
strs
[
1
],
"{
\"
taskKey
\"
:
\"
k2
\"
,
\"
executionCost
\"
:3,
\"
predictedCost
\"
:3.5,"
"
\"
trace
\"
:{}}"
);
}
// Adds several records per task key and checks Size/Count/LookUp, including
// the per-task capacity limit the fixture's database was built with (2).
TEST_F(TestJSONFileDatabase, Basic) {
  // Adds one record whose schedule is freshly built for `task_key`.
  auto add_record = [this](const char* task_key,
                           double predicted_cost,
                           double execution_cost) {
    test_db.AddRecord(TuningRecord(
        task_key,
        SearchState(MakeIRSchedule(lowered_funcs, task_key), predicted_cost),
        execution_cost));
  };

  add_record("k1", 1.0, 1.0);
  add_record("k2", 1.0, 2.0);
  add_record("k2", 1.0, 3.0);
  add_record("k3", 8.0, 3.0);
  add_record("k3", 7.0, 4.0);
  add_record("k3", 6.0, 5.0);
  add_record("k4", 1.0, 4.0);

  ASSERT_EQ(test_db.Size(), 6);

  auto k3_records = test_db.LookUp("k3");
  // check the max number of stored candidates will
  // be restricted to capacity_per_task
  ASSERT_EQ(test_db.Count("k3"), 2);
  ASSERT_EQ(k3_records.size(), 2);
  EXPECT_EQ(k3_records[0].execution_cost, 3.0);
  EXPECT_EQ(k3_records[1].execution_cost, 4.0);
}
// Adds several records per task key and checks what GetTopK returns for "k4"
// when more candidates are requested than the database keeps per task.
TEST_F(TestJSONFileDatabase, GetTopK) {
  // Adds one record whose schedule is freshly built for `task_key`.
  auto add_record = [this](const char* task_key,
                           double predicted_cost,
                           double execution_cost) {
    test_db.AddRecord(TuningRecord(
        task_key,
        SearchState(MakeIRSchedule(lowered_funcs, task_key), predicted_cost),
        execution_cost));
  };

  add_record("k1", 1.0, 1.0);
  add_record("k2", 1.0, 2.0);
  add_record("k2", 1.0, 3.0);
  add_record("k3", 1.0, 3.0);
  add_record("k3", 1.0, 4.0);
  add_record("k3", 1.0, 5.0);
  add_record("k4", 2.0, 4.0);
  add_record("k4", 1.2, 2.0);
  add_record("k4", 1.0, 3.0);

  auto top_records = test_db.GetTopK("k4", 3);
  ASSERT_EQ(top_records.size(), 2);
  EXPECT_FLOAT_EQ(top_records[0].predicted_cost, 1.2);
  EXPECT_FLOAT_EQ(top_records[1].predicted_cost, 1.0);
}
// Writes records through the fixture's database, opens a second database on
// the same file and checks that the reloaded "k1" record matches the
// original one (scalar fields, trace proto, and replayed module exprs).
TEST_F(TestJSONFileDatabase, Reload) {
  ir::IRSchedule ir_sch = MakeIRSchedule(lowered_funcs, "k1");
  auto fused = ir_sch.Fuse("B", {0, 1});
  test_db.AddRecord(
      TuningRecord("k1", SearchState(std::move(ir_sch), 1.0), 1.0));
  test_db.AddRecord(TuningRecord(
      "k2", SearchState(MakeIRSchedule(lowered_funcs, "k2"), 1.0), 2.0));
  auto records = test_db.LookUp("k1");
  ASSERT_EQ(records.size(), 1);

  // allow_new_file is false here: the record file must already exist.
  JSONFileDatabase new_db(2, record_file_path, false);
  ASSERT_EQ(new_db.Size(), 2);
  auto loaded_records = new_db.LookUp("k1");
  ASSERT_EQ(records.size(), loaded_records.size());
  EXPECT_EQ(records[0].task_key, loaded_records[0].task_key);
  EXPECT_EQ(records[0].execution_cost, loaded_records[0].execution_cost);
  EXPECT_EQ(records[0].predicted_cost, loaded_records[0].predicted_cost);

  // check the equality of trace info between original TuningRecord and the
  // loaded TuningRecord
  const auto& lhs_trace = records[0].trace;
  const auto& rhs_trace = loaded_records[0].trace;
  google::protobuf::util::MessageDifferencer dif;
  static const google::protobuf::Descriptor* descriptor =
      cinn::ir::proto::ScheduleDesc_Step::descriptor();
  // Compare the "attrs" field of a step as a set, i.e. ignoring order.
  dif.TreatAsSet(descriptor->FindFieldByName("attrs"));
  EXPECT_TRUE(dif.Compare(lhs_trace, rhs_trace));

  // check the equality of module expr between original TuningRecord
  // and the loaded TuningRecord by replaying with tracing ScheduleDesc
  ir::IRSchedule lhs_sch = MakeIRSchedule(lowered_funcs, "k1");
  ir::IRSchedule rhs_sch = MakeIRSchedule(lowered_funcs, "k1");
  ir::ScheduleDesc::ReplayWithProto(lhs_trace, &lhs_sch);
  ir::ScheduleDesc::ReplayWithProto(rhs_trace, &rhs_sch);

  auto lhs_exprs = lhs_sch.GetModule().GetExprs();
  auto rhs_exprs = rhs_sch.GetModule().GetExprs();

  ASSERT_EQ(lhs_exprs.size(), rhs_exprs.size());
  // NOTE(review): the first 28 characters of each printed expr are dropped
  // before comparing; presumably that prefix may differ between the two
  // schedules -- confirm what it contains.
  constexpr size_t remove_prefix_len = 28;
  // size_t index: avoids comparing a signed int against size().
  for (size_t i = 0; i < lhs_exprs.size(); ++i) {
    std::string lhs = utils::GetStreamCnt(lhs_exprs.at(i));
    std::string rhs = utils::GetStreamCnt(rhs_exprs.at(i));
    ASSERT_EQ(lhs.erase(0, remove_prefix_len), rhs.erase(0, remove_prefix_len));
  }
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/measure/CMakeLists.txt
0 → 100644
View file @
992bec46
# NOTE(review): core_gather_headers, gather_srcs and cinn_cc_test are helpers
# defined elsewhere in the build; the comments below describe only what the
# arguments passed here show.
core_gather_headers()

# Pass the measurer, builder and runner sources to gather_srcs under the
# cinnapi_src name.
gather_srcs(cinnapi_src SRCS schedule_measurer.cc simple_builder.cc
            simple_runner.cc)

# Test targets for the runner and the measurer; each lists cinncore under
# DEPS.
cinn_cc_test(test_simple_runner SRCS simple_runner_test.cc DEPS cinncore)

cinn_cc_test(test_measurer SRCS measurer_test.cc DEPS cinncore)
paddle/cinn/auto_schedule/measure/measure.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/task/tune_task.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/hlir/framework/instruction.h"
#include "paddle/cinn/runtime/cinn_runtime.h"
namespace
cinn
{
namespace
auto_schedule
{
// The input to a measurer
struct MeasureInput {
  // The task object related to this measurement. Non-owning; defaults to
  // null so a default-constructed MeasureInput never carries an
  // indeterminate pointer.
  const TuneTask* task = nullptr;
  // lowered Exprs to be measured
  std::vector<ir::LoweredFunc> lowered_funcs;
  // Optional arguments whose values have been specified in advance, keyed by
  // name. Non-owning; null (the default) means none are supplied.
  const std::map<std::string, cinn_pod_value_t>* execution_args = nullptr;
};
// Outcome of measuring one schedule.
struct MeasureResult {
  // Average execution time over the configured number of repeated runs,
  // in microseconds (us).
  double execution_cost{0.0};
  // Time spent on the whole measurement process, building and running
  // included, in microseconds (us).
  double elapsed_time{0.0};
  // Detail message describing an error that occurred during measurement;
  // left empty when nothing went wrong.
  std::string error_msg;
};
// The result of building with input schedule
struct BuildResult {
  // The scope that owns detail compilation infos of parameters in the runtime
  // program. Non-owning; defaults to null so a default-constructed
  // BuildResult never carries an indeterminate pointer.
  const hlir::framework::Scope* compiled_scope = nullptr;
  // The executable program, owned by this result (which makes BuildResult
  // move-only).
  std::unique_ptr<hlir::framework::Program> runtime_program;
};
// This interface defines how to generate executable objects
// with input schedule. A builder should not contain stateful data
// related to any task so it can be called in parallel among multiple
// processes of task tuning.
class ScheduleBuilder {
 public:
  // Virtual destructor: implementations are used through this interface, so
  // destroying one via a ScheduleBuilder pointer must run the derived
  // destructor.
  virtual ~ScheduleBuilder() = default;

  // Build the executable objects for the given measurement input.
  virtual BuildResult Build(const MeasureInput& input) = 0;
};
// This interface defines how to run the built result. Like
// ScheduleBuilder, a runner should be implemented without being bound to a
// specific task.
class ScheduleRunner {
 public:
  // Virtual destructor: implementations are used through this interface, so
  // destroying one via a ScheduleRunner pointer must run the derived
  // destructor.
  virtual ~ScheduleRunner() = default;

  // Run the built result for the given measurement input and report the
  // measurement.
  virtual MeasureResult Run(const MeasureInput& input,
                            const BuildResult& build_result) = 0;
};
}
// namespace auto_schedule
}
// namespace cinn
Prev
1
…
3
4
5
6
7
8
9
10
11
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment