Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
3ce7ad8b
Commit
3ce7ad8b
authored
Feb 08, 2022
by
Shucai Xiao
Browse files
merge from develop branch and resolve merge conflicts
parents
d0202590
b304d97d
Changes
96
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
825 additions
and
134 deletions
+825
-134
src/targets/gpu/argmax.cpp
src/targets/gpu/argmax.cpp
+1
-1
src/targets/gpu/argmin.cpp
src/targets/gpu/argmin.cpp
+1
-1
src/targets/gpu/compile_hip_code_object.cpp
src/targets/gpu/compile_hip_code_object.cpp
+2
-1
src/targets/gpu/compile_ops.cpp
src/targets/gpu/compile_ops.cpp
+11
-1
src/targets/gpu/compile_pointwise.cpp
src/targets/gpu/compile_pointwise.cpp
+18
-4
src/targets/gpu/compile_roialign.cpp
src/targets/gpu/compile_roialign.cpp
+1
-1
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
...targets/gpu/device/include/migraphx/gpu/device/launch.hpp
+3
-2
src/targets/gpu/device/softmax.cpp
src/targets/gpu/device/softmax.cpp
+48
-24
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+184
-6
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
...gets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
+1
-1
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
+2
-1
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
+57
-6
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+55
-2
src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp
...pu/kernels/include/migraphx/kernels/integral_constant.hpp
+1
-0
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
+162
-0
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
...argets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
+28
-2
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
+23
-7
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+56
-0
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
...argets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
+137
-74
test/auto_contiguous_test.cpp
test/auto_contiguous_test.cpp
+34
-0
No files found.
src/targets/gpu/argmax.cpp
View file @
3ce7ad8b
...
@@ -9,7 +9,7 @@ namespace gpu {
...
@@ -9,7 +9,7 @@ namespace gpu {
shape
hip_argmax
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
hip_argmax
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
)
.
standard
()
;
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
normalize_compute_shape
({
inputs
.
at
(
0
)});
return
op
.
normalize_compute_shape
({
inputs
.
at
(
0
)});
}
}
...
...
src/targets/gpu/argmin.cpp
View file @
3ce7ad8b
...
@@ -9,7 +9,7 @@ namespace gpu {
...
@@ -9,7 +9,7 @@ namespace gpu {
shape
hip_argmin
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
hip_argmin
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
)
.
standard
()
;
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
normalize_compute_shape
({
inputs
.
at
(
0
)});
return
op
.
normalize_compute_shape
({
inputs
.
at
(
0
)});
}
}
...
...
src/targets/gpu/compile_hip_code_object.cpp
View file @
3ce7ad8b
...
@@ -108,12 +108,13 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
...
@@ -108,12 +108,13 @@ operation compile_hip_code_object(const std::string& content, hip_compile_option
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
srcs
.
push_back
(
src_file
{
fs
::
path
{
"main.cpp"
},
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())});
std
::
make_pair
(
content
.
data
(),
content
.
data
()
+
content
.
size
())});
auto
args_hpp
=
auto
args_hpp
=
generate_args_hpp
(
options
.
reduced
_inputs
.
empty
()
?
options
.
inputs
:
options
.
reduced
_inputs
);
generate_args_hpp
(
options
.
virtual
_inputs
.
empty
()
?
options
.
inputs
:
options
.
virtual
_inputs
);
srcs
.
push_back
(
src_file
{
fs
::
path
{
"args.hpp"
},
srcs
.
push_back
(
src_file
{
fs
::
path
{
"args.hpp"
},
std
::
make_pair
(
args_hpp
.
data
(),
args_hpp
.
data
()
+
args_hpp
.
size
())});
std
::
make_pair
(
args_hpp
.
data
(),
args_hpp
.
data
()
+
args_hpp
.
size
())});
options
.
params
+=
" -DMIGRAPHX_NGLOBAL="
+
std
::
to_string
(
options
.
global
);
options
.
params
+=
" -DMIGRAPHX_NGLOBAL="
+
std
::
to_string
(
options
.
global
);
options
.
params
+=
" -DMIGRAPHX_NLOCAL="
+
std
::
to_string
(
options
.
local
);
options
.
params
+=
" -DMIGRAPHX_NLOCAL="
+
std
::
to_string
(
options
.
local
);
options
.
params
+=
" "
+
join_strings
(
compiler_warnings
(),
" "
);
options
.
params
+=
" "
+
join_strings
(
compiler_warnings
(),
" "
);
options
.
params
+=
" -ftemplate-backtrace-limit=0"
;
options
.
params
+=
" -Werror"
;
options
.
params
+=
" -Werror"
;
auto
cos
=
compile_hip_src
(
srcs
,
std
::
move
(
options
.
params
),
get_device_name
());
auto
cos
=
compile_hip_src
(
srcs
,
std
::
move
(
options
.
params
),
get_device_name
());
if
(
cos
.
size
()
!=
1
)
if
(
cos
.
size
()
!=
1
)
...
...
src/targets/gpu/compile_ops.cpp
View file @
3ce7ad8b
...
@@ -12,6 +12,8 @@ namespace migraphx {
...
@@ -12,6 +12,8 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_GPU_COMPILE_PARALLEL
);
struct
precompile_op
struct
precompile_op
{
{
operation
op
=
op
::
identity
{};
operation
op
=
op
::
identity
{};
...
@@ -70,6 +72,14 @@ struct compiled_result
...
@@ -70,6 +72,14 @@ struct compiled_result
instruction_ref
ins
;
instruction_ref
ins
;
};
};
template
<
class
F
>
void
par_compile
(
std
::
size_t
n
,
F
f
)
{
if
(
n
==
0
)
return
;
par_for
(
n
,
n
/
value_of
(
MIGRAPHX_GPU_COMPILE_PARALLEL
{},
n
),
f
);
}
void
compile_ops
::
apply
(
module
&
m
)
const
void
compile_ops
::
apply
(
module
&
m
)
const
{
{
auto
compilers
=
make_compilers
(
pointwise_compiler
{});
auto
compilers
=
make_compilers
(
pointwise_compiler
{});
...
@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const
...
@@ -85,7 +95,7 @@ void compile_ops::apply(module& m) const
compiles
.
emplace_back
([
=
]()
->
compiled_result
{
return
{
c
(
*
ctx
,
ins
,
preop
),
ins
};
});
compiles
.
emplace_back
([
=
]()
->
compiled_result
{
return
{
c
(
*
ctx
,
ins
,
preop
),
ins
};
});
}
}
std
::
vector
<
compiled_result
>
results
(
compiles
.
size
());
std
::
vector
<
compiled_result
>
results
(
compiles
.
size
());
par_
for
(
compiles
.
size
(),
1
,
[
&
](
auto
i
)
{
results
[
i
]
=
compiles
[
i
]();
});
par_
compile
(
compiles
.
size
(),
[
&
](
auto
i
)
{
results
[
i
]
=
compiles
[
i
]();
});
for
(
const
auto
&
cr
:
results
)
for
(
const
auto
&
cr
:
results
)
{
{
m
.
replace_instruction
(
cr
.
ins
,
cr
.
op
,
cr
.
ins
->
inputs
());
m
.
replace_instruction
(
cr
.
ins
,
cr
.
op
,
cr
.
ins
->
inputs
());
...
...
src/targets/gpu/compile_pointwise.cpp
View file @
3ce7ad8b
...
@@ -20,7 +20,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
...
@@ -20,7 +20,7 @@ static const char* const pointwise_kernel = R"__migraphx__(
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <args.hpp>
#include <args.hpp>
using
namespace migraphx
;
namespace migraphx
{
${preamble}
${preamble}
...
@@ -32,6 +32,8 @@ __global__ void kernel(${params})
...
@@ -32,6 +32,8 @@ __global__ void kernel(${params})
}
}
} // namespace migraphx
int main() {}
int main() {}
)__migraphx__"
;
)__migraphx__"
;
...
@@ -46,7 +48,7 @@ operation compile_pointwise(context&,
...
@@ -46,7 +48,7 @@ operation compile_pointwise(context&,
options
.
local
=
1024
;
options
.
local
=
1024
;
options
.
inputs
=
inputs
;
options
.
inputs
=
inputs
;
options
.
output
=
inputs
.
back
();
options
.
output
=
inputs
.
back
();
options
.
reduced
_inputs
=
reduce_dims
(
inputs
);
options
.
virtual
_inputs
=
reduce_dims
(
inputs
);
options
.
params
=
"-Wno-float-equal"
;
options
.
params
=
"-Wno-float-equal"
;
auto
src
=
interpolate_string
(
pointwise_kernel
,
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
...
@@ -60,8 +62,20 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
...
@@ -60,8 +62,20 @@ operation compile_pointwise(context& ctx, const std::vector<shape>& inputs, modu
{
{
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
run_passes
(
m
,
{
eliminate_common_subexpression
{},
dead_code_elimination
{}});
cpp_generator
g
;
cpp_generator
g
;
auto
name
=
g
.
create_function
(
g
.
generate_module
(
m
).
set_attributes
({
"__device__"
}));
g
.
fmap
([](
const
std
::
string
&
fname
)
{
return
"migraphx::"
+
fname
;
});
return
compile_pointwise
((
ctx
),
inputs
,
"&"
+
name
,
g
.
str
());
g
.
add_point_op
(
"where"
,
"${function:where}(${0}, ${1}, ${2})"
);
g
.
add_point_op
(
"prelu"
,
"${function:where}(${0} < 0, ${0} * ${1}, ${0})"
);
g
.
add_point_op
(
"sign"
,
"${function:where}(${0} > 0, 1, ${function:where}(${0} < 0, -1, 0))"
);
g
.
add_point_op
(
"equal"
,
"migraphx::abs(${0} == ${1})"
);
g
.
add_point_op
(
"less"
,
"migraphx::abs(${0} < ${1})"
);
g
.
add_point_op
(
"greater"
,
"migraphx::abs(${0} > ${1})"
);
g
.
add_point_op
(
"not"
,
"migraphx::abs(not ${0})"
);
// Add explict conversions
g
.
fresult
(
[](
const
shape
&
s
)
{
return
"migraphx::convert<"
+
shape
::
cpp_type
(
s
.
type
())
+
">"
;
});
auto
name
=
g
.
create_function
(
g
.
generate_module
(
m
).
set_attributes
({
"__device__"
}).
set_generic_types
(
m
));
return
compile_pointwise
((
ctx
),
inputs
,
"MIGRAPHX_LIFT("
+
name
+
")"
,
g
.
str
());
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/compile_roialign.cpp
View file @
3ce7ad8b
...
@@ -50,7 +50,7 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
...
@@ -50,7 +50,7 @@ operation compile_roialign(context&, const std::vector<shape>& io_shapes, const
options
.
inputs
=
io_shapes
;
options
.
inputs
=
io_shapes
;
options
.
output
=
out_s
;
options
.
output
=
out_s
;
options
.
kernel_name
=
"roialign_kernel"
;
options
.
kernel_name
=
"roialign_kernel"
;
options
.
reduced
_inputs
=
io_shapes
;
options
.
virtual
_inputs
=
io_shapes
;
// sampling_ratio
// sampling_ratio
assert
(
val
.
contains
(
"sampling_ratio"
));
assert
(
val
.
contains
(
"sampling_ratio"
));
...
...
src/targets/gpu/device/include/migraphx/gpu/device/launch.hpp
100755 → 100644
View file @
3ce7ad8b
...
@@ -75,8 +75,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
...
@@ -75,8 +75,9 @@ MIGRAPHX_DEVICE_CONSTEXPR auto gs_invoke(F&& f, index_int i, index) -> decltype(
inline
auto
gs_launch
(
hipStream_t
stream
,
index_int
n
,
index_int
local
=
1024
)
inline
auto
gs_launch
(
hipStream_t
stream
,
index_int
n
,
index_int
local
=
1024
)
{
{
index_int
groups
=
(
n
+
local
-
1
)
/
local
;
index_int
groups
=
(
n
+
local
-
1
)
/
local
;
index_int
nglobal
=
std
::
min
<
index_int
>
(
256
,
groups
)
*
local
;
// max possible number of blocks is set to 1B (1,073,741,824)
index_int
nglobal
=
std
::
min
<
index_int
>
(
1073741824
,
groups
)
*
local
;
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
launch
(
stream
,
nglobal
,
local
)([
=
](
auto
idx
)
__device__
{
launch
(
stream
,
nglobal
,
local
)([
=
](
auto
idx
)
__device__
{
...
...
src/targets/gpu/device/softmax.cpp
View file @
3ce7ad8b
...
@@ -20,34 +20,58 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
...
@@ -20,34 +20,58 @@ void softmax(hipStream_t stream, const argument& result, const argument& arg, in
migraphx
::
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
result
.
get_shape
().
type
(),
batch_lens
};
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
hip_visit_all
(
result
,
arg
,
batch_shape
)([
&
](
auto
output
,
auto
input
,
auto
batch
)
{
const
index_int
max_block_size
=
256
;
const
index_int
max_block_size
=
128
;
const
index_int
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
const
index_int
block_size
=
compute_block_size
(
batch_item_num
,
max_block_size
);
gs_launch
(
stream
,
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
batch_shape
.
elements
()
*
block_size
,
type
init
=
lowest
();
block_size
)([
=
](
auto
i
,
auto
idx
)
__device__
{
auto
data_idx
=
batch
.
multi
(
i
/
block_size
);
if
(
axis
==
batch_lens
.
size
()
-
1
)
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
{
type
init
=
lowest
();
gs_launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)(
[
=
](
auto
i
,
auto
idx
)
__device__
{
auto
batch_max
=
block_reduce
<
max_block_size
>
(
auto
start_loc
=
i
/
block_size
*
batch_item_num
;
idx
,
max
{},
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
auto
batch_max
=
block_reduce
<
max_block_size
>
(
data_idx
[
axis
]
=
j
;
idx
,
max
{},
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
return
input
[
data_idx
];
return
input
[
start_loc
+
j
];
});
});
auto
batch_sum
=
block_reduce
<
max_block_size
>
(
idx
,
sum
{},
0
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
auto
val
=
input
[
start_loc
+
j
]
-
batch_max
;
return
::
exp
(
to_hip_type
(
val
));
});
auto
batch_sum
=
idx
.
local_stride
(
batch_item_num
,
[
&
](
auto
j
)
__device__
{
block_reduce
<
max_block_size
>
(
idx
,
sum
{},
0
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
auto
val
=
input
[
start_loc
+
j
]
-
batch_max
;
data_idx
[
axis
]
=
j
;
output
[
start_loc
+
j
]
=
::
exp
(
to_hip_type
(
val
))
/
batch_sum
;
auto
val
=
input
[
data_idx
]
-
batch_max
;
});
return
::
exp
(
to_hip_type
(
val
));
});
});
}
else
{
gs_launch
(
stream
,
batch_shape
.
elements
()
*
block_size
,
block_size
)(
[
=
](
auto
i
,
auto
idx
)
__device__
{
auto
data_idx
=
batch
.
multi
(
i
/
block_size
);
auto
batch_max
=
block_reduce
<
max_block_size
>
(
idx
,
max
{},
init
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
data_idx
[
axis
]
=
j
;
return
input
[
data_idx
];
});
idx
.
local_stride
(
batch_item_num
,
[
&
](
auto
j
)
__device__
{
auto
batch_sum
=
block_reduce
<
max_block_size
>
(
data_idx
[
axis
]
=
j
;
idx
,
sum
{},
0
,
batch_item_num
,
[
&
](
auto
j
)
__device__
{
auto
val
=
input
[
data_idx
]
-
batch_max
;
data_idx
[
axis
]
=
j
;
output
[
data_idx
]
=
::
exp
(
to_hip_type
(
val
))
/
batch_sum
;
auto
val
=
input
[
data_idx
]
-
batch_max
;
});
return
::
exp
(
to_hip_type
(
val
));
});
});
idx
.
local_stride
(
batch_item_num
,
[
&
](
auto
j
)
__device__
{
data_idx
[
axis
]
=
j
;
auto
val
=
input
[
data_idx
]
-
batch_max
;
output
[
data_idx
]
=
::
exp
(
to_hip_type
(
val
))
/
batch_sum
;
});
});
}
});
});
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
3ce7ad8b
...
@@ -62,6 +62,8 @@ struct fusion
...
@@ -62,6 +62,8 @@ struct fusion
keep_alive
(
std
::
move
(
t
));
keep_alive
(
std
::
move
(
t
));
}
}
bool
empty
()
const
{
return
fp
==
nullptr
;
}
op_t
operator
[](
std
::
size_t
i
)
const
op_t
operator
[](
std
::
size_t
i
)
const
{
{
assert
(
fp
);
assert
(
fp
);
...
@@ -125,12 +127,11 @@ struct fusion
...
@@ -125,12 +127,11 @@ struct fusion
return
shape
{
shape
::
int8_type
,
{
ws_size
}};
return
shape
{
shape
::
int8_type
,
{
ws_size
}};
}
}
void
compile
(
context
&
ctx
)
bool
compile
(
context
&
ctx
)
{
{
assert
(
fp
);
assert
(
fp
);
auto
status
=
miopenCompileFusionPlan
(
ctx
.
get_stream
().
get_miopen
(),
fp
.
get
());
return
miopenCompileFusionPlan
(
ctx
.
get_stream
().
get_miopen
(),
fp
.
get
())
==
if
(
status
!=
miopenStatusSuccess
)
miopenStatusSuccess
;
MIGRAPHX_THROW
(
"Compiling fusion plan failed"
);
}
}
argument
execute
(
context
&
ctx
,
argument
execute
(
context
&
ctx
,
...
@@ -169,7 +170,7 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
...
@@ -169,7 +170,7 @@ MIGRAPHX_PRED_MATCHER(bias_shape, instruction_ref ins)
MIGRAPHX_PRED_MATCHER
(
fusable_conv
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
fusable_conv
,
instruction_ref
ins
)
{
{
const
auto
device_name
=
split_string
(
get_device_name
(),
':'
).
front
();
const
auto
device_name
=
trim
(
split_string
(
get_device_name
(),
':'
).
front
()
)
;
if
(
not
contains
(
get_supported_archs
(),
device_name
))
if
(
not
contains
(
get_supported_archs
(),
device_name
))
return
false
;
return
false
;
if
(
enabled
(
MIGRAPHX_DISABLE_MIOPEN_FUSION
{}))
if
(
enabled
(
MIGRAPHX_DISABLE_MIOPEN_FUSION
{}))
...
@@ -561,6 +562,122 @@ struct find_mul_add_relu
...
@@ -561,6 +562,122 @@ struct find_mul_add_relu
}
}
};
};
struct
miopen_fusion
{
struct
fuse_op_data
{
operation
op
;
float
alpha
=
1
;
float
beta
=
0
;
};
struct
fuse_op
:
fuse_op_data
,
reflect_equality
<
fuse_op
>
,
reflect_stream
<
fuse_op
>
{
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
),
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
));
}
};
std
::
vector
<
fuse_op
>
ops
=
{};
fusion
f
=
{};
std
::
function
<
void
(
context
&
,
const
fusion
&
,
const
std
::
vector
<
argument
>&
)
>
execute
;
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
ops
,
"ops"
));
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
value
compile
(
context
&
ctx
,
const
shape
&
,
std
::
vector
<
shape
>
inputs
)
{
// Compensate for allocation
inputs
.
pop_back
();
std
::
size_t
i
=
0
;
f
=
fusion
(
inputs
[
i
]);
i
++
;
std
::
vector
<
std
::
function
<
void
(
const
fused_operator_args
&
,
const
std
::
vector
<
argument
>&
)
>>
invokers
;
for
(
auto
&&
fop
:
ops
)
{
if
(
i
>
inputs
.
size
())
{
f
=
{};
return
{};
}
if
(
fop
.
op
.
name
()
==
"convolution"
)
{
auto
*
mop
=
f
.
create_conv
(
any_cast
<
op
::
convolution
>
(
fop
.
op
),
inputs
[
i
]);
invokers
.
push_back
(
[
=
](
const
fused_operator_args
&
fargs
,
const
std
::
vector
<
argument
>&
args
)
{
miopenSetOpArgsConvForward
(
fargs
.
get
(),
mop
,
&
fop
.
alpha
,
&
fop
.
beta
,
args
[
i
].
implicit
());
});
i
++
;
}
else
if
(
fop
.
op
.
name
()
==
"add"
)
{
auto
*
mop
=
f
.
create_bias
(
inputs
[
i
]);
invokers
.
push_back
(
[
=
](
const
fused_operator_args
&
fargs
,
const
std
::
vector
<
argument
>&
args
)
{
miopenSetOpArgsBiasForward
(
fargs
.
get
(),
mop
,
&
fop
.
alpha
,
&
fop
.
beta
,
args
[
i
].
implicit
());
});
i
++
;
}
else
if
(
fop
.
op
.
name
()
==
"relu"
)
{
auto
*
mop
=
f
.
create_relu
();
invokers
.
push_back
([
=
](
const
fused_operator_args
&
fargs
,
const
std
::
vector
<
argument
>&
)
{
miopenSetOpArgsActivForward
(
fargs
.
get
(),
mop
,
&
fop
.
alpha
,
&
fop
.
beta
,
0
,
0
,
0
);
});
}
else
{
f
=
{};
return
{};
}
}
if
(
not
f
.
compile
(
ctx
))
{
f
=
{};
return
{};
}
execute
=
[
invokers
](
context
&
c
,
const
fusion
&
ff
,
const
std
::
vector
<
argument
>&
args
)
{
auto
fargs
=
make_fused_args
();
for
(
auto
&&
invoker
:
invokers
)
invoker
(
fargs
,
args
);
ff
.
execute
(
c
,
fargs
,
args
.
front
(),
args
.
back
());
};
return
{{
"workspace"
,
f
.
get_workspace
(
ctx
).
bytes
()}};
}
void
finalize
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
shape
>&
inputs
)
{
if
(
not
f
.
empty
())
return
;
auto
v
=
compile
(
ctx
,
output_shape
,
inputs
);
if
(
not
v
.
is_object
())
MIGRAPHX_THROW
(
"Failed to compile fusion plan"
);
}
std
::
string
name
()
const
{
return
"gpu::miopen_fusion"
;
}
shape
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
if
(
ops
.
empty
())
return
{};
// TODO: Check number of arguments
return
ops
.
front
().
op
.
compute_shape
({
inputs
[
0
],
inputs
[
1
]});
}
argument
compute
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
argument
>&
args
)
const
{
execute
(
ctx
,
f
,
args
);
return
args
.
back
();
}
};
struct
miopen_conv_bias
struct
miopen_conv_bias
{
{
op
::
convolution
op
;
op
::
convolution
op
;
...
@@ -596,7 +713,8 @@ struct miopen_conv_bias
...
@@ -596,7 +713,8 @@ struct miopen_conv_bias
f
=
fusion
(
inputs
[
0
]);
f
=
fusion
(
inputs
[
0
]);
conv
=
f
.
create_conv
(
op
,
inputs
[
1
]);
conv
=
f
.
create_conv
(
op
,
inputs
[
1
]);
bias
=
f
.
create_bias
(
inputs
[
3
]);
bias
=
f
.
create_bias
(
inputs
[
3
]);
f
.
compile
(
ctx
);
if
(
not
f
.
compile
(
ctx
))
MIGRAPHX_THROW
(
"Failed to compile fusion plan"
);
}
}
shape
get_workspace
(
context
&
ctx
)
{
return
f
.
get_workspace
(
ctx
);
}
shape
get_workspace
(
context
&
ctx
)
{
return
f
.
get_workspace
(
ctx
);
}
...
@@ -683,6 +801,25 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
...
@@ -683,6 +801,25 @@ void apply_conv_bias(context& ctx, module& p, match::matcher_result r)
p
.
replace_instruction
(
ins
,
cb
,
input_ins
,
weights_ins
,
old_ws_ins
,
bias_ins
,
alloc_ins
);
p
.
replace_instruction
(
ins
,
cb
,
input_ins
,
weights_ins
,
old_ws_ins
,
bias_ins
,
alloc_ins
);
}
}
inline
auto
precompile_name
(
std
::
string
s
)
// NOLINT
{
return
match
::
make_basic_pred_matcher
([
=
](
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"gpu::precompile_op"
)
return
false
;
auto
op
=
from_value
<
operation
>
(
ins
->
get_operator
().
to_value
().
at
(
"op"
));
return
(
op
.
name
()
==
s
);
});
}
template
<
class
...
Ms
>
auto
conv_bias_pointwise
(
Ms
...
ms
)
{
return
precompile_name
(
"pointwise"
)(
match
::
either_arg
(
0
,
1
)(
bias_shape
(
match
::
used_once
()).
bind
(
"bias"
),
fusable_conv
(
match
::
used_once
()).
bind
(
"conv"
)),
ms
...);
}
struct
find_conv_bias
struct
find_conv_bias
{
{
context
*
ctx
=
nullptr
;
context
*
ctx
=
nullptr
;
...
@@ -709,6 +846,46 @@ struct find_conv_bias_relu
...
@@ -709,6 +846,46 @@ struct find_conv_bias_relu
}
}
};
};
struct
find_conv_pointwise
{
context
*
ctx
=
nullptr
;
auto
matcher
()
const
{
return
precompile_name
(
"pointwise"
)(
match
::
nargs
(
3
),
match
::
either_arg
(
0
,
1
)(
bias_shape
(
match
::
used_once
()).
bind
(
"bias"
),
fusable_conv
(
match
::
used_once
()).
bind
(
"conv"
)));
}
void
apply
(
module
&
m
,
match
::
matcher_result
r
)
const
{
auto
conv_ins
=
r
.
instructions
[
"conv"
];
auto
bias_ins
=
r
.
instructions
[
"bias"
];
auto
ins
=
r
.
result
;
auto
input_ins
=
conv_ins
->
inputs
().
at
(
0
);
auto
weights_ins
=
conv_ins
->
inputs
().
at
(
1
);
auto
conv_op
=
any_cast
<
miopen_convolution
>
(
conv_ins
->
get_operator
()).
op
;
auto
alloc_ins
=
ins
->
inputs
().
back
();
module_ref
pm
=
ins
->
module_inputs
().
front
();
miopen_fusion
op
{};
op
.
ops
.
push_back
({{
conv_op
}});
for
(
auto
&&
i
:
*
pm
)
{
if
(
i
.
name
()[
0
]
==
'@'
)
continue
;
auto
inputs
=
to_shapes
(
i
.
inputs
());
op
.
ops
.
push_back
({{
i
.
get_operator
()}});
}
std
::
vector
<
instruction_ref
>
inputs
=
{
input_ins
,
weights_ins
,
bias_ins
,
alloc_ins
};
auto
v
=
op
.
compile
(
*
ctx
,
ins
->
get_shape
(),
to_shapes
(
inputs
));
if
(
not
v
.
is_object
())
return
;
m
.
replace_instruction
(
ins
,
op
,
inputs
);
}
};
struct
find_gemm_add
struct
find_gemm_add
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -778,6 +955,7 @@ void fuse_ops::apply(module& p) const
...
@@ -778,6 +955,7 @@ void fuse_ops::apply(module& p) const
match
::
find_matches
(
p
,
find_triadd
{});
match
::
find_matches
(
p
,
find_triadd
{});
match
::
find_matches
(
p
,
match
::
find_matches
(
p
,
find_layernorm
{},
find_layernorm
{},
find_conv_pointwise
{
ctx
},
find_conv_bias_relu
{
ctx
},
find_conv_bias_relu
{
ctx
},
find_conv_bias
{
ctx
},
find_conv_bias
{
ctx
},
find_add_gelu
{},
find_add_gelu
{},
...
...
src/targets/gpu/include/migraphx/gpu/compile_hip_code_object.hpp
View file @
3ce7ad8b
...
@@ -16,7 +16,7 @@ struct hip_compile_options
...
@@ -16,7 +16,7 @@ struct hip_compile_options
shape
output
;
shape
output
;
std
::
string
kernel_name
=
"kernel"
;
std
::
string
kernel_name
=
"kernel"
;
std
::
string
params
=
""
;
std
::
string
params
=
""
;
std
::
vector
<
shape
>
reduced
_inputs
=
{};
std
::
vector
<
shape
>
virtual
_inputs
=
{};
};
};
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
);
operation
compile_hip_code_object
(
const
std
::
string
&
content
,
hip_compile_options
options
);
...
...
src/targets/gpu/include/migraphx/gpu/device/arg_op.hpp
View file @
3ce7ad8b
...
@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
...
@@ -76,8 +76,9 @@ void arg_op(Op op, hipStream_t stream, const argument& result, const argument& a
size_t
batch_item_num
=
batch_lens
[
axis
];
size_t
batch_item_num
=
batch_lens
[
axis
];
batch_lens
[
axis
]
=
1
;
batch_lens
[
axis
]
=
1
;
migraphx
::
shape
batch_shape
{
arg_shape
.
type
(),
batch_lens
};
migraphx
::
shape
batch_shape
{
arg_shape
.
type
(),
batch_lens
};
migraphx
::
shape
std_arg_shape
{
arg_shape
.
type
(),
arg_shape
.
lens
()};
hip_visit_all
(
arg
,
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
hip_visit_all
(
arg
,
std_
arg_shape
,
batch_shape
)([
&
](
auto
input
,
auto
arg_s
,
auto
batch_s
)
{
auto
*
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
auto
*
output
=
device_cast
(
result
.
get
<
int64_t
>
().
data
());
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
using
type
=
device_type
<
std
::
remove_cv_t
<
typename
decltype
(
input
)
::
value_type
>>
;
// use one block for items in one batch.
// use one block for items in one batch.
...
...
src/targets/gpu/kernels/include/migraphx/kernels/debug.hpp
100755 → 100644
View file @
3ce7ad8b
...
@@ -5,6 +5,9 @@
...
@@ -5,6 +5,9 @@
namespace
migraphx
{
namespace
migraphx
{
#define MIGRAPHX_STRINGIZE_1(...) #__VA_ARGS__
#define MIGRAPHX_STRINGIZE(...) MIGRAPHX_STRINGIZE_1(__VA_ARGS__)
// Workaround hip's broken abort on device code
// Workaround hip's broken abort on device code
#ifdef __HIP_DEVICE_COMPILE__
#ifdef __HIP_DEVICE_COMPILE__
// NOLINTNEXTLINE
// NOLINTNEXTLINE
...
@@ -14,19 +17,67 @@ namespace migraphx {
...
@@ -14,19 +17,67 @@ namespace migraphx {
#define MIGRAPHX_HIP_NORETURN [[noreturn]]
#define MIGRAPHX_HIP_NORETURN [[noreturn]]
#endif
#endif
namespace
debug
{
struct
swallow
{
template
<
class
...
Ts
>
constexpr
swallow
(
Ts
&&
...)
{
}
};
template
<
size_t
N
>
struct
print_buffer
{
char
buffer
[
N
+
1
]
=
{
0
};
char
*
pos
=
buffer
;
constexpr
void
append
(
char
c
)
{
if
(
c
==
0
)
return
;
if
(
pos
<
buffer
+
N
)
{
*
pos
=
c
;
pos
++
;
}
}
template
<
size_t
M
>
constexpr
void
append
(
const
char
(
&
array
)[
M
])
{
for
(
int
i
=
0
;
i
<
M
;
i
++
)
append
(
array
[
i
]);
}
};
template
<
class
...
Ts
>
__host__
__device__
void
print
(
const
Ts
&
...
xs
)
{
const
auto
size
=
(
sizeof
(
xs
)
+
...);
print_buffer
<
size
>
buffer
;
swallow
{(
buffer
.
append
(
xs
),
0
)...};
printf
(
"%s"
,
buffer
.
buffer
);
}
}
// namespace debug
// noreturn cannot be used on this function because abort in hip is broken
// noreturn cannot be used on this function because abort in hip is broken
template
<
class
T1
,
class
T2
,
class
T3
,
class
T4
>
MIGRAPHX_HIP_NORETURN
inline
__host__
__device__
void
MIGRAPHX_HIP_NORETURN
inline
__host__
__device__
void
assert_fail
(
const
char
*
assertion
,
const
char
*
file
,
unsigned
int
line
,
const
char
*
function
)
assert_fail
(
const
T1
&
assertion
,
const
T2
&
file
,
const
T3
&
line
,
const
T4
&
function
)
{
{
printf
(
"%s:%u: %s: assertion '%s' failed.
\n
"
,
file
,
line
,
function
,
assertion
);
// printf is broken on hip with more than one argument, so use a simple print functions instead
debug
::
print
(
file
,
":"
,
line
,
": "
,
function
,
": assertion '"
,
assertion
,
"' failed.
\n
"
);
// printf("%s:%s: %s: assertion '%s' failed.\n", file, line, function, assertion);
abort
();
abort
();
}
}
#ifdef MIGRAPHX_DEBUG
#ifdef MIGRAPHX_DEBUG
#define MIGRAPHX_ASSERT(cond) \
#define MIGRAPHX_ASSERT(cond)
\
((cond) ? void(0) : [](auto... xs) { \
((cond) ? void(0) : [](auto
&&
...
private_migraphx_
xs) { \
assert_fail(xs...); \
assert_fail(
private_migraphx_
xs...);
\
}(#cond, __FILE__, __LINE__, __PRETTY_FUNCTION__))
}(#cond, __FILE__,
MIGRAPHX_STRINGIZE(
__LINE__
)
, __PRETTY_FUNCTION__))
#else
#else
#define MIGRAPHX_ASSERT(cond)
#define MIGRAPHX_ASSERT(cond)
#endif
#endif
...
...
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
100755 → 100644
View file @
3ce7ad8b
...
@@ -16,6 +16,19 @@ struct swallow
...
@@ -16,6 +16,19 @@ struct swallow
template
<
index_int
>
template
<
index_int
>
using
ignore
=
swallow
;
using
ignore
=
swallow
;
template
<
class
...
Fs
>
struct
overloaded
:
Fs
...
{
using
Fs
::
operator
()...;
overloaded
(
Fs
...
fs
)
:
Fs
(
fs
)...
{}
};
template
<
class
...
Fs
>
overloaded
<
Fs
...
>
overload
(
Fs
...
fs
)
{
return
{
fs
...};
}
namespace
detail
{
namespace
detail
{
template
<
class
R
>
template
<
class
R
>
...
@@ -124,12 +137,48 @@ constexpr void each_args(F)
...
@@ -124,12 +137,48 @@ constexpr void each_args(F)
{
{
}
}
template
<
class
F
,
class
T
>
constexpr
auto
fold_impl
(
F
&&
,
T
&&
x
)
{
return
static_cast
<
T
&&>
(
x
);
}
template
<
class
F
,
class
T
,
class
U
,
class
...
Ts
>
constexpr
auto
fold_impl
(
F
&&
f
,
T
&&
x
,
U
&&
y
,
Ts
&&
...
xs
)
{
return
fold_impl
(
f
,
f
(
static_cast
<
T
&&>
(
x
),
static_cast
<
U
&&>
(
y
)),
static_cast
<
Ts
&&>
(
xs
)...);
}
template
<
class
F
>
constexpr
auto
fold
(
F
f
)
{
return
[
=
](
auto
&&
...
xs
)
{
return
fold_impl
(
f
,
static_cast
<
decltype
(
xs
)
&&>
(
xs
)...);
};
}
template
<
class
...
Ts
>
template
<
class
...
Ts
>
auto
pack
(
Ts
...
xs
)
constexpr
auto
pack
(
Ts
...
xs
)
{
{
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
return
[
=
](
auto
f
)
{
return
f
(
xs
...);
};
}
}
template
<
class
Compare
,
class
P1
,
class
P2
>
constexpr
auto
pack_compare
(
Compare
compare
,
P1
p1
,
P2
p2
)
{
return
p1
([
&
](
auto
...
xs
)
{
return
p2
([
&
](
auto
...
ys
)
{
auto
c
=
[
&
](
auto
x
,
auto
y
)
->
int
{
if
(
compare
(
x
,
y
))
return
1
;
else
if
(
compare
(
y
,
x
))
return
-
1
;
else
return
0
;
};
return
fold
([](
auto
x
,
auto
y
)
{
return
x
?
x
:
y
;
})(
c
(
xs
,
ys
)...,
0
);
});
});
}
template
<
index_int
N
>
template
<
index_int
N
>
constexpr
auto
arg_c
()
constexpr
auto
arg_c
()
{
{
...
@@ -168,9 +217,13 @@ constexpr auto transform_args(F f, Fs... fs)
...
@@ -168,9 +217,13 @@ constexpr auto transform_args(F f, Fs... fs)
return
[
=
](
auto
...
xs
)
{
return
transform_args
(
f
)(
xs
...)(
transform_args
(
fs
...));
};
return
[
=
](
auto
...
xs
)
{
return
transform_args
(
f
)(
xs
...)(
transform_args
(
fs
...));
};
}
}
// NOLINTNEXTLINE
#define MIGRAPHX_RETURNS(...) \
->decltype(__VA_ARGS__) { return __VA_ARGS__; }
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_LIFT(...) \
#define MIGRAPHX_LIFT(...) \
(
[](auto&&... xs)
{ return
(__VA_ARGS__)(static_cast<decltype(xs)>(xs)...)
; }
)
[](auto&&... xs)
MIGRAPHX_RETURNS(
(__VA_ARGS__)(static_cast<decltype(xs)>(xs)...))
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
#endif // MIGRAPHX_GUARD_KERNELS_FUNCTIONAL_HPP
src/targets/gpu/kernels/include/migraphx/kernels/integral_constant.hpp
View file @
3ce7ad8b
...
@@ -13,6 +13,7 @@ struct integral_constant
...
@@ -13,6 +13,7 @@ struct integral_constant
using
type
=
integral_constant
;
using
type
=
integral_constant
;
constexpr
operator
value_type
()
const
noexcept
{
return
value
;
}
constexpr
operator
value_type
()
const
noexcept
{
return
value
;
}
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
constexpr
value_type
operator
()()
const
noexcept
{
return
value
;
}
static
constexpr
type
to
()
{
return
{};
}
};
};
// NOLINTNEXTLINE
// NOLINTNEXTLINE
...
...
src/targets/gpu/kernels/include/migraphx/kernels/math.hpp
0 → 100644
View file @
3ce7ad8b
#ifndef MIGRAPHX_GUARD_KERNELS_MATH_HPP
#define MIGRAPHX_GUARD_KERNELS_MATH_HPP
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/vec.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/type_traits.hpp>
#include <hip/hip_fp16.h>
#include <hip/math_functions.h>
namespace
migraphx
{
namespace
math
{
constexpr
float
as_float
(
migraphx
::
half
x
)
{
return
x
;
}
template
<
class
T
>
constexpr
T
as_float
(
T
x
)
{
return
x
;
}
}
// namespace math
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) MIGRAPHX_RETURNS(fname(xs...))
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_VEC(name) \
template <class... Ts, MIGRAPHX_REQUIRES(is_any_vec<Ts...>())> \
auto __device__ name(Ts... xs) \
{ \
return vec_transform(xs...)([](auto... ys) { return name(ys...); }); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_FOR(type, name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(type x, Ts... xs)->type \
{ \
return fname(x, xs...); \
}
// NOLINTNEXTLINE
#define MIGRAPHX_DEVICE_MATH_HALF(name, fname) \
template <class... Ts, MIGRAPHX_REQUIRES(not is_any_vec<Ts...>())> \
auto __device__ name(migraphx::half x, Ts... xs) \
MIGRAPHX_RETURNS(fname(math::as_float(x), math::as_float(xs)...))
MIGRAPHX_DEVICE_MATH
(
abs
,
::
abs
)
MIGRAPHX_DEVICE_MATH
(
acos
,
::
acos
)
MIGRAPHX_DEVICE_MATH
(
acosh
,
::
acosh
)
MIGRAPHX_DEVICE_MATH
(
asin
,
::
asin
)
MIGRAPHX_DEVICE_MATH
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH
(
exp
,
::
exp
)
MIGRAPHX_DEVICE_MATH
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH
(
log
,
::
log
)
MIGRAPHX_DEVICE_MATH
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH
(
rsqrt
,
::
rsqrt
)
MIGRAPHX_DEVICE_MATH
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH
(
sqrt
,
::
sqrt
)
MIGRAPHX_DEVICE_MATH
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH
(
tanh
,
::
tanh
)
// Float overloads
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
acos
,
::
acosf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
acosh
,
::
acoshf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
asin
,
::
asinf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
asinh
,
::
asinhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
atan
,
::
atanf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
atanh
,
::
atanhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
cos
,
::
cosf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
cosh
,
::
coshf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
rsqrt
,
::
rsqrtf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
sin
,
::
sinf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
sinh
,
::
sinhf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
tan
,
::
tanf
)
MIGRAPHX_DEVICE_MATH_FOR
(
float
,
tanh
,
::
tanhf
)
// Builtin half functions
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
abs
,
::
__habs
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
exp
,
::
hexp
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
log
,
::
hlog
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
rsqrt
,
::
hrsqrt
)
MIGRAPHX_DEVICE_MATH_FOR
(
migraphx
::
half
,
sqrt
,
::
hsqrt
)
// Use float to compute half overload
MIGRAPHX_DEVICE_MATH_HALF
(
acos
,
::
acos
)
MIGRAPHX_DEVICE_MATH_HALF
(
acosh
,
::
acosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
asin
,
::
asin
)
MIGRAPHX_DEVICE_MATH_HALF
(
asinh
,
::
asinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
atan
,
::
atan
)
MIGRAPHX_DEVICE_MATH_HALF
(
atanh
,
::
atanh
)
MIGRAPHX_DEVICE_MATH_HALF
(
ceil
,
::
ceil
)
MIGRAPHX_DEVICE_MATH_HALF
(
cos
,
::
cos
)
MIGRAPHX_DEVICE_MATH_HALF
(
cosh
,
::
cosh
)
MIGRAPHX_DEVICE_MATH_HALF
(
erf
,
::
erf
)
MIGRAPHX_DEVICE_MATH_HALF
(
floor
,
::
floor
)
MIGRAPHX_DEVICE_MATH_HALF
(
pow
,
::
pow
)
MIGRAPHX_DEVICE_MATH_HALF
(
round
,
::
round
)
MIGRAPHX_DEVICE_MATH_HALF
(
sin
,
::
sin
)
MIGRAPHX_DEVICE_MATH_HALF
(
sinh
,
::
sinh
)
MIGRAPHX_DEVICE_MATH_HALF
(
tan
,
::
tan
)
MIGRAPHX_DEVICE_MATH_HALF
(
tanh
,
::
tanh
)
template
<
class
T
,
class
U
>
constexpr
auto
where
(
bool
cond
,
const
T
&
a
,
const
U
&
b
)
{
return
cond
?
a
:
b
;
}
MIGRAPHX_DEVICE_MATH_VEC
(
abs
)
MIGRAPHX_DEVICE_MATH_VEC
(
acos
)
MIGRAPHX_DEVICE_MATH_VEC
(
acosh
)
MIGRAPHX_DEVICE_MATH_VEC
(
asin
)
MIGRAPHX_DEVICE_MATH_VEC
(
asinh
)
MIGRAPHX_DEVICE_MATH_VEC
(
atan
)
MIGRAPHX_DEVICE_MATH_VEC
(
atanh
)
MIGRAPHX_DEVICE_MATH_VEC
(
ceil
)
MIGRAPHX_DEVICE_MATH_VEC
(
cos
)
MIGRAPHX_DEVICE_MATH_VEC
(
cosh
)
MIGRAPHX_DEVICE_MATH_VEC
(
erf
)
MIGRAPHX_DEVICE_MATH_VEC
(
exp
)
MIGRAPHX_DEVICE_MATH_VEC
(
floor
)
MIGRAPHX_DEVICE_MATH_VEC
(
log
)
MIGRAPHX_DEVICE_MATH_VEC
(
pow
)
MIGRAPHX_DEVICE_MATH_VEC
(
round
)
MIGRAPHX_DEVICE_MATH_VEC
(
rsqrt
)
MIGRAPHX_DEVICE_MATH_VEC
(
sin
)
MIGRAPHX_DEVICE_MATH_VEC
(
sinh
)
MIGRAPHX_DEVICE_MATH_VEC
(
sqrt
)
MIGRAPHX_DEVICE_MATH_VEC
(
tan
)
MIGRAPHX_DEVICE_MATH_VEC
(
tanh
)
MIGRAPHX_DEVICE_MATH_VEC
(
where
)
template
<
class
T
,
class
U
>
constexpr
auto
max
(
const
T
&
a
,
const
U
&
b
)
{
return
where
(
a
<
b
,
b
,
a
);
}
template
<
class
T
,
class
U
>
constexpr
auto
min
(
const
T
&
a
,
const
U
&
b
)
{
return
where
(
a
>
b
,
b
,
a
);
}
template
<
class
T
,
class
U
>
constexpr
auto
convert
(
U
v
)
{
return
vec_transform
(
v
)([](
auto
x
)
->
T
{
return
x
;
});
}
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_MATH_HPP
src/targets/gpu/kernels/include/migraphx/kernels/pointwise.hpp
100755 → 100644
View file @
3ce7ad8b
...
@@ -3,19 +3,45 @@
...
@@ -3,19 +3,45 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/functional.hpp>
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/preload.hpp>
#include <migraphx/kernels/preload.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/vectorize.hpp>
#include <migraphx/kernels/args.hpp>
#include <migraphx/kernels/args.hpp>
namespace
migraphx
{
namespace
migraphx
{
template
<
class
T
>
struct
implicit_conversion_op
{
T
x
;
template
<
index_int
N
,
class
U
>
constexpr
operator
vec
<
U
,
N
>
()
const
{
static_assert
(
vec_size
<
T
>
()
==
N
,
"Vector mismatch size"
);
return
__builtin_convertvector
(
x
,
vec
<
U
,
N
>
);
}
template
<
class
U
>
constexpr
operator
U
()
const
{
return
x
;
}
};
template
<
class
T
>
constexpr
implicit_conversion_op
<
T
>
implicit_conversion
(
T
x
)
{
return
{
x
};
}
template
<
class
F
,
class
T
,
class
...
Ts
>
template
<
class
F
,
class
T
,
class
...
Ts
>
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
__device__
void
pointwise_tensor
(
index
idx
,
F
f
,
T
out
,
Ts
...
xs
)
{
{
preload
<
typename
T
::
type
>
(
idx
,
xs
...)([
&
](
auto
...
ps
)
{
preload
<
typename
T
::
type
>
(
idx
,
xs
...)([
&
](
auto
...
ps
)
{
idx
.
global_stride
(
out
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
idx
.
global_stride
(
out
.
get_shape
().
elements
(),
[
&
](
auto
i
)
{
auto
multi_idx
=
out
.
get_shape
().
multi
(
i
);
auto
multi_idx
=
out
.
get_shape
().
multi
(
i
);
out
[
multi_idx
]
=
f
(
ps
[
multi_idx
]...);
out
[
multi_idx
]
=
implicit_conversion
(
f
(
ps
[
multi_idx
]...)
)
;
});
});
});
});
}
}
...
@@ -23,7 +49,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
...
@@ -23,7 +49,7 @@ __device__ void pointwise_tensor(index idx, F f, T out, Ts... xs)
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
__device__
void
pointwise
(
F
f
,
Ts
*
...
ps
)
__device__
void
pointwise
(
F
f
,
Ts
*
...
ps
)
{
{
auto
t
=
transform_args
(
make_tensors
(),
rotate_last
());
auto
t
=
transform_args
(
make_tensors
(),
rotate_last
()
,
auto_vectorize
()
);
t
(
ps
...)([
&
](
auto
...
xs
)
{
t
(
ps
...)([
&
](
auto
...
xs
)
{
auto
idx
=
make_index
();
auto
idx
=
make_index
();
pointwise_tensor
(
idx
,
f
,
xs
...);
pointwise_tensor
(
idx
,
f
,
xs
...);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/preload.hpp
100755 → 100644
View file @
3ce7ad8b
...
@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
...
@@ -29,7 +29,7 @@ constexpr auto traverse_preload(Shapes... ss)
}
}
template
<
class
T
,
class
...
Shapes
>
template
<
class
T
,
class
...
Shapes
>
constexpr
index_int
compute_preload_size
(
Shapes
...)
constexpr
index_int
compute_preload_size
_c
(
Shapes
...)
{
{
index_int
size
=
0
;
index_int
size
=
0
;
traverse_preload
<
T
>
(
Shapes
{}...)(
traverse_preload
<
T
>
(
Shapes
{}...)(
...
@@ -37,6 +37,12 @@ constexpr index_int compute_preload_size(Shapes...)
...
@@ -37,6 +37,12 @@ constexpr index_int compute_preload_size(Shapes...)
return
size
;
return
size
;
}
}
template
<
class
T
,
class
...
Shapes
>
constexpr
auto
compute_preload_size
(
Shapes
...)
{
return
_c
<
compute_preload_size_c
<
T
>
(
Shapes
{}...)
>
;
}
template
<
class
F
,
class
T
,
class
...
Ts
>
template
<
class
F
,
class
T
,
class
...
Ts
>
__device__
auto
preload_copy
(
index
idx
,
F
f
,
__shared__
T
*
buffer
,
Ts
...
xs
)
__device__
auto
preload_copy
(
index
idx
,
F
f
,
__shared__
T
*
buffer
,
Ts
...
xs
)
{
{
...
@@ -48,11 +54,21 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
...
@@ -48,11 +54,21 @@ __device__ auto preload_copy(index idx, F f, __shared__ T* buffer, Ts... xs)
[
&
](
auto
x
,
auto
offset
,
auto
copy
)
{
[
&
](
auto
x
,
auto
offset
,
auto
copy
)
{
if
constexpr
(
copy
)
if
constexpr
(
copy
)
{
{
auto
v
=
vectorize
(
x
);
if
constexpr
(
decltype
(
tensor_vec_size
(
x
)){}
==
0
)
auto
b
=
as_vec
(
tensor_vec_size
(
v
),
buffer
+
offset
);
{
idx
.
local_stride
(
v
.
get_shape
().
element_space
(),
auto
v
=
vectorize
(
x
);
[
&
](
auto
i
)
{
b
[
i
]
=
v
.
data
()[
i
];
});
auto
b
=
as_vec
(
tensor_vec_size
(
v
),
buffer
+
offset
);
return
x
.
with
(
buffer
+
offset
);
idx
.
local_stride
(
v
.
get_shape
().
element_space
(),
[
&
](
auto
i
)
{
b
[
i
]
=
v
.
data
()[
i
];
});
return
x
.
with
(
buffer
+
offset
);
}
else
{
auto
b
=
as_vec
(
tensor_vec_size
(
x
),
buffer
+
offset
);
idx
.
local_stride
(
x
.
get_shape
().
element_space
(),
[
&
](
auto
i
)
{
b
[
i
]
=
x
.
data
()[
i
];
});
return
x
.
with
(
b
);
}
}
}
else
else
{
{
...
@@ -78,7 +94,7 @@ template <class T, class... Ts>
...
@@ -78,7 +94,7 @@ template <class T, class... Ts>
__device__
auto
preload
(
index
idx
,
Ts
...
xs
)
__device__
auto
preload
(
index
idx
,
Ts
...
xs
)
{
{
using
type
=
typename
remove_vec
<
T
>::
type
;
using
type
=
typename
remove_vec
<
T
>::
type
;
constexpr
auto
size
=
compute_preload_size
<
type
>
(
xs
.
get_shape
()...);
constexpr
auto
size
=
decltype
(
compute_preload_size
<
type
>
(
xs
.
get_shape
()...)
){}
;
const
index_int
max_size
=
512
*
sizeof
(
type
);
const
index_int
max_size
=
512
*
sizeof
(
type
);
return
[
=
](
auto
f
)
{
return
[
=
](
auto
f
)
{
if
constexpr
(
size
>
0
and
size
<
max_size
)
if
constexpr
(
size
>
0
and
size
<
max_size
)
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
100755 → 100644
View file @
3ce7ad8b
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/types.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/functional.hpp>
namespace
migraphx
{
namespace
migraphx
{
...
@@ -24,6 +25,38 @@ constexpr auto vec_size()
...
@@ -24,6 +25,38 @@ constexpr auto vec_size()
return
decltype
(
vec_size
(
T
{})){};
return
decltype
(
vec_size
(
T
{})){};
}
}
template
<
class
...
Ts
>
constexpr
auto
is_any_vec
()
{
if
constexpr
(
sizeof
...(
Ts
)
==
0
)
return
false_type
{};
else
return
bool_constant
<
((
vec_size
<
Ts
>
()
+
...)
>
0
)
>
{};
}
template
<
class
T
,
class
I
>
constexpr
auto
vec_at
(
T
x
,
I
i
)
{
if
constexpr
(
vec_size
<
T
>
()
==
0
)
return
x
;
else
{
MIGRAPHX_ASSERT
(
i
<
vec_size
<
T
>
());
return
x
[
i
];
}
}
template
<
class
...
Ts
>
constexpr
auto
common_vec_size
()
{
return
fold
([](
auto
x
,
auto
y
)
{
if
constexpr
(
x
>
y
)
return
x
;
else
return
y
;
})(
vec_size
<
Ts
>
()...);
}
template
<
index_int
N
,
class
T
>
template
<
index_int
N
,
class
T
>
__device__
__host__
auto
as_vec
(
T
*
x
)
__device__
__host__
auto
as_vec
(
T
*
x
)
{
{
...
@@ -33,5 +66,28 @@ __device__ __host__ auto as_vec(T* x)
...
@@ -33,5 +66,28 @@ __device__ __host__ auto as_vec(T* x)
return
reinterpret_cast
<
vec
<
T
,
N
>*>
(
x
);
return
reinterpret_cast
<
vec
<
T
,
N
>*>
(
x
);
}
}
template
<
class
T
,
index_int
N
>
using
safe_vec
=
vec
<
std
::
conditional_t
<
std
::
is_same
<
T
,
bool
>
{},
uint8_t
,
T
>
,
N
>
;
template
<
class
...
Ts
>
constexpr
auto
vec_transform
(
Ts
...
xs
)
{
return
[
=
](
auto
f
)
{
if
constexpr
(
is_any_vec
<
Ts
...
>
())
{
using
type
=
decltype
(
f
(
vec_at
(
xs
,
0
)...));
constexpr
auto
size
=
common_vec_size
<
Ts
...
>
();
safe_vec
<
type
,
size
>
result
=
{
0
};
for
(
int
i
=
0
;
i
<
size
;
i
++
)
result
[
i
]
=
f
(
vec_at
(
xs
,
i
)...);
return
result
;
}
else
{
return
f
(
xs
...);
}
};
}
}
// namespace migraphx
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
#endif // MIGRAPHX_GUARD_KERNELS_VEC_HPP
src/targets/gpu/kernels/include/migraphx/kernels/vectorize.hpp
View file @
3ce7ad8b
...
@@ -7,40 +7,70 @@
...
@@ -7,40 +7,70 @@
namespace
migraphx
{
namespace
migraphx
{
template
<
class
T
>
template
<
class
T
>
constexpr
auto
tensor_vec_size
(
T
)
constexpr
auto
tensor_vec_size
()
{
{
return
vec_size
<
typename
T
::
type
>
();
return
vec_size
<
typename
T
::
type
>
();
}
}
template
<
index_int
N
,
class
Shape
>
template
<
class
T
>
constexpr
auto
as_vec_shape
(
Shape
s
)
constexpr
auto
tensor_vec_size
(
T
)
{
{
auto
lens
=
transform
(
s
.
lens
,
s
.
strides
,
[](
auto
len
,
auto
stride
)
{
return
tensor_vec_size
<
T
>
();
if
(
stride
==
1
)
}
return
len
/
N
;
else
template
<
index_int
N
,
class
Shape
,
class
Axis
>
return
len
;
constexpr
auto
shape_step
(
Shape
s
,
Axis
)
});
{
auto
strides
=
transform
(
s
.
strides
,
[](
auto
stride
)
{
static_assert
(
N
>
0
,
"Vector size must be non-zero"
);
if
(
stride
==
1
)
return
sequence
(
s
.
lens
.
size
(),
[
&
](
auto
...
is
)
{
return
stride
;
auto
lens
=
transform
(
s
.
lens
,
index_ints
<
is
...
>
{},
[
&
](
auto
i
,
auto
j
)
{
return
stride
/
N
;
constexpr
auto
axis
=
Axis
::
to
();
MIGRAPHX_ASSERT
(
i
!=
0
);
MIGRAPHX_ASSERT
(
j
!=
axis
or
i
%
N
==
0
);
if
(
j
==
axis
)
return
i
/
N
;
else
return
i
;
});
auto
strides
=
transform
(
s
.
strides
,
index_ints
<
is
...
>
{},
[
&
](
auto
i
,
auto
j
)
{
constexpr
auto
axis
=
Axis
::
to
();
// If stride of the axis is zero then we dont need to adjust the other strides
if
(
Shape
{}.
strides
[
axis
]
==
0
)
return
i
;
MIGRAPHX_ASSERT
(
j
==
axis
or
i
%
N
==
0
);
if
(
j
==
axis
)
return
i
;
else
return
i
/
N
;
});
MIGRAPHX_ASSERT
(
make_shape
(
lens
,
strides
).
elements
()
*
N
==
s
.
elements
());
MIGRAPHX_ASSERT
(
strides
[
Axis
{}]
==
0
or
make_shape
(
lens
,
strides
).
element_space
()
*
N
==
s
.
element_space
());
return
make_shape
(
lens
,
strides
);
});
});
MIGRAPHX_ASSERT
(
make_shape
(
lens
,
strides
).
element_space
()
*
N
==
s
.
element_space
());
return
make_shape
(
lens
,
strides
);
}
}
template
<
index_int
N
,
class
T
>
// Bools can not be used as a vector type so convert it to uint8
__device__
__host__
auto
as_vec
(
T
x
)
template
<
class
T
>
__device__
__host__
T
*
remove_bool
(
T
*
x
)
{
return
x
;
}
inline
__device__
__host__
uint8_t
*
remove_bool
(
bool
*
x
)
{
return
reinterpret_cast
<
uint8_t
*>
(
x
);
}
template
<
index_int
N
,
class
T
,
class
Axis
>
__device__
__host__
auto
as_vec
(
T
x
,
Axis
axis
)
{
{
if
constexpr
(
N
==
0
)
if
constexpr
(
N
==
0
)
return
x
;
return
x
;
else
else
return
make_tensor_view
(
as_vec
<
N
>
(
x
.
data
()),
as_vec_shape
<
N
>
(
x
.
get_shape
()));
return
make_tensor_view
(
as_vec
<
N
>
(
remove_bool
(
x
.
data
())),
shape_step
<
N
>
(
x
.
get_shape
(),
axis
));
}
}
template
<
index_int
N
,
class
T
,
class
Axis
>
template
<
index_int
N
,
class
T
,
class
Axis
>
constexpr
auto
tensor_step
(
T
x
,
Axis
)
constexpr
auto
tensor_step
(
T
x
,
Axis
axis
)
{
{
if
constexpr
(
N
==
0
)
if
constexpr
(
N
==
0
)
{
{
...
@@ -49,17 +79,8 @@ constexpr auto tensor_step(T x, Axis)
...
@@ -49,17 +79,8 @@ constexpr auto tensor_step(T x, Axis)
else
else
{
{
constexpr
auto
s
=
decltype
(
x
.
get_shape
()){};
constexpr
auto
s
=
decltype
(
x
.
get_shape
()){};
MIGRAPHX_ASSERT
(
s
.
strides
[
Axis
{}]
==
0
);
MIGRAPHX_ASSERT
(
s
.
strides
[
axis
]
==
0
);
return
sequence
(
x
.
get_shape
().
lens
.
size
(),
[
&
](
auto
...
is
)
{
return
make_tensor_view
(
x
.
data
(),
shape_step
<
N
>
(
s
,
axis
));
auto
lens
=
transform
(
s
.
lens
,
index_ints
<
is
...
>
{},
[
&
](
auto
i
,
auto
j
)
{
constexpr
auto
axis
=
Axis
{};
if
(
j
==
axis
)
return
i
/
N
;
else
return
i
;
});
return
make_tensor_view
(
x
.
data
(),
make_shape
(
lens
,
s
.
strides
));
});
}
}
}
}
...
@@ -69,42 +90,71 @@ __device__ __host__ auto as_vec(IntegralConstant ic, T&& x)
...
@@ -69,42 +90,71 @@ __device__ __host__ auto as_vec(IntegralConstant ic, T&& x)
return
as_vec
<
ic
>
(
x
);
return
as_vec
<
ic
>
(
x
);
}
}
template
<
class
...
Shape
s
>
template
<
class
Shape
>
constexpr
index_int
find_vector_axis
(
Shape
s
...
s
s
)
constexpr
index_int
find_vector_axis
_c
(
Shape
s
)
{
{
// Find the fastest axis that is not broadcasted
index_int
axis
=
0
;
index_int
axis
=
0
;
bool
b
=
false
;
for
(
index_int
i
=
1
;
i
<
s
.
lens
.
size
();
i
++
)
{
if
(
s
.
strides
[
i
]
==
0
)
continue
;
if
(
s
.
strides
[
axis
]
==
0
or
pack_compare
(
less
{},
pack
(
s
.
strides
[
i
],
s
.
lens
[
i
]),
pack
(
s
.
strides
[
axis
],
s
.
lens
[
axis
])))
axis
=
i
;
}
return
axis
;
}
template
<
class
...
Shapes
>
constexpr
index_int
find_vector_axis_c
(
Shapes
...
ss
)
{
const
bool
all_broadcasted
=
(
ss
.
broadcasted
()
and
...);
index_int
axis
=
0
;
bool
b
=
false
;
by
([
&
](
auto
s
)
{
by
([
&
](
auto
s
)
{
if
(
s
.
broadcasted
()
or
b
)
if
(
b
)
return
;
return
;
auto
it
=
find
(
s
.
strides
.
begin
(),
s
.
strides
.
end
(),
1
);
// Skip broadcasted shapes if there are shapes not broadcasted
if
(
it
==
s
.
strides
.
en
d
())
if
(
not
all_broadcasted
and
s
.
broadcaste
d
())
return
;
return
;
axis
=
it
-
s
.
strides
.
begin
();
axis
=
find_vector_axis_c
(
s
);
b
=
true
;
if
(
s
.
strides
[
axis
]
==
1
)
b
=
true
;
})(
ss
...);
})(
ss
...);
if
(
not
b
)
return
-
1
;
return
axis
;
return
axis
;
}
}
template
<
class
...
Shapes
>
constexpr
auto
find_vector_axis
(
Shapes
...)
{
return
_c
<
find_vector_axis_c
(
Shapes
{}...)
>
;
}
template
<
index_int
N
,
class
Axis
,
class
...
Shapes
>
template
<
index_int
N
,
class
Axis
,
class
...
Shapes
>
constexpr
auto
is_vectorizable
(
Axis
axis
,
Shapes
...
ss
)
constexpr
auto
is_vectorizable
_c
(
Axis
axis
,
Shapes
...
ss
)
{
{
return
(((
ss
.
lens
[
axis
]
%
N
)
==
0
and
(
ss
.
strides
[
axis
]
==
1
or
ss
.
strides
[
axis
]
==
0
))
and
return
((
axis
<
ss
.
lens
.
size
()
and
ss
.
lens
[
axis
]
%
N
==
0
and
// Only vectorize broadcasted types with stride 0, since this causes issues in the
// preloader
((
not
ss
.
broadcasted
()
and
ss
.
strides
[
axis
]
==
1
)
or
ss
.
strides
[
axis
]
==
0
))
and
...);
...);
}
}
template
<
index_int
N
,
class
...
Shapes
>
template
<
index_int
N
,
class
Axis
,
class
...
Shapes
>
constexpr
bool
is_vectorizable
(
Shapes
...
ss
)
constexpr
auto
is_vectorizable
(
Axis
,
Shapes
...)
{
{
return
(
is_vectorizable
<
N
>
(
ss
,
find_vector_axis
(
ss
))
and
...);
return
_c
<
is_vectorizable
_c
<
N
>
(
Axis
::
to
(),
Shapes
{}
...)
>
;
}
}
template
<
class
P
>
template
<
class
P
>
constexpr
auto
find_vectorize_size
(
P
pred
)
constexpr
auto
find_vectorize_size
(
P
pred
)
{
{
if
constexpr
(
pred
(
_c
<
4
>
))
if
constexpr
(
decltype
(
pred
(
_c
<
4
>
))
{})
return
_c
<
4
>
;
return
_c
<
4
>
;
else
if
constexpr
(
pred
(
_c
<
2
>
))
else
if
constexpr
(
decltype
(
pred
(
_c
<
2
>
))
{})
return
_c
<
2
>
;
return
_c
<
2
>
;
else
else
return
_c
<
0
>
;
return
_c
<
0
>
;
...
@@ -113,11 +163,12 @@ constexpr auto find_vectorize_size(P pred)
...
@@ -113,11 +163,12 @@ constexpr auto find_vectorize_size(P pred)
template
<
class
T
>
template
<
class
T
>
__host__
__device__
auto
vectorize
(
T
x
)
__host__
__device__
auto
vectorize
(
T
x
)
{
{
if
constexpr
(
vec_size
<
T
>
()
==
0
)
if
constexpr
(
tensor_
vec_size
<
T
>
()
==
0
)
{
{
constexpr
auto
axis
=
find_vector_axis
(
x
.
get_shape
());
constexpr
auto
n
=
constexpr
auto
n
=
find_vectorize_size
([
&
](
auto
i
)
{
return
_c
<
is_vectorizable
<
i
>
(
x
.
get_shape
())
>
;
});
find_vectorize_size
([
&
](
auto
i
)
{
return
is_vectorizable
<
i
>
(
axis
,
x
.
get_shape
());
});
return
as_vec
<
n
>
(
x
);
return
as_vec
<
n
>
(
x
,
axis
);
}
}
else
else
{
{
...
@@ -125,34 +176,46 @@ __host__ __device__ auto vectorize(T x)
...
@@ -125,34 +176,46 @@ __host__ __device__ auto vectorize(T x)
}
}
}
}
template
<
class
F
,
class
...
Ts
>
inline
__device__
__host__
auto
auto_vectorize_impl
(
F
f
,
Ts
...
xs
)
{
// TODO: Just check there a single axis of 1
constexpr
bool
packed_or_broadcasted
=
((
xs
.
get_shape
().
packed
()
or
xs
.
get_shape
().
broadcasted
())
and
...);
if
constexpr
(
packed_or_broadcasted
)
{
constexpr
auto
axis
=
decltype
(
find_vector_axis
(
xs
.
get_shape
()...)){};
constexpr
auto
n
=
find_vectorize_size
(
[
&
](
auto
i
)
{
return
is_vectorizable
<
i
>
(
axis
,
xs
.
get_shape
()...);
});
by
(
[
&
](
auto
x
)
{
constexpr
auto
s
=
decltype
(
x
.
get_shape
()){};
if
constexpr
(
axis
<
s
.
strides
.
size
())
{
MIGRAPHX_ASSERT
(
s
.
strides
[
axis
]
==
0
or
s
.
strides
[
axis
]
==
1
);
MIGRAPHX_ASSERT
(
s
.
lens
[
axis
]
>
0
);
MIGRAPHX_ASSERT
(
n
==
0
or
s
.
lens
[
axis
]
%
n
==
0
);
if
constexpr
(
s
.
strides
[
axis
]
==
0
)
return
tensor_step
<
n
>
(
x
,
axis
);
else
return
as_vec
<
n
>
(
x
,
axis
);
}
else
{
return
x
;
}
},
f
)(
xs
...);
}
else
{
f
(
xs
...);
}
}
inline
__device__
__host__
auto
auto_vectorize
()
inline
__device__
__host__
auto
auto_vectorize
()
{
{
return
[](
auto
...
xs
)
{
return
[](
auto
...
xs
)
{
return
[
=
](
auto
f
)
{
auto_vectorize_impl
(
f
,
xs
...);
};
};
return
[
=
](
auto
f
)
{
// TODO: Just check there a single axis of 1
constexpr
bool
packed_or_broadcasted
=
((
xs
.
get_shape
().
packed
()
or
xs
.
get_shape
().
broadcasted
())
and
...);
if
constexpr
(
packed_or_broadcasted
)
{
constexpr
auto
axis
=
find_vector_axis
(
xs
.
get_shape
()...);
constexpr
auto
n
=
find_vectorize_size
(
[
&
](
auto
i
)
{
return
_c
<
is_vectorizable
<
i
>
(
axis
,
xs
.
get_shape
()...)
>
;
});
by
(
[
&
](
auto
x
)
{
constexpr
auto
s
=
x
.
get_shape
();
if
constexpr
(
s
.
strides
[
axis
]
==
0
)
return
tensor_step
<
n
>
(
x
,
axis
);
else
return
as_vec
<
n
>
(
x
);
},
f
)(
xs
...);
}
else
{
f
(
xs
...);
}
};
};
}
}
}
// namespace migraphx
}
// namespace migraphx
...
...
test/auto_contiguous_test.cpp
View file @
3ce7ad8b
...
@@ -101,4 +101,38 @@ TEST_CASE(after_param_broadcast)
...
@@ -101,4 +101,38 @@ TEST_CASE(after_param_broadcast)
EXPECT
(
not
m
.
get_output_shapes
().
back
().
broadcasted
());
EXPECT
(
not
m
.
get_output_shapes
().
back
().
broadcasted
());
}
}
TEST_CASE
(
two_transpose_gather
)
{
migraphx
::
module
m1
;
{
auto
data
=
m1
.
add_parameter
(
"2x2"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
ind
=
m1
.
add_parameter
(
"ind"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}});
auto
td
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
data
);
auto
sd
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
2
}}),
td
);
auto
bd
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
3
,
1
,
2
}}}),
sd
);
auto
r
=
m1
.
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
2
}}),
bd
,
ind
);
m1
.
add_return
({
r
});
}
run_pass
(
m1
);
migraphx
::
module
m2
;
{
auto
data
=
m2
.
add_parameter
(
"2x2"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
3
,
4
,
5
}});
auto
ind
=
m2
.
add_parameter
(
"ind"
,
{
migraphx
::
shape
::
float_type
,
{
2
,
3
}});
auto
td
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
3
,
1
}}}),
data
);
auto
ctd
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
td
);
auto
sd
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"softmax"
,
{{
"axis"
,
2
}}),
ctd
);
auto
bd
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
3
,
1
,
2
}}}),
sd
);
auto
cbd
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"contiguous"
),
bd
);
auto
r
=
m2
.
add_instruction
(
migraphx
::
make_op
(
"gather"
,
{{
"axis"
,
2
}}),
cbd
,
ind
);
m2
.
add_return
({
r
});
}
EXPECT
(
m1
==
m2
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
int
main
(
int
argc
,
const
char
*
argv
[])
{
test
::
run
(
argc
,
argv
);
}
Prev
1
2
3
4
5
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment