Commit 297645e8 authored by Alan Turner

Merge remote-tracking branch 'origin/ck-gemm-fused-transpose' into ck-gsg

parents ac7a0025 55b363c9
@@ -212,22 +212,28 @@ void from_value_impl(rank<6>, const value& v, optional<T>& x)
     x = from_value<T>(v);
 }
-template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{} or std::is_enum<T>{})>
+template <class T, MIGRAPHX_REQUIRES(std::is_arithmetic<T>{})>
 void from_value_impl(rank<7>, const value& v, T& x)
 {
     x = v.to<T>();
 }
-inline void from_value_impl(rank<8>, const value& v, std::string& x) { x = v.to<std::string>(); }
+template <class T, MIGRAPHX_REQUIRES(std::is_enum<T>{})>
+void from_value_impl(rank<8>, const value& v, T& x)
+{
+    x = v.to<T>();
+}
+inline void from_value_impl(rank<9>, const value& v, std::string& x) { x = v.to<std::string>(); }
 template <class T>
-auto from_value_impl(rank<9>, const value& v, T& x) -> decltype(x.from_value(v), void())
+auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(x.from_value(v), void())
 {
     x.from_value(v);
 }
 template <class T>
-auto from_value_impl(rank<10>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
+auto from_value_impl(rank<11>, const value& v, T& x) -> decltype(migraphx_from_value(v, x), void())
 {
     migraphx_from_value(v, x);
 }
@@ -243,7 +249,7 @@ value to_value(const T& x)
 template <class T>
 void from_value(const value& v, T& x)
 {
-    detail::from_value_impl(rank<10>{}, v, x);
+    detail::from_value_impl(rank<11>{}, v, x);
 }
 } // namespace MIGRAPHX_INLINE_NS
...
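For context on the rank<N> numbering shifts above: from_value_impl uses rank-based tag dispatch, where rank<N> derives from rank<N-1>, so the call site passes the highest rank and overload resolution picks the highest-numbered overload whose constraint is satisfiable, falling back down the chain otherwise. Splitting the combined arithmetic-or-enum overload into two levels is why every later rank, and the rank passed by from_value, moves up by one. A minimal standalone sketch of the pattern, with illustrative names (describe, widget) that are not part of MIGraphX:

    #include <iostream>
    #include <string>
    #include <type_traits>

    // rank<N> derives from rank<N-1>, so a rank<2> argument prefers the
    // rank<2> overload, then rank<1>, then rank<0> during overload resolution.
    template <int N>
    struct rank : rank<N - 1>
    {
    };
    template <>
    struct rank<0>
    {
    };

    // Lowest priority: generic fallback.
    template <class T>
    void describe_impl(rank<0>, const T&)
    {
        std::cout << "generic\n";
    }

    // Higher priority, but only viable for arithmetic types (SFINAE).
    template <class T, std::enable_if_t<std::is_arithmetic<T>::value, int> = 0>
    void describe_impl(rank<1>, const T&)
    {
        std::cout << "arithmetic\n";
    }

    // Highest priority: exact overload for std::string.
    inline void describe_impl(rank<2>, const std::string&) { std::cout << "string\n"; }

    // Entry point always passes the highest rank; unsatisfied overloads fall away.
    template <class T>
    void describe(const T& x)
    {
        describe_impl(rank<2>{}, x);
    }

    struct widget
    {
    };

    int main()
    {
        describe(1);                // prints "arithmetic"
        describe(std::string("x")); // prints "string"
        describe(widget{});         // prints "generic"
    }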
@@ -186,20 +186,19 @@ if(MIGRAPHX_USE_HIPRTC)
     message(STATUS "MIGraphX is using hipRTC")
     target_compile_definitions(migraphx_gpu PRIVATE -DMIGRAPHX_USE_HIPRTC=1)
 else()
-    message(STATUS "MIGraphX is using HIP Clang")
     # Get flags needed to compile hip
     include(TargetFlags)
-    message(STATUS "HIP COMPILER FLAGS: ${HIP_COMPILER_FLAGS}")
     target_flags(HIP_COMPILER_FLAGS hip::device)
     message(STATUS "HIP COMPILER FLAGS: ${HIP_COMPILER_FLAGS}")
     # Remove cuda arch flags
     string(REGEX REPLACE --cuda-gpu-arch=[a-z0-9]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
     string(REGEX REPLACE --offload-arch=[a-z0-9:+-]+ "" HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
     # Skip library paths since hip will incorrectly treat it as a source file
     string(APPEND HIP_COMPILER_FLAGS " ")
+    # Add ck includes
+    find_path(CK_INCLUDE_PATH ck/ck.hpp)
+    message(STATUS "CK path: ${CK_INCLUDE_PATH}")
+    string(APPEND HIP_COMPILER_FLAGS " -isystem ${CK_INCLUDE_PATH}")
     foreach(_unused RANGE 2)
         string(REGEX REPLACE " /[^ ]+\\.(a|so) " " " HIP_COMPILER_FLAGS "${HIP_COMPILER_FLAGS}")
     endforeach()
...
@@ -681,6 +681,7 @@ struct find_contiguous_tranpose_precompile
     {
         return match::name("gpu::contiguous")(match::arg(0)(
             match::name("transpose")(
+                match::used_once(),
                 match::arg(0)(match::name("gpu::precompile_op")(match::used_once()).bind("op")))
                 .bind("transpose")));
     }
@@ -693,12 +694,13 @@ struct find_contiguous_tranpose_precompile
         auto transpose = r.instructions["transpose"];
         auto perm = transpose->get_operator().to_value()["permutation"].to_vector<int64_t>();
         auto iperm = invert_permutation(perm);
-        auto s =
-            shape::from_permutation(op_ins->get_shape().type(), op_ins->get_shape().lens(), iperm);
+        auto s = shape::from_permutation(
+            op_ins->get_shape().type(), op_ins->get_shape().lens(), perm); // perm or iperm?
         auto v = op_ins->get_operator().to_value();
         v["output_shape"] = to_value(s);
         auto new_op = make_op("gpu::precompile_op", v);
         m.replace_instruction(op_ins, new_op, op_ins->inputs(), op_ins->module_inputs());
+        assert(ins->get_shape() == transpose->get_shape());
         m.replace_instruction(ins, transpose);
     }
 };
...
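On the open // perm or iperm? question above: the two candidates differ only in which direction the axes are reordered, since applying a permutation and then its inverse restores the original ordering. A small self-contained illustration of that relationship, using hypothetical helpers (apply_permutation, plus a local invert_permutation with the usual iperm[perm[i]] = i definition) rather than MIGraphX's shape::from_permutation:

    #include <cassert>
    #include <cstdint>
    #include <vector>

    // result[i] = v[perm[i]]: reorder v according to perm.
    std::vector<std::int64_t> apply_permutation(const std::vector<std::int64_t>& v,
                                                const std::vector<std::int64_t>& perm)
    {
        std::vector<std::int64_t> result(v.size());
        for(std::size_t i = 0; i < perm.size(); i++)
            result[i] = v[perm[i]];
        return result;
    }

    // iperm[perm[i]] = i: the permutation that undoes perm.
    std::vector<std::int64_t> invert_permutation(const std::vector<std::int64_t>& perm)
    {
        std::vector<std::int64_t> iperm(perm.size());
        for(std::size_t i = 0; i < perm.size(); i++)
            iperm[perm[i]] = i;
        return iperm;
    }

    int main()
    {
        std::vector<std::int64_t> lens = {2, 3, 4, 5}; // original dimension lengths
        std::vector<std::int64_t> perm = {0, 2, 3, 1}; // the transpose's permutation
        auto iperm      = invert_permutation(perm);
        auto transposed = apply_permutation(lens, perm);        // {2, 4, 5, 3}
        auto round_trip = apply_permutation(transposed, iperm); // back to {2, 3, 4, 5}
        assert(round_trip == lens);
    }

The round trip is what makes the two directions easy to conflate; only checking the result against the expected output layout settles which one a given API wants.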
-import os, json, subprocess, tempfile, sys, argparse, contextlib
+import os, json, subprocess, tempfile, sys, argparse, contextlib, multiprocessing, multiprocessing.dummy
 ck_function = -1
@@ -23,10 +23,14 @@ def pretty_print(obj):
 def run_driver(b):
     print(b)
     with tmp_file(lambda tf: json.dump(b, tf)) as tf:
+        if not os.path.exists('./bin/gpu-driver'):
+            print("./bin/gpu-driver not found")
+            os.abort()
         cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
                             capture_output=True,
-                            check=True,
                             shell=True)
+        print(cp.stderr.decode())
+        cp.check_returncode()
         for line in cp.stdout.decode().split("\n"):
             s = line.strip()
             if not s:
@@ -45,23 +49,29 @@ def get_device_time(s):
     return convert_to_float(fields[-1].strip())
-def benchmark_ck(config, name, tuning):
-    try:
-        b = {
-            'settings': {
-                'iterations': 100
-            },
-            'compile_op': {
-                'name': name,
-                'check': True,
-                'tuning_val': tuning,
-                'inputs': config
-            }
-        }
-        for line in run_driver(b):
+def run_driver_ck(config, tuning, iterations):
+    b = {
+        'settings': {
+            'iterations': iterations
+        },
+        'compile_op': {
+            'name': 'ck_gemm',
+            'check': True,
+            'tuning_val': tuning,
+            'inputs': config
+        }
+    }
+    return run_driver(b)
+def benchmark_ck(config, tuning):
+    try:
+        for line in run_driver_ck(config, tuning, 100):
             dtime = get_device_time(line)
             print(dtime)
             return float(dtime)
+        print("Failed")
+        sys.exit(1)
     except:
         return sys.float_info.max
@@ -86,6 +96,19 @@ def parse_log(f):
         yield (config, 'ck_gemm_softmax_gemm')
+def precompile(x):
+    try:
+        list(run_driver_ck(x[0], x[1], 0))
+    except:
+        pass
+def precompile_log(f, n):
+    solutions = ((config, i) for config in parse_log(f) for i in range(n))
+    with multiprocessing.Pool(24) as p:
+        list(p.imap(precompile, solutions))
 def benchmark_log(f, n):
     result = []
     for config, name in parse_log(f):
@@ -107,12 +130,18 @@ def parse_args():
                         type=str,
                         metavar='file',
                         help='Output json file to save tunings')
+    parser.add_argument('--precompile',
+                        '-p',
+                        action='store_true',
+                        help='Precompile kernels first in parallel')
     parser.add_argument('-n', type=int, help='Number of instances to tune')
     args = parser.parse_args()
     return args
 def run(args):
+    if (args.precompile):
+        precompile_log(args.log, args.n)
     tuned = benchmark_log(args.log, args.n)
     json.dump(tuned, open(args.out, 'w+'))
...