Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
bfa373f6
Unverified
Commit
bfa373f6
authored
Aug 10, 2022
by
Paul Fultz II
Committed by
GitHub
Aug 10, 2022
Browse files
Merge branch 'develop' into jit-layernorm
parents
7cc97df7
5bf4dee6
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
17 additions
and
10 deletions
+17
-10
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
...phx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
+8
-6
tools/license_stamper.py
tools/license_stamper.py
+9
-4
No files found.
examples/migraphx/custom_op_rocblas_kernel/custom_op_rocblas_kernel.cpp
View file @
bfa373f6
...
@@ -23,7 +23,7 @@
...
@@ -23,7 +23,7 @@
*/
*/
#include <algorithm>
#include <algorithm>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <rocblas.h>
#include <rocblas
/rocblas
.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.h>
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <migraphx/migraphx.hpp> // MIGraphX's C++ API
#include <numeric>
#include <numeric>
...
@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
...
@@ -56,11 +56,13 @@ struct sscal_custom_op final : migraphx::experimental_custom_op_base
migraphx
::
arguments
args
)
const
override
migraphx
::
arguments
args
)
const
override
{
{
// create rocblas stream handle
// create rocblas stream handle
auto
rocblas_handle
=
create_rocblas_handle_ptr
(
ctx
);
auto
rb_handle
=
create_rocblas_handle_ptr
(
ctx
);
rocblas_int
n
=
args
[
1
].
get_shape
().
lengths
()[
0
];
MIGRAPHX_ROCBLAS_ASSERT
(
rocblas_set_pointer_mode
(
rb_handle
,
rocblas_pointer_mode_device
));
float
*
alpha
=
reinterpret_cast
<
float
*>
(
args
[
0
].
data
());
rocblas_int
n
=
args
[
1
].
get_shape
().
lengths
()[
0
];
float
*
vec_ptr
=
reinterpret_cast
<
float
*>
(
args
[
1
].
data
());
float
*
alpha
=
reinterpret_cast
<
float
*>
(
args
[
0
].
data
());
MIGRAPHX_ROCBLAS_ASSERT
(
rocblas_sscal
(
rocblas_handle
,
n
,
alpha
,
vec_ptr
,
1
));
float
*
vec_ptr
=
reinterpret_cast
<
float
*>
(
args
[
1
].
data
());
MIGRAPHX_ROCBLAS_ASSERT
(
rocblas_sscal
(
rb_handle
,
n
,
alpha
,
vec_ptr
,
1
));
MIGRAPHX_ROCBLAS_ASSERT
(
rocblas_destroy_handle
(
rb_handle
));
return
args
[
1
];
return
args
[
1
];
}
}
...
...
tools/license_stamper.py
View file @
bfa373f6
...
@@ -22,11 +22,14 @@
...
@@ -22,11 +22,14 @@
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# THE SOFTWARE.
#####################################################################################
#####################################################################################
import
subprocess
import
subprocess
,
os
#Debug flag
#Debug flag
debug
=
False
debug
=
False
__repo_dir__
=
os
.
path
.
normpath
(
os
.
path
.
join
(
os
.
path
.
realpath
(
__file__
),
'..'
,
'..'
))
# Markdown code blob we should use to insert into notebook files
# Markdown code blob we should use to insert into notebook files
def
getipynb_markdownBlockAsList
():
def
getipynb_markdownBlockAsList
():
...
@@ -222,14 +225,15 @@ def getDelimiter(filename):
...
@@ -222,14 +225,15 @@ def getDelimiter(filename):
def
main
():
def
main
():
message
=
open
(
'LICENSE'
).
read
()
message
=
open
(
os
.
path
.
join
(
__repo_dir__
,
'LICENSE'
)
)
.
read
()
#Get a list of all the files in our git repo
#Get a list of all the files in our git repo
#bashCommand = "git ls-files --exclude-standard"
#bashCommand = "git ls-files --exclude-standard"
#print (bashCommand.split())
#print (bashCommand.split())
proc
=
subprocess
.
run
(
"git ls-files --exclude-standard"
,
proc
=
subprocess
.
run
(
"git ls-files --exclude-standard"
,
shell
=
True
,
shell
=
True
,
stdout
=
subprocess
.
PIPE
)
stdout
=
subprocess
.
PIPE
,
cwd
=
__repo_dir__
)
fileList
=
proc
.
stdout
.
decode
().
split
(
'
\n
'
)
fileList
=
proc
.
stdout
.
decode
().
split
(
'
\n
'
)
message
=
message
.
split
(
'
\n
'
)
message
=
message
.
split
(
'
\n
'
)
...
@@ -237,7 +241,8 @@ def main():
...
@@ -237,7 +241,8 @@ def main():
print
(
"Target file list:
\n
"
+
str
(
fileList
))
print
(
"Target file list:
\n
"
+
str
(
fileList
))
print
(
"Output Message:
\n
"
+
str
(
message
))
print
(
"Output Message:
\n
"
+
str
(
message
))
for
file
in
fileList
:
for
rfile
in
fileList
:
file
=
os
.
path
.
join
(
__repo_dir__
,
rfile
)
#print(file)
#print(file)
commentDelim
=
getDelimiter
(
file
)
commentDelim
=
getDelimiter
(
file
)
if
commentDelim
is
not
None
:
if
commentDelim
is
not
None
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment