Commit 221dcea5 authored by Paul
Browse files

Only enable check when tuning

parent c7657479
...@@ -184,6 +184,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler> ...@@ -184,6 +184,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
options.kernel_name = v.get("kernel", "ck_gemm_kernel"); options.kernel_name = v.get("kernel", "ck_gemm_kernel");
options.virtual_inputs = inputs; options.virtual_inputs = inputs;
if (v.get("check", false))
options.params += " -DMIGRAPHX_CK_CHECK=1";
auto src = interpolate_string(ck_gemm_kernel, auto src = interpolate_string(ck_gemm_kernel,
{{"instance", join_strings(instance, ",")}, {{"instance", join_strings(instance, ",")},
{"params", enum_params(inputs.size(), "void * private_p")}, {"params", enum_params(inputs.size(), "void * private_p")},
......
...@@ -109,5 +109,11 @@ struct ck_passthrough ...@@ -109,5 +109,11 @@ struct ck_passthrough
} }
}; };
// MIGRAPHX_CK_STATIC_ASSERT(cond[, msg]): opt-in compile-time check.
// When MIGRAPHX_CK_CHECK is defined the arguments are forwarded to
// static_assert; when it is not defined the invocation expands to nothing,
// so the condition is never evaluated by the compiler.
#if defined(MIGRAPHX_CK_CHECK)
#define MIGRAPHX_CK_STATIC_ASSERT(...) static_assert(__VA_ARGS__)
#else
#define MIGRAPHX_CK_STATIC_ASSERT(...)
#endif
} // namespace migraphx } // namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_CK_HPP #endif // MIGRAPHX_GUARD_KERNELS_CK_HPP
...@@ -59,7 +59,7 @@ __device__ void ck_gemm(E e, A a, B b, Ds... ds) ...@@ -59,7 +59,7 @@ __device__ void ck_gemm(E e, A a, B b, Ds... ds)
constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock = constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n); GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
static_assert(GridwiseGemm::CheckValidity( MIGRAPHX_CK_STATIC_ASSERT(GridwiseGemm::CheckValidity(
a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map)); a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map));
__shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()]; __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
......
...@@ -23,6 +23,7 @@ def run_driver(b): ...@@ -23,6 +23,7 @@ def run_driver(b):
with tmp_file(lambda tf: json.dump(b, tf)) as tf: with tmp_file(lambda tf: json.dump(b, tf)) as tf:
cp = subprocess.run('./bin/gpu-driver {}'.format(tf), cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
capture_output=True, capture_output=True,
check=True,
shell=True) shell=True)
for line in cp.stdout.decode().split("\n"): for line in cp.stdout.decode().split("\n"):
s = line.strip() s = line.strip()
...@@ -49,6 +50,7 @@ def benchmark_ck(config, tuning): ...@@ -49,6 +50,7 @@ def benchmark_ck(config, tuning):
}, },
'compile_op': { 'compile_op': {
'name': 'ck_gemm', 'name': 'ck_gemm',
'check': True,
'tuning_val': tuning, 'tuning_val': tuning,
'inputs': config 'inputs': config
} }
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment