Unverified Commit 3c9df3b4 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Improvement to ck integration (#1859)

Add a CI job to test CK
Add MIGRAPHX_TUNE_CK env variable to only do tuning for CK
Continue tuning even when there is invalid configs
Fix a bug with parallel compilation not using all available threads
Add additional test for gemms using half types
Removed int32 as supported type since it doesnt pass our test suite
parent 1f827a7a
......@@ -130,6 +130,7 @@ rocm_enable_clang_tidy(
-bugprone-implicit-widening-of-multiplication-result
-bugprone-macro-parentheses
-bugprone-signed-char-misuse
-bugprone-unchecked-optional-access
# Disable the aliased reserved identifiers
-cert-dcl37-c
-cert-dcl51-cpp
......
......@@ -89,6 +89,8 @@ def rocmnodename(name) {
node_name = "${rocmtest_name} && vega";
} else if(name == "navi21") {
node_name = "${rocmtest_name} && navi21";
} else if(name == "mi100+") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a)";
} else if(name == "anygpu") {
node_name = "${rocmtest_name} && (gfx908 || gfx90a || vega)";
} else if(name == "nogpu") {
......@@ -120,7 +122,7 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
}
}, hiprtc_gpu_debug: rocmnode('vega') { cmake_build ->
stage('HipRTC GPU Debug') {
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On", gpu_debug: true, hiprtc_workarounds: true)
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release -DMIGRAPHX_USE_HIPRTC=On", gpu_debug: true, hiprtc_workarounds: true)
}
}, all_targets_debug : rocmnode('vega') { cmake_build ->
stage('All targets Release') {
......@@ -134,6 +136,12 @@ rocmtest clang_debug: rocmnode('vega') { cmake_build ->
cmake_build(flags: "-DCMAKE_BUILD_TYPE=debug -DMIGRAPHX_ENABLE_PYTHON=Off -DMIGRAPHX_ENABLE_MLIR=On -DCMAKE_CXX_FLAGS_DEBUG='${debug_flags}' -DCMAKE_C_FLAGS_DEBUG='${debug_flags}'")
}
}
}, ck_release: rocmnode('mi100+') { cmake_build ->
stage('CK Release') {
withEnv(['MIGRAPHX_ENABLE_CK=1', 'MIGRAPHX_TUNE_CK=1']) {
cmake_build(flags: "-DCMAKE_BUILD_TYPE=release")
}
}
}, clang_asan: rocmnode('nogpu') { cmake_build ->
stage('Clang ASAN') {
def sanitizers = "undefined,address"
......
......@@ -28,4 +28,4 @@ ROCmSoftwarePlatform/half@rocm-5.4.2
pybind/pybind11@d159a563383d10c821ba7b2a71905d1207db6de4 --build
msgpack/msgpack-c@cpp-3.3.0 -DMSGPACK_BUILD_TESTS=Off
sqlite3@3.17 -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@ac580f77a84c705c678816ef7195adfcc02bdda5 -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
ROCmSoftwarePlatform/composable_kernel@5172ec5280f14974beee2acf1af1db3b2670244c -DCK_BUILD_JIT_LIB=On -DCMAKE_POSITION_INDEPENDENT_CODE=On
......@@ -111,9 +111,27 @@ struct compile_plan
context* ctx;
operation preop;
instruction_ref ins;
optional<tuning_config> config = nullopt;
std::vector<compiled_result> results = {};
void update_config() { config = get_tuning_config(*ctx, ins, preop); }
optional<tuning_config> config = nullopt;
std::vector<optional<compiled_result>> results = {};
void update_config(bool exhaustive)
{
config = get_tuning_config(*ctx, ins, preop, exhaustive);
}
template <class Vector>
void insert_compiles(Vector& compiles, const value& solution, std::size_t i)
{
compiles.emplace_back([=] {
try
{
results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
}
catch(...)
{
results[i] = nullopt;
}
});
}
template <class Vector>
void add_compiles(Vector& compiles, problem_cache& pc)
{
......@@ -127,9 +145,7 @@ struct compile_plan
if(solution.is_null())
return;
results.resize(1);
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
insert_compiles(compiles, solution, 0);
}
else
{
......@@ -139,18 +155,14 @@ struct compile_plan
for(auto i : range(solutions.size()))
{
auto solution = solutions[i];
compiles.emplace_back([=] {
results[i] = compiled_result{compile(*ctx, ins, preop, solution), ins};
});
insert_compiles(compiles, solution, i);
}
}
}
else
{
results.resize(1);
compiles.emplace_back([=] {
results[0] = compiled_result{compile(*ctx, ins, preop, value{}), ins};
});
insert_compiles(compiles, value{}, 0);
}
}
const compiled_result& benchmark(problem_cache& pc) const
......@@ -158,7 +170,11 @@ struct compile_plan
if(results.empty())
MIGRAPHX_THROW("No configs to tune");
if(results.size() == 1)
return results.front();
{
if(not results.front().has_value())
MIGRAPHX_THROW("No configs to tune");
return *results.front();
}
if(not config)
MIGRAPHX_THROW("Multiple kernels without config");
std::cout << "Benchmarking " << preop.name() << ": " << results.size() << " configs"
......@@ -167,11 +183,17 @@ struct compile_plan
times.reserve(results.size());
std::transform(
results.begin(), results.end(), std::back_inserter(times), [&](const auto& cr) {
return time_op(*ctx, cr.replace.code_object, to_shapes(cr.ins->inputs()), 20).first;
if(not cr.has_value())
return std::numeric_limits<double>::max();
return time_op(*ctx, cr->replace.code_object, to_shapes(cr->ins->inputs()), 20)
.first;
});
auto i = std::distance(times.begin(), std::min_element(times.begin(), times.end()));
std::cout << "Fastest solution: " << config->solutions.at(i) << std::endl;
pc.insert(preop.name(), config->problem, config->solutions.at(i));
return results[i];
if(not results[i].has_value())
MIGRAPHX_THROW("No valid tuned compilation.");
return *results[i];
}
void replace(module& m, problem_cache& pc) const
{
......@@ -185,7 +207,10 @@ void par_compile(std::size_t n, F f)
{
if(n == 0)
return;
par_for(n, n / value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{}, n), f);
auto d = value_of(MIGRAPHX_GPU_COMPILE_PARALLEL{});
if(d == 0)
d = n;
par_for(n, n / d, f);
}
struct compile_manager
......@@ -202,9 +227,7 @@ struct compile_manager
void update_configs()
{
if(not exhaustive)
return;
par_compile(cps.size(), [&](auto i) { cps[i].update_config(); });
par_compile(cps.size(), [&](auto i) { cps[i].update_config(exhaustive); });
}
void compile(module& m)
......
......@@ -63,9 +63,10 @@ compile_op(const std::string& name, context& ctx, const std::vector<shape>& inpu
return compiler_map().at(name).compile_op(ctx, inputs, v);
}
optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op)
optional<tuning_config>
get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive)
{
return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op);
return compiler_map().at(op.name()).get_tuning_config(ctx, ins, op, exhaustive);
}
} // namespace gpu
......
......@@ -83,10 +83,23 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
return false;
auto a = ins->inputs().front()->get_shape();
auto b = ins->inputs().back()->get_shape();
auto m = a.lens()[a.lens().size() - 2];
auto n = b.lens().back();
auto k = a.lens().back();
// Integer gemms must be divisible by 4 in ck
if(contains({shape::int8_type, shape::int32_type}, ins->get_shape().type()))
{
if(m % 4 != 0)
return false;
if(n % 4 != 0)
return false;
if(k % 4 != 0)
return false;
}
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
return a.lens().back() <= 2048;
return k <= 2048;
}
struct find_ck_gemm_pointwise
......
......@@ -79,7 +79,7 @@ using compiler_compile =
using compiler_compile_op =
std::function<operation(context&, const std::vector<shape>& inputs, const value&)>;
using compiler_tuning_config =
std::function<optional<tuning_config>(context&, instruction_ref, const operation&)>;
std::function<optional<tuning_config>(context&, instruction_ref, const operation&, bool)>;
void register_compiler(const std::string& name,
compiler_compile c,
......@@ -91,7 +91,8 @@ compiler_replace
compile(context& ctx, instruction_ref ins, const operation& op, const value& solution);
operation
compile_op(const std::string& name, context& ctx, const std::vector<shape>& inputs, const value& v);
optional<tuning_config> get_tuning_config(context& ctx, instruction_ref ins, const operation& op);
optional<tuning_config>
get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive);
template <class T>
void register_compiler()
......@@ -125,7 +126,8 @@ template <class Derived>
struct compiler : auto_register_compiler<Derived>
{
const Derived& derived() const { return static_cast<const Derived&>(*this); }
optional<tuning_config> get_tuning_config(context&, instruction_ref, const operation&) const
optional<tuning_config>
get_tuning_config(context&, instruction_ref, const operation&, bool) const
{
return nullopt;
}
......
......@@ -50,6 +50,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_LOG_CK_GEMM);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_TUNING_VALUE);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_CK_DEBUG);
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TUNE_CK);
// NOLINTNEXTLINE
static const char* const ck_gemm_kernel = R"__migraphx__(
......@@ -265,7 +266,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
s = shape{s.type(), {m1, m2}};
}
std::vector<std::string> names() const { return {"gpu::ck_gemm"}; }
std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
static bool standard_batch(const shape& s)
{
......@@ -418,9 +419,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
{
auto shapes = to_shapes(ins->inputs());
auto v = create_settings(ins, op);
if(solution.is_null())
v["tuning_value"] = 4;
else
if(not solution.is_null())
v["tuning_value"] = solution;
return {compile_op(ctx, shapes, v),
[=](module& m, instruction_ref ins2, const operation& code_object) {
......@@ -436,8 +435,10 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
}
optional<tuning_config>
get_tuning_config(context& ctx, instruction_ref ins, const operation& op) const
get_tuning_config(context& ctx, instruction_ref ins, const operation& op, bool exhaustive) const
{
if(not exhaustive and not enabled(MIGRAPHX_TUNE_CK{}))
return nullopt;
tuning_config tc;
auto shapes = to_shapes(ins->inputs());
auto problem = create_problem(shapes, create_settings(ins, op));
......
......@@ -52,7 +52,7 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
ck::make_tuple(to_ck_tensor<Ds>()...),
to_ck_tensor<E>());
static_assert(desc.is_valid, "Invalid ck gemm.");
static_assert(desc.IsValid(), "Invalid ck gemm.");
G::Run(desc,
to_ck_const_pointer(a.data()),
......
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_add_broadcast_half : verify_program<gemm_add_broadcast_half>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::half_type, {1, 1, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape);
auto l3_b =
mm->add_instruction(migraphx::make_op("multibroadcast", {{"out_lens", {1, 2, 4}}}), l3);
auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2);
mm->add_instruction(migraphx::make_op("add"), dot, l3_b);
return p;
}
};
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_add_half : verify_program<gemm_add_half>
{
migraphx::program create_program() const
{
migraphx::program p;
auto* mm = p.get_main_module();
migraphx::shape m1_shape{migraphx::shape::half_type, {1, 2, 3}};
migraphx::shape m2_shape{migraphx::shape::half_type, {1, 3, 4}};
migraphx::shape m3_shape{migraphx::shape::half_type, {1, 2, 4}};
auto l1 = mm->add_parameter("1", m1_shape);
auto l2 = mm->add_parameter("2", m2_shape);
auto l3 = mm->add_parameter("3", m3_shape);
auto dot = mm->add_instruction(migraphx::make_op("dot"), l1, l2);
mm->add_instruction(migraphx::make_op("add"), dot, l3);
return p;
}
};
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment