"test/vscode:/vscode.git/clone" did not exist on "2c8149f684edaebbace930db34b8c39e4d47b5b7"
Commit 4b189c12 authored by Alan Turner's avatar Alan Turner
Browse files

Add flag to use exclusively CK gemms

parent 70b7a68f
......@@ -34,6 +34,8 @@ struct module;
namespace gpu {
MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_USE_CK_ONLY);
struct ck_gemm
{
operation op = make_op("dot");
......@@ -106,7 +108,7 @@ MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
// Skipping GEMMs with a K dimension greater than 2048 is a course-grained strategy
// to avoid poor-performing GEMM kernels from CK
// To-do: Investigate a more precise strategy
return k <= 2048;
return k <= 2048 or enabled(MIGRAPHX_USE_CK_ONLY{});
}
struct find_ck_gemm_pointwise
......
......@@ -115,9 +115,10 @@ def get_gemm_time(config, fp16, provider, timeout):
env=dict(os.environ,
MIGRAPHX_ENABLE_CK=use_CK,
MIGRAPHX_ENABLE_MLIR=use_MLIR,
MIGRAPHX_USE_CK_ONLY="1",
MIGRAPHX_MLIR_USE_SPECIFIC_OPS="dot"))
except Exception as e:
print(f"{provider.name} encountered and exception {e}")
print(f"{provider.name} encountered an exception: {e}")
return -100.0
if verify_output(str(out.stdout), provider):
......@@ -147,7 +148,7 @@ def get_gemm_softmax_gemm_time(config, provider, timeout):
MIGRAPHX_ENABLE_MLIR=use_MLIR,
MIGRAPHX_MLIR_USE_SPECIFIC_OPS="dot"))
except Exception as e:
print(f"{provider.name} encountered and exception {e}")
print(f"{provider.name} encountered an exception: {e}")
return -100.0
if verify_output(str(out.stdout), provider):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment