Commit 337787d6 authored by Khalique Ahmed's avatar Khalique Ahmed
Browse files

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX...

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/AMDMIGraphX into bcast_transpose_generic
parents 827d22e9 c7f0fbc4
#####################################################################################
# The MIT License (MIT)
#
# Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
......@@ -23,10 +23,55 @@
#####################################################################################
import migraphx
import ctypes
import os
import glob
def test_conv_relu():
hip = ctypes.cdll.LoadLibrary("libamdhip64.so")
# Full path of the library is needed to fix an issue on sles
# where the library is not loaded otherwise.
# We check for the presence of library at the following paths,
# in the order listed below:
#
# 1. 'rocm_path' environment variable
# 2. /opt/rocm
# 3. /opt/rocm-*
#
# If the library is not found at any of these paths, we fall back
# to the library path being detected automatically.
library = "libamdhip64.so"
# Environment variable containing path to rocm
rocm_path_env_var = "rocm_path"
# Check for rocm_path, default to /opt/rocm if it does not exist.
rocm_path_var = os.getenv(rocm_path_env_var, default="/opt/rocm")
# Join the paths to the library to get full path,
# e.g. /opt/rocm/lib/libamdhip64.so
library_file = os.path.join(rocm_path_var, "lib", library)
# Check if the library file exists at the specified path
if os.path.exists(library_file):
# Replace library name by full path to the library
library = library_file
else:
# Pattern match to look for path to different
# rocm versions: /opt/rocm-*
rocm_path_pattern = "/opt/rocm-*/lib/libamdhip64.so"
matching_libraries = glob.glob(rocm_path_pattern)
if matching_libraries:
# Replace library name by full path to the first
# library found.
library = matching_libraries[0]
# Loads library either by using the full path to the
# library, if it has been detected earlier,
# or, proceeds to load the library based on the name
# of the library.
hip = ctypes.cdll.LoadLibrary(library)
p = migraphx.parse_onnx("conv_relu_maxpool_test.onnx")
print(p)
......
......@@ -82,6 +82,27 @@ def parse_args():
default=False,
help='Turn on ort VERBOSE logging via session options')
parser.add_argument(
'--disable-offload-copy',
dest="offload_copy",
action='store_false',
default=True,
help=
'Disable offload copying (user must handle copy to and from device)')
parser.add_argument(
'--disable-fast-math',
dest="fast_math",
action='store_false',
default=True,
help='Disable fast math optimizations (etc: rewrite_gelu)')
parser.add_argument('--exhaustive_tune',
dest="exhaustive_tune",
action='store_true',
default=False,
help='Enable exhaustive tuning for solutions')
args = parser.parse_args()
return args
......@@ -177,7 +198,12 @@ def main():
print(model)
if not args.ort_run:
model.compile(migraphx.get_target(args.target))
model.compile(
migraphx.get_target(args.target),
offload_copy=args.offload_copy,
fast_math=args.fast_math,
exhaustive_tune=args.exhaustive_tune,
)
params = {}
test_inputs = {}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment