Unverified Commit bfa373f6 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into jit-layernorm

parents 7cc97df7 5bf4dee6
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
*/ */
#include <algorithm> #include <algorithm>
#include <hip/hip_runtime.h> #include <hip/hip_runtime.h>
#include <rocblas.h> #include <rocblas/rocblas.h>
#include <migraphx/migraphx.h> #include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API #include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric> #include <numeric>
...@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base ...@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
migraphx::arguments args) const override migraphx::arguments args) const override
{ {
// create rocblas stream handle // create rocblas stream handle
auto rocblas_handle = create_rocblas_handle_ptr(ctx); auto rb_handle = create_rocblas_handle_ptr(ctx);
rocblas_int n = args[1].get_shape().lengths()[0]; MIGRAPHX_ROCBLAS_ASSERT(rocblas_set_pointer_mode(rb_handle, rocblas_pointer_mode_device));
float* alpha = reinterpret_cast<float*>(args[0].data()); rocblas_int n = args[1].get_shape().lengths()[0];
float* vec_ptr = reinterpret_cast<float*>(args[1].data()); float* alpha = reinterpret_cast<float*>(args[0].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rocblas_handle, n, alpha, vec_ptr, 1)); float* vec_ptr = reinterpret_cast<float*>(args[1].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rb_handle, n, alpha, vec_ptr, 1));
MIGRAPHX_ROCBLAS_ASSERT(rocblas_destroy_handle(rb_handle));
return args[1]; return args[1];
} }
......
...@@ -22,11 +22,14 @@ ...@@ -22,11 +22,14 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE. # THE SOFTWARE.
##################################################################################### #####################################################################################
import subprocess import subprocess, os
#Debug flag #Debug flag
debug = False debug = False
__repo_dir__ = os.path.normpath(
os.path.join(os.path.realpath(__file__), '..', '..'))
# Markdown code blob we should use to insert into notebook files # Markdown code blob we should use to insert into notebook files
def getipynb_markdownBlockAsList(): def getipynb_markdownBlockAsList():
...@@ -222,14 +225,15 @@ def getDelimiter(filename): ...@@ -222,14 +225,15 @@ def getDelimiter(filename):
def main(): def main():
message = open('LICENSE').read() message = open(os.path.join(__repo_dir__, 'LICENSE')).read()
#Get a list of all the files in our git repo #Get a list of all the files in our git repo
#bashCommand = "git ls-files --exclude-standard" #bashCommand = "git ls-files --exclude-standard"
#print (bashCommand.split()) #print (bashCommand.split())
proc = subprocess.run("git ls-files --exclude-standard", proc = subprocess.run("git ls-files --exclude-standard",
shell=True, shell=True,
stdout=subprocess.PIPE) stdout=subprocess.PIPE,
cwd=__repo_dir__)
fileList = proc.stdout.decode().split('\n') fileList = proc.stdout.decode().split('\n')
message = message.split('\n') message = message.split('\n')
...@@ -237,7 +241,8 @@ def main(): ...@@ -237,7 +241,8 @@ def main():
print("Target file list:\n" + str(fileList)) print("Target file list:\n" + str(fileList))
print("Output Message:\n" + str(message)) print("Output Message:\n" + str(message))
for file in fileList: for rfile in fileList:
file = os.path.join(__repo_dir__, rfile)
#print(file) #print(file)
commentDelim = getDelimiter(file) commentDelim = getDelimiter(file)
if commentDelim is not None: if commentDelim is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment