"ts/webui/package.json" did not exist on "5e17852435e3620c01697f75ce3164cde1475747"
Unverified Commit bfa373f6 authored by Paul Fultz II's avatar Paul Fultz II Committed by GitHub
Browse files

Merge branch 'develop' into jit-layernorm

parents 7cc97df7 5bf4dee6
......@@ -23,7 +23,7 @@
*/
#include <algorithm>
#include <hip/hip_runtime.h>
#include <rocblas.h>
#include <rocblas/rocblas.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric>
......@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
migraphx::arguments args) const override
{
// create rocblas stream handle
auto rocblas_handle = create_rocblas_handle_ptr(ctx);
rocblas_int n = args[1].get_shape().lengths()[0];
float* alpha = reinterpret_cast<float*>(args[0].data());
float* vec_ptr = reinterpret_cast<float*>(args[1].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rocblas_handle, n, alpha, vec_ptr, 1));
auto rb_handle = create_rocblas_handle_ptr(ctx);
MIGRAPHX_ROCBLAS_ASSERT(rocblas_set_pointer_mode(rb_handle, rocblas_pointer_mode_device));
rocblas_int n = args[1].get_shape().lengths()[0];
float* alpha = reinterpret_cast<float*>(args[0].data());
float* vec_ptr = reinterpret_cast<float*>(args[1].data());
MIGRAPHX_ROCBLAS_ASSERT(rocblas_sscal(rb_handle, n, alpha, vec_ptr, 1));
MIGRAPHX_ROCBLAS_ASSERT(rocblas_destroy_handle(rb_handle));
return args[1];
}
......
......@@ -22,11 +22,14 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
#####################################################################################
import subprocess
import subprocess, os
#Debug flag
debug = False
__repo_dir__ = os.path.normpath(
os.path.join(os.path.realpath(__file__), '..', '..'))
# Markdown code blob we should use to insert into notebook files
def getipynb_markdownBlockAsList():
......@@ -222,14 +225,15 @@ def getDelimiter(filename):
def main():
message = open('LICENSE').read()
message = open(os.path.join(__repo_dir__, 'LICENSE')).read()
#Get a list of all the files in our git repo
#bashCommand = "git ls-files --exclude-standard"
#print (bashCommand.split())
proc = subprocess.run("git ls-files --exclude-standard",
shell=True,
stdout=subprocess.PIPE)
stdout=subprocess.PIPE,
cwd=__repo_dir__)
fileList = proc.stdout.decode().split('\n')
message = message.split('\n')
......@@ -237,7 +241,8 @@ def main():
print("Target file list:\n" + str(fileList))
print("Output Message:\n" + str(message))
for file in fileList:
for rfile in fileList:
file = os.path.join(__repo_dir__, rfile)
#print(file)
commentDelim = getDelimiter(file)
if commentDelim is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment