Commit 992bec46 authored by “yuguo”'s avatar “yuguo”
Browse files

2.5

parent 0259837d
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include "paddle/cinn/auto_schedule/measure/schedule_measurer.h"
#include "paddle/cinn/auto_schedule/measure/simple_builder.h"
#include "paddle/cinn/auto_schedule/measure/simple_runner.h"
#include "paddle/cinn/auto_schedule/task/task_creator.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
#include "paddle/cinn/runtime/flags.h"
namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
// Builds the small frontend program relu(A + B), where A and B are float32
// inputs of shape [32, 24]. It is the workload measured by the tests below.
frontend::Program CreateAddReluProgram() {
  constexpr int kRows = 32;
  constexpr int kCols = 24;

  frontend::NetBuilder net_builder("test");
  auto lhs = net_builder.CreateInput(Float(32), {kRows, kCols}, "A");
  auto rhs = net_builder.CreateInput(Float(32), {kRows, kCols}, "B");
  auto sum = net_builder.Add(lhs, rhs);
  auto activated = net_builder.Relu(sum);
  return net_builder.Build();
}
class TestMeasurer : public ::testing::Test {
public:
std::unique_ptr<GraphCompiler> graph_compiler;
std::vector<TuneTask> tasks;
std::vector<MeasureInput> inputs;
void SetUp() override {
#ifdef CINN_WITH_CUDA
Target target = common::DefaultNVGPUTarget();
#else
Target target = common::DefaultHostTarget();
#endif
std::unordered_set<std::string> fetch_ids;
auto program = CreateAddReluProgram();
auto graph = cinn::frontend::Optimize(&program, fetch_ids, target);
auto scope = BuildScope(target, graph);
graph_compiler = std::make_unique<GraphCompiler>(target, scope, graph);
TaskCreator task_creator;
tasks = task_creator.CreateTuneTaskOpLevel(graph.get());
const auto& dtype_dict =
graph->GetAttrs<absl::flat_hash_map<std::string, common::Type>>(
"inferdtype");
const auto& shape_dict = graph->GetAttrs<
absl::flat_hash_map<std::string, hlir::framework::shape_t>>(
"infershape");
auto op_lowerer = std::make_unique<hlir::framework::OpLowerer>(
dtype_dict, shape_dict, target);
inputs.reserve(tasks.size());
for (int i = 0; i < tasks.size(); ++i) {
auto* task = &tasks[i];
task->Initialize(shape_dict, dtype_dict, op_lowerer.get());
MeasureInput input;
input.task = task;
input.lowered_funcs = task->lowered_funcs;
inputs.emplace_back(input);
}
}
};
// A ScheduleBuilder whose Build() always throws; used to verify that the
// measurer reports build failures instead of propagating them.
class ThrowExceptionBuilder : public ScheduleBuilder {
  // The exception thrown by Build(); its message is "BuildError".
  struct Exception : public std::exception {
    // `noexcept override` replaces the dynamic exception specification
    // `throw()`, which was deprecated in C++17 and removed in C++20, and makes
    // the compiler verify that this really overrides std::exception::what().
    const char* what() const noexcept override { return "BuildError"; }
  };
  BuildResult Build(const MeasureInput& input) override { throw Exception(); }
};
// A ScheduleRunner whose Run() always throws; used to verify that the
// measurer reports run failures instead of propagating them.
class ThrowExceptionRunner : public ScheduleRunner {
  // The exception thrown by Run(); its message is "RunError".
  struct Exception : public std::exception {
    // `noexcept override` replaces the dynamic exception specification
    // `throw()`, which was deprecated in C++17 and removed in C++20, and makes
    // the compiler verify that this really overrides std::exception::what().
    const char* what() const noexcept override { return "RunError"; }
  };
  MeasureResult Run(const MeasureInput& input,
                    const BuildResult& build_result) override {
    throw Exception();
  }
};
// Measuring with the real builder and runner yields one result per input.
TEST_F(TestMeasurer, Basic) {
  SimpleBuilder simple_builder(graph_compiler.get());
  SimpleRunner simple_runner(1);
  ScheduleMeasurer measurer(&simple_builder, &simple_runner);

  const std::vector<MeasureResult> measured = measurer.Measure(inputs);
  ASSERT_EQ(inputs.size(), measured.size());
}
// Exceptions thrown while building or running a candidate must be caught by
// the measurer and surfaced through MeasureResult::error_msg.
TEST_F(TestMeasurer, CatchException) {
  auto builder = std::make_unique<SimpleBuilder>(graph_compiler.get());
  auto runner = std::make_unique<SimpleRunner>(1);
  auto throw_builder = std::make_unique<ThrowExceptionBuilder>();
  auto throw_runner = std::make_unique<ThrowExceptionRunner>();

  // the builder throws: every result carries the build error
  auto measurer_with_build_error =
      std::make_unique<ScheduleMeasurer>(throw_builder.get(), runner.get(), 2);
  std::vector<MeasureResult> results =
      measurer_with_build_error->Measure(inputs);
  ASSERT_EQ(inputs.size(), results.size());
  // Measure() returns an empty vector for empty inputs, in which case the
  // size check above passes; make sure results[0] really exists.
  ASSERT_FALSE(results.empty());
  EXPECT_EQ(results[0].error_msg, "Build failed, error: BuildError\n");

  // TODO(CtfGo): test parallel build after we support thread-safe compilation
  // the runner throws: every result carries the run error
  auto measurer_with_run_error =
      std::make_unique<ScheduleMeasurer>(builder.get(), throw_runner.get(), 1);
  results = measurer_with_run_error->Measure(inputs);
  ASSERT_EQ(inputs.size(), results.size());
  ASSERT_FALSE(results.empty());
  EXPECT_EQ(results[0].error_msg, "Run failed, error: RunError\n");
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/measure/schedule_measurer.h"
#include <exception>
#include "paddle/cinn/utils/multi_threading.h"
namespace cinn {
namespace auto_schedule {
ScheduleMeasurer::ScheduleMeasurer(ScheduleBuilder* builder,
ScheduleRunner* runner,
int num_threads)
: builder_(builder), runner_(runner), num_threads_(num_threads) {}
// Builds and then runs every candidate in `inputs`.
//
// Returns one MeasureResult per input, in the same order as `inputs`; an
// empty vector is returned (with a warning) when `inputs` is empty.
// A std::exception thrown by the builder or the runner is caught and reported
// through `error_msg` of the corresponding result, and a candidate whose build
// failed is not run. `elapsed_time` of each result is the time spent in this
// function on building plus running that candidate, in microseconds.
std::vector<MeasureResult> ScheduleMeasurer::Measure(
    const std::vector<MeasureInput>& inputs) {
  if (inputs.empty()) {
    LOG(WARNING) << "inputs is empty";
    return {};
  }
  std::vector<BuildResult> build_results(inputs.size());
  std::vector<MeasureResult> results(inputs.size());

  // define how to build a candidate with the specified index
  auto build_fn =
      [builder = builder_, &inputs, &build_results, &results](int index) {
        VLOG(6) << "Build candidate index: " << index;
        auto m_start = std::chrono::steady_clock::now();
        try {
          build_results[index] = builder->Build(inputs[index]);
        } catch (const std::exception& e) {
          results[index].error_msg =
              utils::StringFormat("Build failed, error: %s\n", e.what());
        }
        auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::steady_clock::now() - m_start);
        results[index].elapsed_time += static_cast<double>(time_span.count());
      };

  // define how to run a candidate with the specified index
  auto run_fn =
      [runner = runner_, &inputs, &build_results, &results](int index) {
        VLOG(6) << "Run candidate index: " << index;
        auto m_start = std::chrono::steady_clock::now();
        try {
          // if error occurred in building, then skip running
          if (results[index].error_msg.empty()) {
            // Run() returns a fresh MeasureResult, so assigning it would drop
            // the build time accumulated above and, together with the `+=`
            // below, count the run time twice. Keep the build time and let
            // the span measured here account for the run.
            const auto build_elapsed = results[index].elapsed_time;
            results[index] = runner->Run(inputs[index], build_results[index]);
            results[index].elapsed_time = build_elapsed;
          }
        } catch (const std::exception& e) {
          results[index].error_msg =
              utils::StringFormat("Run failed, error: %s\n", e.what());
        }
        auto time_span = std::chrono::duration_cast<std::chrono::microseconds>(
            std::chrono::steady_clock::now() - m_start);
        results[index].elapsed_time += static_cast<double>(time_span.count());
      };

  // measure a candidate by calling build and run successively
  auto measure_fn = [&build_fn, &run_fn](int index) {
    build_fn(index);
    run_fn(index);
  };

  // default num_threads_ is 1 and in that case it will perform all
  // measurements sequentially inplace.
  utils::parallel_run(
      measure_fn, utils::SequenceDispatcher(0, inputs.size()), num_threads_);
  VLOG(4) << "Measure " << inputs.size() << " candidates";
  return results;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/cinn/auto_schedule/measure/measure.h"
namespace cinn {
namespace auto_schedule {
// Entry point of schedule measurement. Measuring consists of two stages:
// building the input schedules and then running the generated code.
class ScheduleMeasurer {
public:
// Stores the builder and runner handles together with the thread count.
// The handles are raw pointers and this class declares no destructor, so the
// pointed-to objects are not released here.
ScheduleMeasurer(ScheduleBuilder* builder,
ScheduleRunner* runner,
int num_threads = 1);
// Measure a batch of inputs and return all results at once.
std::vector<MeasureResult> Measure(const std::vector<MeasureInput>& inputs);
private:
// Handle to the ScheduleBuilder implementation used for the build stage.
ScheduleBuilder* builder_;
// Handle to the ScheduleRunner implementation used for the run stage.
ScheduleRunner* runner_;
// The number of threads used to perform measurement;
// a value greater than 1 means parallel measurement.
const int num_threads_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/measure/simple_builder.h"
namespace cinn {
namespace auto_schedule {
using hlir::framework::GraphCompiler;
SimpleBuilder::SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler)
: graph_compiler_(graph_compiler) {}
// Compiles the subgraph and lowered functions carried by `input` with the
// bound GraphCompiler and packs the compiler's scope together with the
// resulting runtime program into a BuildResult.
BuildResult SimpleBuilder::Build(const MeasureInput& input) {
  CHECK_NE(graph_compiler_, static_cast<GraphCompiler*>(nullptr))
      << "empty handle to GraphCompiler";

  // compile exactly one group: the task's subgraph with its lowered funcs
  GraphCompiler::CompileOptions options;
  options.groups.emplace_back(input.task->subgraph);
  options.lowered_funcs.emplace_back(input.lowered_funcs);
  options.remove_unused_variables = false;
  VLOG(5) << "call GraphCompiler to Build with Graph::Group size="
          << options.groups.size()
          << ", lowered_funcs group size=" << options.lowered_funcs.size();

  GraphCompiler::CompilationResult compiled = graph_compiler_->Build(options);

  BuildResult packed;
  packed.compiled_scope = graph_compiler_->GetScope().get();
  packed.runtime_program = std::move(compiled.runtime_program);
  return packed;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/measure/measure.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
namespace cinn {
namespace auto_schedule {
// This class utilizes the GraphCompiler bound to the graph to build
// the input schedule into executable objects.
class SimpleBuilder : public ScheduleBuilder {
public:
// Stores the handle to the GraphCompiler used by Build(). The handle is a raw
// pointer and this class declares no destructor, so it is not released here.
explicit SimpleBuilder(hlir::framework::GraphCompiler* graph_compiler);
// Build the schedule described by `input` and pack the result.
BuildResult Build(const MeasureInput& input) override;
private:
// Handle to the GraphCompiler that performs the actual compilation.
hlir::framework::GraphCompiler* graph_compiler_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/measure/simple_runner.h"
#include <algorithm>
#include <chrono>
#include <iterator>
#include <limits>
#include <memory>
#include <random>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/hlir/framework/buffer.h"
#include "paddle/cinn/hlir/framework/scope.h"
#include "paddle/cinn/hlir/framework/tensor.h"
namespace cinn {
namespace auto_schedule {
using hlir::framework::Buffer;
using hlir::framework::Shape;
using hlir::framework::Tensor;
// Parameters that need to be initialized to 0.
// Key is the op name, and value lists the indices of the affected input
// parameters of that op.
static const std::unordered_map<std::string, std::vector<int>>
kInitWithZeroParams = {
{"lookup_table", {1}},
{"gather", {1}},
{"gather_nd", {1}},
{"scatter_assign", {2}},
{"scatter_add", {2}},
};
// Generates random values and writes `numel` of them to `raw_ptr`.
//
// `type` selects the element type of the written values: bool, int32, int64
// and float32 are written as such; any other type passes through the
// fallback branch, which writes `numel` random bytes. The engine is seeded
// from std::random_device, so the values differ between calls.
static void PopulateRandomValue(const common::Type& type,
                                const int numel,
                                void* raw_ptr) {
  std::random_device seed;
  std::default_random_engine engine(seed());
  if (type == common::Bool()) {
    auto* fmt_ptr = reinterpret_cast<bool*>(raw_ptr);
    std::bernoulli_distribution dist(0.5);
    std::generate_n(
        fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
  } else if (type == common::I32()) {
    auto* fmt_ptr = reinterpret_cast<int*>(raw_ptr);
    std::uniform_int_distribution<int> dist(std::numeric_limits<int>::min(),
                                            std::numeric_limits<int>::max());
    std::generate_n(
        fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
  } else if (type == common::I64()) {
    auto* fmt_ptr = reinterpret_cast<int64_t*>(raw_ptr);
    std::uniform_int_distribution<int64_t> dist(
        std::numeric_limits<int64_t>::min(),
        std::numeric_limits<int64_t>::max());
    std::generate_n(
        fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
  } else if (type == common::F32()) {
    auto* fmt_ptr = reinterpret_cast<float*>(raw_ptr);
    std::uniform_real_distribution<float> dist(
        std::numeric_limits<float>::min(), std::numeric_limits<float>::max());
    std::generate_n(
        fmt_ptr, numel, [&engine, &dist]() { return dist(engine); });
  } else {
    // NOTE(review): the check requires 8-byte types while only `numel` bytes
    // are written below — confirm whether `bytes() == 1` (or filling
    // numel * bytes() bytes) was intended.
    CHECK_EQ(type.bytes(), 8)
        << "Unsupported type: " << type << ", type.bytes = " << type.bytes();
    auto* fmt_ptr = reinterpret_cast<uint8_t*>(raw_ptr);
    // std::uniform_int_distribution is only specified for short, int, long,
    // long long and their unsigned counterparts; instantiating it with a
    // character-sized type such as uint8_t is undefined behavior. Draw an int
    // from the uint8_t range and narrow it instead.
    std::uniform_int_distribution<int> dist(
        std::numeric_limits<uint8_t>::min(),
        std::numeric_limits<uint8_t>::max());
    std::generate_n(fmt_ptr, numel, [&engine, &dist]() {
      return static_cast<uint8_t>(dist(engine));
    });
  }
}
// Initializes the data of `tensor` on `target`: with zeros if
// `init_with_zero` is true, otherwise with random values.
//
// For the NVGPU target (CUDA builds only) random values are generated in a
// temporary host buffer and copied to the device; for the host target the
// tensor memory is written directly. Other targets are left untouched.
static void InitTensorData(Tensor tensor,
                           const common::Target& target,
                           bool init_with_zero) {
  const auto numel = tensor->shape().numel();
  // compute the byte size in size_t so that large tensors do not overflow int
  const size_t mem_size =
      static_cast<size_t>(numel) * static_cast<size_t>(tensor->type().bytes());
  auto* tensor_data = tensor->mutable_data(target, tensor->type());
#ifdef CINN_WITH_CUDA
  if (target == common::DefaultNVGPUTarget()) {
    // check the CUDA return codes instead of silently ignoring failures
    if (init_with_zero) {
      CUDA_CALL(cudaMemset(tensor_data, 0, mem_size));
    } else {
      // owning buffer: released on every exit path, allocation failure throws
      std::unique_ptr<uint8_t[]> host_buffer(new uint8_t[mem_size]);
      PopulateRandomValue(tensor->type(), numel, host_buffer.get());
      CUDA_CALL(cudaMemcpy(
          tensor_data, host_buffer.get(), mem_size, cudaMemcpyHostToDevice));
    }
  }
#endif
  if (target == common::DefaultHostTarget()) {
    if (init_with_zero) {
      memset(tensor_data, 0, mem_size);
    } else {
      PopulateRandomValue(tensor->type(), numel, tensor_data);
    }
  }
}
// Collects the names of all parameters in the task of `input` that must be
// initialized to 0 when measuring, i.e. the inputs listed in
// kInitWithZeroParams for every node of the task's subgraph.
static std::unordered_set<std::string> ParamsNeedInitWithZero(
    const MeasureInput& input) {
  std::unordered_set<std::string> zero_init_params;
  std::vector<hlir::framework::Node*> subgraph_nodes =
      input.task->subgraph->CollectNodes();
  for (auto* node : subgraph_nodes) {
    const auto found = kInitWithZeroParams.find(node->op()->name);
    if (found == kInitWithZeroParams.end()) {
      continue;  // this op has no parameter requiring zero-initialization
    }
    const auto& inlinks = node->inlinks_in_order();
    for (int param_idx : found->second) {
      CHECK_GT(inlinks.size(), param_idx);
      auto& edge = inlinks.at(param_idx);
      std::string param_name =
          edge->source()->as<hlir::framework::NodeData>()->id();
      VLOG(6) << "param needs to be init with 0: " << param_name;
      zero_init_params.insert(param_name);
    }
  }
  return zero_init_params;
}
// Stores how many times each instruction is executed per measurement.
// `repeat_times` must be positive; the check aborts otherwise.
SimpleRunner::SimpleRunner(int repeat_times) : repeat_times_(repeat_times) {
  // the check rejects 0 as well, so the message must say "greater than 0"
  CHECK_GT(repeat_times_, 0) << "repeat_times must be greater than 0";
}
// Prepare execution arguments of all instructions to run, a argument
// may be obtained from the input of measurement or allocating new buffer
// with random value.
std::map<std::string, cinn_pod_value_t> SimpleRunner::PrepareArgs(
const MeasureInput& input,
const BuildResult& build_result,
hlir::framework::Scope* temp_scope) {
std::map<std::string, cinn_pod_value_t> result;
const auto& target = input.task->target;
const auto* input_args = input.execution_args;
const auto* compiled_scope = build_result.compiled_scope;
const auto& instructions = build_result.runtime_program->GetRunInstructions();
std::unordered_set<std::string> params_need_init_with_zero =
ParamsNeedInitWithZero(input);
auto fill_arg_fn = [&](const std::string& param) {
VLOG(6) << "Filling argument:" << param;
// the argument is duplicated and has been prepared.
if (result.count(param)) {
return;
}
// if the input of measurement specifies this argument,
// we should use it firstly.
if (input_args && input_args->count(param)) {
VLOG(6) << "Argument[" << param << "] use input value";
result.emplace(param, input_args->at(param));
return;
}
if (temp_scope->FindVar(param)) {
auto temp_tensor = temp_scope->GetTensor(param);
result.emplace(param, temp_tensor->buffer());
return;
}
// allocate a new buffer for this argument and store it in
// the temporary scope to be released at proper time.
auto compiled_tensor = compiled_scope->GetTensor(param);
temp_scope->Var<Tensor>(param);
auto temp_tensor = temp_scope->GetTensor(param);
temp_tensor->Resize(compiled_tensor->shape());
temp_tensor->set_type(compiled_tensor->type());
temp_tensor->mutable_data(target, compiled_tensor->type());
InitTensorData(
temp_tensor, target, params_need_init_with_zero.count(param) != 0);
result.emplace(param, temp_tensor->buffer());
};
for (auto&& instr : instructions) {
for (auto&& args : instr->GetInArgs()) {
std::for_each(args.begin(), args.end(), fill_arg_fn);
}
for (auto&& args : instr->GetOutArgs()) {
std::for_each(args.begin(), args.end(), fill_arg_fn);
}
}
return result;
}
// Runs every instruction of the built program `repeat_times_` times.
//
// The returned result holds in `execution_cost` the sum over all instructions
// of the average time of one execution, and in `elapsed_time` the total time
// spent in this call including argument preparation; both in microseconds.
MeasureResult SimpleRunner::Run(const MeasureInput& input,
                                const BuildResult& build_result) {
  using Clock = std::chrono::steady_clock;
  // microseconds elapsed since `since`, as a whole-number count
  const auto micros_since = [](Clock::time_point since) {
    return std::chrono::duration_cast<std::chrono::microseconds>(
               Clock::now() - since)
        .count();
  };

  MeasureResult result;
  const auto t_start = Clock::now();

  // prepare execution arguments
  VLOG(4) << "SimpleRunner prepare execution arguments";
  hlir::framework::Scope temp_scope;  // used for store temporary allocated data
  auto execution_args = PrepareArgs(input, build_result, &temp_scope);

  // Execute each instruction repeatedly and take the average as cost.
  result.execution_cost = 0;
  const auto& instructions = build_result.runtime_program->GetRunInstructions();
  for (size_t ct = 0; ct < instructions.size(); ++ct) {
    auto&& instr = instructions.at(ct);
    VLOG(5) << "Start running instruction-" << ct;
    const auto run_start = Clock::now();
    for (int round = 0; round < repeat_times_; ++round) {
      instr->Run(&execution_args);
    }
#ifdef CINN_WITH_CUDA
    if (instr->target_ == common::DefaultNVGPUTarget()) {
      CUDA_CALL(cudaDeviceSynchronize());
    }
#endif
    result.execution_cost +=
        static_cast<double>(micros_since(run_start)) / repeat_times_;
  }

  result.elapsed_time = static_cast<double>(micros_since(t_start));
  VLOG(4) << "A measurement done:repeat_times[" << repeat_times_
          << "]total_elapsed_time[" << result.elapsed_time
          << "]us,execution_cost[" << result.execution_cost << "]us";
  return result;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/measure/measure.h"
#include "paddle/cinn/hlir/framework/instruction.h"
namespace cinn {
namespace auto_schedule {
// This class utilizes the built instructions to execute the generated
// kernels and counts the elapsed time as the measurement of performance.
class SimpleRunner : public ScheduleRunner {
public:
// `repeat_times` is the number of times the instructions are run per
// measurement.
explicit SimpleRunner(int repeat_times);
// Run the program in `build_result` for `input` and return the measured
// result.
MeasureResult Run(const MeasureInput& input,
const BuildResult& build_result) override;
private:
// Prepare the execution arguments, keyed by argument name, for the
// instructions in `build_result`; `temp_scope` is the scope handed in for
// temporary data.
std::map<std::string, cinn_pod_value_t> PrepareArgs(
const MeasureInput& input,
const BuildResult& build_result,
hlir::framework::Scope* temp_scope);
private:
// The repeat times of running instructions,
// this runner will return the average time
const int repeat_times_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/measure/simple_runner.h"
#include <gtest/gtest.h>
#include <chrono>
#include <thread>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/optimize.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/hlir/framework/graph_compiler.h"
namespace cinn {
namespace auto_schedule {
using ::cinn::hlir::framework::BuildScope;
using ::cinn::hlir::framework::Graph;
using ::cinn::hlir::framework::GraphCompiler;
using ::cinn::hlir::framework::Instruction;
using ::cinn::hlir::framework::Scope;
// Test fixture that compiles the add+relu program for the default target and
// exposes the pieces (scope, compiler, build result, measure input) needed to
// drive a SimpleRunner directly.
class TestSimpleRunner : public ::testing::Test {
 public:
#ifdef CINN_WITH_CUDA
  Target target = common::DefaultNVGPUTarget();
#else
  Target target = common::DefaultHostTarget();
#endif
  std::shared_ptr<Graph> graph;
  std::shared_ptr<Scope> compiled_scope;
  std::unique_ptr<GraphCompiler> graph_compiler;
  std::unique_ptr<TuneTask> task;
  MeasureInput input;
  BuildResult build_result;

  // Builds the frontend program relu(A + B) used by this fixture.
  static frontend::Program CreateAddReluProgram();

  void SetUp() override {
    std::unordered_set<std::string> fetch_ids;
    auto program = CreateAddReluProgram();
    // Assign the member instead of declaring a local `graph` that shadows it;
    // otherwise the member stays null for the whole test.
    graph = cinn::frontend::Optimize(&program, fetch_ids, target);
    compiled_scope = BuildScope(target, graph);
    graph_compiler =
        std::make_unique<GraphCompiler>(target, compiled_scope, graph);

    // the whole program is expected to compile into a single instruction
    auto runtime_program = graph_compiler->Build();
    const auto& instructions = runtime_program->GetRunInstructions();
    ASSERT_EQ(1u, instructions.size());

    build_result.compiled_scope = compiled_scope.get();
    build_result.runtime_program = std::move(runtime_program);

    // the task measured by the tests: first fusion group on the same target
    task = std::make_unique<TuneTask>();
    task->target = target;
    task->subgraph = graph->fusion_groups.front();
    input.task = task.get();
  }
};
// Builds the frontend program relu(A + B), where A and B are float32 inputs
// of shape [32, 24].
frontend::Program TestSimpleRunner::CreateAddReluProgram() {
  constexpr int kRows = 32;
  constexpr int kCols = 24;

  frontend::NetBuilder net_builder("test");
  auto lhs = net_builder.CreateInput(Float(32), {kRows, kCols}, "A");
  auto rhs = net_builder.CreateInput(Float(32), {kRows, kCols}, "B");
  auto sum = net_builder.Add(lhs, rhs);
  auto activated = net_builder.Relu(sum);
  return net_builder.Build();
}
// With no preset arguments the runner allocates and fills every argument
// itself; the run must complete without throwing.
TEST_F(TestSimpleRunner, MeasureWithRandomValue) {
  SimpleRunner single_shot_runner(1);
  ASSERT_NO_THROW(single_shot_runner.Run(input, build_result));
}
// The runner must accept arguments preset through MeasureInput: here the
// buffers of inputs "A" and "B" come from the compiled scope.
TEST_F(TestSimpleRunner, MeasureWithSpecifiedArgs) {
  auto tensor_a = compiled_scope->GetTensor("A");
  tensor_a->mutable_data<float>(target);
  auto tensor_b = compiled_scope->GetTensor("B");
  tensor_b->mutable_data<float>(target);

  // specific several execution args
  std::map<std::string, cinn_pod_value_t> preset_args;
  preset_args.emplace("A", tensor_a->buffer());
  preset_args.emplace("B", tensor_b->buffer());
  input.execution_args = &preset_args;

  SimpleRunner single_shot_runner(1);
  ASSERT_NO_THROW(single_shot_runner.Run(input, build_result));
}
// Checks the reported timings against a program whose only instruction
// sleeps for a known duration.
TEST_F(TestSimpleRunner, TimeMeasured) {
  // the `sleep` function that stands in for a compiled kernel
  void (*sleep_fn)(void*, int32_t) = [](void*, int32_t) -> void {
    std::this_thread::sleep_for(std::chrono::microseconds(100));
  };

  // one instruction that calls the `sleep` function
  std::vector<std::unique_ptr<Instruction>> sleep_instructions;
  sleep_instructions.emplace_back(new Instruction(common::DefaultHostTarget(),
                                                  nullptr,
                                                  {},
                                                  {"empty_placeholder"},
                                                  "sleep_fn"));
  sleep_instructions.back()->SetLoweredFunc(
      reinterpret_cast<void*>(sleep_fn));
  sleep_instructions.back()->Finalize();

  // set up a BuildResult object around that single instruction
  BuildResult sleep_build_result;
  sleep_build_result.compiled_scope = nullptr;
  sleep_build_result.runtime_program.reset(
      new hlir::framework::Program(nullptr, std::move(sleep_instructions)));

  // to skip the condition check of params in Instruction::PreparePodArgs
  std::map<std::string, cinn_pod_value_t> preset_args;
  preset_args.emplace("empty_placeholder", cinn_pod_value_t());
  input.execution_args = &preset_args;

  SimpleRunner twice_runner(2);
  MeasureResult measure_result = twice_runner.Run(input, sleep_build_result);
  // because the kernel function will sleep 100 us,
  // the cost time of execution and span in total must
  // be greater than 100us and 200us (repeatedly running 2 times) respectively.
  ASSERT_GE(measure_result.execution_cost, 100);
  ASSERT_GE(measure_result.elapsed_time, 200);
}
} // namespace auto_schedule
} // namespace cinn
# Project helper macro; presumably registers this directory's headers — see
# its definition for the exact behavior.
core_gather_headers()
# Pass cooperative_process.cc to gather_srcs for the cinnapi_src list.
gather_srcs(cinnapi_src SRCS cooperative_process.cc)
# The cooperative_process test is only declared when WITH_CUDA is set.
if(WITH_CUDA)
cinn_nv_test(
test_cooperative_process
SRCS
cooperative_process_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder)
endif()
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/schedule/schedule_desc.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
// Returns the extent of the loop bound to `bind_axis` by the first matching
// "Bind" step in the trace of `ir_schedule`, or 0 if no such step exists.
int ExtractNumThreads(const ir::IRSchedule& ir_schedule,
                      const std::string& bind_axis) {
  const ir::ScheduleDesc& trace = ir_schedule.GetTraceDesc();
  for (auto&& step : trace.Steps()) {
    if (step.type != "Bind") {
      continue;
    }
    const auto axis_attr = step.attrs.find("thread_axis");
    if (axis_attr == step.attrs.end()) {
      continue;
    }
    if (absl::get<std::string>(axis_attr->second) != bind_axis) {
      continue;
    }
    // a Bind step is expected to act on exactly one loop
    CHECK_EQ(step.inputs.at("loop").size(), 1);
    auto&& bound_loops = step.inputs.at("loop");
    return bound_loops[0].As<ir::For>()->extent.as_int32();
  }
  return 0;
}
// Returns the names of all schedule blocks that an "AnnotateIntAttr" step of
// `trace` annotated with the cooperative_process key, in trace order.
std::vector<std::string> FindCandidates(const ir::ScheduleDesc& trace) {
  std::vector<std::string> annotated_block_names;
  for (auto&& step : trace.Steps()) {
    if (step.type != "AnnotateIntAttr") {
      continue;
    }
    if (absl::get<std::string>(step.attrs.at("key")) !=
        ir::attr::cooperative_process) {
      continue;
    }
    auto&& annotated = step.inputs.at("block")[0];
    auto* realize = annotated.As<ir::ScheduleBlockRealize>();
    auto* schedule_block = realize->schedule_block.As<ir::ScheduleBlock>();
    annotated_block_names.push_back(schedule_block->name);
  }
  return annotated_block_names;
}
// Rewrites every block annotated with cooperative_process: its innermost loop
// is bound to threadIdx.x (after splitting by the number of threads when the
// loop is longer than that), a thread synchronization is inserted, and the
// annotation is removed. Always returns true.
bool CooperativeProcess::Apply(ir::IRSchedule* schedule) {
  // NOTE(review): ExtractNumThreads returns 0 when no loop is bound to
  // threadIdx.x; a candidate would then be split with a factor of 0 — confirm
  // that this cannot happen for schedules reaching this rule.
  const int threads_per_block = ExtractNumThreads(*schedule, "threadIdx.x");
  std::vector<std::string> block_names =
      FindCandidates(schedule->GetTraceDesc());

  for (auto&& block_name : block_names) {
    auto innermost = schedule->GetLoops(block_name).back();
    const bool fits_in_one_pass =
        innermost.As<ir::For>()->extent.as_int32() <= threads_per_block;
    if (fits_in_one_pass) {
      schedule->Bind(innermost, "threadIdx.x");
      // fetch the loop again after binding before synchronizing on it
      innermost = schedule->GetLoops(block_name).back();
      schedule->SyncThreads(innermost);
    } else {
      auto split_loops = schedule->Split(innermost, {-1, threads_per_block});
      schedule->Bind(split_loops.back(), "threadIdx.x");
      schedule->SyncThreads(split_loops[0]);
    }

    auto annotated_block = schedule->GetBlock(block_name);
    schedule->Unannotate(annotated_block, ir::attr::cooperative_process);
  }
  return true;
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/auto_schedule/post_schedule_rule/post_schedule_rule.h"
namespace cinn {
namespace auto_schedule {
/*
* @brief Rewrite the cooperative_process annotation to actually bind the loop
* on threadIdx. This rule is used for collaborative data handling of multiple
* threads within the same block.
*/
class CooperativeProcess : public PostScheduleRule {
public:
CooperativeProcess() = default;
// Apply this rule to `schedule`; declared `final`, so subclasses cannot
// override it.
bool Apply(ir::IRSchedule* schedule) final;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/post_schedule_rule/cooperative_process.h"
#include <gtest/gtest.h>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
namespace cinn {
namespace auto_schedule {
// Test fixture for the CooperativeProcess post schedule rule. It reuses the
// helpers of TestAutoGenRuleBase and only adds the data shared by the tests.
class TestCooperativeProcess : public TestAutoGenRuleBase {
public:
// seed handed to MakeIRSchedule by the tests of this fixture
int fixed_rand_seed = 1;
// names of the kernel inputs / outputs, filled in by each test and forwarded
// to CheckResult
std::vector<std::string> default_input_names;
std::vector<std::string> default_output_names;
};
// Schedules a 32x32 matmul by hand (split, reorder, fuse, bind), adds a
// CacheRead block for read-buffer indices 1 and 2, annotates both cache blocks
// with cooperative_process, then applies CooperativeProcess. The test checks
// the printed IR against an expected string and finally compares the result of
// the generated kernel with the one produced by the manual schedule.
TEST_F(TestCooperativeProcess, Matmul) {
default_input_names = {"X", "Y"};
default_output_names = {"temp_matmul_out"};
std::vector<int32_t> X_shape = {32, 32};
std::vector<int32_t> Y_shape = {32, 32};
std::vector<int32_t> out_shape = {32, 32};
// split factors used for the i / j / k loops below
int num_blocks_y = 2;
int num_blocks_x = 2;
int num_threads_y = 8;
int num_threads_x = 2;
int steps_k = 8;
Initialize(common::DefaultNVGPUTarget());
frontend::Program matmul_op =
tests::OpBuilder("matmul").Build({{"X", X_shape}, {"Y", Y_shape}});
ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_op, fixed_rand_seed);
// split loops
// the loops are split from the innermost (k) to the outermost (i) one, all
// through the `loops` handles fetched before the first split
std::vector<ir::Expr> loops = ir_schedule.GetLoops("temp_matmul_out");
std::vector<ir::Expr> k_loops = ir_schedule.Split(loops[2], {steps_k, -1});
std::vector<ir::Expr> j_loops =
ir_schedule.Split(loops[1], {num_blocks_x, num_threads_x, -1});
std::vector<ir::Expr> i_loops =
ir_schedule.Split(loops[0], {num_blocks_y, num_threads_y, -1});
// reorder to "SSRRS": i0, j0, i1, j1, k0, k1, j2, i2
loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.Reorder({loops[0],
loops[3],
loops[1],
loops[4],
loops[6],
loops[7],
loops[2],
loops[5]});
// fuse and bind
// loops[2]/loops[3] are fused before loops[0]/loops[1], both through the
// handles fetched just above
loops = ir_schedule.GetLoops("temp_matmul_out");
ir::Expr i1_j1_fused = ir_schedule.Fuse({loops[2], loops[3]});
ir::Expr i0_j0_fused = ir_schedule.Fuse({loops[0], loops[1]});
loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.Bind(loops[1], "threadIdx.x");
ir_schedule.Bind(loops[0], "blockIdx.x");
// cache read
// first cache block: read-buffer index 1 with memory type "shared", computed
// at loops[2] of the output block and marked with cooperative_process
ir::Expr out_block = ir_schedule.GetBlock("temp_matmul_out");
ir::Expr X_cache_block = ir_schedule.CacheRead(out_block, 1, "shared");
std::string X_cache_block_name = X_cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.ComputeAt(X_cache_block, loops[2]);
std::vector<ir::Expr> X_cache_loops =
ir_schedule.GetLoops(X_cache_block_name);
ir_schedule.Fuse({X_cache_loops[3], X_cache_loops[4]});
ir_schedule.Annotate(ir_schedule.GetBlock(X_cache_block_name),
ir::attr::cooperative_process,
0);
// second cache block: same steps for read-buffer index 2
out_block = ir_schedule.GetBlock("temp_matmul_out");
ir::Expr Y_cache_block = ir_schedule.CacheRead(out_block, 2, "shared");
std::string Y_cache_block_name = Y_cache_block.As<ir::ScheduleBlockRealize>()
->schedule_block.As<ir::ScheduleBlock>()
->name;
loops = ir_schedule.GetLoops("temp_matmul_out");
ir_schedule.ComputeAt(Y_cache_block, loops[2]);
std::vector<ir::Expr> Y_cache_loops =
ir_schedule.GetLoops(Y_cache_block_name);
ir_schedule.Fuse({Y_cache_loops[3], Y_cache_loops[4]});
ir_schedule.Annotate(ir_schedule.GetBlock(Y_cache_block_name),
ir::attr::cooperative_process,
0);
// apply CooperativeProcess
// NOTE(review): the bool returned by Apply is not checked here
CooperativeProcess cooperative_process;
cooperative_process.Apply(&ir_schedule);
// check ir
auto ir = GetIR(ir_schedule);
VLOG(6) << "after CooperativeProcess, ir: \n" << ir;
// the expected text is compared with ASSERT_EQ below, so every character of
// the raw string (including whitespace) is significant
std::string expected_ir = R"ROC(Expr 0 {
{
ScheduleBlock(root)
{
{
serial for (i, 0, 2)
{
serial for (j, 0, 2)
{
serial for (i_0, 0, 8)
{
serial for (j_0, 0, 2)
{
serial for (i_1, 0, 2)
{
serial for (j_1, 0, 8)
{
ScheduleBlock(temp_matmul_out__reduce_init)
{
i0, i1 = axis.bind(((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1)))
{
temp_matmul_out__reduce_init[((16 * i) + ((2 * i_0) + i_1)), ((16 * j) + ((8 * j_0) + j_1))] = 0.00000000f
}
}
}
}
}
}
}
}
thread_bind[blockIdx.x] for (i_j_fused, 0, 4)
{
thread_bind[threadIdx.x] for (i_0_j_0_fused, 0, 16)
{
serial for (reduce_k_0, 0, 8)
{
serial for (ax0_0_ax1_0_fused, 0, 2)
{
thread_bind[threadIdx.x] for (ax0_0_ax1_0_fused_0, 0, 16)
{
ScheduleBlock(Y_reshape_shared_temp_buffer)
{
v0, v1 = axis.bind(((((16 * ax0_0_ax1_0_fused) + ax0_0_ax1_0_fused_0) / 8) + (4 * reduce_k_0)), ((((16 * ax0_0_ax1_0_fused) + ax0_0_ax1_0_fused_0) % 8) + ((8 * (i_0_j_0_fused % 2)) + (16 * (i_j_fused % 2)))))
attrs(compute_at_extra_var:ax0_0,ax1_0)
{
Y_reshape_shared_temp_buffer[v0, v1] = Y_reshape[v0, v1]
}
}
}
}
__syncthreads()
thread_bind[threadIdx.x] for (ax0_ax1_fused, 0, 8)
{
ScheduleBlock(X_reshape_shared_temp_buffer)
{
v0, v1 = axis.bind(((ax0_ax1_fused / 4) + ((2 * (i_0_j_0_fused / 2)) + (16 * (i_j_fused / 2)))), ((ax0_ax1_fused % 4) + (4 * reduce_k_0)))
attrs(compute_at_extra_var:ax0,ax1)
{
X_reshape_shared_temp_buffer[v0, v1] = X_reshape[v0, v1]
}
}
}
__syncthreads()
serial for (reduce_k_1, 0, 4)
{
serial for (i_1, 0, 2)
{
serial for (j_1, 0, 8)
{
ScheduleBlock(temp_matmul_out)
{
i0_0, i1_0, i2 = axis.bind(((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1)), ((4 * reduce_k_0) + reduce_k_1))
{
temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] = (temp_matmul_out[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))] + (X_reshape_shared_temp_buffer[((2 * (i_0_j_0_fused / 2)) + ((16 * (i_j_fused / 2)) + i_1)), ((4 * reduce_k_0) + reduce_k_1)] * Y_reshape_shared_temp_buffer[((4 * reduce_k_0) + reduce_k_1), ((8 * (i_0_j_0_fused % 2)) + ((16 * (i_j_fused % 2)) + j_1))]))
}
}
}
}
}
}
}
}
}
}
}
} // end Expr 0
)ROC";
ASSERT_EQ(ir, expected_ir);
// build ir::Module and debug source code
auto ir_module = BuildIRModule(ir_schedule);
auto source_code = GenSourceCode(ir_module);
VLOG(6) << "scheduled source code:\n" << source_code;
// execute and check precision
// the reference kernel is built from the same program with
// apply_manual_schedule set to true
CheckResult(
GenExecutableKernel(ir_module),
GenExecutableKernel(BuildIRModule(MakeIRSchedule(
matmul_op, fixed_rand_seed, /* apply_manual_schedule*/ true))),
default_input_names,
default_output_names,
{X_shape, Y_shape},
{out_shape},
target_);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
/**
 * Base class for rules of post process,
 * used to process schedules that rely on mutate results.
 *
 * Concrete rules derive from this class and implement Apply.
 */
class PostScheduleRule {
 public:
  PostScheduleRule() = default;
  // The class is a polymorphic interface, so the destructor has to be virtual:
  // destroying a derived rule through a PostScheduleRule pointer would
  // otherwise be undefined behavior.
  virtual ~PostScheduleRule() = default;
  /**
   * @brief Apply the post schedule rule to the given IRSchedule.
   * @param schedule The IRSchedule to post-process; it is passed by non-const
   * pointer so that the rule can modify it.
   * @return True if apply successfully.
   */
  virtual bool Apply(ir::IRSchedule* schedule) = 0;
};
} // namespace auto_schedule
} // namespace cinn
# Build description of the search_space directory (presumably its
# CMakeLists.txt; the file name is not visible here).
# The auto_gen_rule sub-directory has its own build description.
add_subdirectory(auto_gen_rule)
core_gather_headers()
# Sources of this directory are listed under the cinnapi_src name.
gather_srcs(cinnapi_src SRCS search_space.cc search_state.cc block_sampler.cc
rule_sampler.cc)
# One test target per *_test.cc file; each depends on cinncore.
cinn_cc_test(test_search_space SRCS search_space_test.cc DEPS cinncore)
cinn_cc_test(test_search_state SRCS search_state_test.cc DEPS cinncore)
cinn_cc_test(test_block_sampler SRCS block_sampler_test.cc DEPS cinncore)
cinn_cc_test(test_rule_sampler SRCS rule_sampler_test.cc DEPS cinncore)
# Build description of the auto_gen_rule directory (presumably its
# CMakeLists.txt; the file name is not visible here).
core_gather_headers()
# Rule implementations, listed under the cinnapi_src name.
gather_srcs(
cinnapi_src
SRCS
auto_gen_rule.cc
auto_inline.cc
auto_unroll.cc
multi_level_tiling.cc
skip_rule.cc
auto_bind.cc)
# The shared test helper library is only declared when WITH_TESTING is set.
if(WITH_TESTING)
cinn_cc_library(
auto_gen_rule_test_helper
SRCS
test_helper.cc
DEPS
glog
gtest
cinncore)
endif()
# The tests below are declared with cinn_nv_test and only when WITH_CUDA is
# set; all of them use the test helper and test_program_builder.
if(WITH_CUDA)
cinn_nv_test(
test_mix_rules
SRCS
mix_rules_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder)
cinn_nv_test(
test_auto_bind
SRCS
auto_bind_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder)
cinn_nv_test(
test_multi_level_tiling
SRCS
multi_level_tiling_test.cc
DEPS
cinncore
auto_gen_rule_test_helper
test_program_builder)
endif()
# NOTE(review): test_auto_inline is disabled; the reason is not recorded here.
#cinn_cc_test(test_auto_inline SRCS auto_inline_test.cc DEPS cinncore auto_gen_rule_test_helper)
# Tests declared regardless of WITH_CUDA.
cinn_cc_test(test_skip_rule SRCS skip_rule_test.cc DEPS cinncore)
cinn_cc_test(test_auto_unroll SRCS auto_unroll_test.cc DEPS cinncore)
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h"
#include <glog/logging.h>
#include "paddle/cinn/ir/schedule/ir_schedule.h"
#include "paddle/cinn/ir/utils/ir_copy.h"
#include "paddle/cinn/ir/utils/ir_nodes_collector.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
namespace cinn {
namespace auto_schedule {
// Value passed as `max_blocks` to BindGPUIndex by the AutoBind rule: the
// largest extent given to the loop bound to blockIdx.x when a fused loop has
// to be split three ways.
static constexpr uint32_t kMaxBlocks = 256;
// Tells whether `for_node` is a spatial loop: it must be a serial loop and its
// loop variable must not be used in the binding of any reduce axis of a
// ScheduleBlock located in its body.
bool IsSpatialLoop(const ir::For* for_node) {
  if (for_node->for_type() != ir::ForType::Serial) {
    return false;
  }
  const auto& loop_var = for_node->loop_var;

  // matches a variable node that denotes loop_var, either by identity or by
  // name
  auto refers_to_loop_var = [&loop_var](const Expr* expr) {
    const ir::_Var_* var = expr->As<ir::_Var_>();
    return var != nullptr &&
           (expr->same_as(loop_var) || var->name == loop_var->name);
  };

  // matches a ScheduleBlockRealize in which at least one reduce iter_var is
  // bound to an expression that mentions loop_var
  auto binds_reduce_axis_to_loop_var = [&refers_to_loop_var](const Expr* expr) {
    const auto* realize = expr->As<ir::ScheduleBlockRealize>();
    if (realize == nullptr) {
      return false;
    }
    const auto* block = realize->schedule_block.As<ir::ScheduleBlock>();
    CHECK(block) << "schedule_block field is not a ScheduleBlock";
    CHECK_EQ(realize->iter_values.size(), block->iter_vars.size());
    for (size_t idx = 0; idx < realize->iter_values.size(); ++idx) {
      const ir::Var& iter_var = block->iter_vars[idx];
      // an axis counts as a reduce axis when it is flagged as such or when its
      // name starts with "reduce"
      const bool is_reduce_axis =
          iter_var->is_reduce_axis || iter_var->name.substr(0, 6) == "reduce";
      if (!is_reduce_axis) {
        continue;
      }
      const ir::Expr& binding = realize->iter_values[idx];
      if (!ir::CollectIRNodesWithoutTensor(binding, refers_to_loop_var)
               .empty()) {
        return true;
      }
    }
    return false;
  };

  // the loop is spatial exactly when no such block is found underneath it
  return ir::CollectIRNodesWithoutTensor(for_node->body,
                                         binds_reduce_axis_to_loop_var)
      .empty();
}
// Walks down the loop nest starting at `for_node` and returns how many
// consecutive loops can still be bound to a GPU index.
int CountLoopCanBinded(const ir::For* for_node) {
  int num_bindable = 0;
  for (const ir::For* current = for_node; current != nullptr;) {
    // stop at the first loop that is already bound or that is not spatial
    if (current->is_binded() || !IsSpatialLoop(current)) {
      break;
    }
    ++num_bindable;
    CHECK(current->body.defined() && current->body.As<ir::Block>())
        << "Body is not defined";
    const ir::Block* body = current->body.As<ir::Block>();
    // descend only when the body consists of exactly one statement; if that
    // statement is not an ir::For, As<> yields nullptr and the walk ends
    current =
        (body->stmts.size() == 1) ? body->stmts[0].As<ir::For>() : nullptr;
  }
  return num_bindable;
}
// Fuses the outermost `num_loops_to_bind` loops around block `block_name` and
// binds the result to GPU indices:
//  - when the loop right after them is already bound to a thread index, the
//    fused loop is bound to blockIdx.x;
//  - when its extent fits into `max_threads_per_block`, it is bound to
//    threadIdx.x;
//  - otherwise it is split and bound to blockIdx.x / threadIdx.x, leaving an
//    unbound loop when the extent exceeds max_blocks * max_threads_per_block.
void BindGPUIndex(ir::IRSchedule* ir_schedule,
                  const std::string& block_name,
                  int num_loops_to_bind,
                  int max_blocks,
                  int max_threads_per_block) {
  auto all_loops = ir_schedule->GetLoops(block_name);
  CHECK_LE(num_loops_to_bind, all_loops.size())
      << "The number of loops to be bind is greater than size of all_loops";

  // Detect the case where threadIdx is bound but blockIdx is not. Only the
  // first loop after the ones to bind has to be inspected, the other cases
  // have been excluded by CountLoopCanBinded. This must be evaluated before
  // the loops are fused.
  bool gpu_thread_has_binded = false;
  if (num_loops_to_bind < all_loops.size()) {
    gpu_thread_has_binded =
        all_loops[num_loops_to_bind].As<ir::For>()->is_gpu_thread_binded();
  }

  const std::vector<Expr> loops_to_fuse(all_loops.begin(),
                                        all_loops.begin() + num_loops_to_bind);
  Expr fused_loop = ir_schedule->Fuse(loops_to_fuse);
  const int32_t extent = fused_loop.As<ir::For>()->extent.as_int32();

  if (gpu_thread_has_binded) {
    ir_schedule->Bind(fused_loop, "blockIdx.x");
    return;
  }
  if (extent <= max_threads_per_block) {
    ir_schedule->Bind(fused_loop, "threadIdx.x");
    return;
  }
  if (extent > max_blocks * max_threads_per_block) {
    // too large for one block/thread pair per element: keep a third loop,
    // moved innermost, which stays unbound
    auto parts =
        ir_schedule->Split(fused_loop, {-1, max_blocks, max_threads_per_block});
    CHECK_EQ(parts.size(), 3);
    ir_schedule->Reorder({parts[1], parts[2], parts[0]});
    all_loops = ir_schedule->GetLoops(block_name);
    ir_schedule->Bind(all_loops[0], "blockIdx.x");
    ir_schedule->Bind(all_loops[1], "threadIdx.x");
    return;
  }
  auto parts = ir_schedule->Split(fused_loop, {-1, max_threads_per_block});
  CHECK_EQ(parts.size(), 2);
  ir_schedule->Bind(parts[0], "blockIdx.x");
  ir_schedule->Bind(parts[1], "threadIdx.x");
}
// Remembers `ir_schedule` and collects every block whose outermost loop starts
// a run of loops that can be bound. Returns kApplyAndPruneOtherRules when at
// least one such block exists, kCannotApply otherwise.
RuleApplyType AutoBind::Init(ir::IRSchedule* ir_schedule) {
  ir_schedule_ = ir_schedule;
  for (auto&& candidate : ir_schedule->GetAllBlocks()) {
    auto loops = ir_schedule->GetLoops(candidate);
    const int num_bindable = CountLoopCanBinded(loops[0].As<ir::For>());
    if (num_bindable > 0) {
      applicable_schedule_blocks_.emplace_back(candidate);
    }
  }
  num_applicable_ = applicable_schedule_blocks_.size();
  VLOG(6) << "Collect applicable_schedule_blocks_:" << num_applicable_;
  if (num_applicable_ > 0) {
    return RuleApplyType::kApplyAndPruneOtherRules;
  }
  return RuleApplyType::kCannotApply;
}
// Binds GPU indices for the `index`-th block collected by Init, working on the
// schedule stored by Init.
void AutoBind::Apply(int index) {
  CHECK_LT(index, applicable_schedule_blocks_.size())
      << "invalid apply index:" << index;
  auto target_block = applicable_schedule_blocks_.at(index);
  const std::string& target_block_name =
      target_block.As<ir::ScheduleBlockRealize>()
          ->schedule_block.As<ir::ScheduleBlock>()
          ->name;
  auto loops = ir_schedule_->GetLoops(target_block);
  const int num_loops_to_bind = CountLoopCanBinded(loops[0].As<ir::For>());
  BindGPUIndex(ir_schedule_,
               target_block_name,
               num_loops_to_bind,
               kMaxBlocks,
               target_->max_num_threads());
}
// Reports whether the rule can be applied on block `block_name` of `state`:
// kApplyAndPruneOtherRules when at least one loop, counted from the outermost
// one, can be bound, kCannotApply otherwise.
RuleApplyType AutoBind::AnalyseApplyType(SearchState state,
                                         const std::string& block_name) const {
  Expr block = state->ir_schedule.GetBlock(block_name);
  auto loops = state->ir_schedule.GetLoops(block);
  const bool has_bindable_loop =
      CountLoopCanBinded(loops[0].As<ir::For>()) > 0;
  if (has_bindable_loop) {
    return RuleApplyType::kApplyAndPruneOtherRules;
  }
  return RuleApplyType::kCannotApply;
}
// Applies the rule on block `block_name` of a copy of `state` and returns that
// copy as the only element of the result; `state` itself is not scheduled.
std::vector<SearchState> AutoBind::ApplyOnBlock(SearchState state,
                                                const std::string& block_name) {
  SearchState bound_state = state.Copy();
  // the number of loops to bind is computed on the input state, the binding is
  // performed on the copy
  auto loops = state->ir_schedule.GetLoops(block_name);
  const int num_loops_to_bind = CountLoopCanBinded(loops[0].As<ir::For>());
  BindGPUIndex(&bound_state->ir_schedule,
               block_name,
               num_loops_to_bind,
               kMaxBlocks,
               target_->max_num_threads());
  return {bound_state};
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include "paddle/cinn/ir/ir.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
// Auto bind GPU index(BlockIdx, ThreadIdx) to the loops around the block
class AutoBind : public AutoGenRule {
public:
// `target` is forwarded to the AutoGenRule base class.
explicit AutoBind(const common::Target& target) : AutoGenRule(target) {}
~AutoBind() = default;
// Collects the blocks of `init_schedule` the rule can be applied on.
RuleApplyType Init(ir::IRSchedule* init_schedule) override;
// Applies the rule on the `index`-th block collected by Init.
void Apply(int index) override;
std::string GetRuleName() const override { return "AutoBind"; }
// Tells whether the rule can be applied on block `block_name` of `state`.
RuleApplyType AnalyseApplyType(SearchState state,
const std::string& block_name) const override;
// Applies the rule on block `block_name` and returns the resulting state(s).
std::vector<SearchState> ApplyOnBlock(SearchState state,
const std::string& block_name) override;
private:
// blocks found applicable by Init, indexed by the argument of Apply
std::vector<Expr> applicable_schedule_blocks_;
};
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_bind.h"
#include <glog/logging.h>
#include <gtest/gtest.h>
#include <cmath>
#include <functional>
#include <numeric>
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/test_helper.h"
#include "paddle/cinn/ir/utils/ir_printer.h"
#include "test/cpp/cinn/program_builder.h"
namespace cinn {
namespace auto_schedule {
// Limits used to predict the loop structure produced by AutoBind.
// kMaxBlocks repeats the constant of the same name in auto_bind.cc.
static constexpr uint32_t kMaxBlocks = 256;
// NOTE(review): presumably equals max_num_threads() of the default NVGPU
// target, which is what AutoBind actually uses - confirm.
static constexpr uint32_t kMaxThreadsPerBlock = 1024;
// Test fixture for the AutoBind rule, built on the helpers of
// TestAutoGenRuleBase.
class TestAutoBind : public TestAutoGenRuleBase {
public:
std::vector<std::string> default_input_names = {"X", "Y"};
std::vector<std::string> default_output_names = {"temp_matmul_out"};
// Applies AutoBind on block `block_name` of an elementwise_add program whose
// two inputs have the given `shape`, checks the number, extents and bindings
// of the resulting loops, then compares the generated kernel with the one of
// the manual schedule.
void TestApplyOnElementWiseAdd(const std::vector<int>& shape,
const std::string& block_name) {
Initialize(common::DefaultNVGPUTarget());
auto test_program =
tests::OpBuilder("elementwise_add").Build({{"X", shape}, {"Y", shape}});
// construct input parameter
ir::IRSchedule ir_schedule = MakeIRSchedule(test_program);
SearchState state(ir_schedule, 0, {});
std::vector<ir::Expr> func_bodys = ir_schedule.GetModule().GetExprs();
ASSERT_EQ(func_bodys.size(), 1UL);
VLOG(6) << "Original Expr:\n" << func_bodys[0];
// apply
AutoBind auto_bind(target_);
ASSERT_EQ(auto_bind.AnalyseApplyType(state, block_name),
RuleApplyType::kApplyAndPruneOtherRules);
auto result = auto_bind.ApplyOnBlock(state, block_name)[0];
std::vector<ir::Expr> exprs = result->ir_schedule.GetModule().GetExprs();
EXPECT_EQ(exprs.size(), 1UL);
VLOG(6) << "AutoBind applied Expr: " << exprs[0];
// check bind result
auto all_loops = result->ir_schedule.GetLoops(block_name);
// total number of elements, i.e. the extent of the fused loop
int total_num =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
// the three branches mirror the three cases of BindGPUIndex; ASSERT_EQ on
// the loop count guards the indexed accesses that follow
if (total_num <= kMaxThreadsPerBlock) {
ASSERT_EQ(all_loops.size(), 1);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), total_num);
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_thread_binded());
} else if (total_num <= kMaxBlocks * kMaxThreadsPerBlock) {
ASSERT_EQ(all_loops.size(), 2);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(),
static_cast<int32_t>(std::ceil(static_cast<double>(total_num) /
kMaxThreadsPerBlock)));
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded());
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(),
kMaxThreadsPerBlock);
EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded());
} else {
ASSERT_EQ(all_loops.size(), 3);
EXPECT_EQ(all_loops[0].As<ir::For>()->extent.as_int32(), kMaxBlocks);
EXPECT_TRUE(all_loops[0].As<ir::For>()->is_gpu_block_binded());
EXPECT_EQ(all_loops[1].As<ir::For>()->extent.as_int32(),
kMaxThreadsPerBlock);
EXPECT_TRUE(all_loops[1].As<ir::For>()->is_gpu_thread_binded());
EXPECT_EQ(
all_loops[2].As<ir::For>()->extent.as_int32(),
static_cast<int32_t>(std::ceil(static_cast<double>(total_num) /
(kMaxBlocks * kMaxThreadsPerBlock))));
// the innermost loop is expected to stay unbound
EXPECT_FALSE(all_loops[2].As<ir::For>()->is_binded());
}
// build and run
auto ir_module = BuildIRModule(result->ir_schedule);
auto source_code = GenSourceCode(ir_module);
VLOG(6) << "Optimized source code:\n" << source_code;
// NOTE(review): `true` is the second argument here, whereas the
// CooperativeProcess test passes apply_manual_schedule as the third one
// after a seed - confirm against the MakeIRSchedule signature.
auto manual_ir_module = BuildIRModule(
MakeIRSchedule(test_program, /* apply_manual_schedule*/ true));
VLOG(6) << "Manual-schedule compiled source code:\n"
<< GenSourceCode(manual_ir_module);
CheckResult(GenExecutableKernel(ir_module),
GenExecutableKernel(manual_ir_module),
default_input_names,
{block_name},
{shape, shape},
{shape},
target_);
}
};
// AnalyseApplyType must accept a freshly built matmul and reject it once its
// two outer loops have been fused and bound.
TEST_F(TestAutoBind, AnalyseApplyType) {
  Initialize(common::DefaultNVGPUTarget());
  auto matmul_program =
      tests::OpBuilder("matmul").Build({{"X", {32, 64}}, {"Y", {64, 32}}});
  ir::IRSchedule ir_schedule = MakeIRSchedule(matmul_program);
  SearchState state(ir_schedule, 0, {});
  AutoBind auto_bind(target_);
  const std::string& block_name = default_output_names.back();

  // the two outer loops of the initial Expr are spatial, so the rule applies
  EXPECT_EQ(auto_bind.AnalyseApplyType(state, block_name),
            RuleApplyType::kApplyAndPruneOtherRules);

  state->ir_schedule.Fuse(block_name, {0, 1});
  auto loops_after_fuse = state->ir_schedule.GetLoops(block_name);
  state->ir_schedule.Bind(loops_after_fuse[0], "threadIdx.x");

  // once fused and bound, no loop is left to be bound
  EXPECT_EQ(auto_bind.AnalyseApplyType(state, block_name),
            RuleApplyType::kCannotApply);
}
// Runs the elementwise_add check on two shapes: 64*128 = 8192 elements and
// 57*133*125 = 947625 elements. With kMaxThreadsPerBlock = 1024 and
// kMaxBlocks = 256 these fall into the two-loop and the three-loop branch of
// the helper respectively.
TEST_F(TestAutoBind, ApplyOnBlock) {
TestApplyOnElementWiseAdd({64, 128}, "var_1");
TestApplyOnElementWiseAdd({57, 133, 125}, "var_1");
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/cinn/auto_schedule/search_space/auto_gen_rule/auto_gen_rule.h"
#include <glog/logging.h>
#include <cstdlib>
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
// Stores the address of `target` without copying it, so the referenced Target
// has to stay alive for as long as the rule uses it.
AutoGenRule::AutoGenRule(const common::Target& target) : target_(&target) {}
// Returns the stored number of applicable schedule blocks. The check fails
// while the counter is still negative (its in-class initial value is -1),
// which is reported as a call without initialization.
int AutoGenRule::NumberApplicable() const {
CHECK_GE(num_applicable_, 0)
<< "Call " << GetRuleName()
<< "::NumberApplicable() without initialization.";
return num_applicable_;
}
// Picks one of the applicable schedule blocks with rand() and applies the rule
// on it. At least one block must be applicable.
void AutoGenRule::ApplyRandomly() {
  CHECK_GT(num_applicable_, 0)
      << "Call " << GetRuleName()
      << "::ApplyRandomly() with NumberApplicable() == 0";
  const int chosen_index = rand() % num_applicable_;  // NOLINT
  Apply(chosen_index);
}
} // namespace auto_schedule
} // namespace cinn
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/cinn/auto_schedule/search_space/search_state.h"
#include "paddle/cinn/common/target.h"
#include "paddle/cinn/ir/schedule/ir_schedule.h"
namespace cinn {
namespace auto_schedule {
/**
* Enum class representing how this rule can be applied to a ModuleExpr.
* It is the result type of AutoGenRule::Init and
* AutoGenRule::AnalyseApplyType.
*/
enum class RuleApplyType : int {
// This rule cannot be applied to ModuleExpr.
kCannotApply = 0,
// This rule can be applied to ModuleExpr,
// and the original ModuleExpr will be retained for branching with other
// rules.
kApply = 1,
// This rule can be applied, but the original ModuleExpr will be deleted,
// so the branches with other rules applied on the original ModuleExpr will be
// pruned.
kApplyAndPruneOtherRules = 2,
};
/**
 * Base class for rules of auto-generating schedule (like Ansor's sketch
 * generation)
 *
 */
class AutoGenRule {
 public:
  // The rule keeps a pointer to `target` (not owned), so `target` has to
  // outlive the rule.
  explicit AutoGenRule(const common::Target& target);
  // Virtual because the class is a polymorphic interface: destroying a
  // concrete rule through an AutoGenRule pointer would otherwise be undefined
  // behavior.
  virtual ~AutoGenRule() = default;
  // Initialize the AutoGenRule, it must be called before further actions.
  // Returns RuleApplyType::kCannotApply if the rule cannot be applied on the
  // given schedule, one of the other RuleApplyType values otherwise.
  virtual RuleApplyType Init(ir::IRSchedule* ir_schedule) = 0;
  // CINN IRSchedule can contain many ScheduleBlock(s) and Loop(s), so
  // a auto gen rule may be suitable to different number of
  // Schedule Blocks. This method returns the number of ScheduleBlock
  // that can be applied by this auto gen rule
  virtual int NumberApplicable() const;
  // Applies rule on the ir::ModuleExpr for a schedule block randomly
  virtual void ApplyRandomly();
  // Applies rule on the ir::ModuleExpr for a schedule block specified by index
  // between 0 (inclusive) and NumberApplicable() (exclusive)
  virtual void Apply(int index) = 0;
  // Returns the name of the rule, used for debug.
  virtual std::string GetRuleName() const = 0;
  // Analyze the ApplyType of the rule used for a block determined by a specific
  // SearchState and block name
  virtual RuleApplyType AnalyseApplyType(
      SearchState state, const std::string& block_name) const = 0;
  // Apply the rule to a block determined by a specific SearchState and block
  // name
  virtual std::vector<SearchState> ApplyOnBlock(
      SearchState state, const std::string& block_name) = 0;

 protected:
  // number of ScheduleBlock that can apply this auto gen rule;
  // stays negative until the rule has been initialized
  int num_applicable_ = -1;
  // Target, not owned.
  const common::Target* target_;
  // IRSchedule, not owned; null until a derived class stores the schedule it
  // works on, so that it never holds an indeterminate value.
  ir::IRSchedule* ir_schedule_ = nullptr;
};
} // namespace auto_schedule
} // namespace cinn
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment