Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
05e81ed3
Commit
05e81ed3
authored
Feb 21, 2023
by
charlie
Browse files
Merge branch 'select_module_op' of github.com:ROCmSoftwarePlatform/AMDMIGraphX into dyn_batch_pass
parents
89c8b52c
5de36e4a
Changes
18
Hide whitespace changes
Inline
Side-by-side
Showing
18 changed files
with
464 additions
and
112 deletions
+464
-112
Dockerfile
Dockerfile
+2
-1
cmake/PythonModules.cmake
cmake/PythonModules.cmake
+1
-1
src/include/migraphx/context.hpp
src/include/migraphx/context.hpp
+10
-10
src/include/migraphx/op/select_module.hpp
src/include/migraphx/op/select_module.hpp
+2
-2
src/targets/gpu/jit/reduce.cpp
src/targets/gpu/jit/reduce.cpp
+7
-7
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
+4
-0
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
+71
-12
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
...argets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
+9
-10
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
+12
-3
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
+292
-51
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
+9
-7
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+24
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+1
-1
test/ref_ops_test.cpp
test/ref_ops_test.cpp
+3
-3
test/verify/test_reduce_op_large.cpp
test/verify/test_reduce_op_large.cpp
+13
-0
test/verify/test_select_module_add.cpp
test/verify/test_select_module_add.cpp
+1
-1
test/verify/test_select_module_reduce.cpp
test/verify/test_select_module_reduce.cpp
+2
-2
tools/install_prereqs.sh
tools/install_prereqs.sh
+1
-1
No files found.
Dockerfile
View file @
05e81ed3
...
...
@@ -54,8 +54,9 @@ RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y --allow-
apt-get clean
&&
\
rm
-rf
/var/lib/apt/lists/
*
# add this for roctracer dependancies
RUN
pip3
install
CppHeaderParser
packaging
==
22.0
RUN
pip3
install
CppHeaderParser
# Workaround broken rocm packages
RUN
ln
-s
/opt/rocm-
*
/opt/rocm
...
...
cmake/PythonModules.cmake
View file @
05e81ed3
...
...
@@ -76,7 +76,7 @@ function(py_add_module NAME)
)
endfunction
()
set
(
PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9
)
set
(
PYTHON_SEARCH_VERSIONS 2.7 3.5 3.6 3.7 3.8 3.9
3.10
)
set
(
PYTHON_DISABLE_VERSIONS
""
CACHE STRING
""
)
foreach
(
PYTHON_DISABLE_VERSION
${
PYTHON_DISABLE_VERSIONS
}
)
list
(
REMOVE_ITEM PYTHON_SEARCH_VERSIONS
${
PYTHON_DISABLE_VERSION
}
)
...
...
src/include/migraphx/context.hpp
View file @
05e81ed3
...
...
@@ -182,13 +182,13 @@ struct context
void
wait_for
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
wait_for
(
std
::
move
(
queue
)
)
;
(
*
this
).
private_detail_te_get_handle
().
wait_for
(
queue
);
}
void
finish_on
(
any_ptr
queue
)
{
assert
((
*
this
).
private_detail_te_handle_mem_var
);
(
*
this
).
private_detail_te_get_handle
().
finish_on
(
std
::
move
(
queue
)
)
;
(
*
this
).
private_detail_te_get_handle
().
finish_on
(
queue
);
}
void
finish
()
const
...
...
@@ -261,29 +261,29 @@ struct context
template
<
class
T
>
static
auto
private_detail_te_default_wait_for
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
wait_for
(
std
::
move
(
queue
))
)
->
decltype
(
private_detail_te_self
.
wait_for
(
queue
))
{
private_detail_te_self
.
wait_for
(
std
::
move
(
queue
)
)
;
private_detail_te_self
.
wait_for
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_wait_for
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
wait_for_context
(
private_detail_te_self
,
std
::
move
(
queue
)
)
;
wait_for_context
(
private_detail_te_self
,
queue
);
}
template
<
class
T
>
static
auto
private_detail_te_default_finish_on
(
char
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
->
decltype
(
private_detail_te_self
.
finish_on
(
std
::
move
(
queue
))
)
->
decltype
(
private_detail_te_self
.
finish_on
(
queue
))
{
private_detail_te_self
.
finish_on
(
std
::
move
(
queue
)
)
;
private_detail_te_self
.
finish_on
(
queue
);
}
template
<
class
T
>
static
void
private_detail_te_default_finish_on
(
float
,
T
&&
private_detail_te_self
,
any_ptr
queue
)
{
finish_on_context
(
private_detail_te_self
,
std
::
move
(
queue
)
)
;
finish_on_context
(
private_detail_te_self
,
queue
);
}
template
<
typename
PrivateDetailTypeErasedT
>
...
...
@@ -335,13 +335,13 @@ struct context
void
wait_for
(
any_ptr
queue
)
override
{
private_detail_te_default_wait_for
(
char
(
0
),
private_detail_te_value
,
std
::
move
(
queue
)
)
;
private_detail_te_default_wait_for
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish_on
(
any_ptr
queue
)
override
{
private_detail_te_default_finish_on
(
char
(
0
),
private_detail_te_value
,
std
::
move
(
queue
)
)
;
private_detail_te_default_finish_on
(
char
(
0
),
private_detail_te_value
,
queue
);
}
void
finish
()
const
override
{
private_detail_te_value
.
finish
();
}
...
...
src/include/migraphx/op/select_module.hpp
View file @
05e81ed3
...
...
@@ -43,7 +43,7 @@ struct select_module
std
::
string
name
()
const
{
return
"select_module"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
,
std
::
vector
<
module_ref
>
)
const
shape
compute_shape
(
const
std
::
vector
<
shape
>&
,
const
std
::
vector
<
module_ref
>
&
)
const
{
return
shape
{
output_dyn_shapes
};
}
...
...
@@ -72,7 +72,7 @@ struct select_module
{
MIGRAPHX_THROW
(
"SELECT_MODULE: no compatible submodules found for given input shapes"
);
}
auto
module_to_run
=
*
module_iter
;
auto
*
module_to_run
=
*
module_iter
;
std
::
unordered_map
<
std
::
string
,
argument
>
params
;
// add input parameters
...
...
src/targets/gpu/jit/reduce.cpp
View file @
05e81ed3
...
...
@@ -118,17 +118,17 @@ struct reduce_compiler : compiler<reduce_compiler>
options
.
virtual_inputs
=
reduce_dims
(
inputs
);
auto
faxis
=
find_fast_axis
({
options
.
virtual_inputs
.
front
()});
vectorize
vec
{};
// Vectorize if the axis is a reduction axis
if
(
options
.
virtual_inputs
.
back
().
lens
()[
faxis
]
==
1
)
{
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
}
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
nelements
=
options
.
virtual_inputs
.
back
().
elements
();
auto
algo
=
v
.
get
(
"algo"
,
get_reduce_algo
(
options
.
virtual_inputs
));
if
(
algo
==
"block"
)
{
// Vectorize if the axis is a reduction axis
if
(
options
.
virtual_inputs
.
back
().
lens
()[
faxis
]
==
1
)
vec
=
vectorize
::
elements
(
ctx
,
faxis
,
options
.
virtual_inputs
);
auto
relements
=
get_reduce_elements
(
options
.
virtual_inputs
)
/
vec
.
size
;
auto
block_size
=
compute_block_size
(
relements
,
256
);
if
(
relements
>
block_size
*
256
)
algo
=
"block_large"
;
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
nelements
*
block_size
,
256
),
block_size
);
}
...
...
@@ -166,7 +166,7 @@ struct reduce_compiler : compiler<reduce_compiler>
auto
reduce_elements
=
get_reduce_elements
(
ins
->
inputs
());
auto
reduce_type
=
ins
->
inputs
().
front
()
->
get_shape
().
type
();
v
[
"reduction"
]
=
"op::sum{}"
;
std
::
string
mean
=
"op::mean
{
"
+
std
::
to_string
(
reduce_elements
)
+
"}"
;
std
::
string
mean
=
"op::mean
<
"
+
std
::
to_string
(
reduce_elements
)
+
"
>{
}"
;
// Use float accumulator when reduction size is too large for half
if
(
reduce_type
==
shape
::
half_type
and
reduce_elements
>
16384
)
v
[
"read"
]
=
"compose("
+
mean
+
", op::convert_to<float>{})"
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
View file @
05e81ed3
...
...
@@ -178,5 +178,9 @@ MIGRAPHX_HIP_NORETURN inline __host__ __device__ void assert_fail(const source_l
#define MIGRAPHX_WARN(...)
#endif
#define MIGRAPHX_STATIC_ASSERT_FOR(...) \
static_assert(__VA_ARGS__); \
if constexpr(__VA_ARGS__)
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_DEBUG_HPP
src/targets/gpu/kernels/include/migraphx/kernels/index.hpp
View file @
05e81ed3
...
...
@@ -29,6 +29,7 @@
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <migraphx/kernels/debug.hpp>
#include <migraphx/kernels/functional.hpp>
namespace
migraphx
{
...
...
@@ -135,42 +136,100 @@ struct index
return
(
n
-
_c
<
1
>
)
/
stride
+
_c
<
1
>
;
}
template
<
class
N
>
constexpr
auto
max_global_stride_iterations
(
N
n
)
const
{
return
max_stride_iterations
(
n
,
nglobal
());
}
template
<
class
N
>
constexpr
auto
max_local_stride_iterations
(
N
n
)
const
{
return
max_stride_iterations
(
n
,
nlocal
());
}
template
<
class
F
,
class
I
,
class
D
>
static
constexpr
auto
invoke_loop
(
F
f
,
I
i
,
D
d
)
->
decltype
(
f
(
i
,
d
))
{
return
f
(
i
,
d
);
}
template
<
class
F
,
class
I
,
class
D
>
static
constexpr
auto
invoke_loop
(
F
f
,
I
i
,
D
)
->
decltype
(
f
(
i
))
{
return
f
(
i
);
}
template
<
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride_loop_unroll
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
sequence
(
max_stride_iterations
(
n
,
stride
),
[
&
](
auto
...
ks
)
{
fold
([
&
](
auto
d
,
auto
k
)
{
auto
i
=
start
+
stride
*
k
;
if
(
i
<
n
)
invoke_loop
(
f
,
i
,
d
);
return
d
+
_c
<
1
>
;
})(
_c
<
0
>
,
ks
...);
});
}
template
<
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride_loop
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
index_int
k
=
0
;
for
(
index_int
i
=
start
;
i
<
n
;
i
+=
stride
)
{
invoke_loop
(
f
,
i
,
k
);
k
++
;
}
}
template
<
bool
Unroll
,
class
F
,
class
N
,
class
Stride
>
static
constexpr
void
for_stride
(
index_int
start
,
N
n
,
Stride
stride
,
F
f
)
{
MIGRAPHX_ASSERT
(
start
<
stride
);
if
constexpr
(
not
is_integral
<
N
>
{}
and
not
is_integral
<
Stride
>
{}
and
max_stride_iterations
(
n
,
stride
)
==
1
)
if
constexpr
(
not
is_integral
<
N
>
{}
and
not
is_integral
<
Stride
>
{})
{
if
constexpr
(
stride
>
n
)
if
constexpr
(
max_stride_iterations
(
n
,
stride
)
==
1
)
{
if
constexpr
(
stride
>
n
)
{
if
(
start
<
n
)
invoke_loop
(
f
,
start
,
_c
<
0
>
);
}
else
{
invoke_loop
(
f
,
start
,
_c
<
0
>
);
}
}
else
if
constexpr
(
Unroll
)
{
if
(
start
<
n
)
f
(
start
);
MIGRAPHX_STATIC_ASSERT_FOR
(
max_stride_iterations
(
n
,
stride
)
<
256
)
{
for_stride_loop_unroll
(
start
,
n
,
stride
,
f
);
}
}
else
{
f
(
start
);
f
or_stride_loop
(
start
,
n
,
stride
,
f
);
}
}
else
{
for
(
index_int
i
=
start
;
i
<
n
;
i
+=
stride
)
{
f
(
i
);
}
for_stride_loop
(
start
,
n
,
stride
,
f
);
}
}
template
<
class
F
,
class
N
>
__device__
void
global_stride
(
N
n
,
F
f
)
const
{
for_stride
(
global
,
n
,
nglobal
(),
f
);
for_stride
<
false
>
(
global
,
n
,
nglobal
(),
f
);
}
template
<
class
F
,
class
N
>
__device__
void
local_stride
(
N
n
,
F
f
)
const
{
for_stride
(
local
,
n
,
nlocal
(),
f
);
for_stride
<
true
>
(
local
,
n
,
nlocal
(),
f
);
}
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/layernorm.hpp
View file @
05e81ed3
...
...
@@ -46,28 +46,27 @@ template <index_int Axis,
__device__
void
generic_binary_layernorm
(
F
compute
,
BinOp
op
,
float
eps
,
Output
output
,
Input1
input1
,
Input2
input2
,
Inputs
...
inputs
)
{
using
block
=
reduce
::
auto_block
<
reduce
::
reduce_elements_with_axis
<
Input1
,
Axis
>
()
>
;
using
reduce_output
=
reduce
::
with_axis
<
Input1
,
Axis
>
;
reduce
::
block
::
run
<
reduce_output
>
([
&
](
auto
,
auto
r
)
{
using
value_type
=
typename
Input1
::
type
;
block
::
template
run
<
reduce_output
>([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
([
&
](
auto
x1
,
auto
x2
)
{
return
op
(
x1
,
x2
);
})(
input1
,
input2
);
using
value_type
=
typename
Input1
::
type
;
constexpr
auto
relements
=
r
.
template
elements
<
Input1
>();
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_type
<
value_type
>>
(
0
,
0
),
[
&
](
auto
x1
,
auto
x2
)
{
auto
x
=
op
(
x1
,
x2
);
return
make_array
(
x
,
x
*
x
)
*
vec_type
<
value_type
>
{
1.0
/
relements
};
})(
input1
,
input2
);
auto
means
=
r
.
reduce
(
op
::
sum
{},
make_array
<
vec_type
<
value_type
>>
(
0
,
0
),
[
&
](
auto
x
)
{
return
make_array
(
x
,
x
*
x
)
*
vec_type
<
value_type
>
{
1.0
/
relements
};
})(
input
);
auto
mean_x
=
means
[
0
];
auto
mean_x2
=
means
[
1
];
auto
variance
=
mean_x2
-
(
mean_x
*
mean_x
);
value_type
eps_val
=
eps
;
// implicit conversion for eps
r
.
inner
([
&
](
auto
&
y
,
auto
x1
,
auto
x2
,
auto
...
xs
)
{
auto
x
=
op
(
x1
,
x2
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
,
auto
...
xs
)
{
auto
m
=
x
-
mean_x
;
// m * rsqrt(mean(m ^ 2) + epsilon)
y
=
compute
(
m
*
rsqrt
(
variance
+
eps_val
),
xs
...);
})(
output
,
input
1
,
input2
,
inputs
...);
})(
output
,
input
,
inputs
...);
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ops.hpp
View file @
05e81ed3
...
...
@@ -66,13 +66,22 @@ struct convert_to
}
};
template
<
index_int
N
>
struct
mean
{
index_int
item_num
=
1
;
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
auto
operator
()(
T
x
)
const
MIGRAPHX_DEVICE_CONSTEXPR
T
operator
()(
T
x
)
const
{
return
x
/
static_cast
<
T
>
(
item_num
);
using
type
=
vec_type
<
T
>
;
if
constexpr
(
is_floating_point
<
type
>
{})
{
constexpr
type
d
=
1.0
/
N
;
return
x
*
d
;
}
else
{
return
x
/
static_cast
<
type
>
(
N
);
}
}
};
...
...
src/targets/gpu/kernels/include/migraphx/kernels/reduce.hpp
View file @
05e81ed3
...
...
@@ -103,10 +103,10 @@ __device__ auto block_reduce(index idx, Op op, T init, Index n, F f)
#else
constexpr
index_int
lanes_per_thread
=
64
;
#endif
using
type
=
decltype
(
f
(
0
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()
/
lanes_per_thread
];
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
x
=
op
(
x
,
f
(
i
));
});
idx
.
local_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
dpp_reduce
(
x
,
op
);
const
auto
ldsidx
=
idx
.
local
/
lanes_per_thread
;
...
...
@@ -128,10 +128,10 @@ template <class Op, class T, class Index, class F>
__device__
auto
block_reduce
(
index
idx
,
Op
op
,
T
init
,
Index
n
,
F
f
)
{
MIGRAPHX_ASSERT
(
idx
.
max_nlocal
()
==
idx
.
nlocal
());
using
type
=
decltype
(
f
(
0
));
using
type
=
decltype
(
index
::
invoke_loop
(
f
,
0
,
_c
<
0
>
));
__shared__
type
buffer
[
idx
.
max_nlocal
()];
type
x
=
init
;
idx
.
local_stride
(
n
,
[
&
](
auto
i
)
{
x
=
op
(
x
,
f
(
i
));
});
idx
.
local_stride
(
n
,
[
&
](
auto
i
,
auto
d
)
{
x
=
op
(
x
,
index
::
invoke_loop
(
f
,
i
,
d
));
});
buffer
[
idx
.
local
]
=
x
;
__syncthreads
();
...
...
@@ -167,6 +167,25 @@ constexpr auto reduce_slice(Input input, T i)
namespace
reduce
{
struct
inner_storage_tag
{
};
template
<
class
T
>
using
is_inner_storage
=
is_base_of
<
inner_storage_tag
,
remove_cv_t
<
remove_reference_t
<
T
>>>
;
template
<
class
R
,
class
F
>
struct
storage_access
:
F
{
using
type
=
R
;
};
template
<
class
R
,
class
F
>
constexpr
storage_access
<
R
,
F
>
make_storage_access
(
F
f
)
{
return
{{
f
}};
}
template
<
class
Slicer
,
class
F
>
constexpr
auto
sliced
(
Slicer
slicer
,
F
f
)
{
...
...
@@ -191,20 +210,140 @@ constexpr auto compute_reduce_axis()
template
<
class
Input
,
index_int
Axis
>
using
with_axis
=
decltype
(
compute_reduce_axis
<
Input
,
Axis
>
());
template
<
class
Derived
>
struct
reducer_base
{
template
<
class
T
>
__device__
auto
make_inner_slice
(
T
x
)
const
{
if
constexpr
(
is_inner_storage
<
T
>
{})
{
return
x
;
}
else
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
return
make_storage_access
<
typename
decltype
(
t
)
::
type
>
([
=
](
auto
i
,
auto
...)
->
auto
&
{
return
t
[
i
];
});
}
}
template
<
class
T
,
class
...
Ts
>
constexpr
auto
get_size
(
T
&&
x
,
[[
maybe_unused
]]
Ts
&&
...
xs
)
const
{
MIGRAPHX_ASSERT
(
get_size
(
x
)
==
get_size
(
xs
...));
return
get_size
(
x
);
}
template
<
class
T
,
class
...
Ts
>
constexpr
auto
get_size
(
T
&&
x
)
const
{
if
constexpr
(
is_inner_storage
<
T
>
{})
{
return
x
.
rsize
();
}
else
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
auto
t
=
derived
.
slice
(
x
);
return
t
.
size
();
}
}
template
<
class
F
>
__device__
auto
inner_sliced
(
F
f
)
const
{
return
[
=
](
auto
&&
...
xs
)
{
return
f
(
get_size
(
xs
...),
make_inner_slice
(
xs
)...);
};
}
template
<
class
T
>
static
__device__
typename
T
::
type
&
decl_inner_storage
(
const
T
&
);
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
{
return
this
->
inner_sliced
([
=
](
auto
n
,
auto
&&
...
xs
)
{
using
result_type
=
decltype
(
f
(
decl_inner_storage
(
xs
)...));
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
if
constexpr
(
is_void
<
result_type
>
{})
{
derived
.
inner_void_impl
(
f
,
n
,
xs
...);
}
else
{
return
derived
.
template
inner_impl
<
result_type
>(
f
,
n
,
xs
...);
}
});
}
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
{
return
this
->
inner_sliced
([
=
](
auto
n
,
auto
&&
...
xs
)
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
return
derived
.
reduce_impl
(
op
,
init
,
read
,
n
,
xs
...);
});
}
template
<
class
Op
,
class
T
>
__device__
auto
reduce
(
Op
op
,
T
init
)
const
{
return
this
->
reduce
(
op
,
init
,
op
::
id
{});
}
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
f
();
}
template
<
class
Input
>
constexpr
auto
elements
()
const
{
auto
&&
derived
=
static_cast
<
const
Derived
&>
(
*
this
);
using
reduce_type
=
decltype
(
derived
.
slice
(
Input
{}));
using
value_type
=
typename
Input
::
type
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
return
relements
*
vec_size
<
value_type
>
();
else
return
relements
;
}
};
struct
block
{
template
<
class
Slicer
>
struct
reducer
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
index
idx
;
Slicer
slice
;
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
template
<
class
T
,
index_int
N
,
class
Size
>
struct
inner_storage
:
inner_storage_tag
{
using
type
=
T
;
array
<
T
,
N
>
arr
;
constexpr
Size
rsize
()
const
{
return
{};
}
template
<
class
U
,
class
V
>
constexpr
auto
&
operator
()(
U
,
V
d
)
const
{
return
arr
[
d
];
}
template
<
class
U
,
class
V
>
constexpr
auto
&
operator
()(
U
,
V
d
)
{
return
arr
[
d
];
}
};
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
...
Ts
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
Ts
&&
...
xs
)
const
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
return
block_reduce
(
idx
,
op
,
init
,
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
return
vec_reduce
(
read
(
x
[
j
],
xs
[
j
]...),
op
);
});
return
block_reduce
(
idx
,
op
,
init
,
n
,
[
&
](
auto
j
,
auto
d
)
{
return
vec_reduce
(
read
(
xs
(
j
,
d
)...),
op
);
});
}
...
...
@@ -215,31 +354,99 @@ struct block
f
();
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
void
inner_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
idx
.
local_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
f
(
xs
(
j
,
d
)...);
});
}
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
idx
.
local_stride
(
x
.
get_shape
().
elements
(),
[
&
](
auto
j
)
{
f
(
x
[
j
],
xs
[
j
]...);
});
using
max_iterations
=
decltype
(
idx
.
max_local_stride_iterations
(
n
));
inner_storage
<
R
,
max_iterations
{},
N
>
storage
;
idx
.
local_stride
(
n
,
[
&
](
auto
j
,
auto
d
)
{
storage
(
j
,
d
)
=
f
(
xs
(
j
,
d
)...);
});
return
storage
;
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{{},
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
static
__device__
void
run
(
F
f
)
{
auto
idx
=
make_index
();
constexpr
auto
nelements
=
get_shape_c
<
Output
>
{}.
elements
();
idx
.
global_stride
(
nelements
*
idx
.
nlocal
(),
[
&
](
auto
i
)
{
const
auto
out_idx
=
get_shape_c
<
Output
>
{}.
multi
(
i
/
idx
.
nlocal
());
f
(
out_idx
,
make
(
idx
,
[
&
](
auto
input
)
{
return
reduce_slice
<
Output
>
(
input
,
out_idx
);
}));
});
}
};
struct
block_large
{
template
<
class
Slicer
>
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
index
idx
;
Slicer
slice
;
template
<
class
Size
,
class
F
>
struct
inner_storage
:
inner_storage_tag
{
using
type
=
remove_reference_t
<
decltype
(
declval
<
F
>
()(
0
,
_c
<
0
>
))
>
;
F
f
;
constexpr
Size
rsize
()
const
{
return
{};
}
template
<
class
U
,
class
V
>
constexpr
auto
operator
()(
U
j
,
V
d
)
const
{
return
f
(
j
,
d
);
}
};
template
<
class
Size
,
class
F
>
constexpr
inner_storage
<
Size
,
F
>
make_inner_storage
(
Size
,
F
f
)
{
return
{
f
};
}
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
...
Ts
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
Ts
&&
...
xs
)
const
{
return
block_reduce
(
idx
,
op
,
init
,
index_int
{
n
},
[
&
](
auto
j
,
auto
d
)
{
return
vec_reduce
(
read
(
xs
(
j
,
d
)...),
op
);
});
}
template
<
class
Input
>
constexpr
auto
elements
(
)
const
template
<
class
F
>
__device__
void
outer
(
F
f
)
const
{
using
reduce_type
=
decltype
(
slice
(
Input
{}));
using
value_type
=
typename
Input
::
type
;
constexpr
auto
relements
=
get_shape_c
<
reduce_type
>
{}.
elements
();
if
constexpr
(
vec_size
<
value_type
>
()
>
1
)
return
relements
*
vec_size
<
value_type
>
();
else
return
relements
;
if
(
idx
.
local
==
0
)
f
();
}
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
void
inner_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
idx
.
local_stride
(
index_int
{
n
},
[
&
](
auto
j
,
auto
d
)
{
f
(
xs
(
j
,
d
)...);
});
}
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
return
make_inner_storage
(
n
,
[
=
](
auto
j
,
auto
d
)
{
return
f
(
xs
(
j
,
d
)...);
});
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{
idx
,
slicer
};
return
reducer
<
Slicer
>
{
{},
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
...
...
@@ -257,22 +464,40 @@ struct block
struct
lane
{
template
<
class
Slicer
>
struct
reducer
struct
reducer
:
reducer_base
<
reducer
<
Slicer
>>
{
index
idx
;
Slicer
slice
;
template
<
class
Op
,
class
T
,
class
Read
>
__device__
auto
reduce
(
Op
op
,
T
init
,
Read
read
)
const
template
<
class
Size
,
class
F
>
struct
inner_storage
:
inner_storage_tag
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
using
type
=
typename
decltype
(
x
)
::
type
;
type
r
=
init
;
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
r
=
op
(
r
,
read
(
x
[
j
],
xs
[
j
]...));
}
return
r
;
});
using
type
=
remove_reference_t
<
decltype
(
declval
<
F
>
()(
0
,
_c
<
0
>
))
>
;
F
f
;
constexpr
Size
rsize
()
const
{
return
{};
}
template
<
class
U
,
class
V
>
constexpr
auto
operator
()(
U
j
,
V
d
)
const
{
return
f
(
j
,
d
);
}
};
template
<
class
Size
,
class
F
>
constexpr
inner_storage
<
Size
,
F
>
make_inner_storage
(
Size
,
F
f
)
{
return
{
f
};
}
template
<
class
Op
,
class
T
,
class
Read
,
class
N
,
class
U
,
class
...
Us
>
__device__
auto
reduce_impl
(
Op
op
,
T
init
,
Read
read
,
N
n
,
U
&&
x
,
Us
&&
...
xs
)
const
{
using
type
=
remove_reference_t
<
decltype
(
x
(
0
,
_c
<
0
>
))
>
;
type
r
=
init
;
for
(
index_int
j
=
0
;
j
<
n
;
j
++
)
{
r
=
op
(
r
,
read
(
x
(
j
,
_c
<
0
>
),
xs
(
j
,
_c
<
0
>
)...));
}
return
r
;
}
template
<
class
F
>
...
...
@@ -281,29 +506,25 @@ struct lane
f
();
}
template
<
class
F
>
__device__
auto
inner
(
F
f
)
const
template
<
class
F
,
class
N
,
class
...
Ts
>
__device__
void
inner
_void_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
return
sliced
(
slice
,
[
=
](
auto
x
,
auto
...
xs
)
{
for
(
index_int
j
=
0
;
j
<
x
.
get_shape
().
elements
();
j
++
)
{
f
(
x
[
j
],
xs
[
j
]...);
}
});
for
(
index_int
j
=
0
;
j
<
n
;
j
++
)
{
f
(
xs
(
j
,
_c
<
0
>
)...);
}
}
template
<
class
Input
>
constexpr
auto
elements
(
)
const
template
<
class
R
,
class
F
,
class
N
,
class
...
Ts
>
__device__
auto
inner_impl
(
F
f
,
N
n
,
Ts
&&
...
xs
)
const
{
using
reduce_type
=
decltype
(
slice
(
Input
{}));
return
get_shape_c
<
reduce_type
>
{}.
elements
();
return
make_inner_storage
(
n
,
[
=
](
auto
j
,
auto
d
)
{
return
f
(
xs
(
j
,
d
)...);
});
}
};
template
<
class
Slicer
>
static
__device__
auto
make
(
index
idx
,
Slicer
slicer
)
{
return
reducer
<
Slicer
>
{
idx
,
slicer
};
return
reducer
<
Slicer
>
{
{},
idx
,
slicer
};
}
template
<
class
Output
,
class
F
>
...
...
@@ -318,6 +539,26 @@ struct lane
}
};
// TODO: Remove these in the future when they can be selected in the compiler class
template
<
index_int
RElements
>
constexpr
auto
pick_block
()
{
using
nlocal
=
decltype
(
index
{}.
max_nlocal
());
if
constexpr
(
RElements
<
nlocal
{}
*
256
)
return
block
{};
else
return
block_large
{};
}
template
<
index_int
RElements
>
using
auto_block
=
decltype
(
pick_block
<
RElements
>
());
template
<
class
Input
,
index_int
Axis
>
constexpr
auto
reduce_elements_with_axis
()
{
constexpr
auto
s
=
get_shape_c
<
Input
>
{};
return
s
.
lens
[
Axis
];
}
}
// namespace reduce
template
<
class
Algo
,
...
...
src/targets/gpu/kernels/include/migraphx/kernels/softmax.hpp
View file @
05e81ed3
...
...
@@ -30,18 +30,20 @@
namespace
migraphx
{
template
<
index_int
Axis
,
class
Input
,
class
Output
>
__device__
void
softmax
(
Input
input
,
Output
output
)
__device__
void
softmax
(
Input
input
1
,
Output
output
)
{
reduce
::
block
::
run
<
reduce
::
with_axis
<
Input
,
Axis
>>
([
&
](
auto
,
auto
r
)
{
using
block
=
reduce
::
auto_block
<
reduce
::
reduce_elements_with_axis
<
Input
,
Axis
>
()
>
;
block
::
template
run
<
reduce
::
with_axis
<
Input
,
Axis
>
>
([
&
](
auto
,
auto
r
)
{
auto
input
=
r
.
inner
(
op
::
id
{})(
input1
);
#ifdef MIGRAPHX_USE_FAST_SOFTMAX
const
auto
c
=
vec_at
(
r
.
slice
(
input
)[
0
],
0
);
const
auto
c
=
vec_at
(
r
.
slice
(
input
1
)[
0
],
0
);
#else
const
auto
c
=
r
.
reduce
(
op
::
max
{},
lowest
{},
op
::
id
{})(
input
);
#endif
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[
&
](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
migraphx
::
exp
(
x
-
c
));
})(
input
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
migraphx
::
exp
(
x
-
c
)
/
batch_sum
;
})(
output
,
input
);
auto
exp_in
=
r
.
inner
([
&
](
auto
x
)
{
return
migraphx
::
exp
(
x
-
c
);
})(
input
);
auto
batch_sum
=
r
.
reduce
(
op
::
sum
{},
0
,
[](
auto
x
)
{
return
migraphx
::
convert
<
float
>
(
x
);
})(
exp_in
);
r
.
inner
([
&
](
auto
&
y
,
auto
x
)
{
y
=
x
/
batch_sum
;
})(
output
,
exp_in
);
});
}
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
05e81ed3
...
...
@@ -141,6 +141,25 @@ MIGRAPHX_BUILTIN_TYPE_TRAITN(is_constructible);
MIGRAPHX_BUILTIN_TYPE_TRAITN
(
is_nothrow_constructible
);
MIGRAPHX_BUILTIN_TYPE_TRAITN
(
is_trivially_constructible
);
template
<
class
T
>
struct
remove_cv
{
using
type
=
T
;
};
template
<
class
T
>
struct
remove_cv
<
const
T
>
:
remove_cv
<
T
>
{
};
template
<
class
T
>
struct
remove_cv
<
volatile
T
>
:
remove_cv
<
T
>
{
};
template
<
class
T
>
using
remove_cv_t
=
typename
remove_cv
<
T
>::
type
;
template
<
class
T
>
struct
remove_reference
{
...
...
@@ -168,6 +187,11 @@ struct add_pointer : type_identity<typename remove_reference<T>::type*>
template
<
class
T
>
using
add_pointer_t
=
typename
add_pointer
<
T
>::
type
;
template
<
class
T
>
struct
is_void
:
is_same
<
void
,
remove_cv_t
<
T
>>
{
};
template
<
class
...
Ts
>
struct
common_type
;
...
...
src/targets/gpu/lowering.cpp
View file @
05e81ed3
...
...
@@ -369,7 +369,7 @@ struct miopen_apply
apply_map
.
emplace
(
"select_module"
,
[
=
](
instruction_ref
ins
)
{
std
::
vector
<
instruction_ref
>
inputs
=
ins
->
inputs
();
auto
mod_args
=
ins
->
module_inputs
();
for
(
auto
smod
:
mod_args
)
for
(
auto
*
smod
:
mod_args
)
{
smod
->
use_local_alloc
=
true
;
auto
last_ins
=
std
::
prev
(
smod
->
end
());
...
...
test/ref_ops_test.cpp
View file @
05e81ed3
...
...
@@ -7285,7 +7285,7 @@ TEST_CASE(select_module_add_test)
auto literal_ins = mm->add_literal(migraphx::literal{lit_s, {6}});
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) {
auto create_submodule = [&](std::size_t batch_size,
const
std::string
&
module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 4}};
auto sm_input = submod->add_parameter("data", sm_shape);
...
...
@@ -7329,7 +7329,7 @@ TEST_CASE(select_module_reduce_test0)
migraphx::program p;
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) {
auto create_submodule = [&](std::size_t batch_size,
const
std::string
&
module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape);
...
...
@@ -7375,7 +7375,7 @@ TEST_CASE(select_module_reduce_test1)
migraphx::program p;
// create batch submodules
auto create_submodule = [&](std::size_t batch_size, std::string module_name) {
auto create_submodule = [&](std::size_t batch_size,
const
std::string
&
module_name) {
auto* submod = p.create_module(module_name);
migraphx::shape sm_shape{migraphx::shape::float_type, {batch_size, 2, 2}};
auto sm_input = submod->add_parameter("data", sm_shape);
...
...
test/verify/test_reduce_op_large.cpp
View file @
05e81ed3
...
...
@@ -76,3 +76,16 @@ struct test_reduce_mean_2 : verify_program<test_reduce_mean_2>
return
p
;
};
};
struct
test_large_reduce_mean
:
verify_program
<
test_large_reduce_mean
>
{
migraphx
::
program
create_program
()
const
{
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
s
{
migraphx
::
shape
::
float_type
,
{
2
,
256
*
256
*
16
}};
auto
x
=
mm
->
add_parameter
(
"x"
,
s
);
mm
->
add_instruction
(
migraphx
::
op
::
reduce_mean
{{
1
}},
x
);
return
p
;
};
};
test/verify/test_select_module_add.cpp
View file @
05e81ed3
...
...
@@ -37,7 +37,7 @@ struct test_select_module_add : verify_program<test_select_module_add>
auto
literal_ins
=
mm
->
add_literal
(
migraphx
::
literal
{
lit_s
,
{
6
}});
// create batch submodules
auto
create_submodule
=
[
&
](
std
::
size_t
batch_size
,
std
::
string
module_name
)
{
auto
create_submodule
=
[
&
](
std
::
size_t
batch_size
,
const
std
::
string
&
module_name
)
{
auto
*
submod
=
p
.
create_module
(
module_name
);
migraphx
::
shape
sm_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
4
}};
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
...
...
test/verify/test_select_module_reduce.cpp
View file @
05e81ed3
...
...
@@ -34,8 +34,8 @@ struct test_select_module_reduce : verify_program<test_select_module_reduce>
migraphx
::
program
p
;
// create batch submodules
auto
create_submodule
=
[
&
](
std
::
size_t
batch_size
,
std
::
string
module_name
)
{
auto
submod
=
p
.
create_module
(
module_name
);
auto
create_submodule
=
[
&
](
std
::
size_t
batch_size
,
const
std
::
string
&
module_name
)
{
auto
*
submod
=
p
.
create_module
(
module_name
);
migraphx
::
shape
sm_shape
{
migraphx
::
shape
::
float_type
,
{
batch_size
,
2
,
2
}};
auto
sm_input
=
submod
->
add_parameter
(
"data"
,
sm_shape
);
auto
reduce_ins
=
...
...
tools/install_prereqs.sh
View file @
05e81ed3
...
...
@@ -57,7 +57,7 @@ echo "Dependencies are installed at $PREFIX"
rbuild prepare
-d
$PREFIX
-s
develop
# install onnx package for unit tests
pip3
install
onnx
==
1.10.
0
numpy
==
1.21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
16.8
pip3
install
onnx
==
1.10.
2
numpy
==
1.21.6
typing
==
3.7.4
pytest
==
6.0.1
packaging
==
23.0
# pin version of protobuf in Python for onnx runtime unit tests
pip3
install
protobuf
==
3.20.0
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment