gaoqiong / MIGraphX · Commits · d6b4ae77

Commit d6b4ae77, authored Mar 10, 2022 by Shucai Xiao
merge optimization to print flops branch
Parents: bdf91961, abe2a889

Showing 20 changed files with 495 additions and 146 deletions (page 1 of 6; 106 files changed in the full commit).
Files on this page of the diff:

src/onnx/parse_scatternd.cpp                                        +33   -0
src/program.cpp                                                     +12  -15
src/py/migraphx_py.cpp                                              +84  -80
src/reduce_dims.cpp                                                  +1   -3
src/shape.cpp                                                       +11   -0
src/simplify_algebra.cpp                                             +0   -1
src/simplify_reshapes.cpp                                            +0   -1
src/targets/cpu/include/migraphx/cpu/pointwise.hpp                   +0   -1
src/targets/gpu/CMakeLists.txt                                       +2   -1
src/targets/gpu/compile_scatternd.cpp                               +60   -0
src/targets/gpu/device/add.cpp                                      +44   -1
src/targets/gpu/device/contiguous.cpp                               +36   -0
src/targets/gpu/device/gelu.cpp                                     +60   -4
src/targets/gpu/device/include/migraphx/gpu/device/multi_index.hpp   +4   -3
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp          +6   -4
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp          +7   -6
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp         +5  -10
src/targets/gpu/device/layernorm.cpp                                 +5   -5
src/targets/gpu/device/mul.cpp                                      +42   -9
src/targets/gpu/device/mul_add.cpp                                  +83   -2
src/onnx/parse_scatternd.cpp (new file, 0 → 100644)

#include <migraphx/onnx/op_parser.hpp>
#include <migraphx/instruction.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/make_op.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace onnx {

struct parse_scatternd : op_parser<parse_scatternd>
{
    std::vector<op_desc> operators() const { return {{"ScatterND"}}; }

    instruction_ref parse(const op_desc& /*opd*/,
                          const onnx_parser& /*parser*/,
                          const onnx_parser::node_info& info,
                          std::vector<instruction_ref>& args) const
    {
        if(contains(info.attributes, "reduction"))
        {
            if(info.attributes.at("reduction").s() == "add")
                return info.add_instruction(migraphx::make_op("scatternd_add"), args);
            if(info.attributes.at("reduction").s() == "mul")
                return info.add_instruction(migraphx::make_op("scatternd_mul"), args);
        }
        return info.add_instruction(migraphx::make_op("scatternd_none"), args);
    }
};

} // namespace onnx
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
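The parser's dispatch rule is small enough to restate: the ONNX "reduction" attribute selects between three MIGraphX operators, defaulting to plain assignment when the attribute is absent. A standalone C++ sketch of that rule, where scatternd_op_name is a hypothetical helper standing in for the real op_parser/node_info machinery:

#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical helper mirroring parse_scatternd::parse: map the ONNX
// "reduction" attribute to the name of the MIGraphX operator to emit.
std::string scatternd_op_name(const std::unordered_map<std::string, std::string>& attributes)
{
    auto it = attributes.find("reduction");
    if(it != attributes.end())
    {
        if(it->second == "add")
            return "scatternd_add";
        if(it->second == "mul")
            return "scatternd_mul";
    }
    // A missing (or unrecognized) reduction falls back to plain assignment.
    return "scatternd_none";
}

int main()
{
    std::cout << scatternd_op_name({{"reduction", "add"}}) << "\n"; // scatternd_add
    std::cout << scatternd_op_name({}) << "\n";                     // scatternd_none
}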
src/program.cpp

@@ -947,9 +947,9 @@ void program::perf_report(std::ostream& os,
 }

 void program::debug_print() const { std::cout << *this << std::endl; }
-void program::debug_print(instruction_ref ins) const
+void program::debug_print(instruction_ref ins,
+                          const std::unordered_map<instruction_ref, std::string>& ins_names) const
 {
-    std::unordered_map<instruction_ref, std::string> names;
     if(std::any_of(this->impl->modules.begin(), this->impl->modules.end(), [&](const auto& pp) {
            return is_end(pp.second.end(), ins);
        }))
@@ -965,14 +965,10 @@ void program::debug_print(instruction_ref ins) const
         return;
     }
-    std::stringstream ss;
-    this->print(names, [&](auto x, auto ins_names) {
-        if(x == ins)
-        {
-            instruction::print(std::cout, x, ins_names);
-            std::cout << std::endl;
-        }
-    });
+    if(contains(ins_names, ins))
+    {
+        instruction::print(std::cout, ins, ins_names);
+    }
 }
 void program::debug_print(std::ostream& os,
@@ -1078,11 +1074,12 @@ void generic_get_unused_modules(Map& m, const std::vector<T*>& mods, OutputItera
     std::transform(mods.begin(), mods.end(), std::inserter(used, used.end()), [](auto&& mod) {
         return mod->name();
     });
-    transform_if(m.begin(),
-                 m.end(),
-                 out,
-                 [&](auto&& pp) { return not contains(used, pp.first); },
-                 [](auto&& pp) { return &pp.second; });
+    transform_if(
+        m.begin(),
+        m.end(),
+        out,
+        [&](auto&& pp) { return not contains(used, pp.first); },
+        [](auto&& pp) { return &pp.second; });
 }
 std::vector<const module*> program::get_modules() const
src/py/migraphx_py.cpp

@@ -303,86 +303,90 @@ MIGRAPHX_PYBIND11_MODULE(migraphx, m)

This hunk is a formatting-only reflow: each m.def(...) call now places the binding name on its own line; the arguments, defaults, and lambda bodies are unchanged. The affected bindings, as they read after the change:

        .def("name", &migraphx::operation::name);

    m.def(
        "parse_tf",
        [](const std::string& filename,
           bool is_nhwc,
           unsigned int batch_size,
           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
           std::vector<std::string> output_names) {
            return migraphx::parse_tf(
                filename, migraphx::tf_options{is_nhwc, batch_size, map_input_dims, output_names});
        },
        "Parse tf protobuf (default format is nhwc)",
        py::arg("filename"),
        py::arg("is_nhwc")        = true,
        py::arg("batch_size")     = 1,
        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
        py::arg("output_names")   = std::vector<std::string>());

    m.def(
        "parse_onnx",
        [](const std::string& filename,
           unsigned int default_dim_value,
           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
           bool skip_unknown_operators,
           bool print_program_on_error,
           int64_t max_loop_iterations) {
            migraphx::onnx_options options;
            options.default_dim_value      = default_dim_value;
            options.map_input_dims         = map_input_dims;
            options.skip_unknown_operators = skip_unknown_operators;
            options.print_program_on_error = print_program_on_error;
            options.max_loop_iterations    = max_loop_iterations;
            return migraphx::parse_onnx(filename, options);
        },
        "Parse onnx file",
        py::arg("filename"),
        py::arg("default_dim_value") = 1,
        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
        py::arg("skip_unknown_operators") = false,
        py::arg("print_program_on_error") = false,
        py::arg("max_loop_iterations") = 10);

    m.def(
        "parse_onnx_buffer",
        [](const std::string& onnx_buffer,
           unsigned int default_dim_value,
           std::unordered_map<std::string, std::vector<std::size_t>> map_input_dims,
           bool skip_unknown_operators,
           bool print_program_on_error) {
            migraphx::onnx_options options;
            options.default_dim_value      = default_dim_value;
            options.map_input_dims         = map_input_dims;
            options.skip_unknown_operators = skip_unknown_operators;
            options.print_program_on_error = print_program_on_error;
            return migraphx::parse_onnx_buffer(onnx_buffer, options);
        },
        "Parse onnx file",
        py::arg("filename"),
        py::arg("default_dim_value") = 1,
        py::arg("map_input_dims") = std::unordered_map<std::string, std::vector<std::size_t>>(),
        py::arg("skip_unknown_operators") = false,
        py::arg("print_program_on_error") = false);

    m.def(
        "load",
        [](const std::string& name, const std::string& format) {
            migraphx::file_options options;
            options.format = format;
            return migraphx::load(name, options);
        },
        "Load MIGraphX program",
        py::arg("filename"),
        py::arg("format") = "msgpack");

    m.def(
        "save",
        [](const migraphx::program& p, const std::string& name, const std::string& format) {
            migraphx::file_options options;
            options.format = format;
            return migraphx::save(p, name, options);
        },
        "Save MIGraphX program",
        py::arg("p"),
        py::arg("filename"),
        py::arg("format") = "msgpack");

    m.def("get_target", &migraphx::make_target);
    m.def("generate_argument", &migraphx::generate_argument, py::arg("s"), py::arg("seed") = 0);
src/reduce_dims.cpp

@@ -39,9 +39,7 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
 std::size_t reduce_dim_all(std::vector<shape>& shapes, std::size_t n)
 {
-    while(reduce_dim(shapes, n) and n < shapes.size())
-    {
-    }
+    while(reduce_dim(shapes, n) and n < shapes.size()) {}
     return n + 1;
 }
src/shape.cpp (mode changed 100755 → 100644)

@@ -86,6 +86,8 @@ struct shape_impl
         return std::accumulate(
             m_lens.begin(), m_lens.end(), std::size_t{1}, std::multiplies<std::size_t>());
     }

+    std::shared_ptr<shape_impl> copy() const { return std::make_shared<shape_impl>(*this); }
 };

 const std::vector<shape::type_t>& shape::types()
@@ -135,6 +137,8 @@ shape::shape(type_t t, std::vector<std::size_t> l, std::vector<std::size_t> s)
 shape::shape(const std::vector<shape>& subs) : impl(std::make_shared<shape_impl>(subs)) {}

+shape::shape(std::shared_ptr<shape_impl> pimpl) : impl(std::move(pimpl)) {}
+
 shape shape::from_permutation(type_t t,
                               const std::vector<std::size_t>& l,
                               const std::vector<int64_t>& perm)
@@ -294,6 +298,13 @@ shape shape::with_lens(const std::vector<std::size_t>& l) const
     return this->with_lens(this->type(), l);
 }

+shape shape::with_type(type_t t) const
+{
+    auto c    = impl->copy();
+    c->m_type = t;
+    return {c};
+}
+
 std::size_t shape::element_space() const { return impl->element_space(); }
 std::string shape::type_string() const { return name(this->type()); }
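Together, these additions give shape a cheap way to produce a copy that differs only in element type: shape_impl::copy() clones the impl, with_type() swaps the type on the clone, and the new shared_ptr constructor wraps the clone without touching lens or strides. A hedged usage sketch (assumes the MIGraphX headers and library are available; the printed form is illustrative):

#include <migraphx/shape.hpp>
#include <iostream>

int main()
{
    migraphx::shape s{migraphx::shape::float_type, {2, 3}};
    // Same lens and strides, only the element type changes.
    auto h = s.with_type(migraphx::shape::half_type);
    std::cout << s << "\n"; // float_type, {2, 3}, strides {3, 1}
    std::cout << h << "\n"; // half_type,  {2, 3}, strides {3, 1}
}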
src/simplify_algebra.cpp

@@ -335,7 +335,6 @@ struct find_concat_op
         }
         auto y = p.insert_instruction(ins, op, concats);
         return {y};
     };
-
     std::vector<instruction_ref> args;
src/simplify_reshapes.cpp

@@ -327,7 +327,6 @@ struct find_nested_concat
             else
                 args.push_back(i);
         }
     })(ins->inputs());
-
     p.replace_instruction(ins, ins->get_operator(), args);
 }
src/targets/cpu/include/migraphx/cpu/pointwise.hpp

@@ -213,7 +213,6 @@ template <std::size_t N, class... Xs>
 bool is_vectorizable(const Xs&... xs)
 {
     return all_of({xs...}, [](const auto& s) {
-
         if(s.standard() and (s.lens().back() % N) == 0)
             return true;
         if(s.broadcasted())
src/targets/gpu/CMakeLists.txt

@@ -11,7 +11,7 @@ if(NOT TARGET MIOpen)
 endif()
 include(Embed)
-file(GLOB KERNEL_FILES ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
+file(GLOB KERNEL_FILES ${CONFIGURE_DEPENDS} ${CMAKE_CURRENT_SOURCE_DIR}/kernels/include/migraphx/kernels/*.hpp)
 message(STATUS "KERNEL_FILES: ${KERNEL_FILES}")
 add_embed_library(migraphx_kernels ${KERNEL_FILES})
@@ -133,6 +133,7 @@ add_library(migraphx_gpu
     compile_hip_code_object.cpp
     compile_pointwise.cpp
     compile_roialign.cpp
+    compile_scatternd.cpp
     concat.cpp
     convert.cpp
     convolution.cpp
src/targets/gpu/compile_scatternd.cpp (new file, 0 → 100644)

#include <migraphx/gpu/compile_scatternd.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace gpu {

// NOLINTNEXTLINE
static const char* const scatternd_kernel = R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>

namespace migraphx {

extern "C" {
__global__ void scatternd_kernel(void* in_indices, void* in_updates, void* output)
{
    make_tensors()(in_indices, in_updates, output)([](auto&&... xs) {
        scatternd(xs..., REDUCTION);
    });
}
}

} // namespace migraphx

int main() {}

)__migraphx__";

operation
compile_scatternd(context&, const std::vector<shape>& io_shapes, const std::string& reduction)
{
    hip_compile_options options;
    auto out_s             = io_shapes.back();
    options.local          = 1024;
    options.global         = compute_global(io_shapes.at(1).elements(), options.local);
    options.inputs         = io_shapes;
    options.output         = out_s;
    options.kernel_name    = "scatternd_kernel";
    options.virtual_inputs = io_shapes;
    options.params += " -DREDUCTION=assign_" + reduction + "{}";
    return compile_hip_code_object(scatternd_kernel, options);
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
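compile_scatternd specializes the embedded HIP source at compile time: the reduction name chosen by the ONNX parser becomes a preprocessor define, so REDUCTION expands to a functor object such as assign_add{} inside the kernel. The same macro-injection idea in a standalone sketch — the assign_* functors here are stand-ins for the ones in MIGraphX's kernel headers; compile with e.g. g++ -DREDUCTION='assign_add{}' demo.cpp:

#include <iostream>

// Stand-in functors; MIGraphX defines the real ones in its kernel headers.
struct assign_none { void operator()(float& y, float x) const { y = x; } };
struct assign_add  { void operator()(float& y, float x) const { y += x; } };
struct assign_mul  { void operator()(float& y, float x) const { y *= x; } };

#ifndef REDUCTION
#define REDUCTION assign_none{}
#endif

int main()
{
    float y = 2.0f;
    auto f  = REDUCTION; // expands to the functor chosen on the command line
    f(y, 3.0f);
    std::cout << y << "\n"; // 5 with assign_add, 6 with assign_mul, 3 by default
}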
src/targets/gpu/device/add.cpp

 #include <migraphx/gpu/device/add.hpp>
 #include <migraphx/gpu/device/nary.hpp>
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+static bool is_bert(const std::vector<shape>& ss)
+{
+    auto n_dim = ss.front().lens().size();
+    if(n_dim == 2)
+    {
+        auto stride = ss.at(1).strides();
+        return (stride[0] == 0);
+    }
+    return false;
+}
+
+__global__ void add_kernel(void* a, void* b, int n_dim, void* r, int n)
+{
+    __half2* ha = reinterpret_cast<__half2*>(a);
+    __half2* hb = reinterpret_cast<__half2*>(b);
+    __half2* hr = reinterpret_cast<__half2*>(r);
+    int tid     = blockIdx.x * blockDim.x + threadIdx.x;
+    if(tid < n)
+    {
+        int idb = tid % n_dim;
+        hr[tid] = __hadd2(ha[tid], hb[idb]);
+    }
+}
+
 void add(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
 {
-    nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
+    auto sr = result.get_shape();
+    std::vector<shape> ss;
+    ss.push_back(arg1.get_shape());
+    ss.push_back(arg2.get_shape());
+    if(sr.type() == shape::half_type and is_bert(ss))
+    {
+        auto elem_num  = sr.elements() / 2;
+        auto last_dim  = sr.lens().back() / 2;
+        int block_size = 1024;
+        int block_num  = (elem_num + block_size - 1) / block_size;
+        add_kernel<<<block_num, block_size, 0, stream>>>(
+            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
+    }
+    else
+    {
+        nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x + y; });
+    }
 }

 void add(hipStream_t stream,
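The fast path targets one specific pattern: a rank-2 half-precision output whose second operand has leading stride 0, i.e. a bias row broadcast across every row, which is what BERT's bias-add lowers to (hence the is_bert name). Because the kernel reinterprets the buffers as __half2, both the element count and the row length are halved, and __hadd2 adds two fp16 values per instruction. A CPU sketch of the resulting index pattern, with illustrative sizes:

#include <cstdio>

int main()
{
    const int n_dim = 4;  // __half2 pairs per row (last_dim in the real code)
    const int n     = 12; // total pairs (elements / 2)
    for(int tid = 0; tid < n; tid++)
    {
        int idb = tid % n_dim; // the same bias pair is reused for every row
        std::printf("r[%2d] = a[%2d] + b[%d]\n", tid, tid, idb);
    }
}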
src/targets/gpu/device/contiguous.cpp

 #include <migraphx/gpu/device/contiguous.hpp>
 #include <migraphx/gpu/device/nary.hpp>
+#include <migraphx/permutation.hpp>
+#include <hip/hip_fp16.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+__global__ void
+cont_kernel(void* in, void* out, int os1, int os2, int os3, int is1, int is2, int is3)
+{
+    int i1 = blockIdx.x;
+    int i2 = blockIdx.y;
+    int i3 = blockIdx.z;
+    int i4 = threadIdx.x;
+    __half* in_ptr   = reinterpret_cast<__half*>(in);
+    __half* out_ptr  = reinterpret_cast<__half*>(out);
+    int out_idx      = i1 * os1 + i2 * os2 + i3 * os3 + i4;
+    int in_idx       = i1 * is1 + i2 * is2 + i3 * is3 + i4;
+    out_ptr[out_idx] = in_ptr[in_idx];
+}
+
 void contiguous_nonstandard(hipStream_t stream, const argument& result, const argument& arg)
 {
     shape s{result.get_shape().type(), result.get_shape().lens()};
+    // auto in_s = arg.get_shape();
+    // auto perm = find_permutation(in_s);
+    // if (in_s.type() == shape::half_type and perm == std::vector<int64_t>({0, 2, 1, 3}))
+    // {
+    //     auto lens = s.lens();
+    //     auto last_dim = s.lens().back();
+    //     dim3 grid(lens[0], lens[1], lens[2]);
+    //     dim3 block(last_dim);
+    //     auto in_stride = in_s.strides();
+    //     auto out_stride = s.strides();
+    //     cont_kernel<<<grid, block, 0, stream>>>(arg.data(), result.data(), out_stride[0],
+    //     out_stride[1], out_stride[2], in_stride[0], in_stride[1], in_stride[2]);
+    // }
+    // else
+    // {
     visit_all(result, arg)([&](auto output_v, auto input_v) {
         hip_visit_views(output_v, input_v, s)([&](auto output, auto input, auto standard_shape) {
             mi_gs_launch(stream,
                          standard_shape)([=](auto idx) __device__ { output[idx] = input[idx]; });
         });
     });
+    // }
 }

 void contiguous_packed(hipStream_t stream, const argument& result, const argument& arg)
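The commented-out fast path would copy a half tensor whose layout is the permutation {0, 2, 1, 3} of its logical shape, mapping one thread block per (i1, i2, i3) slice and one thread per innermost element; only the generic gather-scatter path is active in this commit. A CPU sketch of cont_kernel's index arithmetic under that permutation (sizes are illustrative):

#include <cstdio>

int main()
{
    // Logical shape (a, b, c, d); the input buffer is laid out with dims 1
    // and 2 swapped, i.e. stored contiguously as (a, c, b, d).
    const int a = 2, b = 3, c = 2, d = 2;
    const int os1 = b * c * d, os2 = c * d, os3 = d; // contiguous output strides
    const int is1 = c * b * d, is2 = d, is3 = b * d; // permuted input strides
    for(int i1 = 0; i1 < a; i1++)
        for(int i2 = 0; i2 < b; i2++)
            for(int i3 = 0; i3 < c; i3++)
                for(int i4 = 0; i4 < d; i4++)
                {
                    int out_idx = i1 * os1 + i2 * os2 + i3 * os3 + i4;
                    int in_idx  = i1 * is1 + i2 * is2 + i3 * is3 + i4;
                    std::printf("out[%2d] <- in[%2d]\n", out_idx, in_idx);
                }
}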
src/targets/gpu/device/gelu.cpp

 #include <migraphx/gpu/device/gelu.hpp>
 #include <migraphx/gpu/device/nary.hpp>
 #include <migraphx/gpu/device/types.hpp>
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>
 #include <cmath>

 namespace migraphx {
@@ -32,15 +34,69 @@ void gelu_new(hipStream_t stream, const argument& result, const argument& arg)
     nary(stream, result, arg)([](auto x) __device__ { return gelu_fn_new(to_hip_type(x)); });
 }

+static bool is_bert(const std::vector<shape>& ss)
+{
+    auto n_dim = ss.front().lens().size();
+    if(n_dim == 2)
+    {
+        auto stride = ss.at(1).strides();
+        return (stride[0] == 0);
+    }
+    return false;
+}
+
+__global__ void add_gelu_kernel(void* a, void* b, int n_dim, void* r, int n)
+{
+    __half2* ha = reinterpret_cast<__half2*>(a);
+    __half2* hb = reinterpret_cast<__half2*>(b);
+    __half2* hr = reinterpret_cast<__half2*>(r);
+    int tid     = blockIdx.x * blockDim.x + threadIdx.x;
+    if(tid < n)
+    {
+        int idb        = tid % n_dim;
+        auto sum       = __hadd2(ha[tid], hb[idb]);
+        __half2 sqrt2  = __float2half2_rn(M_SQRT1_2);
+        auto x         = __hmul2(sum, sqrt2);
+        auto f2        = __half22float2(x);
+        f2.x           = ::erff(f2.x);
+        f2.y           = ::erff(f2.y);
+        auto h2        = __floats2half2_rn(f2.x, f2.y);
+        auto one       = __float2half2_rn(1.0f);
+        h2             = __hadd2(h2, one);
+        __half2 point5 = __float2half2_rn(0.5f);
+        hr[tid]        = __hmul2(sum, __hmul2(point5, h2));
+    }
+}
+
 void add_gelu(hipStream_t stream,
               const argument& result,
               const argument& arg1,
               const argument& arg2)
 {
-    nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ {
-        auto sum = to_hip_type(x + y);
-        return gelu_fn(sum);
-    });
+    auto sr   = result.get_shape();
+    auto type = sr.type();
+    std::vector<shape> ss;
+    ss.push_back(arg1.get_shape());
+    ss.push_back(arg2.get_shape());
+    if(type == shape::half_type and is_bert(ss))
+    {
+        auto elem_num  = sr.elements() / 2;
+        auto last_dim  = sr.lens().back() / 2;
+        int block_size = 1024;
+        int block_num  = (elem_num + block_size - 1) / block_size;
+        add_gelu_kernel<<<block_num, block_size, 0, stream>>>(
+            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
+    }
+    else
+    {
+        nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ {
+            auto sum = to_hip_type(x + y);
+            return gelu_fn(sum);
+        });
+    }
 }

 void add_gelu_new(hipStream_t stream,
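The fused kernel evaluates the exact (erf-based) GELU on a bias-added sum of two packed half values,

$\mathrm{GELU}(x) = \tfrac{x}{2}\,\bigl(1 + \operatorname{erf}(x/\sqrt{2})\bigr)$,

step by step: it computes sum = a + b in fp16, scales by M_SQRT1_2 = 1/\sqrt{2}, unpacks to float for ::erff (there is no fp16 erf intrinsic), then repacks and finishes 0.5 * sum * (1 + erf(...)) in fp16. Only the erf itself takes the float round-trip for accuracy; every other operation stays two-wide in __half2.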
src/targets/gpu/device/include/migraphx/gpu/device/multi_index.hpp

@@ -57,9 +57,10 @@ inline auto mi_nglobal(const hip_shape<N>& s, index_int nlocal)
 {
     assert(s.standard);
     assert(s.elements() > 0);
     index_int n      = s.elements();
     index_int groups = (n + nlocal - 1) / nlocal;
-    index_int nglobal = std::min<index_int>(128, groups) * nlocal;
+    // change the max group num to 1 Million
+    index_int nglobal = std::min<index_int>((1 << 20), groups) * nlocal;
     assert(groups > 0);
     assert(nglobal > 0);
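mi_nglobal sizes a grid-stride launch: previously at most 128 groups were launched regardless of tensor size, so on large tensors each thread strode over many elements serially; the cap is now 2^20 groups, so the grid grows with the tensor. A worked example of the change (the values of n and nlocal are illustrative):

#include <algorithm>
#include <cstdint>
#include <cstdio>

int main()
{
    using index_int  = std::uint64_t;
    index_int n      = 100000000; // elements to cover
    index_int nlocal = 1024;      // threads per group
    index_int groups = (n + nlocal - 1) / nlocal;                     // 97657
    index_int old_ng = std::min<index_int>(128, groups) * nlocal;     // 131072 threads
    index_int new_ng = std::min<index_int>(1 << 20, groups) * nlocal; // 100000768 threads
    std::printf("old cap: %llu threads, about %llu elements per thread\n",
                (unsigned long long)old_ng, (unsigned long long)((n + old_ng - 1) / old_ng));
    std::printf("new cap: %llu threads, about %llu elements per thread\n",
                (unsigned long long)new_ng, (unsigned long long)((n + new_ng - 1) / new_ng));
}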
src/targets/gpu/device/include/migraphx/gpu/device/nary.hpp

@@ -24,6 +24,8 @@ MIGRAPHX_DECLARE_ENV_VAR(MIGRAPHX_TRACE_NARY);
     if(enabled(MIGRAPHX_TRACE_NARY{}))                                               \
         std::cout << "nary device function: " << __PRETTY_FUNCTION__ << std::endl;

+static index_int group_num_global = (1 << 8);
+
 template <class... Ts>
 constexpr auto pack(Ts... xs)
 {
@@ -87,7 +89,7 @@ void nary_broadcast_vec_impl(
     const index_int vec_size = 4;
     const index_int nlocal   = 1024;
-    const index_int nglobal  = 256 * nlocal;
+    const index_int nglobal  = group_num_global * nlocal;
     const index_int bdim_vec_len = bdim_len / vec_size;
     hip_vec_visit_all<vec_size>(result, barg, args...)(
         [&](auto output, auto binput, auto... inputs) {
@@ -134,7 +136,7 @@ void nary_broadcast_impl(hipStream_t stream, F f, argument result, argument barg
     auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
     const index_int nlocal  = 1024;
-    const index_int nglobal = 256 * nlocal;
+    const index_int nglobal = group_num_global * nlocal;
     index_int nelements     = result.get_shape().elements();
     hip_visit_all(result, barg, args...)([&](auto output, auto binput, auto... inputs) {
         using type = typename decltype(output)::value_type;
@@ -178,7 +180,7 @@ void nary_double_broadcast_vec_impl(
     const index_int vec_size = 4;
     const index_int nlocal   = 1024;
-    const index_int nglobal  = 256 * nlocal;
+    const index_int nglobal  = group_num_global * nlocal;
     const index_int bdim_vec_len = bdim_len / vec_size;
     hip_vec_visit_all<vec_size>(result, barg1, barg2, args...)(
         [&](auto output, auto binput1, auto binput2, auto... inputs) {
@@ -234,7 +236,7 @@ void nary_double_broadcast_impl(
     auto broadcast_idx = create_broadcast_index(bdim_len, bdim_stride);
     const index_int nlocal  = 1024;
-    const index_int nglobal = 256 * nlocal;
+    const index_int nglobal = group_num_global * nlocal;
     index_int nelements = result.get_shape().elements();
     hip_visit_all(result, barg1, barg2, args...)(
         [&](auto output, auto binput1, auto binput2, auto... inputs) {

The four replacements are behavior-preserving: group_num_global is (1 << 8) = 256, so the literal group count is simply hoisted into one tunable constant.
src/targets/gpu/device/include/migraphx/gpu/device/scan.hpp

@@ -44,12 +44,13 @@ __device__ void block_scan(index idx, Op op, T init, ForStride fs, Input input,
 template <index_int N, class Op, class T, class Input, class Output>
 __device__ void block_scan(index idx, Op op, T init, index_int n, Input input, Output output)
 {
-    block_scan<N>(idx,
-                  op,
-                  init,
-                  [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
-                  input,
-                  output);
+    block_scan<N>(
+        idx,
+        op,
+        init,
+        [&](auto f) -> decltype(f(index_int{})) { return idx.local_stride(n, f); },
+        input,
+        output);
 }

 } // namespace device
src/targets/gpu/device/include/migraphx/gpu/device/visit.hpp

@@ -14,28 +14,23 @@ constexpr void visit_tensor_size(index_int n, F f)
 {
     switch(n)
     {
     case 1:
     {
         f(std::integral_constant<index_int, 1>{});
         break;
     }
     case 2:
     {
         f(std::integral_constant<index_int, 2>{});
         break;
     }
     case 3:
     {
         f(std::integral_constant<index_int, 3>{});
         break;
     }
     case 4:
     {
         f(std::integral_constant<index_int, 4>{});
         break;
     }
     case 5:
     {
         f(std::integral_constant<index_int, 5>{});
         break;
     }

(Only context lines of this hunk are visible on this page; the five lines it removes are collapsed.)
src/targets/gpu/device/layernorm.cpp

@@ -184,11 +184,11 @@ auto layernorm_fusion(hipStream_t stream,
                       const Arguments&... args)
 {
     return [=](auto input, auto output) {
         auto relements = arg1.get_shape().lens().back();
         auto nelements = result.get_shape().elements() / relements;
-        auto output_shape = result.get_shape();
-        auto reduce_output_lens(output_shape.lens());
-        reduce_output_lens.back() = 1;
+        // auto output_shape = result.get_shape();
+        // auto reduce_output_lens(output_shape.lens());
+        // reduce_output_lens.back() = 1;
         if((relements % 4) == 0)
             layernorm_vec_impl<4>(
src/targets/gpu/device/mul.cpp

 #include <migraphx/gpu/device/mul.hpp>
 #include <migraphx/gpu/device/nary.hpp>
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

-void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
+static bool is_bert(const std::vector<shape>& ss)
+{
+    auto n_dim = ss.front().lens().size();
+    if(n_dim == 2)
+    {
+        auto stride = ss.at(1).strides();
+        return (stride[0] == 0);
+    }
+    return false;
+}
+
+__global__ void mul_kernel(void* a, void* b, int n_dim, void* r, int n)
 {
-    nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
+    __half2* ha = reinterpret_cast<__half2*>(a);
+    __half2* hb = reinterpret_cast<__half2*>(b);
+    __half2* hr = reinterpret_cast<__half2*>(r);
+    int tid     = blockIdx.x * blockDim.x + threadIdx.x;
+    if(tid < n)
+    {
+        int idb = tid % n_dim;
+        hr[tid] = __hmul2(ha[tid], hb[idb]);
+    }
+}
+
+void mul(hipStream_t stream, const argument& result, const argument& arg1, const argument& arg2)
+{
+    auto sr = result.get_shape();
+    std::vector<shape> ss;
+    ss.push_back(arg1.get_shape());
+    ss.push_back(arg2.get_shape());
+    if(sr.type() == shape::half_type and is_bert(ss))
+    {
+        auto elem_num  = sr.elements() / 2;
+        auto last_dim  = sr.lens().back() / 2;
+        int block_size = 1024;
+        int block_num  = (elem_num + block_size - 1) / block_size;
+        mul_kernel<<<block_num, block_size, 0, stream>>>(
+            arg1.data(), arg2.data(), last_dim, result.data(), elem_num);
+    }
+    else
+    {
+        nary(stream, result, arg1, arg2)([](auto x, auto y) __device__ { return x * y; });
+    }
 }

 void mul(hipStream_t stream,
          const argument& result,
          const argument& arg1,
          const argument& arg2,
          const argument& arg3)
 {
     nary(stream, result, arg1, arg2, arg3)([](auto x, auto y, auto z)
                                                __device__ { return x * y * z; });
 }

 } // namespace device
src/targets/gpu/device/mul_add.cpp

+#include "migraphx/gpu/device/launch.hpp"
+#include <hip/amd_detail/amd_device_functions.h>
+#include <hip/amd_detail/amd_hip_runtime.h>
 #include <migraphx/gpu/device/mul_add.hpp>
 #include <migraphx/gpu/device/nary.hpp>
+#include <hip/hip_runtime.h>
+#include <hip/hip_fp16.h>

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {
 namespace gpu {
 namespace device {

+__global__ void mul_add_kernel_dim3(void* a, void* x, void* b, int dim3, void* r, int n)
+{
+    int id      = blockDim.x * blockIdx.x + threadIdx.x;
+    __half2* ha = reinterpret_cast<__half2*>(a);
+    __half2* hb = reinterpret_cast<__half2*>(b);
+    __half2* hx = reinterpret_cast<__half2*>(x);
+    __half2* hr = reinterpret_cast<__half2*>(r);
+    if(id < n)
+    {
+        auto id1 = id % dim3;
+        hr[id]   = __hfma2(ha[id], hx[id1], hb[id1]);
+    }
+}
+
+__global__ void
+mul_add_kernel_dim4(void* a, void* x, void* b, int factor, int dim4, void* r, int n)
+{
+    int id      = blockDim.x * blockIdx.x + threadIdx.x;
+    __half2* ha = reinterpret_cast<__half2*>(a);
+    __half2* hb = reinterpret_cast<__half2*>(b);
+    __half2* hx = reinterpret_cast<__half2*>(x);
+    __half2* hr = reinterpret_cast<__half2*>(r);
+    if(id < n)
+    {
+        int idb = id / (factor * dim4) * dim4 + id % dim4;
+        hr[id]  = __hfma2(ha[id], hx[id], hb[idb]);
+    }
+}
+
+static bool is_bert(const std::vector<shape>& ss)
+{
+    auto n_dim = ss.front().lens().size();
+    if(n_dim == 3)
+    {
+        auto stride = ss.at(2).strides();
+        return (stride[1] == 0);
+    }
+    else if(n_dim == 2)
+    {
+        auto stride1 = ss.at(1).strides();
+        auto stride2 = ss.at(2).strides();
+        return (stride1 == stride2 and stride1[0] == 0);
+    }
+    return false;
+}
+
 void mul_add(hipStream_t stream,
              const argument& result,
              const argument& arg1,
              const argument& arg2,
              const argument& arg3)
 {
-    nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
-                                               __device__ { return a * x + b; });
+    auto sr   = result.get_shape();
+    auto type = sr.type();
+    std::vector<shape> ss;
+    ss.push_back(arg1.get_shape());
+    ss.push_back(arg2.get_shape());
+    ss.push_back(arg3.get_shape());
+    auto lens    = sr.lens();
+    int last_dim = lens.back() / 2;
+    auto n_dim   = lens.size();
+    if(type == shape::half_type and is_bert(ss))
+    {
+        auto elem_num  = sr.elements() / 2;
+        int block_size = 1024;
+        int block_num  = (elem_num + block_size - 1) / block_size;
+        if(n_dim == 2)
+        {
+            mul_add_kernel_dim3<<<block_num, block_size, 0, stream>>>(
+                arg1.data(), arg2.data(), arg3.data(), last_dim, result.data(), elem_num);
+        }
+        else
+        {
+            int factor = lens[1];
+            mul_add_kernel_dim4<<<block_num, block_size, 0, stream>>>(
+                arg1.data(), arg2.data(), arg3.data(), factor, last_dim, result.data(), elem_num);
+        }
+    }
+    else
+    {
+        nary(stream, result, arg1, arg2, arg3)([](auto x, auto a, auto b)
+                                                   __device__ { return a * x + b; });
+    }
 }

 } // namespace device
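mul_add_kernel_dim4 handles the rank-3 case where the bias is broadcast along the middle dimension: for an output viewed as (d0, factor, dim4) packed pairs, the bias index idb = id / (factor * dim4) * dim4 + id % dim4 keeps the outer and innermost coordinates but drops the middle one, i.e. it reads b as if its middle stride were 0, which is exactly what is_bert checks. A CPU check of that mapping, with illustrative sizes:

#include <cstdio>

int main()
{
    const int d0 = 2, factor = 3, dim4 = 4; // output as (d0, factor, dim4) __half2 pairs
    const int n = d0 * factor * dim4;
    for(int id = 0; id < n; id++)
    {
        // Same formula as mul_add_kernel_dim4: the bias ignores the middle index.
        int idb = id / (factor * dim4) * dim4 + id % dim4;
        std::printf("r[%2d] = fma(a[%2d], x[%2d], b[%2d])\n", id, id, id, idb);
    }
}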