Commit bd42c933 authored by Paul

Work with any size

parent 8694b810
@@ -57,8 +57,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
         return false;
     if(a.lens()[1] > 2048)
         return false;
-    return (a.lens()[0] % 8 == 0 and a.lens()[1] % 8 == 0 and b.lens()[0] % 8 == 0 and
-            b.lens()[1] % 8 == 0);
+    return true;
 }
 struct find_ck_gemm
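The divisibility-by-8 requirement can go because the compiler below now pads each problem dimension up to the chosen instance's per-block tile size instead of rejecting the GEMM. A minimal standalone sketch of that rounding arithmetic, with hypothetical values (m = 100 and mpb = 32 are illustrative, not from the diff):

    #include <cstddef>
    #include <iostream>

    // Same rounding helper the compiler uses below.
    static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

    int main()
    {
        std::size_t m   = 100; // hypothetical M dimension, not a multiple of 8
        std::size_t mpb = 32;  // hypothetical MPerBlock tile size
        // Round M up to the next tile boundary; the difference is the padding,
        // which is what instance::get_pad computes per dimension.
        std::size_t pad = int_div_ceil(m, mpb) * mpb - m;
        std::cout << "pad M by " << pad << "\n"; // prints: pad M by 28
    }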
@@ -79,19 +79,72 @@ __global__ void ${kernel}(${params})
 static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }
-static std::size_t block_size_index = 15;
-static std::size_t get_block_size(const std::vector<std::string>& s)
-{
-    return std::stoull(s[block_size_index]);
-}
-static std::size_t get_grid_size(const std::vector<std::string>& s, std::size_t m, std::size_t n)
-{
-    auto mpb = std::stoull(s[block_size_index + 1]);
-    auto npb = std::stoull(s[block_size_index + 2]);
-    return int_div_ceil(m, mpb) * int_div_ceil(n, npb);
-}
+struct instance
+{
+    std::vector<std::string> params;
+    static const std::size_t block_size_index = 15;
+    std::size_t int_at(std::size_t i) const
+    {
+        return std::stoull(params[i]);
+    }
+    std::size_t get_block_size() const
+    {
+        return int_at(block_size_index);
+    }
+    // Per-block tile sizes (MPerBlock, NPerBlock, KPerBlock, ...) follow the
+    // block size in the instance's template parameter list.
+    std::size_t get_pb(std::size_t i) const
+    {
+        assert(i < 4);
+        return int_at(block_size_index + 1 + i);
+    }
+    // How much each of {M, N, K} must grow to reach the next tile boundary.
+    std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
+    {
+        std::array<std::size_t, 3> result{};
+        for(auto i : range(config.size()))
+        {
+            result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
+        }
+        return result;
+    }
+    std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
+    {
+        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
+    }
+    void set_ds_layout(const std::string& s)
+    {
+        assert(params[2] == "ck::Tuple<>");
+        params[2] = s;
+    }
+    void set_ds_type(const std::string& s)
+    {
+        assert(params[8] == "ck::Tuple<>");
+        params[8] = s;
+    }
+    void set_ds_op(const std::string& s)
+    {
+        assert(params[12] == "ck_passthrough");
+        params[12] = s;
+    }
+    void set_gemm(const std::string& s)
+    {
+        assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
+        params[13] = s;
+    }
+    std::string str() const
+    {
+        return join_strings(params, ",");
+    }
+};
 template <class F, class Action>
 auto action_decorate(F f, Action action)
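To make the wrapper's indexing concrete: a tuned instance arrives as the comma-separated CK template parameter list, entry 15 is the block size, and the entries after it are the per-block tiles used for padding and grid sizing. A standalone sketch with hypothetical values (the placeholder params and the 256/128/128/32 tile are assumptions, not a real tuning record):

    #include <array>
    #include <cstddef>
    #include <iostream>
    #include <string>
    #include <vector>

    static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

    int main()
    {
        // Entries 0-14 stand in for layouts/types/ops; 15 is BlockSize,
        // 16-18 are MPerBlock/NPerBlock/KPerBlock (hypothetical values).
        std::vector<std::string> params(19, "...");
        params[15] = "256";
        params[16] = "128";
        params[17] = "128";
        params[18] = "32";
        std::array<std::size_t, 3> config{1000, 1000, 64}; // M, N, K
        auto pb = [&](std::size_t i) { return std::stoull(params[15 + 1 + i]); };
        std::size_t block_size = std::stoull(params[15]);
        std::size_t grid_size  = int_div_ceil(config[0], pb(0)) * int_div_ceil(config[1], pb(1));
        std::cout << block_size << " threads per block, " << grid_size << " blocks\n"; // 256, 64
    }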
@@ -156,29 +209,40 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         auto b_shape = inputs[1];
         auto c_shape = inputs.back();
-        auto m = c_shape.lens().front();
-        auto n = c_shape.lens().back();
+        std::array<char, 3> keys{'M', 'N', 'K'};
+        std::array<std::size_t, 3> config{
+            c_shape.lens().front(), c_shape.lens().back(), a_shape.lens().back()};
-        auto i = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
-        auto instance = get_instance(i, [&](const auto& x) -> bool {
+        auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
+        auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
             return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
                    get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
                    get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
-        });
+        })};
         assert(inputs.size() < 4 or v.contains("post"));
         if(v.contains("post"))
         {
-            assert(instance[2] == "ck::Tuple<>");
-            instance[2] = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout);
-            assert(instance[8] == "ck::Tuple<>");
-            instance[8] = ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type);
-            assert(instance[12] == "ck_passthrough");
-            instance[12] = v.at("post").to<std::string>();
+            ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
+            ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
+            ip.set_ds_op(v.at("post").to<std::string>());
         }
+        // Build the GemmSpecialization name from whichever dimensions need padding.
+        auto padding = ip.get_pad(config);
+        std::string gemm_type;
+        for(auto i : range(padding.size()))
+        {
+            if(padding[i] != 0)
+                gemm_type += keys[i];
+        }
+        if(gemm_type.empty())
+            gemm_type = "Default";
+        else
+            gemm_type += "Padding";
+        ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
         hip_compile_options options;
-        auto block_size = get_block_size(instance);
-        auto grid_size  = get_grid_size(instance, m, n);
+        auto block_size = ip.get_block_size();
+        auto grid_size  = ip.get_grid_size(config);
         options.set_launch_params(v, grid_size * block_size, block_size);
         options.inputs = inputs;
         options.output = c_shape;
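That loop is what selects the CK specialization: each dimension with nonzero padding contributes its letter, and a non-empty result gets the Padding suffix. A standalone sketch of the selection (the config and tile values are hypothetical):

    #include <array>
    #include <cstddef>
    #include <iostream>
    #include <string>

    static std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

    int main()
    {
        std::array<char, 3> keys{'M', 'N', 'K'};
        std::array<std::size_t, 3> config{1000, 1024, 60}; // hypothetical M, N, K
        std::array<std::size_t, 3> pb{128, 128, 32};       // hypothetical per-block tiles
        std::string gemm_type;
        for(std::size_t i = 0; i < config.size(); i++)
        {
            std::size_t pad = int_div_ceil(config[i], pb[i]) * pb[i] - config[i];
            if(pad != 0)
                gemm_type += keys[i];
        }
        gemm_type = gemm_type.empty() ? "Default" : gemm_type + "Padding";
        std::cout << "GemmSpecialization::" << gemm_type << "\n"; // prints: ...::MKPadding
    }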
@@ -189,7 +253,7 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         options.params += " -DMIGRAPHX_CK_CHECK=1";
         auto src = interpolate_string(ck_gemm_kernel,
-                                      {{"instance", join_strings(instance, ",")},
+                                      {{"instance", ip.str()},
                                        {"params", enum_params(inputs.size(), "void * private_p")},
                                        {"args", enum_params(inputs.size(), "private_p")},
                                        {"preamble", v.get("preamble", std::string{})},
@@ -72,7 +72,7 @@ __device__ void ck_gemm(E e, A a, B b, Ds... ds)
     constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
         GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
-    MIGRAPHX_CK_STATIC_ASSERT(GridwiseGemm::CheckValidity(
+    static_assert(GridwiseGemm::CheckValidity(
         a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map));
     __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
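The wrapper macro being dropped here is not defined anywhere in this diff; one plausible definition, assumed only for illustration, is that it asserted solely when checking was enabled:

    // Assumed definition, not shown in the diff.
    #ifdef MIGRAPHX_CK_CHECK
    #define MIGRAPHX_CK_STATIC_ASSERT static_assert
    #else
    #define MIGRAPHX_CK_STATIC_ASSERT(...)
    #endif

Under that assumption, switching to a bare static_assert makes the GridwiseGemm::CheckValidity test unconditional; note the compiler already passes -DMIGRAPHX_CK_CHECK=1 above, so the check was effectively always on.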
@@ -33,9 +33,9 @@ struct gemm_add_relu : verify_program<gemm_add_relu>
     {
         migraphx::program p;
         auto* mm = p.get_main_module();
-        auto a = mm->add_parameter("1", {migraphx::shape::half_type, {256, 512}});
-        auto b = mm->add_parameter("2", {migraphx::shape::half_type, {512, 1024}});
-        auto c = mm->add_parameter("3", {migraphx::shape::half_type, {256, 1024}});
+        auto a = mm->add_parameter("1", {migraphx::shape::half_type, {2, 3}});
+        auto b = mm->add_parameter("2", {migraphx::shape::half_type, {3, 4}});
+        auto c = mm->add_parameter("3", {migraphx::shape::half_type, {2, 4}});
         auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
         auto add = mm->add_instruction(migraphx::make_op("add"), dot, c);
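The new shapes make the verify test exercise the padding path: the config is {M, N, K} = {2, 4, 3}, so with any realistic tile (say a hypothetical 128x128x32) every dimension needs padding and GemmSpecialization::MNKPadding is selected, whereas the old 256/512/1024 shapes, all tile multiples, always compiled as Default.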