Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
Paddle
Commits
992bec46
Commit
992bec46
authored
Oct 08, 2023
by
“yuguo”
Browse files
2.5
parent
0259837d
Changes
357
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1744 additions
and
0 deletions
+1744
-0
paddle/cinn/auto_schedule/measure/measurer_test.cc
paddle/cinn/auto_schedule/measure/measurer_test.cc
+136
-0
paddle/cinn/auto_schedule/measure/schedule_measurer.cc
paddle/cinn/auto_schedule/measure/schedule_measurer.cc
+88
-0
paddle/cinn/auto_schedule/measure/schedule_measurer.h
paddle/cinn/auto_schedule/measure/schedule_measurer.h
+46
-0
paddle/cinn/auto_schedule/measure/simple_builder.cc
paddle/cinn/auto_schedule/measure/simple_builder.cc
+45
-0
paddle/cinn/auto_schedule/measure/simple_builder.h
paddle/cinn/auto_schedule/measure/simple_builder.h
+37
-0
paddle/cinn/auto_schedule/measure/simple_runner.cc
paddle/cinn/auto_schedule/measure/simple_runner.cc
+252
-0
paddle/cinn/auto_schedule/measure/simple_runner.h
paddle/cinn/auto_schedule/measure/simple_runner.h
+45
-0
paddle/cinn/auto_schedule/measure/simple_runner_test.cc
paddle/cinn/auto_schedule/measure/simple_runner_test.cc
+144
-0
paddle/cinn/auto_schedule/post_schedule_rule/CMakeLists.txt
paddle/cinn/auto_schedule/post_schedule_rule/CMakeLists.txt
+14
-0
paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc
...n/auto_schedule/post_schedule_rule/cooperative_process.cc
+77
-0
paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h
...nn/auto_schedule/post_schedule_rule/cooperative_process.h
+35
-0
paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc
...o_schedule/post_schedule_rule/cooperative_process_test.cc
+218
-0
paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h
...inn/auto_schedule/post_schedule_rule/post_schedule_rule.h
+38
-0
paddle/cinn/auto_schedule/search_space/CMakeLists.txt
paddle/cinn/auto_schedule/search_space/CMakeLists.txt
+11
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/CMakeLists.txt
...n/auto_schedule/search_space/auto_gen_rule/CMakeLists.txt
+53
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc
...inn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc
+185
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h
...cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h
+50
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc
...uto_schedule/search_space/auto_gen_rule/auto_bind_test.cc
+133
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc
...auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc
+45
-0
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h
.../auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h
+92
-0
No files found.
Too many changes to show.
To preserve performance only
357 of 357+
files are displayed.
Plain diff
Email patch
paddle/cinn/auto_schedule/measure/measurer_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include "paddle/cinn/auto_schedule/measure/schedule_measurer.h"
#include "paddle/cinn/auto_schedule/measure/simple_builder.h"
#include "paddle/cinn/auto_schedule/measure/simple_runner.h"
#include "paddle/cinn/auto_schedule/task/task_creator.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/runtime/flags.h"
namespace
cinn
{
namespace
auto_schedule
{
using
::
cinn
::
hlir
::
framework
::
BuildScope
;
using
::
cinn
::
hlir
::
framework
::
Graph
;
using
::
cinn
::
hlir
::
framework
::
GraphCompiler
;
// Builds a small frontend program named "test" that adds two Float(32)
// inputs "A" and "B" of shape [32, 24] and applies Relu to the sum.
frontend::Program CreateAddReluProgram() {
  constexpr int kRows = 32;
  constexpr int kCols = 24;

  frontend::NetBuilder net("test");
  auto lhs = net.CreateInput(Float(32), {kRows, kCols}, "A");
  auto rhs = net.CreateInput(Float(32), {kRows, kCols}, "B");
  auto sum = net.Add(lhs, rhs);
  // The Relu output handle is not needed here; the op is recorded in `net`.
  auto activated = net.Relu(sum);
  return net.Build();
}
class
TestMeasurer
:
public
::
testing
::
Test
{
public:
std
::
unique_ptr
<
GraphCompiler
>
graph_compiler
;
std
::
vector
<
TuneTask
>
tasks
;
std
::
vector
<
MeasureInput
>
inputs
;
void
SetUp
()
override
{
#ifdef CINN_WITH_CUDA
Target
target
=
common
::
DefaultNVGPUTarget
();
#else
Target
target
=
common
::
DefaultHostTarget
();
#endif
std
::
unordered_set
<
std
::
string
>
fetch_ids
;
auto
program
=
CreateAddReluProgram
();
auto
graph
=
cinn
::
frontend
::
Optimize
(
&
program
,
fetch_ids
,
target
);
auto
scope
=
BuildScope
(
target
,
graph
);
graph_compiler
=
std
::
make_unique
<
GraphCompiler
>
(
target
,
scope
,
graph
);
TaskCreator
task_creator
;
tasks
=
task_creator
.
CreateTuneTaskOpLevel
(
graph
.
get
());
const
auto
&
dtype_dict
=
graph
->
GetAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
common
::
Type
>>
(
"inferdtype"
);
const
auto
&
shape_dict
=
graph
->
GetAttrs
<
absl
::
flat_hash_map
<
std
::
string
,
hlir
::
framework
::
shape_t
>>
(
"infershape"
);
auto
op_lowerer
=
std
::
make_unique
<
hlir
::
framework
::
OpLowerer
>
(
dtype_dict
,
shape_dict
,
target
);
inputs
.
reserve
(
tasks
.
size
());
for
(
int
i
=
0
;
i
<
tasks
.
size
();
++
i
)
{
auto
*
task
=
&
tasks
[
i
];
task
->
Initialize
(
shape_dict
,
dtype_dict
,
op_lowerer
.
get
());
MeasureInput
input
;
input
.
task
=
task
;
input
.
lowered_funcs
=
task
->
lowered_funcs
;
inputs
.
emplace_back
(
input
);
}
}
};
// A ScheduleBuilder stub whose Build() always throws. The CatchException test
// uses it to check that ScheduleMeasurer records a build failure in the
// result's error message.
class ThrowExceptionBuilder : public ScheduleBuilder {
  // Exception whose what() is the fixed text "BuildError".
  struct Exception : public std::exception {
    // `noexcept` replaces the dynamic exception specification `throw()`,
    // which is deprecated; `override` lets the compiler verify the signature
    // against std::exception::what().
    const char* what() const noexcept override { return "BuildError"; }
  };

  // Ignores `input` and always throws Exception.
  BuildResult Build(const MeasureInput& input) override { throw Exception(); }
};
// A ScheduleRunner stub whose Run() always throws. The CatchException test
// uses it to check that ScheduleMeasurer records a run failure in the
// result's error message.
class ThrowExceptionRunner : public ScheduleRunner {
  // Exception whose what() is the fixed text "RunError".
  struct Exception : public std::exception {
    // `noexcept` replaces the dynamic exception specification `throw()`,
    // which is deprecated; `override` lets the compiler verify the signature
    // against std::exception::what().
    const char* what() const noexcept override { return "RunError"; }
  };

  // Ignores both arguments and always throws Exception.
  MeasureResult Run(const MeasureInput& input,
                    const BuildResult& build_result) override {
    throw Exception();
  }
};
// Measuring with the real builder and runner must yield exactly one result
// per prepared input.
TEST_F(TestMeasurer, Basic) {
  SimpleBuilder builder(graph_compiler.get());
  SimpleRunner runner(1);
  ScheduleMeasurer measurer(&builder, &runner);

  const std::vector<MeasureResult> results = measurer.Measure(inputs);
  ASSERT_EQ(inputs.size(), results.size());
}
// Exceptions thrown by the builder or by the runner must be caught by the
// measurer and reported through MeasureResult::error_msg.
TEST_F(TestMeasurer, CatchException) {
  SimpleBuilder builder(graph_compiler.get());
  SimpleRunner runner(1);
  ThrowExceptionBuilder throw_builder;
  ThrowExceptionRunner throw_runner;

  // Build stage fails (measured with 2 threads).
  ScheduleMeasurer measurer_with_build_error(&throw_builder, &runner, 2);
  std::vector<MeasureResult> results =
      measurer_with_build_error.Measure(inputs);
  ASSERT_EQ(inputs.size(), results.size());
  EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n");

  // TODO(CtfGo): test parallel build after we support thread-safe compilation
  // Run stage fails (measured with 1 thread).
  ScheduleMeasurer measurer_with_run_error(&builder, &throw_runner, 1);
  results = measurer_with_run_error.Measure(inputs);
  ASSERT_EQ(inputs.size(), results.size());
  EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n");
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/measure/schedule_measurer.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/measure/schedule_measurer.h"
#include <exception>
#include "paddle/cinn/utils/multi_threading.h"
namespace
cinn
{
namespace
auto_schedule
{
ScheduleMeasurer
::
ScheduleMeasurer
(
ScheduleBuilder
*
builder
,
ScheduleRunner
*
runner
,
int
num_threads
)
:
builder_
(
builder
),
runner_
(
runner
),
num_threads_
(
num_threads
)
{}
std
::
vector
<
MeasureResult
>
ScheduleMeasurer
::
Measure
(
const
std
::
vector
<
MeasureInput
>&
inputs
)
{
if
(
inputs
.
empty
())
{
LOG
(
WARNING
)
<<
"inputs is empty"
;
return
{};
}
std
::
vector
<
BuildResult
>
build_results
(
inputs
.
size
());
std
::
vector
<
MeasureResult
>
results
(
inputs
.
size
());
// define how to build a candidate with the specified index
auto
build_fn
=
[
builder
=
builder_
,
&
inputs
,
&
build_results
,
&
results
](
int
index
)
{
VLOG
(
6
)
<<
"Build candidate index: "
<<
index
;
auto
m_start
=
std
::
chrono
::
steady_clock
::
now
();
try
{
build_results
[
index
]
=
builder
->
Build
(
inputs
[
index
]);
}
catch
(
std
::
exception
&
e
)
{
results
[
index
].
error_msg
=
utils
::
StringFormat
(
"Build failed, error: %s
\n
"
,
e
.
what
());
}
auto
time_span
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
microseconds
>
(
std
::
chrono
::
steady_clock
::
now
()
-
m_start
);
results
[
index
].
elapsed_time
+=
static_cast
<
double
>
(
time_span
.
count
());
};
// define how to run a candidate with the specified index
auto
run_fn
=
[
runner
=
runner_
,
&
inputs
,
&
build_results
,
&
results
](
int
index
)
{
VLOG
(
6
)
<<
"Run candidate index: "
<<
index
;
auto
m_start
=
std
::
chrono
::
steady_clock
::
now
();
try
{
// if error occurred in building, then skip running
if
(
results
[
index
].
error_msg
.
empty
())
{
results
[
index
]
=
runner
->
Run
(
inputs
[
index
],
build_results
[
index
]);
}
}
catch
(
std
::
exception
&
e
)
{
results
[
index
].
error_msg
=
utils
::
StringFormat
(
"Run failed, error: %s
\n
"
,
e
.
what
());
}
auto
time_span
=
std
::
chrono
::
duration_cast
<
std
::
chrono
::
microseconds
>
(
std
::
chrono
::
steady_clock
::
now
()
-
m_start
);
results
[
index
].
elapsed_time
+=
static_cast
<
double
>
(
time_span
.
count
());
};
// measure a candidate by calling build and run successively
auto
measure_fn
=
[
&
build_fn
,
&
run_fn
](
int
index
)
{
build_fn
(
index
);
run_fn
(
index
);
};
// default num_threads_ is 1 and in that case it will perform all measurements
// sequentially inplace.
utils
::
parallel_run
(
measure_fn
,
utils
::
SequenceDispatcher
(
0
,
inputs
.
size
()),
num_threads_
);
VLOG
(
4
)
<<
"Measure "
<<
inputs
.
size
()
<<
" candidates"
;
return
results
;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/measure/schedule_measurer.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/cinn/auto_schedule/measure/measure.h"
namespace
cinn
{
namespace
auto_schedule
{
// Entry point of schedule measurement. Measuring consists of two steps:
// building the input schedules and then running the generated code.
class ScheduleMeasurer {
 public:
  // `builder` and `runner` are stored as raw pointers and used by Measure();
  // `num_threads` defaults to 1.
  ScheduleMeasurer(ScheduleBuilder* builder,
                   ScheduleRunner* runner,
                   int num_threads = 1);

  // Measure a batch of inputs and return all results at once.
  std::vector<MeasureResult> Measure(const std::vector<MeasureInput>& inputs);

 private:
  // The handle to the ScheduleBuilder implementation
  ScheduleBuilder* builder_;
  // The handle to the ScheduleRunner implementation
  ScheduleRunner* runner_;
  // The number of threads used to perform measurement;
  // a value greater than 1 means the measurement runs in parallel.
  const int num_threads_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/measure/simple_builder.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/measure/simple_builder.h"
namespace
cinn
{
namespace
auto_schedule
{
using
hlir
::
framework
::
GraphCompiler
;
SimpleBuilder
::
SimpleBuilder
(
hlir
::
framework
::
GraphCompiler
*
graph_compiler
)
:
graph_compiler_
(
graph_compiler
)
{}
// Compiles the subgraph and lowered functions of `input` with the bound
// GraphCompiler and packs the compiler's scope and the produced runtime
// program into a BuildResult. Fails a CHECK if no GraphCompiler was supplied.
BuildResult SimpleBuilder::Build(const MeasureInput& input) {
  CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr))
      << "empty handle to GraphCompiler";

  // Compile exactly one group: the task's subgraph with its lowered funcs.
  GraphCompiler::CompileOptions options;
  options.remove_unused_variables = false;
  options.groups.emplace_back(input.task->subgraph);
  options.lowered_funcs.emplace_back(input.lowered_funcs);

  VLOG(5) << "call GraphCompiler to Build with Graph::Group size="
          << options.groups.size() << ", lowered_funcs group size="
          << options.lowered_funcs.size();
  GraphCompiler::CompilationResult compiled = graph_compiler_->Build(options);

  BuildResult packed;
  packed.compiled_scope = graph_compiler_->GetScope().get();
  packed.runtime_program = std::move(compiled.runtime_program);
  return packed;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/measure/simple_builder.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/measure/measure.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
namespace
cinn
{
namespace
auto_schedule
{
// This class uses the GraphCompiler bound to the graph to build
// the input schedule into executable objects.
class SimpleBuilder : public ScheduleBuilder {
 public:
  // `graph_compiler` is stored as a raw pointer and used by Build().
  explicit SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler);

  // Build and pack the result
  BuildResult Build(const MeasureInput& input) override;

 private:
  // Handle to the GraphCompiler used for compilation.
  hlir::framework::GraphCompiler* graph_compiler_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/measure/simple_runner.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/measure/simple_runner.h"
#include <algorithm>
#include <chrono>
#include <iterator>
#include <limits>
#include <memory>
#include <random>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/buffer.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/framework/tensor.h"
namespace
cinn
{
namespace
auto_schedule
{
using
hlir
::
framework
::
Buffer
;
using
hlir
::
framework
::
Shape
;
using
hlir
::
framework
::
Tensor
;
// Parameters that need to be initialized to 0 instead of random values.
// Key is the Op name, and value is the list of indices of such input
// parameters in the Op (positions in the node's ordered input links).
static const std::unordered_map<std::string, std::vector<int>>
    kInitWithZeroParams = {
        {"lookup_table", {1}},
        {"gather", {1}},
        {"gather_nd", {1}},
        {"scatter_assign", {2}},
        {"scatter_add", {2}},
};
// Generates random values and writes them to the memory at `raw_ptr`.
//
// `type`    element type of the buffer; Bool, I32, I64 and F32 get `numel`
//           typed random elements. Any other type must be 8 bytes wide (the
//           CHECK below enforces it) and is filled with random bytes.
// `numel`   number of elements in the buffer.
// `raw_ptr` destination; must hold at least numel * type.bytes() bytes.
//
// The engine is seeded from std::random_device, so values differ per call.
static void PopulateRandomValue(const common::Type& type,
                                const int numel,
                                void* raw_ptr) {
  std::random_device seed;
  std::default_random_engine engine(seed());

  if (type == common::Bool()) {
    auto* fmt_ptr = reinterpret_cast<bool*>(raw_ptr);
    std::bernoulli_distribution dist(0.5);
    std::generate_n(
        fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
  } else if (type == common::I32()) {
    auto* fmt_ptr = reinterpret_cast<int*>(raw_ptr);
    std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(),
                                            std::numeric_limits<int>::max());
    std::generate_n(
        fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
  } else if (type == common::I64()) {
    auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr);
    std::uniform_int_distribution<int64_t> dist(
        std::numeric_limits<int64_t>::min(),
        std::numeric_limits<int64_t>::max());
    std::generate_n(
        fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
  } else if (type == common::F32()) {
    auto* fmt_ptr = reinterpret_cast<float*>(raw_ptr);
    std::uniform_real_distribution<float> dist(
        std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
    std::generate_n(
        fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
  } else {
    CHECK_EQ(type.bytes(), 8)
        << "Unsupported type: " << type << ", type.bytes = " << type.bytes();
    auto* fmt_ptr = reinterpret_cast<uint8_t*>(raw_ptr);
    // std::uniform_int_distribution is not defined for character-sized types
    // such as uint8_t, so draw ints in the uint8_t range and narrow them.
    std::uniform_int_distribution<int> dist(
        std::numeric_limits<uint8_t>::min(),
        std::numeric_limits<uint8_t>::max());
    // Fill every byte of every element, not just the first `numel` bytes;
    // otherwise most of the buffer would be left untouched.
    const size_t num_bytes =
        numel > 0
            ? static_cast<size_t>(numel) * static_cast<size_t>(type.bytes())
            : 0;
    std::generate_n(fmt_ptr, num_bytes, [&engine, &dist]() {
      return static_cast<uint8_t>(dist(engine));
    });
  }
}
// Initializes the data of `tensor` on `target`: with zeros when
// `init_with_zero` is true, otherwise with random values.
//
// Only the default NVGPU target (when built with CINN_WITH_CUDA) and the
// default host target are handled; for any other target the data is only
// allocated through mutable_data() and left as is.
static void InitTensorData(Tensor tensor,
                           const common::Target& target,
                           bool init_with_zero) {
  // Computed in size_t so large tensors cannot overflow an int byte count.
  const size_t mem_size = static_cast<size_t>(tensor->shape().numel()) *
                          static_cast<size_t>(tensor->type().bytes());
  auto* tensor_data = tensor->mutable_data(target, tensor->type());
#ifdef CINN_WITH_CUDA
  if (target == common::DefaultNVGPUTarget()) {
    if (init_with_zero) {
      CUDA_CALL(cudaMemset(tensor_data, 0, mem_size));
    } else {
      // Random values are generated on the host and then copied to the
      // device. The vector owns the staging memory, so it is released on
      // every exit path and never handed out as a null pointer.
      std::vector<uint8_t> host_buffer(mem_size);
      PopulateRandomValue(
          tensor->type(), tensor->shape().numel(), host_buffer.data());
      CUDA_CALL(cudaMemcpy(tensor_data,
                           host_buffer.data(),
                           mem_size,
                           cudaMemcpyHostToDevice));
    }
  }
#endif
  if (target == common::DefaultHostTarget()) {
    if (init_with_zero) {
      memset(tensor_data, 0, mem_size);
    } else {
      PopulateRandomValue(tensor->type(), tensor->shape().numel(), tensor_data);
    }
  }
}
// Collects the names of all parameters in the task of `input` that must be
// zero-initialized when measuring: for every node of the task's subgraph
// whose op is listed in kInitWithZeroParams, the ids of the inputs at the
// listed positions.
static std::unordered_set<std::string> ParamsNeedInitWithZero(
    const MeasureInput& input) {
  std::unordered_set<std::string> zero_init_params;
  std::vector<hlir::framework::Node*> nodes =
      input.task->subgraph->CollectNodes();

  for (auto* node : nodes) {
    const auto entry = kInitWithZeroParams.find(node->op()->name);
    if (entry == kInitWithZeroParams.end()) {
      continue;  // this op has no parameter that needs zero-initialization
    }

    const auto& inlinks = node->inlinks_in_order();
    for (int param_idx : entry->second) {
      CHECK_GT(inlinks.size(), param_idx);
      auto& edge = inlinks.at(param_idx);
      std::string param_name =
          edge->source()->as<hlir::framework::NodeData>()->id();
      VLOG(6) << "param needs to be init with 0: " << param_name;
      zero_init_params.insert(param_name);
    }
  }
  return zero_init_params;
}
// Creates a runner that executes each instruction `repeat_times` times per
// measurement. `repeat_times` must be positive; it is used as the divisor
// when averaging the execution cost.
SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) {
  // The check rejects 0 as well as negative values, so say so in the message.
  CHECK_GT(repeat_times_, 0) << "repeat_times must be greater than 0";
}
// Prepare execution arguments of all instructions to run, a argument
// may be obtained from the input of measurement or allocating new buffer
// with random value.
std
::
map
<
std
::
string
,
cinn_pod_value_t
>
SimpleRunner
::
PrepareArgs
(
const
MeasureInput
&
input
,
const
BuildResult
&
build_result
,
hlir
::
framework
::
Scope
*
temp_scope
)
{
std
::
map
<
std
::
string
,
cinn_pod_value_t
>
result
;
const
auto
&
target
=
input
.
task
->
target
;
const
auto
*
input_args
=
input
.
execution_args
;
const
auto
*
compiled_scope
=
build_result
.
compiled_scope
;
const
auto
&
instructions
=
build_result
.
runtime_program
->
GetRunInstructions
();
std
::
unordered_set
<
std
::
string
>
params_need_init_with_zero
=
ParamsNeedInitWithZero
(
input
);
auto
fill_arg_fn
=
[
&
](
const
std
::
string
&
param
)
{
VLOG
(
6
)
<<
"Filling argument:"
<<
param
;
// the argument is duplicated and has been prepared.
if
(
result
.
count
(
param
))
{
return
;
}
// if the input of measurement specifies this argument,
// we should use it firstly.
if
(
input_args
&&
input_args
->
count
(
param
))
{
VLOG
(
6
)
<<
"Argument["
<<
param
<<
"] use input value"
;
result
.
emplace
(
param
,
input_args
->
at
(
param
));
return
;
}
if
(
temp_scope
->
FindVar
(
param
))
{
auto
temp_tensor
=
temp_scope
->
GetTensor
(
param
);
result
.
emplace
(
param
,
temp_tensor
->
buffer
());
return
;
}
// allocate a new buffer for this argument and store it in
// the temporary scope to be released at proper time.
auto
compiled_tensor
=
compiled_scope
->
GetTensor
(
param
);
temp_scope
->
Var
<
Tensor
>
(
param
);
auto
temp_tensor
=
temp_scope
->
GetTensor
(
param
);
temp_tensor
->
Resize
(
compiled_tensor
->
shape
());
temp_tensor
->
set_type
(
compiled_tensor
->
type
());
temp_tensor
->
mutable_data
(
target
,
compiled_tensor
->
type
());
InitTensorData
(
temp_tensor
,
target
,
params_need_init_with_zero
.
count
(
param
)
!=
0
);
result
.
emplace
(
param
,
temp_tensor
->
buffer
());
};
for
(
auto
&&
instr
:
instructions
)
{
for
(
auto
&&
args
:
instr
->
GetInArgs
())
{
std
::
for_each
(
args
.
begin
(),
args
.
end
(),
fill_arg_fn
);
}
for
(
auto
&&
args
:
instr
->
GetOutArgs
())
{
std
::
for_each
(
args
.
begin
(),
args
.
end
(),
fill_arg_fn
);
}
}
return
result
;
}
// Runs every instruction of the built program `repeat_times_` times and
// reports the timing in microseconds:
//   - execution_cost: sum over instructions of the average time per run;
//   - elapsed_time:   total wall-clock time of this call, including the
//                     preparation of the execution arguments.
MeasureResult SimpleRunner::Run(const MeasureInput& input,
                                const BuildResult& build_result) {
  using Clock = std::chrono::steady_clock;
  // Time elapsed since `begin`, truncated to whole microseconds.
  const auto micros_since = [](const Clock::time_point& begin) {
    return std::chrono::duration_cast<std::chrono::microseconds>(Clock::now() -
                                                                 begin);
  };

  MeasureResult result;
  const auto t_start = Clock::now();

  // prepare execution arguments
  VLOG(4) << "SimpleRunner prepare execution arguments";
  hlir::framework::Scope temp_scope;  // holds temporarily allocated data
  auto execution_args = PrepareArgs(input, build_result, &temp_scope);

  // Execute each instruction repeatedly and take the average as cost.
  result.execution_cost = 0;
  const auto& instructions = build_result.runtime_program->GetRunInstructions();
  for (size_t idx = 0; idx < instructions.size(); ++idx) {
    auto&& instr = instructions.at(idx);
    VLOG(5) << "Start running instruction-" << idx;

    const auto run_start = Clock::now();
    for (int round = 0; round < repeat_times_; ++round) {
      instr->Run(&execution_args);
    }
#ifdef CINN_WITH_CUDA
    // Wait for the device before the clock is read again.
    if (instr->target_ == common::DefaultNVGPUTarget()) {
      CUDA_CALL(cudaDeviceSynchronize());
    }
#endif
    const auto run_span = micros_since(run_start);
    result.execution_cost +=
        static_cast<double>(run_span.count()) / repeat_times_;
  }

  const auto total_span = micros_since(t_start);
  result.elapsed_time = static_cast<double>(total_span.count());

  VLOG(4) << "A measurement done:repeat_times[" << repeat_times_
          << "]total_elapsed_time[" << result.elapsed_time
          << "]us,execution_cost[" << result.execution_cost << "]us";
  return result;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/measure/simple_runner.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/measure/measure.h"
#include "paddle/cinn/hlir/framework/instruction.h"
namespace
cinn
{
namespace
auto_schedule
{
// This class uses the built instructions to execute the generated
// kernels and counts the elapsed time as the measurement of performance.
class SimpleRunner : public ScheduleRunner {
 public:
  // `repeat_times` is the number of times each instruction is run.
  explicit SimpleRunner(int repeat_times);

  // Runs the program held by `build_result` for `input` and returns the
  // measured timing.
  MeasureResult Run(const MeasureInput& input,
                    const BuildResult& build_result) override;

 private:
  // Builds the name -> value map of execution arguments passed to the
  // instructions; `temp_scope` receives the tensors allocated for it.
  std::map<std::string, cinn_pod_value_t> PrepareArgs(
      const MeasureInput& input,
      const BuildResult& build_result,
      hlir::framework::Scope* temp_scope);

 private:
  // The number of times the instructions are run repeatedly;
  // this runner returns the average time.
  const int repeat_times_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/measure/simple_runner_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/measure/simple_runner.h"
#include <gtest/gtest.h>
#include <chrono>
#include <thread>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
namespace
cinn
{
namespace
auto_schedule
{
using
::
cinn
::
hlir
::
framework
::
BuildScope
;
using
::
cinn
::
hlir
::
framework
::
Graph
;
using
::
cinn
::
hlir
::
framework
::
GraphCompiler
;
using
::
cinn
::
hlir
::
framework
::
Instruction
;
using
::
cinn
::
hlir
::
framework
::
Scope
;
// Test fixture: compiles the add+relu program for the default target and
// prepares a MeasureInput / BuildResult pair for SimpleRunner to run.
class TestSimpleRunner : public ::testing::Test {
 public:
  // NVGPU when built with CUDA, otherwise the host.
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif
  std::shared_ptr<Graph> graph;
  std::shared_ptr<Scope> compiled_scope;
  std::unique_ptr<GraphCompiler> graph_compiler;
  std::unique_ptr<TuneTask> task;
  MeasureInput input;
  BuildResult build_result;

  // Returns the program computing Relu(A + B) used by these tests.
  static frontend::Program CreateAddReluProgram();

  void SetUp() override {
    std::unordered_set<std::string> fetch_ids;
    auto program = CreateAddReluProgram();
    // Assign the fixture member rather than a shadowing local, so the graph
    // is owned by the fixture for the whole test.
    graph = cinn::frontend::Optimize(&program, fetch_ids, target);
    compiled_scope = BuildScope(target, graph);
    graph_compiler =
        std::make_unique<GraphCompiler>(target, compiled_scope, graph);

    auto runtime_program = graph_compiler->Build();
    const auto& instructions = runtime_program->GetRunInstructions();
    ASSERT_EQ(1, instructions.size());

    build_result.compiled_scope = compiled_scope.get();
    build_result.runtime_program = std::move(runtime_program);

    task = std::make_unique<TuneTask>();
    // Same target the graph was compiled for.
    task->target = target;
    task->subgraph = graph->fusion_groups.front();
    input.task = task.get();
  }
};
// Builds a small frontend program named "test" that adds two Float(32)
// inputs "A" and "B" of shape [32, 24] and applies Relu to the sum.
frontend::Program TestSimpleRunner::CreateAddReluProgram() {
  constexpr int kRows = 32;
  constexpr int kCols = 24;

  frontend::NetBuilder net("test");
  auto lhs = net.CreateInput(Float(32), {kRows, kCols}, "A");
  auto rhs = net.CreateInput(Float(32), {kRows, kCols}, "B");
  auto sum = net.Add(lhs, rhs);
  // The Relu output handle is not needed here; the op is recorded in `net`.
  auto activated = net.Relu(sum);
  return net.Build();
}
// With no preset execution arguments the runner has to create every argument
// itself; running must not throw.
TEST_F(TestSimpleRunner, MeasureWithRandomValue) {
  SimpleRunner runner(1);
  ASSERT_NO_THROW(runner.Run(input, build_result));
}
// When the buffers of inputs "A" and "B" are supplied through
// MeasureInput::execution_args, running must not throw.
TEST_F(TestSimpleRunner, MeasureWithSpecifiedArgs) {
  std::map<std::string, cinn_pod_value_t> preset_args;
  const char* const kPresetNames[] = {"A", "B"};
  for (const char* name : kPresetNames) {
    auto tensor = compiled_scope->GetTensor(name);
    tensor->mutable_data<float>(target);  // allocate before taking the buffer
    preset_args.emplace(name, tensor->buffer());
  }

  SimpleRunner runner(1);
  // specify several execution args
  input.execution_args = &preset_args;
  ASSERT_NO_THROW(runner.Run(input, build_result));
}
// Checks the reported timings against a kernel with a known minimum duration.
TEST_F(TestSimpleRunner, TimeMeasured) {
  // set up a BuildResult object with one instruction of the `sleep` function
  using SleepFn = void (*)(void*, int32_t);
  SleepFn sleep_fn = [](void*, int32_t) -> void {
    std::this_thread::sleep_for(std::chrono::microseconds(100));
  };

  // A test-local result; the fixture's build_result member is not used here.
  BuildResult sleep_build_result;
  sleep_build_result.compiled_scope = nullptr;

  std::vector<std::unique_ptr<Instruction>> instructions;
  instructions.emplace_back(new Instruction(common::DefaultHostTarget(),
                                            nullptr,
                                            {},
                                            {"empty_placeholder"},
                                            "sleep_fn"));
  auto& sleep_instr = instructions.back();
  sleep_instr->SetLoweredFunc(reinterpret_cast<void*>(sleep_fn));
  sleep_instr->Finalize();
  sleep_build_result.runtime_program.reset(
      new hlir::framework::Program(nullptr, std::move(instructions)));

  // to skip the condition check of params in Instruction::PreparePodArgs
  std::map<std::string, cinn_pod_value_t> placeholder_args;
  placeholder_args.emplace("empty_placeholder", cinn_pod_value_t());
  input.execution_args = &placeholder_args;

  SimpleRunner runner(2);
  const MeasureResult measure_result = runner.Run(input, sleep_build_result);

  // because the kernel function sleeps for 100 us, the execution cost
  // (average per run) must be at least 100 us and the total elapsed time
  // (2 repeated runs) must be at least 200 us.
  ASSERT_GE(measure_result.execution_cost, 100);
  ASSERT_GE(measure_result.elapsed_time, 200);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/post_schedule_rule/CMakeLists.txt
0 → 100644
View file @
992bec46
# Build rules for the post_schedule_rule directory.
# cooperative_process.cc is listed under cinnapi_src via gather_srcs.
core_gather_headers
()
gather_srcs
(
cinnapi_src SRCS cooperative_process.cc
)
# test_cooperative_process is only registered when WITH_CUDA is set.
if
(
WITH_CUDA
)
cinn_nv_test
(
test_cooperative_process
SRCS
cooperative_process_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder
)
endif
()
paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
auto_schedule
{
// Scans the schedule trace for the first "Bind" step whose "thread_axis"
// attribute equals `bind_axis` and returns the extent of the loop that step
// was applied to. Returns 0 when the trace contains no such step.
int ExtractNumThreads(const ir::IRSchedule& ir_schedule,
                      const std::string& bind_axis) {
  const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc();
  for (auto&& step : trace.Steps()) {
    // Skip everything that is not a Bind on the requested axis.
    if (step.type != "Bind") continue;
    if (step.attrs.find("thread_axis") == step.attrs.end()) continue;
    if (absl::get<std::string>(step.attrs.at("thread_axis")) != bind_axis) {
      continue;
    }

    // A matching Bind step is expected to carry exactly one input loop.
    CHECK_EQ(step.inputs.at("loop").size(), 1);
    return step.inputs.at("loop")[0].As<ir::For>()->extent.as_int32();
  }
  return 0;
}
// Collects, in trace order, the names of the schedule blocks that were
// annotated through an "AnnotateIntAttr" step whose "key" attribute equals
// ir::attr::cooperative_process.
std::vector<std::string> FindCandidates(const ir::ScheduleDesc& trace) {
  std::vector<std::string> block_names;
  for (auto&& step : trace.Steps()) {
    const bool is_cooperative_annotation =
        step.type == "AnnotateIntAttr" &&
        absl::get<std::string>(step.attrs.at("key")) ==
            ir::attr::cooperative_process;
    if (!is_cooperative_annotation) continue;

    // The annotated block is the first "block" input of the step.
    const auto* block_realize =
        step.inputs.at("block")[0].As<ir::ScheduleBlockRealize>();
    const auto* schedule_block =
        block_realize->schedule_block.As<ir::ScheduleBlock>();
    block_names.push_back(schedule_block->name);
  }
  return block_names;
}
// For every block annotated with ir::attr::cooperative_process, binds its
// innermost loop to threadIdx.x (splitting it first when its extent exceeds
// the extent already bound to threadIdx.x), inserts a thread synchronization
// and finally removes the annotation. Always returns true.
bool CooperativeProcess::Apply(ir::IRSchedule* schedule) {
  // Extent of the loop already bound to threadIdx.x according to the trace.
  // NOTE(review): ExtractNumThreads returns 0 when the trace holds no such
  // Bind step; the Split factor below would then be 0 - confirm that callers
  // guarantee a prior threadIdx.x bind.
  int num_threads = ExtractNumThreads(*schedule, "threadIdx.x");
  const ir::ScheduleDesc& trace = schedule->GetTraceDesc();
  // The names are copied out before the schedule is modified below.
  std::vector<std::string> candidate_block_names = FindCandidates(trace);

  for (const std::string& block_name : candidate_block_names) {
    auto innermost_loop = schedule->GetLoops(block_name).back();
    const bool fits_in_threads =
        innermost_loop.As<ir::For>()->extent.as_int32() <= num_threads;

    if (fits_in_threads) {
      schedule->Bind(innermost_loop, "threadIdx.x");
      // Fetch the loop again after Bind before synchronizing on it.
      innermost_loop = schedule->GetLoops(block_name).back();
      schedule->SyncThreads(innermost_loop);
    } else {
      // Split into {outer, num_threads}; bind the inner part and synchronize
      // on the outer loop.
      auto split_loops = schedule->Split(innermost_loop, {-1, num_threads});
      schedule->Bind(split_loops.back(), "threadIdx.x");
      schedule->SyncThreads(split_loops[0]);
    }

    // The annotation has been consumed; drop it from the block.
    auto block = schedule->GetBlock(block_name);
    schedule->Unannotate(block, ir::attr::cooperative_process);
  }
  return true;
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h"
namespace
cinn
{
namespace
auto_schedule
{
/*
 * @brief Post schedule rule that rewrites the cooperative_process annotation
 * into an actual binding of the loop on threadIdx. It is used for
 * collaborative data handling by multiple threads within the same block.
 */
class CooperativeProcess : public PostScheduleRule {
 public:
  CooperativeProcess() = default;

  // Applies the rewrite to `schedule`; overrides PostScheduleRule::Apply.
  bool Apply(ir::IRSchedule* schedule) final;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h"
#include <gtest/gtest.h>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
namespace
cinn
{
namespace
auto_schedule
{
// Fixture for the CooperativeProcess tests, built on TestAutoGenRuleBase.
class TestCooperativeProcess : public TestAutoGenRuleBase {
 public:
  // Seed passed to MakeIRSchedule by the tests.
  int fixed_rand_seed{1};
  // Argument names handed to CheckResult; filled in by each test.
  std::vector<std::string> default_input_names;
  std::vector<std::string> default_output_names;
};
// Manually schedules a 32x32 matmul (split, reorder, fuse, bind, shared-memory
// cache reads annotated with cooperative_process), applies CooperativeProcess
// and then checks both the printed IR and the numerical result.
TEST_F(TestCooperativeProcess, Matmul) {
  default_input_names = {"X", "Y"};
  default_output_names = {"temp_matmul_out"};
  std::vector<int32_t> X_shape = {32, 32};
  std::vector<int32_t> Y_shape = {32, 32};
  std::vector<int32_t> out_shape = {32, 32};

  int num_blocks_y = 2;
  int num_blocks_x = 2;
  int num_threads_y = 8;
  int num_threads_x = 2;
  int steps_k = 8;

  // Returns the name of the ScheduleBlock wrapped by a ScheduleBlockRealize.
  auto block_name_of = [](const ir::Expr& block_realize) -> std::string {
    return block_realize.As<ir::ScheduleBlockRealize>()
        ->schedule_block.As<ir::ScheduleBlock>()
        ->name;
  };

  Initialize(common::DefaultNVGPUTarget());
  frontend::Program matmul_op =
      tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
  ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);

  // split loops
  std::vector<ir::Expr> out_loops = ir_schedule.GetLoops("temp_matmul_out");
  std::vector<ir::Expr> k_loops = ir_schedule.Split(out_loops[2], {steps_k, -1});
  std::vector<ir::Expr> j_loops =
      ir_schedule.Split(out_loops[1], {num_blocks_x, num_threads_x, -1});
  std::vector<ir::Expr> i_loops =
      ir_schedule.Split(out_loops[0], {num_blocks_y, num_threads_y, -1});

  // reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2
  out_loops = ir_schedule.GetLoops("temp_matmul_out");
  ir_schedule.Reorder({out_loops[0],
                       out_loops[3],
                       out_loops[1],
                       out_loops[4],
                       out_loops[6],
                       out_loops[7],
                       out_loops[2],
                       out_loops[5]});

  // fuse and bind
  out_loops = ir_schedule.GetLoops("temp_matmul_out");
  ir::Expr i1_j1_fused = ir_schedule.Fuse({out_loops[2], out_loops[3]});
  ir::Expr i0_j0_fused = ir_schedule.Fuse({out_loops[0], out_loops[1]});
  out_loops = ir_schedule.GetLoops("temp_matmul_out");
  ir_schedule.Bind(out_loops[1], "threadIdx.x");
  ir_schedule.Bind(out_loops[0], "blockIdx.x");

  // cache read of read-buffer 1, placed under the third loop and annotated
  ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out");
  ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared");
  std::string X_cache_block_name = block_name_of(X_cache_block);
  out_loops = ir_schedule.GetLoops("temp_matmul_out");
  ir_schedule.ComputeAt(X_cache_block, out_loops[2]);
  std::vector<ir::Expr> X_cache_loops =
      ir_schedule.GetLoops(X_cache_block_name);
  ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]});
  ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name),
                       ir::attr::cooperative_process,
                       0);

  // same procedure for read-buffer 2
  out_block = ir_schedule.GetBlock("temp_matmul_out");
  ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared");
  std::string Y_cache_block_name = block_name_of(Y_cache_block);
  out_loops = ir_schedule.GetLoops("temp_matmul_out");
  ir_schedule.ComputeAt(Y_cache_block, out_loops[2]);
  std::vector<ir::Expr> Y_cache_loops =
      ir_schedule.GetLoops(Y_cache_block_name);
  ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]});
  ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name),
                       ir::attr::cooperative_process,
                       0);

  // apply CooperativeProcess
  CooperativeProcess cooperative_process;
  cooperative_process.Apply(&ir_schedule);

  // check ir
  auto actual_ir = GetIR(ir_schedule);
  VLOG(6) << "after CooperativeProcess, ir:\n" << actual_ir;
  std::string expected_ir = R"ROC(Expr 0 {
{
ScheduleBlock(root)
{
{
serial for (i, 0, 2)
{
serial for (j, 0, 2)
{
serial for (i_0, 0, 8)
{
serial for (j_0, 0, 2)
{
serial for (i_1, 0, 2)
{
serial for (j_1, 0, 8)
{
ScheduleBlock(temp_matmul_out__reduce_init)
{
i0, i1 = axis.bind(((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1)))
{
temp_matmul_out__reduce_init[((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1))] = 0.00000000f
}
}
}
}
}
}
}
}
thread_bind[blockIdx.x] for (i_j_fused, 0, 4)
{
thread_bind[threadIdx.x] for (i_0_j_0_fused, 0, 16)
{
serial for (reduce_k_0, 0, 8)
{
serial for (ax0_0_ax1_0_fused, 0, 2)
{
thread_bind[threadIdx.x] for (ax0_0_ax1_0_fused_0, 0, 16)
{
ScheduleBlock(Y_reshape_shared_temp_buffer)
{
v0, v1 = axis.bind(((((16 * ax0_0_ax1_0_fused) + ax0_0_ax1_0_fused_0) / 8) + (4 * reduce_k_0)), ((((16 * ax0_0_ax1_0_fused) + ax0_0_ax1_0_fused_0) % 8) + ((8 * (i_0_j_0_fused % 2)) + (16 * (i_j_fused % 2)))))
attrs(compute_at_extra_var:ax0_0,ax1_0)
{
Y_reshape_shared_temp_buffer[v0, v1] = Y_reshape[v0, v1]
}
}
}
}
__syncthreads()
thread_bind[threadIdx.x] for (ax0_ax1_fused, 0, 8)
{
ScheduleBlock(X_reshape_shared_temp_buffer)
{
v0, v1 = axis.bind(((ax0_ax1_fused / 4) + ((2 * (i_0_j_0_fused / 2)) + (16 * (i_j_fused / 2)))), ((ax0_ax1_fused % 4) + (4 * reduce_k_0)))
attrs(compute_at_extra_var:ax0,ax1)
{
X_reshape_shared_temp_buffer[v0, v1] = X_reshape[v0, v1]
}
}
}
__syncthreads()
serial for (reduce_k_1, 0, 4)
{
serial for (i_1, 0, 2)
{
serial for (j_1, 0, 8)
{
ScheduleBlock(temp_matmul_out)
{
i0_0, i1_0, i2 = axis.bind(((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1)), ((4 * reduce_k_0) + reduce_k_1))
{
temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] = (temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] + (X_reshape_shared_temp_buffer[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((4 * reduce_k_0) + reduce_k_1)] * Y_reshape_shared_temp_buffer[((4 * reduce_k_0) + reduce_k_1), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))]))
}
}
}
}
}
}
}
}
}
}
}
} // end Expr 0
)ROC";
  ASSERT_EQ(actual_ir, expected_ir);

  // build ir::Module and debug source code
  auto ir_module = BuildIRModule(ir_schedule);
  auto source_code = GenSourceCode(ir_module);
  VLOG(6) << "scheduled source code:\n" << source_code;

  // execute and check precision against the manually scheduled kernel
  CheckResult(
      GenExecutableKernel(ir_module),
      GenExecutableKernel(BuildIRModule(MakeIRSchedule(
          matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
      default_input_names,
      default_output_names,
      {X_shape, Y_shape},
      {out_shape},
      target_);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * Base class for rules of post process,
 * used to process schedules that rely on mutate results.
 */
class PostScheduleRule {
 public:
  PostScheduleRule() = default;

  // This class is a polymorphic base (Apply is pure virtual), so the
  // destructor is virtual to make destroying a derived rule through a
  // PostScheduleRule pointer well-defined.
  virtual ~PostScheduleRule() = default;

  /**
   * @brief Apply the post schedule rule to the given IRSchedule.
   * @param schedule The IRSchedule to post-process; it is passed by pointer
   * so that implementations can modify it.
   * @return True if apply successfully.
   */
  virtual bool Apply(ir::IRSchedule* schedule) = 0;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/CMakeLists.txt
0 → 100644
View file @
992bec46
# Build rules for the search_space directory.
# The auto_gen_rule subdirectory has its own CMakeLists.txt.
add_subdirectory
(
auto_gen_rule
)
core_gather_headers
()
# Sources of this directory listed under cinnapi_src.
gather_srcs
(
cinnapi_src SRCS search_space.cc search_state.cc block_sampler.cc
rule_sampler.cc
)
# One test target per component; each depends on cinncore.
cinn_cc_test
(
test_search_space SRCS search_space_test.cc DEPS cinncore
)
cinn_cc_test
(
test_search_state SRCS search_state_test.cc DEPS cinncore
)
cinn_cc_test
(
test_block_sampler SRCS block_sampler_test.cc DEPS cinncore
)
cinn_cc_test
(
test_rule_sampler SRCS rule_sampler_test.cc DEPS cinncore
)
paddle/cinn/auto_schedule/search_space/auto_gen_rule/CMakeLists.txt
0 → 100644
View file @
992bec46
# Build rules for the auto_gen_rule directory.
core_gather_headers
()
# Rule implementations listed under cinnapi_src.
gather_srcs
(
cinnapi_src
SRCS
auto_gen_rule.cc
auto_inline.cc
auto_unroll.cc
multi_level_tiling.cc
skip_rule.cc
auto_bind.cc
)
# The shared test helper library is only defined when WITH_TESTING is set.
if
(
WITH_TESTING
)
cinn_cc_library
(
auto_gen_rule_test_helper
SRCS
test_helper.cc
DEPS
glog
gtest
cinncore
)
endif
()
# The following tests are only registered when WITH_CUDA is set.
if
(
WITH_CUDA
)
cinn_nv_test
(
test_mix_rules
SRCS
mix_rules_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder
)
cinn_nv_test
(
test_auto_bind
SRCS
auto_bind_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder
)
cinn_nv_test
(
test_multi_level_tiling
SRCS
multi_level_tiling_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder
)
endif
()
# The auto_inline test below is currently disabled (commented out).
#cinn_cc_test(test_auto_inline SRCS auto_inline_test.cc DEPS cinncore auto_gen_rule_test_helper)
cinn_cc_test
(
test_skip_rule SRCS skip_rule_test.cc DEPS cinncore
)
cinn_cc_test
(
test_auto_unroll SRCS auto_unroll_test.cc DEPS cinncore
)
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h"
#include <glog/logging.h>
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace
cinn
{
namespace
auto_schedule
{
// Value passed as the `max_blocks` argument of BindGPUIndex by
// AutoBind::Apply and AutoBind::ApplyOnBlock in this file.
static
constexpr
uint32_t
kMaxBlocks
=
256
;
// Checks whether the input ir::For is a spatial loop: it must be a Serial
// loop, and its loop variable must not appear in the binding of any reduce
// axis of a ScheduleBlock found in its body.
bool IsSpatialLoop(const ir::For* for_node) {
  if (for_node->for_type() != ir::ForType::Serial) return false;
  const auto& loop_var = for_node->loop_var;

  // True for a variable node that is the loop variable itself or carries the
  // same name.
  auto refers_to_loop_var = [&loop_var](const Expr* x) {
    const ir::_Var_* var = x->As<ir::_Var_>();
    return var && (x->same_as(loop_var) || var->name == loop_var->name);
  };

  // True for a ScheduleBlockRealize in which loop_var is used by the binding
  // of at least one reduce axis.
  auto binds_loop_var_to_reduce_axis = [&](const Expr* x) {
    const auto* block_realize = x->As<ir::ScheduleBlockRealize>();
    if (!block_realize) return false;

    const auto* schedule_block =
        block_realize->schedule_block.As<ir::ScheduleBlock>();
    CHECK(schedule_block) << "schedule_block field is not a ScheduleBlock";
    CHECK_EQ(block_realize->iter_values.size(),
             schedule_block->iter_vars.size());

    for (int i = 0; i < block_realize->iter_values.size(); ++i) {
      const ir::Var& iter_var = schedule_block->iter_vars[i];
      const ir::Expr& binding = block_realize->iter_values[i];
      // An axis counts as a reduce axis when it is flagged as one or when its
      // name starts with "reduce".
      const bool is_reduce_axis = iter_var->is_reduce_axis ||
                                  iter_var->name.substr(0, 6) == "reduce";
      if (!is_reduce_axis) continue;

      auto used_exprs =
          ir::CollectIRNodesWithoutTensor(binding, refers_to_loop_var);
      if (!used_exprs.empty()) return true;
    }
    return false;
  };

  // Collect cases where the loop_var is used in one of the reduce axes of an
  // underlying ScheduleBlock; the loop is spatial only if there are none.
  auto used_for_reduce_axis = ir::CollectIRNodesWithoutTensor(
      for_node->body, binds_loop_var_to_reduce_axis);
  return used_for_reduce_axis.empty();
}
// Counts the number of loops that can be bound, walking from the input
// for_node downwards through perfectly nested loops. The walk stops at the
// first loop that is already bound or is not spatial, and also when a loop
// body holds anything other than exactly one ir::For statement.
int CountLoopCanBinded(const ir::For* for_node) {
  int bindable_count = 0;
  const ir::For* current = for_node;
  while (current != nullptr) {
    // Already bound, or not a spatial loop: nothing below is counted either.
    if (current->is_binded() || !IsSpatialLoop(current)) break;
    ++bindable_count;

    CHECK(current->body.defined() && current->body.As<ir::Block>())
        << "Body is not defined";
    const ir::Block* body = current->body.As<ir::Block>();
    // Descend only when the body is a single statement; As<ir::For>() yields
    // nullptr if that statement is not a loop, which also ends the walk.
    if (body->stmts.size() == 1) {
      current = body->stmts[0].As<ir::For>();
    } else {
      current = nullptr;
    }
  }
  return bindable_count;
}
// Fuses the outermost `num_loops_to_bind` loops around `block_name` and binds
// the result to GPU indices:
//   - if the loop right after them is already bound to a GPU thread axis, the
//     fused loop is bound to blockIdx.x;
//   - else if its extent fits in `max_threads_per_block`, it is bound to
//     threadIdx.x;
//   - else if it fits in `max_blocks * max_threads_per_block`, it is split in
//     two and bound to blockIdx.x / threadIdx.x;
//   - otherwise it is split in three, reordered so that the remaining serial
//     loop is innermost, and the two outer loops are bound.
void BindGPUIndex(ir::IRSchedule* ir_schedule,
                  const std::string& block_name,
                  int num_loops_to_bind,
                  int max_blocks,
                  int max_threads_per_block) {
  auto all_loops = ir_schedule->GetLoops(block_name);
  CHECK_LE(num_loops_to_bind, all_loops.size())
      << "The number of loops to be bind is greater than size of all_loops";

  // Check whether threadIdx has been bound but blockIdx has not. In that case
  // the thread-bound loop can only be the first loop after the
  // num_loops_to_bind loops, because other cases are excluded by
  // CountLoopCanBinded.
  const bool gpu_thread_has_binded =
      num_loops_to_bind < all_loops.size() &&
      all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded();

  std::vector<Expr> loops_to_fuse(all_loops.begin(),
                                  all_loops.begin() + num_loops_to_bind);
  Expr fused_loop = ir_schedule->Fuse(loops_to_fuse);
  const int32_t extent = fused_loop.As<ir::For>()->extent.as_int32();

  if (gpu_thread_has_binded) {
    ir_schedule->Bind(fused_loop, "blockIdx.x");
    return;
  }

  if (extent <= max_threads_per_block) {
    ir_schedule->Bind(fused_loop, "threadIdx.x");
    return;
  }

  if (extent <= max_blocks * max_threads_per_block) {
    auto block_thread_loops =
        ir_schedule->Split(fused_loop, {-1, max_threads_per_block});
    CHECK_EQ(block_thread_loops.size(), 2);
    ir_schedule->Bind(block_thread_loops[0], "blockIdx.x");
    ir_schedule->Bind(block_thread_loops[1], "threadIdx.x");
    return;
  }

  // Too large for a single launch dimension pair: keep an extra serial loop
  // and move it inside the block/thread loops.
  auto serial_block_thread_loops =
      ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block});
  CHECK_EQ(serial_block_thread_loops.size(), 3);
  ir_schedule->Reorder({serial_block_thread_loops[1],
                        serial_block_thread_loops[2],
                        serial_block_thread_loops[0]});
  // Fetch the loops again after the reorder before binding.
  all_loops = ir_schedule->GetLoops(block_name);
  ir_schedule->Bind(all_loops[0], "blockIdx.x");
  ir_schedule->Bind(all_loops[1], "threadIdx.x");
}
// Remembers `ir_schedule` and collects every block whose outermost loop
// starts a chain of at least one bindable loop. Returns
// kApplyAndPruneOtherRules when at least one such block exists, otherwise
// kCannotApply.
RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) {
  ir_schedule_ = ir_schedule;

  for (auto&& block_realize : ir_schedule->GetAllBlocks()) {
    auto all_loops = ir_schedule->GetLoops(block_realize);
    // NOTE(review): all_loops[0] assumes every block has at least one
    // enclosing loop - confirm against GetLoops.
    const int num_bindable = CountLoopCanBinded(all_loops[0].As<ir::For>());
    if (num_bindable <= 0) continue;
    applicable_schedule_blocks_.emplace_back(block_realize);
  }

  num_applicable_ = applicable_schedule_blocks_.size();
  VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;

  if (num_applicable_ > 0) {
    return RuleApplyType::kApplyAndPruneOtherRules;
  }
  return RuleApplyType::kCannotApply;
}
// Binds GPU indices for the applicable block selected by `index`, which must
// be smaller than the number of blocks collected by Init().
void AutoBind::Apply(int index) {
  CHECK_LT(index, applicable_schedule_blocks_.size())
      << "invalid apply index:" << index;

  auto applied_block = applicable_schedule_blocks_.at(index);
  auto all_loops = ir_schedule_->GetLoops(applied_block);

  const auto& applied_block_name =
      applied_block.As<ir::ScheduleBlockRealize>()
          ->schedule_block.As<ir::ScheduleBlock>()
          ->name;
  // Number of bindable loops counted from the outermost loop downwards.
  const int num_loops_to_bind =
      CountLoopCanBinded(all_loops[0].As<ir::For>());

  BindGPUIndex(ir_schedule_,
               applied_block_name,
               num_loops_to_bind,
               kMaxBlocks,
               target_->max_num_threads());
}
// Reports whether this rule can be applied to block `block_name` of `state`:
// kApplyAndPruneOtherRules when the outermost loop of the block starts a
// chain of at least one bindable loop, kCannotApply otherwise.
RuleApplyType AutoBind::AnalyseApplyType(SearchState state,
                                         const std::string& block_name) const {
  Expr block_expr = state->ir_schedule.GetBlock(block_name);
  auto all_loops = state->ir_schedule.GetLoops(block_expr);

  const int num_bindable = CountLoopCanBinded(all_loops[0].As<ir::For>());
  if (num_bindable > 0) {
    return RuleApplyType::kApplyAndPruneOtherRules;
  }
  return RuleApplyType::kCannotApply;
}
// Applies the rule to block `block_name` on a copy of `state` and returns
// that copy as the single resulting state.
std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state,
                                                const std::string& block_name) {
  SearchState new_state = state.Copy();

  // The bindable-loop count is computed from the loops of the input state;
  // the binding itself is performed on the schedule of the copy.
  auto all_loops = state->ir_schedule.GetLoops(block_name);
  const int num_loops_to_bind =
      CountLoopCanBinded(all_loops[0].As<ir::For>());

  BindGPUIndex(&new_state->ir_schedule,
               block_name,
               num_loops_to_bind,
               kMaxBlocks,
               target_->max_num_threads());

  return std::vector<SearchState>{new_state};
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
// Rule that automatically binds GPU indices (BlockIdx, ThreadIdx) to the
// loops around a block.
class AutoBind : public AutoGenRule {
 public:
  explicit AutoBind(const common::Target& target) : AutoGenRule(target) {}
  ~AutoBind() = default;

  // Collects the blocks this rule can be applied to in `init_schedule`.
  RuleApplyType Init(ir::IRSchedule* init_schedule) override;

  // Applies the rule to the applicable block selected by `index`.
  void Apply(int index) override;

  // Name of this rule.
  std::string GetRuleName() const override { return "AutoBind"; }

  // Reports how the rule can be applied to block `block_name` of `state`.
  RuleApplyType AnalyseApplyType(SearchState state,
                                 const std::string& block_name) const override;

  // Applies the rule to block `block_name` and returns the resulting states.
  std::vector<SearchState> ApplyOnBlock(SearchState state,
                                        const std::string& block_name) override;

 private:
  // Blocks collected by Init() that the rule can be applied to.
  std::vector<Expr> applicable_schedule_blocks_;
};
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind_test.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cmath>
#include <functional>
#include <numeric>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
namespace
cinn
{
namespace
auto_schedule
{
// Limits the test uses to predict how AutoBind splits the fused loop.
// kMaxBlocks is compared against the extent of the block-bound loop.
static
constexpr
uint32_t
kMaxBlocks
=
256
;
// kMaxThreadsPerBlock is compared against the extent of the thread-bound
// loop. NOTE(review): presumably equal to max_num_threads() of the default
// NVGPU target - confirm.
static
constexpr
uint32_t
kMaxThreadsPerBlock
=
1024
;
class
TestAutoBind
:
public
TestAutoGenRuleBase
{
public:
std
::
vector
<
std
::
string
>
default_input_names
=
{
"X"
,
"Y"
};
std
::
vector
<
std
::
string
>
default_output_names
=
{
"temp_matmul_out"
};
void
TestApplyOnElementWiseAdd
(
const
std
::
vector
<
int
>&
shape
,
const
std
::
string
&
block_name
)
{
Initialize
(
common
::
DefaultNVGPUTarget
());
auto
test_program
=
tests
::
OpBuilder
(
"elementwise_add"
).
Build
({{
"X"
,
shape
},
{
"Y"
,
shape
}});
// construct input parameter
ir
::
IRSchedule
ir_schedule
=
MakeIRSchedule
(
test_program
);
SearchState
state
(
ir_schedule
,
0
,
{});
std
::
vector
<
ir
::
Expr
>
func_bodys
=
ir_schedule
.
GetModule
().
GetExprs
();
ASSERT_EQ
(
func_bodys
.
size
(),
1UL
);
VLOG
(
6
)
<<
"Original Expr:
\n
"
<<
func_bodys
[
0
];
// apply
AutoBind
auto_bind
(
target_
);
ASSERT_EQ
(
auto_bind
.
AnalyseApplyType
(
state
,
block_name
),
RuleApplyType
::
kApplyAndPruneOtherRules
);
auto
result
=
auto_bind
.
ApplyOnBlock
(
state
,
block_name
)[
0
];
std
::
vector
<
ir
::
Expr
>
exprs
=
result
->
ir_schedule
.
GetModule
().
GetExprs
();
EXPECT_EQ
(
exprs
.
size
(),
1UL
);
VLOG
(
6
)
<<
"AutoBind applied Expr: "
<<
exprs
[
0
];
// check bind result
auto
all_loops
=
result
->
ir_schedule
.
GetLoops
(
block_name
);
int
total_num
=
std
::
accumulate
(
shape
.
begin
(),
shape
.
end
(),
1
,
std
::
multiplies
<
int
>
());
if
(
total_num
<=
kMaxThreadsPerBlock
)
{
ASSERT_EQ
(
all_loops
.
size
(),
1
);
EXPECT_EQ
(
all_loops
[
0
].
As
<
ir
::
For
>
()
->
extent
.
as_int32
(),
total_num
);
EXPECT_TRUE
(
all_loops
[
0
].
As
<
ir
::
For
>
()
->
is_gpu_thread_binded
());
}
else
if
(
total_num
<=
kMaxBlocks
*
kMaxThreadsPerBlock
)
{
ASSERT_EQ
(
all_loops
.
size
(),
2
);
EXPECT_EQ
(
all_loops
[
0
].
As
<
ir
::
For
>
()
->
extent
.
as_int32
(),
static_cast
<
int32_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
total_num
)
/
kMaxThreadsPerBlock
)));
EXPECT_TRUE
(
all_loops
[
0
].
As
<
ir
::
For
>
()
->
is_gpu_block_binded
());
EXPECT_EQ
(
all_loops
[
1
].
As
<
ir
::
For
>
()
->
extent
.
as_int32
(),
kMaxThreadsPerBlock
);
EXPECT_TRUE
(
all_loops
[
1
].
As
<
ir
::
For
>
()
->
is_gpu_thread_binded
());
}
else
{
ASSERT_EQ
(
all_loops
.
size
(),
3
);
EXPECT_EQ
(
all_loops
[
0
].
As
<
ir
::
For
>
()
->
extent
.
as_int32
(),
kMaxBlocks
);
EXPECT_TRUE
(
all_loops
[
0
].
As
<
ir
::
For
>
()
->
is_gpu_block_binded
());
EXPECT_EQ
(
all_loops
[
1
].
As
<
ir
::
For
>
()
->
extent
.
as_int32
(),
kMaxThreadsPerBlock
);
EXPECT_TRUE
(
all_loops
[
1
].
As
<
ir
::
For
>
()
->
is_gpu_thread_binded
());
EXPECT_EQ
(
all_loops
[
2
].
As
<
ir
::
For
>
()
->
extent
.
as_int32
(),
static_cast
<
int32_t
>
(
std
::
ceil
(
static_cast
<
double
>
(
total_num
)
/
(
kMaxBlocks
*
kMaxThreadsPerBlock
))));
EXPECT_FALSE
(
all_loops
[
2
].
As
<
ir
::
For
>
()
->
is_binded
());
}
// build and run
auto
ir_module
=
BuildIRModule
(
result
->
ir_schedule
);
auto
source_code
=
GenSourceCode
(
ir_module
);
VLOG
(
6
)
<<
"Optimized source code:
\n
"
<<
source_code
;
auto
manual_ir_module
=
BuildIRModule
(
MakeIRSchedule
(
test_program
,
/* apply_manual_schedule*/
true
));
VLOG
(
6
)
<<
"Manual-schedule compiled source code:
\n
"
<<
GenSourceCode
(
manual_ir_module
);
CheckResult
(
GenExecutableKernel
(
ir_module
),
GenExecutableKernel
(
manual_ir_module
),
default_input_names
,
{
block_name
},
{
shape
,
shape
},
{
shape
},
target_
);
}
};
// AutoBind must be applicable to a fresh matmul schedule and no longer
// applicable once the two outer loops have been fused and bound.
TEST_F(TestAutoBind, AnalyseApplyType) {
  Initialize(common::DefaultNVGPUTarget());
  ir::IRSchedule ir_schedule = MakeIRSchedule(
      tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}}));
  SearchState state(ir_schedule, 0, {});
  AutoBind auto_bind(target_);
  const std::string& applied_block_name = default_output_names.back();

  // outer two loops of initial Expr are spatial loops, so it can be applied
  const RuleApplyType type_before =
      auto_bind.AnalyseApplyType(state, applied_block_name);
  EXPECT_EQ(type_before, RuleApplyType::kApplyAndPruneOtherRules);

  state->ir_schedule.Fuse(applied_block_name, {0, 1});
  state->ir_schedule.Bind(state->ir_schedule.GetLoops(applied_block_name)[0],
                          "threadIdx.x");

  // after fuse and bind, there is no loops to be binded.
  const RuleApplyType type_after =
      auto_bind.AnalyseApplyType(state, applied_block_name);
  EXPECT_EQ(type_after, RuleApplyType::kCannotApply);
}
// Runs the elementwise-add scenario on block "var_1" for a 2-D shape
// (64 * 128 elements) and a 3-D shape (57 * 133 * 125 elements).
TEST_F(TestAutoBind, ApplyOnBlock) {
  const std::string block_name = "var_1";
  TestApplyOnElementWiseAdd({64, 128}, block_name);
  TestApplyOnElementWiseAdd({57, 133, 125}, block_name);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.cc
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include <glog/logging.h>
#include <cstdlib>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
// Stores the address of `target` in target_; the Target object is not
// copied, so it has to outlive this rule.
AutoGenRule::AutoGenRule(const common::Target& target)
    : target_(&target) {}
// Returns num_applicable_. A negative value is treated as "not initialized"
// and aborts through CHECK_GE with a message naming the rule.
int AutoGenRule::NumberApplicable() const {
  CHECK_GE(num_applicable_, 0)
      << "Call " << GetRuleName()
      << "::NumberApplicable() without initialization.";
  return num_applicable_;
}
// Picks an index in [0, num_applicable_) with rand() and applies the rule at
// that index. Requires num_applicable_ > 0.
void AutoGenRule::ApplyRandomly() {
  CHECK_GT(num_applicable_, 0)
      << "Call " << GetRuleName()
      << "::ApplyRandomly() with NumberApplicable() == 0";
  const int chosen_index = rand() % num_applicable_;  // NOLINT
  Apply(chosen_index);
}
}
// namespace auto_schedule
}
// namespace cinn
paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h
0 → 100644
View file @
992bec46
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace
cinn
{
namespace
auto_schedule
{
/**
 * Describes whether, and in which way, a rule can be applied to a ModuleExpr.
 */
enum class RuleApplyType : int {
  /** The rule cannot be applied to the ModuleExpr. */
  kCannotApply = 0,

  /**
   * The rule can be applied, and the original ModuleExpr is retained so that
   * other rules can still branch from it.
   */
  kApply = 1,

  /**
   * The rule can be applied, but the original ModuleExpr is deleted, so the
   * branches where other rules are applied on the original ModuleExpr are
   * pruned.
   */
  kApplyAndPruneOtherRules = 2
};
/**
 * Base class for rules of auto-generating schedule (like Ansor's sketch
 * generation).
 *
 * The class is an abstract, polymorphic base: concrete rules override the pure
 * virtual methods below.
 */
class AutoGenRule {
 public:
  // Binds the rule to `target`. The Target is not owned by the rule.
  explicit AutoGenRule(const common::Target& target);

  // Virtual so that a derived rule destroyed through an AutoGenRule pointer
  // runs the derived destructor (deleting through a base pointer with a
  // non-virtual destructor is undefined behaviour).
  virtual ~AutoGenRule() = default;

  // Initialize the AutoGenRule, it must be called before further actions.
  // Returns RuleApplyType::kCannotApply if the rule cannot be applied on the
  // given ir_schedule, otherwise one of the other RuleApplyType values
  // describing how it applies.
  virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0;

  // CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so
  // an auto gen rule may be suitable to different number of
  // Schedule Blocks. This method returns the number of ScheduleBlock
  // that can be applied by this auto gen rule.
  virtual int NumberApplicable() const;

  // Applies rule on the ir::ModuleExpr for a schedule block randomly.
  virtual void ApplyRandomly();

  // Applies rule on the ir::ModuleExpr for a schedule block specified by index
  // between 0 (inclusive) and NumberApplicable() (exclusive).
  virtual void Apply(int index) = 0;

  // Returns the name of the rule, used for debug.
  virtual std::string GetRuleName() const = 0;

  // Analyze the ApplyType of the rule used for a block determined by a specific
  // SearchState and block name.
  virtual RuleApplyType AnalyseApplyType(
      SearchState state, const std::string& block_name) const = 0;

  // Apply the rule to a block determined by a specific SearchState and block
  // name.
  virtual std::vector<SearchState> ApplyOnBlock(
      SearchState state, const std::string& block_name) = 0;

 protected:
  // Number of ScheduleBlock that can apply this auto gen rule. The initial
  // value -1 marks the rule as not initialized yet.
  int num_applicable_ = -1;
  // Target, not owned. Set by the constructor.
  const common::Target* target_;
  // IRSchedule, not owned. Null until a derived rule assigns it, so that it
  // never holds an indeterminate pointer value.
  ir::IRSchedule* ir_schedule_ = nullptr;
};
}
// namespace auto_schedule
}
// namespace cinn
Prev
1
…
4
5
6
7
8
9
10
11
12
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment