Commit 151dd91a authored by turneram's avatar turneram
Browse files

Merge remote-tracking branch 'origin/ck-flash-attn' into gemm-perf

parents 280e76d0 5b2b7489
......@@ -1245,6 +1245,79 @@ TEST_CASE(nonzero_test)
EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearadd_test)
{
    // Operator reference:
    // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAdd
    //
    // Evaluates the parsed model on the "ref" target with two 64-element uint8 inputs and
    // expects every output element to equal 64.
    migraphx::program prog = migraphx::parse_onnx("qlinearadd_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    // Input A: even values ascending from 0 to 126.
    migraphx::shape shape_a{migraphx::shape::uint8_type, {64}};
    std::vector<uint8_t> input_a = {
        0,  2,  4,  6,  8,  10, 12, 14, 16,  18,  20,  22,  24,  26,  28,  30,
        32, 34, 36, 38, 40, 42, 44, 46, 48,  50,  52,  54,  56,  58,  60,  62,
        64, 66, 68, 70, 72, 74, 76, 78, 80,  82,  84,  86,  88,  90,  92,  94,
        96, 98, 100, 102, 104, 106, 108, 110, 112, 114, 116, 118, 120, 122, 124, 126};

    // Input B: even values descending from 128 to 2.
    migraphx::shape shape_b{migraphx::shape::uint8_type, {64}};
    std::vector<uint8_t> input_b = {
        128, 126, 124, 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98,
        96,  94,  92,  90,  88,  86,  84,  82,  80,  78,  76,  74,  72,  70,  68,  66,
        64,  62,  60,  58,  56,  54,  52,  50,  48,  46,  44,  42,  40,  38,  36,  34,
        32,  30,  28,  26,  24,  22,  20,  18,  16,  14,  12,  10,  8,   6,   4,   2};

    migraphx::parameter_map params;
    params["A"] = migraphx::argument(shape_a, input_a.data());
    params["B"] = migraphx::argument(shape_b, input_b.data());

    auto result = prog.eval(params).back();

    // Copy the output elements into a plain vector for comparison.
    std::vector<uint8_t> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    // Expected output: 64 elements, each equal to 64.
    std::vector<uint8_t> gold(64, 64);

    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(qlinearadd_bcast_test)
{
    // Operator reference:
    // github.com/microsoft/onnxruntime/blob/main/docs/ContribOperators.md#com.microsoft.QLinearAdd
    //
    // Same check as the non-broadcast test but with int8 data, where input A has shape {64}
    // and input B has shape {1, 1, 64}. Every output element is expected to equal -64.
    migraphx::program prog = migraphx::parse_onnx("qlinearadd_bcast_test.onnx");
    prog.compile(migraphx::make_target("ref"));

    // Input A: even values ascending from -64 to 62.
    migraphx::shape shape_a{migraphx::shape::int8_type, {64}};
    std::vector<int8_t> input_a = {
        -64, -62, -60, -58, -56, -54, -52, -50, -48, -46, -44, -42, -40, -38, -36, -34,
        -32, -30, -28, -26, -24, -22, -20, -18, -16, -14, -12, -10, -8,  -6,  -4,  -2,
        0,   2,   4,   6,   8,   10,  12,  14,  16,  18,  20,  22,  24,  26,  28,  30,
        32,  34,  36,  38,  40,  42,  44,  46,  48,  50,  52,  54,  56,  58,  60,  62};

    // Input B: even values descending from 96 to -30.
    migraphx::shape shape_b{migraphx::shape::int8_type, {1, 1, 64}};
    std::vector<int8_t> input_b = {
        96, 94, 92, 90, 88, 86, 84, 82, 80,  78,  76,  74,  72,  70,  68,  66,
        64, 62, 60, 58, 56, 54, 52, 50, 48,  46,  44,  42,  40,  38,  36,  34,
        32, 30, 28, 26, 24, 22, 20, 18, 16,  14,  12,  10,  8,   6,   4,   2,
        0,  -2, -4, -6, -8, -10, -12, -14, -16, -18, -20, -22, -24, -26, -28, -30};

    migraphx::parameter_map params;
    params["A"] = migraphx::argument(shape_a, input_a.data());
    params["B"] = migraphx::argument(shape_b, input_b.data());

    auto result = prog.eval(params).back();

    // Copy the output elements into a plain vector for comparison.
    std::vector<int8_t> result_vector;
    result.visit([&](auto output) { result_vector.assign(output.begin(), output.end()); });

    // Expected output: 64 elements, each equal to -64.
    std::vector<int8_t> gold(64, -64);

    EXPECT(migraphx::verify::verify_rms_range(result_vector, gold));
}
TEST_CASE(resize_downsample_f_test)
{
migraphx::program p = migraphx::parse_onnx("resize_downsample_f_test.onnx");
......
......@@ -44,8 +44,7 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_DUMP_TEST)
// An improved async, that doesn't block
template <class Function>
std::future<typename std::result_of<Function()>::type> detach_async(Function&& f,
bool parallel = true)
std::future<std::invoke_result_t<Function>> detach_async(Function&& f, bool parallel = true)
{
if(parallel)
{
......
......@@ -82,6 +82,27 @@ def parse_args():
default=False,
help='Turn on ort VERBOSE logging via session options')
parser.add_argument(
'--disable-offload-copy',
dest="offload_copy",
action='store_false',
default=True,
help=
'Disable offload copying (user must handle copy to and from device)')
parser.add_argument(
'--disable-fast-math',
dest="fast_math",
action='store_false',
default=True,
help='Disable fast math optimizations (etc: rewrite_gelu)')
parser.add_argument('--exhaustive_tune',
dest="exhaustive_tune",
action='store_true',
default=False,
help='Enable exhaustive tuning for solutions')
args = parser.parse_args()
return args
......@@ -177,7 +198,12 @@ def main():
print(model)
if not args.ort_run:
model.compile(migraphx.get_target(args.target))
model.compile(
migraphx.get_target(args.target),
offload_copy=args.offload_copy,
fast_math=args.fast_math,
exhaustive_tune=args.exhaustive_tune,
)
params = {}
test_inputs = {}
......
......@@ -38,26 +38,32 @@
#include <migraphx/register_op.hpp>
#include <migraphx/json.hpp>
#include <migraphx/convert_to_json.hpp>
#include <array>
#include <algorithm>
#include <cstdarg>
namespace migraphx {
#ifdef MIGRAPHX_BUILD_TESTING
// Per-thread switch, compiled only when MIGRAPHX_BUILD_TESTING is defined.
// When true, try_ invokes the wrapped callable directly instead of inside its try/catch.
static thread_local bool disable_exception_catch = false; // NOLINT
// Exported C entry point that sets the switch above for the calling thread.
// b: new value of the switch (true disables exception catching in try_).
extern "C" MIGRAPHX_C_EXPORT void migraphx_test_private_disable_exception_catch(bool b)
{
disable_exception_catch = b;
}
#endif
template <class F>
migraphx_status try_(F f, bool output = true) // NOLINT
{
#ifdef MIGRAPHX_BUILD_TESTING
if(disable_exception_catch)
{
f();
}
else
{
#endif
try
{
f();
......@@ -81,7 +87,9 @@ migraphx_status try_(F f, bool output = true) // NOLINT
{
return migraphx_status_unknown_error;
}
#ifdef MIGRAPHX_BUILD_TESTING
}
#endif
return migraphx_status_success;
}
......
......@@ -26,6 +26,7 @@
#include <stdlib.h>
#include <stdbool.h>
#include <stdint.h>
#include <migraphx/api/export.h>
......
......@@ -2,7 +2,7 @@
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -27,11 +27,11 @@ import sys
debug = False
# The filetypes we want to check for that are stamped
# LICENSE is included here as it SHOULD have a liscence in it otherwise flag it as unstamped
# LICENSE is included here as it SHOULD have a license in it otherwise flag it as unstamped
supported_file_types = (".cpp", ".hpp", ".h", ".ipynb", ".py", ".txt", ".sh",
".bsh", "LICENSE", ".cmake")
#add general stuff we shouldn't stamp and any exceptions here
# add general stuff we shouldn't stamp and any exceptions here
unsupported_file_types = [
".onnx", ".pb", ".rst", ".jpg", ".jpeg", ".proto", ".md", ".clang",
".weight", ".ini", ".json", ".docker", ".git", ".rules", ".yml"
......@@ -40,105 +40,89 @@ unsupported_file_types = [
specificIgnores = ("digits.txt", "Dockerfile", "Jenkinsfile", "")
def hasKeySequence(inputfile: str, key_message: str) -> bool:
    """Report whether ``key_message`` occurs anywhere inside ``inputfile``.

    Args:
        inputfile: Text to search; the caller in this script passes the full
            contents of a file.
        key_message: Substring to look for.

    Returns:
        True if ``key_message`` is a substring of ``inputfile``, else False.
    """
    return key_message in inputfile
# Open a file read-only and decide whether it is still missing the license stamp
def needStampCheck(filename: str) -> bool:
    """Return True when ``filename`` is readable text without the AMD license stamp.

    Args:
        filename: Path of the file to inspect.

    Returns:
        False if the file cannot be opened, cannot be decoded as text, or
        already contains the license key string; True otherwise.
    """
    if debug:
        print("Open", filename, end=' ')

    try:
        file = open(filename, 'r')
    except OSError as e:
        # open() failed, so there is no file object to close here; closing it
        # would raise UnboundLocalError and abort the whole scan.
        if debug:
            print(str(e) + "....Open Error: Skipping file ")
        return False

    # The with-statement closes the file on every exit path.
    with file as contents:
        try:
            save = contents.read()
        except UnicodeDecodeError as eu:
            if debug:
                print(f"{str(eu)}...Skipping binary file ")
            return False

    # Check if we have a license stamp already
    if hasKeySequence(save,
                      "Advanced Micro Devices, Inc. All rights reserved"):
        if debug:
            print("....Already Stamped: Skipping file ")
        return False

    return True
# Check if any element in fileTuple is in filename
def check_filename(filename: str, fileTuple: "tuple | list") -> bool:
    """Report whether any entry of ``fileTuple`` is a substring of ``filename``.

    Args:
        filename: File name or path to test.
        fileTuple: Tuple or list of strings (extensions or names) to look for.

    Returns:
        True as soon as one entry is found inside ``filename``; False if none
        match or ``fileTuple`` is empty.
    """
    # Matching is by substring, not by suffix, so an entry can match anywhere
    # in the path. The generator lets any() stop at the first hit.
    return any(key in filename for key in fileTuple)
def main() -> None:
    """Scan the git-tracked files and exit non-zero if any lack a license stamp.

    Exit codes:
        1 - at least one supported file has no license stamp.
        2 - at least one file matched neither the supported nor the
            unsupported list.
        0 - otherwise.
    """
    # NOTE(review): specificIgnores contains "" and check_filename matches by
    # substring, so after this extend the unsupported check appears to match
    # every file name and unknownFiles would stay empty - confirm this is
    # intended.
    unsupported_file_types.extend(specificIgnores)

    # Get a list of all the tracked files in our git repo
    proc = subprocess.run("git ls-files --exclude-standard",
                          shell=True,
                          stdout=subprocess.PIPE)
    fileList = proc.stdout.decode().split('\n')
    if debug:
        print("Target file list:\n" + str(fileList))

    unsupportedFiles = []
    unstampedFiles = []
    unknownFiles = []
    for file in fileList:
        if check_filename(file, supported_file_types):
            if needStampCheck(file):
                unstampedFiles.append(file)
        elif check_filename(file, unsupported_file_types):
            unsupportedFiles.append(file)
        else:
            unknownFiles.append(file)

    # Do a bunch of checks based on our file lists
    if len(unstampedFiles) > 0:
        print("\nError: The following " + str(len(unstampedFiles)) +
              " files are currently without a license:")
        print(str(unstampedFiles))
        sys.exit(1)

    if len(unknownFiles) > 0:
        print("\nError: The following " + str(len(unknownFiles)) +
              " files not handled:")
        print(str(unknownFiles))
        sys.exit(2)

    sys.exit(0)


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment