Commit 359bb1cd authored by Paul

Merge branch 'ck-gemm-fused-transpose' into sd-opt

parents 1ac14290 55b363c9
#ifndef MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#define MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/index.hpp>
namespace migraphx {
template <class Tensor>
constexpr auto gemm_get_batches()
{
    // Shape of just the batch dimensions (everything except the last two dims)
    constexpr auto lens    = get_shape_c<Tensor>{}.lens;
    constexpr auto strides = get_shape_c<Tensor>{}.strides;
    constexpr auto new_lens = sequence(
        lens.size() - _c<2>, [&](auto... is) { return make_const_array(_c<lens[is]>...); });
    constexpr auto new_strides = sequence(
        strides.size() - _c<2>, [&](auto... is) { return make_const_array(_c<strides[is]>...); });
    return make_shape(new_lens, new_strides);
}

template <class Tensor>
constexpr auto gemm_get_matrix()
{
    // Shape of a single 2D matrix (the last two dimensions)
    constexpr auto lens    = get_shape_c<Tensor>{}.lens;
    constexpr auto strides = get_shape_c<Tensor>{}.strides;
    constexpr auto m       = lens.size() - _c<2>;
    constexpr auto n       = lens.size() - _c<1>;
    constexpr auto new_lens    = make_const_array(_c<lens[m]>, _c<lens[n]>);
    constexpr auto new_strides = make_const_array(_c<strides[m]>, _c<strides[n]>);
    return make_shape(new_lens, new_strides);
}

template <class Tensor, class T>
constexpr auto gemm_batch_slice(Tensor t, T i)
{
    // Tensor view of the i-th matrix in the batch
    constexpr auto batch  = gemm_get_batches<Tensor>();
    constexpr auto matrix = gemm_get_matrix<Tensor>();
    return make_tensor_view(t.data() + batch.index(i), matrix);
}

template <class BlocksPerBatch, class T, class... Ts>
constexpr auto gemm_batch_args(index idx, BlocksPerBatch bpb, T x, Ts... xs)
{
    return [=](auto f) {
        // All tensors should have the same rank
        static_assert(
            (true and ... and (get_shape_c<T>{}.lens.size() == get_shape_c<Ts>{}.lens.size())));
        if constexpr(get_shape_c<T>{}.lens.size() > 2)
        {
            // Get the first batch since all batches should have the same number of elements
            constexpr auto batch = gemm_get_batches<T>();
            static_assert(
                (true and ... and (batch.elements() == gemm_get_batches<Ts>().elements())));
            // Distribute bpb blocks to each batch; block group gidx works on batch gidx / bpb
            idx.group_stride(bpb * batch.elements(), [&](auto gidx) {
                const auto batch_idx = gidx / bpb;
                f(gemm_batch_slice(x, batch_idx), gemm_batch_slice(xs, batch_idx)...);
            });
        }
        else
        {
            // Rank-2 tensors have no batch dimensions; pass them through unchanged
            f(x, xs...);
        }
    };
}
} // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_GEMM_BATCHER_HPP
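For context, a minimal sketch (not part of this commit) of how these helpers would be consumed from a kernel body; the wrapper name batched_gemm, the per-matrix body gemm_one_matrix, and the bpb argument below are placeholders:

#include <migraphx/kernels/gemm_batcher.hpp>
#include <migraphx/kernels/index.hpp>

namespace migraphx {

// Hypothetical per-matrix body; a real kernel would dispatch into the CK
// device GEMM for a single 2D slice here.
template <class A, class B, class C>
__device__ void gemm_one_matrix(A, B, C)
{
}

// Each workgroup resolves which batch it owns (gidx / bpb), then runs the
// 2D GEMM on that slice; rank-2 inputs are passed through unchanged.
template <class BlocksPerBatch, class A, class B, class C>
__device__ void batched_gemm(BlocksPerBatch bpb, A a, B b, C c)
{
    auto idx = make_index();
    gemm_batch_args(idx, bpb, a, b, c)(
        [](auto... slices) { gemm_one_matrix(slices...); });
}

} // namespace migraphx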
...@@ -130,6 +130,8 @@ struct index
        return blockDim.x;
    }
#endif
    constexpr auto ngroup() const { return nglobal() / max_nlocal(); }

    template <class N, class Stride>
    static constexpr auto max_stride_iterations(N n, Stride stride)
    {
...@@ -231,6 +233,12 @@ struct index
    {
        for_stride<true>(local, n, nlocal(), f);
    }

    template <class F, class N>
    __device__ void group_stride(N n, F f) const
    {
        for_stride<false>(group, n, ngroup(), f);
    }
};
#ifdef MIGRAPHX_NLOCAL
...
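For orientation (a sketch, not part of the commit): assuming for_stride<false> keeps its usual grid-stride meaning, the new group_stride helper has each workgroup start at its own group id and advance by the total workgroup count ngroup() until n work items are covered. A host-side analogue:

#include <cstdio>

// Illustrative loop only: workgroup g out of ngroups visits items
// g, g + ngroups, g + 2*ngroups, ... This is how gemm_batch_args spreads
// bpb * batch.elements() work items across the launched grid.
void group_stride_analogue(unsigned n, unsigned ngroups)
{
    for(unsigned g = 0; g < ngroups; ++g)
        for(unsigned i = g; i < n; i += ngroups)
            std::printf("group %u -> item %u\n", g, i);
}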
...@@ -57,6 +57,7 @@
#include <migraphx/gpu/concat_gpu_opt.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
...@@ -135,6 +136,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
        dead_code_elimination{},
        fuse_mlir{&ctx},
        dead_code_elimination{},
        fuse_ck{&ctx},
        dead_code_elimination{},
        lowering{&ctx, options.offload_copy},
        eliminate_contiguous{"gpu::contiguous"},
        dead_code_elimination{},
...
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/apply_alpha_beta.hpp>
struct gemm_add_relu : verify_program<gemm_add_relu>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();
        auto a   = mm->add_parameter("1", {migraphx::shape::half_type, {2, 3}});
        auto b   = mm->add_parameter("2", {migraphx::shape::half_type, {3, 4}});
        auto c   = mm->add_parameter("3", {migraphx::shape::half_type, {2, 4}});
        auto dot = mm->add_instruction(migraphx::make_op("dot"), a, b);
        auto add = mm->add_instruction(migraphx::make_op("add"), dot, c);
        mm->add_instruction(migraphx::make_op("relu"), add);
        return p;
    }
};
import os, json, subprocess, tempfile, sys, argparse, contextlib, multiprocessing, multiprocessing.dummy


@contextlib.contextmanager
def tmp_file(dump=None):
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        # Only remove the file if it was actually created
        if tmp_name:
            os.unlink(tmp_name)
def pretty_print(obj):
    print(json.dumps(obj, indent=2))


def run_driver(b):
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as tf:
        if not os.path.exists('./bin/gpu-driver'):
            print("./bin/gpu-driver not found")
            os.abort()
        cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
                            capture_output=True,
                            shell=True)
        print(cp.stderr.decode())
        cp.check_returncode()
        for line in cp.stdout.decode().split("\n"):
            s = line.strip()
            if not s:
                continue
            if ']: ' not in s:
                continue
            yield s.split(']: ')[1].strip()


def convert_to_float(s):
    # Strip the trailing unit suffix (e.g. "ms") so the caller can call float()
    return s[:-2]


def get_device_time(s):
    fields = s.split(',')
    return convert_to_float(fields[-1].strip())
def run_driver_ck(config, tuning, iterations):
    b = {
        'settings': {
            'iterations': iterations
        },
        'compile_op': {
            'name': 'ck_gemm',
            'check': True,
            'tuning_val': tuning,
            'inputs': config
        }
    }
    return run_driver(b)


def benchmark_ck(config, tuning):
    try:
        for line in run_driver_ck(config, tuning, 100):
            dtime = get_device_time(line)
            print(dtime)
            return float(dtime)
        print("Failed")
        sys.exit(1)
    except Exception:
        # A failed compile or check disqualifies this tuning value
        return sys.float_info.max


def benchmark(config, size):
    times = [benchmark_ck(config, i) for i in range(size)]
    return times.index(min(times))
def parse_log(f):
    for line in open(f).readlines():
        line = line.strip()
        if not line.startswith('ck_gemm:'):
            continue
        line = line[len('ck_gemm:'):].strip()
        config = json.loads(line)
        yield config


def precompile(x):
    try:
        list(run_driver_ck(x[0], x[1], 0))
    except Exception:
        pass


def precompile_log(f, n):
    solutions = ((config, i) for config in parse_log(f) for i in range(n))
    with multiprocessing.Pool(24) as p:
        list(p.imap(precompile, solutions))


def benchmark_log(f, n):
    result = []
    for config in parse_log(f):
        tuned = benchmark(config, n)
        print("Tuned:", tuned)
        result.append([config, tuned])
    return result
def parse_args():
    parser = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    parser.add_argument('--log',
                        '-l',
                        type=str,
                        metavar='file',
                        help='Path to logfile')
    parser.add_argument('--out',
                        '-o',
                        type=str,
                        metavar='file',
                        help='Output json file to save tunings')
    parser.add_argument('--precompile',
                        '-p',
                        action='store_true',
                        help='Precompile kernels first in parallel')
    parser.add_argument('-n', type=int, help='Number of instances to tune')
    args = parser.parse_args()
    return args


def run(args):
    if args.precompile:
        precompile_log(args.log, args.n)
    tuned = benchmark_log(args.log, args.n)
    json.dump(tuned, open(args.out, 'w+'))


run(parse_args())
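As a usage note (the script filename below is only a placeholder for wherever this file is saved): an invocation like `python3 tune_ck_gemm.py --log ck_gemm.log -n 10 --out tunings.json --precompile` would precompile the candidate kernels in parallel, benchmark each `ck_gemm: {...}` entry found in the log across tuning values 0..n-1, and write the best tuning value per config to the output JSON.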