Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5a1af3d1
Commit
5a1af3d1
authored
May 31, 2022
by
Paul
Browse files
Merge
parents
dfc7bbac
6e94e607
Changes
49
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
183 additions
and
165 deletions
+183
-165
src/reduce_dims.cpp
src/reduce_dims.cpp
+18
-4
src/targets/cpu/copy.cpp
src/targets/cpu/copy.cpp
+0
-1
src/targets/cpu/gather.cpp
src/targets/cpu/gather.cpp
+0
-1
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
+0
-2
src/targets/gpu/compile_hip.cpp
src/targets/gpu/compile_hip.cpp
+0
-1
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+67
-22
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+38
-6
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+7
-25
src/targets/gpu/jit/gathernd.cpp
src/targets/gpu/jit/gathernd.cpp
+1
-1
src/targets/gpu/jit/pointwise.cpp
src/targets/gpu/jit/pointwise.cpp
+22
-3
src/targets/gpu/jit/roialign.cpp
src/targets/gpu/jit/roialign.cpp
+0
-1
src/targets/gpu/jit/scatternd.cpp
src/targets/gpu/jit/scatternd.cpp
+0
-1
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
+2
-2
src/targets/gpu/kernels/include/migraphx/kernels/basic_ops.hpp
...argets/gpu/kernels/include/migraphx/kernels/basic_ops.hpp
+0
-84
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
...rgets/gpu/kernels/include/migraphx/kernels/functional.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
...ts/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
...targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
+9
-7
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
...gets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
+1
-1
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
...gets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
+15
-0
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
+1
-1
No files found.
src/reduce_dims.cpp
View file @
5a1af3d1
...
@@ -16,11 +16,9 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
...
@@ -16,11 +16,9 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
auto
bstride
=
s
.
strides
()[
n
+
1
];
auto
bstride
=
s
.
strides
()[
n
+
1
];
auto
blen
=
s
.
lens
()[
n
+
1
];
auto
blen
=
s
.
lens
()[
n
+
1
];
if
(
astride
==
bstride
*
blen
)
if
(
astride
==
bstride
*
blen
or
alen
==
1
)
{
new_lens
.
push_back
(
alen
*
blen
);
new_lens
.
push_back
(
alen
*
blen
);
}
}
}
if
(
new_lens
.
size
()
!=
shapes
.
size
())
if
(
new_lens
.
size
()
!=
shapes
.
size
())
return
false
;
return
false
;
std
::
size_t
i
=
0
;
std
::
size_t
i
=
0
;
...
@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
...
@@ -37,10 +35,25 @@ bool reduce_dim(std::vector<shape>& shapes, std::size_t n)
return
true
;
return
true
;
}
}
void
reduce_dim1
(
std
::
vector
<
shape
>&
shapes
)
{
if
(
std
::
any_of
(
shapes
.
begin
(),
shapes
.
end
(),
[
&
](
const
auto
&
s
)
{
return
s
.
lens
().
size
()
<
2
or
s
.
lens
().
back
()
!=
1
;
}))
return
;
for
(
auto
&
s
:
shapes
)
{
auto
lens
=
s
.
lens
();
auto
strides
=
s
.
strides
();
lens
.
pop_back
();
strides
.
pop_back
();
s
=
shape
{
s
.
type
(),
lens
,
strides
};
}
}
std
::
size_t
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
,
std
::
size_t
n
)
std
::
size_t
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
,
std
::
size_t
n
)
{
{
while
(
reduce_dim
(
shapes
,
n
)
and
n
<
shapes
.
size
())
{}
while
(
reduce_dim
(
shapes
,
n
)
and
n
<
shapes
.
size
())
{}
return
n
+
1
;
return
n
+
1
;
}
}
void
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
)
void
reduce_dim_all
(
std
::
vector
<
shape
>&
shapes
)
...
@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
...
@@ -48,6 +61,7 @@ void reduce_dim_all(std::vector<shape>& shapes)
std
::
size_t
n
=
0
;
std
::
size_t
n
=
0
;
while
(
n
<
shapes
.
front
().
lens
().
size
()
-
1
)
while
(
n
<
shapes
.
front
().
lens
().
size
()
-
1
)
n
=
reduce_dim_all
(
shapes
,
n
);
n
=
reduce_dim_all
(
shapes
,
n
);
reduce_dim1
(
shapes
);
}
}
std
::
vector
<
std
::
size_t
>
base_lens
(
const
std
::
vector
<
shape
>&
shapes
)
std
::
vector
<
std
::
size_t
>
base_lens
(
const
std
::
vector
<
shape
>&
shapes
)
...
...
src/targets/cpu/copy.cpp
View file @
5a1af3d1
...
@@ -20,7 +20,6 @@ struct cpu_copy : reduce_dims_base, auto_register_op<cpu_copy>
...
@@ -20,7 +20,6 @@ struct cpu_copy : reduce_dims_base, auto_register_op<cpu_copy>
return
inputs
.
at
(
1
);
return
inputs
.
at
(
1
);
}
}
argument
argument
// cppcheck-suppress constParameter
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
...
...
src/targets/cpu/gather.cpp
View file @
5a1af3d1
...
@@ -26,7 +26,6 @@ struct cpu_gather : auto_register_op<cpu_gather>
...
@@ -26,7 +26,6 @@ struct cpu_gather : auto_register_op<cpu_gather>
}
}
argument
argument
// cppcheck-suppress constParameter
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
std
::
size_t
nelements
=
output_shape
.
elements
();
std
::
size_t
nelements
=
output_shape
.
elements
();
...
...
src/targets/cpu/include/migraphx/cpu/pointwise.hpp
View file @
5a1af3d1
...
@@ -323,7 +323,6 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>>
...
@@ -323,7 +323,6 @@ struct cpu_unary : reduce_dims_base, auto_register_op<cpu_unary<Op>>
return
{
s
.
type
(),
s
.
lens
()};
return
{
s
.
type
(),
s
.
lens
()};
}
}
argument
argument
// cppcheck-suppress constParameter
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
...
@@ -362,7 +361,6 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>>
...
@@ -362,7 +361,6 @@ struct cpu_binary : reduce_dims_base, auto_register_op<cpu_binary<Op>>
}
}
argument
argument
// cppcheck-suppress constParameter
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
compute
(
context
&
ctx
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
{
{
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
argument
result
=
get_arg
(
args
,
args
.
size
()
-
1
);
...
...
src/targets/gpu/compile_hip.cpp
View file @
5a1af3d1
...
@@ -134,7 +134,6 @@ struct hiprtc_program
...
@@ -134,7 +134,6 @@ struct hiprtc_program
std
::
vector
<
char
>
buffer
(
n
);
std
::
vector
<
char
>
buffer
(
n
);
MIGRAPHX_HIPRTC
(
hiprtcGetProgramLog
(
prog
.
get
(),
buffer
.
data
()));
MIGRAPHX_HIPRTC
(
hiprtcGetProgramLog
(
prog
.
get
(),
buffer
.
data
()));
assert
(
buffer
.
back
()
==
0
);
assert
(
buffer
.
back
()
==
0
);
// cppcheck-suppress returnDanglingLifetime
return
{
buffer
.
begin
(),
buffer
.
end
()
-
1
};
return
{
buffer
.
begin
(),
buffer
.
end
()
-
1
};
}
}
...
...
src/targets/gpu/fuse_ops.cpp
View file @
5a1af3d1
...
@@ -681,7 +681,7 @@ struct miopen_fusion
...
@@ -681,7 +681,7 @@ struct miopen_fusion
struct
miopen_conv_bias
struct
miopen_conv_bias
{
{
op
::
convolution
op
;
op
::
convolution
op
;
fusion
f
=
{};
fusion
f
p
=
{};
fusion
::
op_t
conv
=
{};
fusion
::
op_t
conv
=
{};
fusion
::
op_t
bias
=
{};
fusion
::
op_t
bias
=
{};
...
@@ -705,19 +705,19 @@ struct miopen_conv_bias
...
@@ -705,19 +705,19 @@ struct miopen_conv_bias
float
beta
=
0
;
float
beta
=
0
;
miopenSetOpArgsConvForward
(
fargs
.
get
(),
conv
,
&
alpha
,
&
beta
,
args
[
1
].
implicit
());
miopenSetOpArgsConvForward
(
fargs
.
get
(),
conv
,
&
alpha
,
&
beta
,
args
[
1
].
implicit
());
miopenSetOpArgsBiasForward
(
fargs
.
get
(),
bias
,
&
alpha
,
&
beta
,
args
[
3
].
implicit
());
miopenSetOpArgsBiasForward
(
fargs
.
get
(),
bias
,
&
alpha
,
&
beta
,
args
[
3
].
implicit
());
return
f
.
execute
(
ctx
,
fargs
,
args
[
0
],
args
[
4
]);
return
f
p
.
execute
(
ctx
,
fargs
,
args
[
0
],
args
[
4
]);
}
}
void
finalize
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
shape
>&
inputs
)
void
finalize
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
shape
>&
inputs
)
{
{
f
=
fusion
(
inputs
[
0
]);
f
p
=
fusion
(
inputs
[
0
]);
conv
=
f
.
create_conv
(
op
,
inputs
[
1
]);
conv
=
f
p
.
create_conv
(
op
,
inputs
[
1
]);
bias
=
f
.
create_bias
(
inputs
[
3
]);
bias
=
f
p
.
create_bias
(
inputs
[
3
]);
if
(
not
f
.
compile
(
ctx
))
if
(
not
f
p
.
compile
(
ctx
))
MIGRAPHX_THROW
(
"Failed to compile fusion plan"
);
MIGRAPHX_THROW
(
"Failed to compile fusion plan"
);
}
}
shape
get_workspace
(
context
&
ctx
)
{
return
f
.
get_workspace
(
ctx
);
}
shape
get_workspace
(
context
&
ctx
)
{
return
f
p
.
get_workspace
(
ctx
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
{
return
shapes
.
size
()
-
1
;
return
shapes
.
size
()
-
1
;
...
@@ -728,7 +728,7 @@ MIGRAPHX_REGISTER_OP(miopen_conv_bias)
...
@@ -728,7 +728,7 @@ MIGRAPHX_REGISTER_OP(miopen_conv_bias)
struct
miopen_conv_bias_relu
struct
miopen_conv_bias_relu
{
{
op
::
convolution
op
;
op
::
convolution
op
;
fusion
f
=
{};
fusion
f
p
=
{};
fusion
::
op_t
conv
=
{};
fusion
::
op_t
conv
=
{};
fusion
::
op_t
bias
=
{};
fusion
::
op_t
bias
=
{};
fusion
::
op_t
relu
=
{};
fusion
::
op_t
relu
=
{};
...
@@ -754,18 +754,18 @@ struct miopen_conv_bias_relu
...
@@ -754,18 +754,18 @@ struct miopen_conv_bias_relu
miopenSetOpArgsConvForward
(
fargs
.
get
(),
conv
,
&
alpha
,
&
beta
,
args
[
1
].
implicit
());
miopenSetOpArgsConvForward
(
fargs
.
get
(),
conv
,
&
alpha
,
&
beta
,
args
[
1
].
implicit
());
miopenSetOpArgsBiasForward
(
fargs
.
get
(),
bias
,
&
alpha
,
&
beta
,
args
[
3
].
implicit
());
miopenSetOpArgsBiasForward
(
fargs
.
get
(),
bias
,
&
alpha
,
&
beta
,
args
[
3
].
implicit
());
miopenSetOpArgsActivForward
(
fargs
.
get
(),
relu
,
&
alpha
,
&
beta
,
0
,
0
,
0
);
miopenSetOpArgsActivForward
(
fargs
.
get
(),
relu
,
&
alpha
,
&
beta
,
0
,
0
,
0
);
return
f
.
execute
(
ctx
,
fargs
,
args
[
0
],
args
[
4
]);
return
f
p
.
execute
(
ctx
,
fargs
,
args
[
0
],
args
[
4
]);
}
}
void
finalize
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
shape
>&
inputs
)
void
finalize
(
context
&
ctx
,
const
shape
&
,
const
std
::
vector
<
shape
>&
inputs
)
{
{
f
=
fusion
(
inputs
[
0
]);
f
p
=
fusion
(
inputs
[
0
]);
conv
=
f
.
create_conv
(
op
,
inputs
[
1
]);
conv
=
f
p
.
create_conv
(
op
,
inputs
[
1
]);
bias
=
f
.
create_bias
(
inputs
[
3
]);
bias
=
f
p
.
create_bias
(
inputs
[
3
]);
relu
=
f
.
create_relu
();
relu
=
f
p
.
create_relu
();
f
.
compile
(
ctx
);
f
p
.
compile
(
ctx
);
}
}
shape
get_workspace
(
context
&
ctx
)
{
return
f
.
get_workspace
(
ctx
);
}
shape
get_workspace
(
context
&
ctx
)
{
return
f
p
.
get_workspace
(
ctx
);
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
{
return
shapes
.
size
()
-
1
;
return
shapes
.
size
()
-
1
;
...
@@ -875,7 +875,6 @@ struct find_conv_pointwise
...
@@ -875,7 +875,6 @@ struct find_conv_pointwise
{
{
if
(
i
.
name
()[
0
]
==
'@'
)
if
(
i
.
name
()[
0
]
==
'@'
)
continue
;
continue
;
auto
inputs
=
to_shapes
(
i
.
inputs
());
op
.
ops
.
push_back
({{
i
.
get_operator
()}});
op
.
ops
.
push_back
({{
i
.
get_operator
()}});
}
}
std
::
vector
<
instruction_ref
>
inputs
=
{
input_ins
,
weights_ins
,
bias_ins
,
alloc_ins
};
std
::
vector
<
instruction_ref
>
inputs
=
{
input_ins
,
weights_ins
,
bias_ins
,
alloc_ins
};
...
@@ -908,11 +907,6 @@ struct find_gemm_add
...
@@ -908,11 +907,6 @@ struct find_gemm_add
if
(
not
float_equal
(
gemm
.
beta
,
0
))
if
(
not
float_equal
(
gemm
.
beta
,
0
))
return
;
return
;
if
(
std
::
any_of
(
ins
->
inputs
().
begin
(),
ins
->
inputs
().
end
(),
[](
auto
i
)
{
return
not
i
->
get_shape
().
standard
();
}))
return
;
auto
inputs
=
gemm_ins
->
inputs
();
auto
inputs
=
gemm_ins
->
inputs
();
inputs
.
pop_back
();
inputs
.
pop_back
();
...
@@ -931,6 +925,53 @@ struct find_gemm_add
...
@@ -931,6 +925,53 @@ struct find_gemm_add
}
}
};
};
auto
pointwise_name
(
const
std
::
string
&
s
)
{
return
precompile_name
(
"pointwise"
)(
match
::
make_basic_pred_matcher
([
=
](
auto
ins
)
{
module_ref
pm
=
ins
->
module_inputs
().
front
();
auto
n
=
std
::
count_if
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
auto
&
i
)
{
return
i
.
name
()
==
s
;
});
if
(
n
!=
1
)
return
false
;
return
std
::
all_of
(
pm
->
begin
(),
pm
->
end
(),
[
&
](
auto
&
i
)
{
return
starts_with
(
i
.
name
(),
"@"
)
or
i
.
name
()
==
s
;
});
}));
}
struct
find_gemm_pointwise
{
auto
matcher
()
const
{
return
pointwise_name
(
"add"
)(
match
::
nargs
(
3
),
match
::
all_of
[
match
::
inputs
()](
match
::
standard_shape
()),
match
::
either_arg
(
0
,
1
)(
match
::
used_once
().
bind
(
"c"
),
match
::
name
(
"gpu::gemm"
)(
match
::
nargs
(
3
)).
bind
(
"gemm"
)));
}
void
apply
(
module
&
m
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm_ins
=
r
.
instructions
[
"gemm"
];
auto
c_ins
=
r
.
instructions
[
"c"
];
auto
gemm
=
any_cast
<
rocblas_gemm
<
op
::
dot
>>
(
gemm_ins
->
get_operator
());
// Already fused gemm
if
(
not
float_equal
(
gemm
.
beta
,
0
))
return
;
auto
inputs
=
gemm_ins
->
inputs
();
inputs
.
pop_back
();
inputs
.
push_back
(
c_ins
);
inputs
.
push_back
(
gemm_ins
->
inputs
().
back
());
gemm
.
beta
=
1
;
m
.
replace_instruction
(
ins
,
gemm
,
inputs
);
}
};
struct
find_commutative_broadcast
struct
find_commutative_broadcast
{
{
auto
matcher
()
const
auto
matcher
()
const
...
@@ -967,7 +1008,11 @@ void fuse_ops::apply(module& m) const
...
@@ -967,7 +1008,11 @@ void fuse_ops::apply(module& m) const
find_add_unary
{
"gpu::tanh"
,
hip_add_tanh
{},
hip_triadd_tanh
{}},
find_add_unary
{
"gpu::tanh"
,
hip_add_tanh
{},
hip_triadd_tanh
{}},
find_add_clip
{});
find_add_clip
{});
run_passes
(
m
,
{
dead_code_elimination
{}});
run_passes
(
m
,
{
dead_code_elimination
{}});
match
::
find_matches
(
m
,
find_triadd_layernorm
{},
find_gemm_add
{},
find_commutative_broadcast
{});
match
::
find_matches
(
m
,
find_triadd_layernorm
{},
find_gemm_add
{},
find_gemm_pointwise
{},
find_commutative_broadcast
{});
}
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/gemm_impl.cpp
View file @
5a1af3d1
#include <rocblas.h>
#include <rocblas.h>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/gpu/gemm_impl.hpp>
#include <migraphx/reduce_dims.hpp>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
...
@@ -27,6 +28,22 @@ rocblas_datatype get_type(shape::type_t type)
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: data type not supported!"
);
MIGRAPHX_THROW
(
"ROCBLAS_GEMM: data type not supported!"
);
}
}
void
blas_shape
(
const
shape
&
s
)
{
if
(
s
.
lens
().
size
()
<
2
)
return
;
if
(
std
::
none_of
(
s
.
strides
().
end
()
-
2
,
s
.
strides
().
end
(),
[
&
](
auto
i
)
{
return
i
==
1
;
}))
MIGRAPHX_THROW
(
"GPU_GEMM: needs to have one matrix stride as 1"
);
if
(
s
.
lens
().
size
()
<
3
)
return
;
shape
batch_shape
{
s
.
type
(),
{
s
.
lens
().
begin
(),
s
.
lens
().
end
()
-
2
},
{
s
.
strides
().
begin
(),
s
.
strides
().
end
()
-
2
}};
auto
batch_shapes
=
reduce_dims
({
batch_shape
});
if
(
batch_shapes
.
front
().
lens
().
size
()
!=
1
)
MIGRAPHX_THROW
(
"GPU_GEMM: Batch dimension is not collapsible"
);
}
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
template
<
class
R
,
class
...
Ts
,
class
...
Us
>
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
R
rocblas_invoke
(
R
(
*
f
)(
Ts
...),
Us
...
xs
)
{
{
...
@@ -36,6 +53,18 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
...
@@ -36,6 +53,18 @@ R rocblas_invoke(R (*f)(Ts...), Us... xs)
return
f
(
xs
...,
nullptr
,
nullptr
);
return
f
(
xs
...,
nullptr
,
nullptr
);
}
}
static
bool
is_transposed
(
const
shape
&
s
)
{
if
(
not
s
.
transposed
())
return
false
;
return
s
.
strides
().
back
()
!=
1
;
}
static
rocblas_int
get_batch_stride
(
const
argument
&
a
)
{
return
a
.
get_shape
().
strides
()[
a
.
get_shape
().
strides
().
size
()
-
3
];
}
template
<
class
T
>
template
<
class
T
>
void
gemm_impl
(
context
&
ctx
,
void
gemm_impl
(
context
&
ctx
,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
...
@@ -45,8 +74,8 @@ void gemm_impl(context& ctx,
...
@@ -45,8 +74,8 @@ void gemm_impl(context& ctx,
bool
int8_x4_format
,
bool
int8_x4_format
,
bool
compute_fp32
)
bool
compute_fp32
)
{
{
bool
transa
=
args
[
0
].
get_shape
()
.
transposed
(
);
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transb
=
args
[
1
].
get_shape
()
.
transposed
(
);
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_1
=
n_dim
-
1
;
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
auto
dim_0
=
n_dim
-
2
;
...
@@ -142,6 +171,9 @@ void gemm_impl(context& ctx,
...
@@ -142,6 +171,9 @@ void gemm_impl(context& ctx,
}
}
else
else
{
{
auto
a_stride
=
get_batch_stride
(
args
[
0
]);
auto
b_stride
=
get_batch_stride
(
args
[
1
]);
auto
c_stride
=
get_batch_stride
(
args
[
2
]);
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
rocblas_invoke
(
&
rocblas_gemm_strided_batched_ex
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
@@ -153,20 +185,20 @@ void gemm_impl(context& ctx,
...
@@ -153,20 +185,20 @@ void gemm_impl(context& ctx,
to_pointer
(
args
.
at
(
1
)),
to_pointer
(
args
.
at
(
1
)),
arg_type
,
arg_type
,
ldb
,
ldb
,
k
*
n
,
b_stride
,
to_pointer
(
args
.
at
(
0
)),
to_pointer
(
args
.
at
(
0
)),
arg_type
,
arg_type
,
lda
,
lda
,
m
*
k
,
a_stride
,
beta_v
,
beta_v
,
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
m
*
n
,
c_stride
,
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
is_3inputs
?
to_pointer
(
args
[
3
])
:
to_pointer
(
args
[
2
]),
output_type
,
output_type
,
ldc
,
ldc
,
m
*
n
,
c_stride
,
num_matrices
,
num_matrices
,
compute_type
,
compute_type
,
rocblas_gemm_algo_standard
,
rocblas_gemm_algo_standard
,
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
5a1af3d1
...
@@ -18,6 +18,8 @@ namespace gpu {
...
@@ -18,6 +18,8 @@ namespace gpu {
struct
context
;
struct
context
;
void
blas_shape
(
const
shape
&
s
);
template
<
class
Op
>
template
<
class
Op
>
struct
rocblas_gemm
struct
rocblas_gemm
{
{
...
@@ -50,13 +52,14 @@ struct rocblas_gemm
...
@@ -50,13 +52,14 @@ struct rocblas_gemm
std
::
vector
<
shape
>
in_shapes
(
inputs
);
std
::
vector
<
shape
>
in_shapes
(
inputs
);
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
check_shapes
{
in_shapes
,
*
this
}.
not_broadcasted
();
check_shapes
{
in_shapes
,
*
this
}.
not_broadcasted
();
b
atch_not_transposed
(
inputs
[
0
].
strides
()
);
b
las_shape
(
inputs
[
0
]
);
b
atch_not_transposed
(
inputs
[
1
].
strides
()
);
b
las_shape
(
inputs
[
1
]
);
// if gemm and add are fused
// if gemm and add are fused
if
(
not
float_equal
(
beta
,
0
)
)
if
(
in_shapes
.
size
()
>
2
)
{
{
auto
cmat_shape
=
in_shapes
.
back
();
auto
cmat_shape
=
in_shapes
.
back
();
in_shapes
.
pop_back
();
in_shapes
.
pop_back
();
blas_shape
(
cmat_shape
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
auto
op_out_shape
=
op
.
compute_shape
(
in_shapes
);
if
(
cmat_shape
.
lens
()
!=
op_out_shape
.
lens
())
if
(
cmat_shape
.
lens
()
!=
op_out_shape
.
lens
())
{
{
...
@@ -71,6 +74,7 @@ struct rocblas_gemm
...
@@ -71,6 +74,7 @@ struct rocblas_gemm
to_string
(
cmat_shape
.
type
())
+
to_string
(
cmat_shape
.
type
())
+
", it must be: "
+
to_string
(
op_out_shape
.
type
()));
", it must be: "
+
to_string
(
op_out_shape
.
type
()));
}
}
return
op_out_shape
;
}
}
return
op
.
compute_shape
(
in_shapes
);
return
op
.
compute_shape
(
in_shapes
);
...
@@ -96,28 +100,6 @@ struct rocblas_gemm
...
@@ -96,28 +100,6 @@ struct rocblas_gemm
return
args
.
back
();
return
args
.
back
();
}
}
void
batch_not_transposed
(
const
std
::
vector
<
std
::
size_t
>&
strides
)
const
{
if
(
strides
.
size
()
<=
2
)
return
;
auto
dim_0
=
strides
.
size
()
-
2
;
auto
matrix_size
=
std
::
max
(
strides
[
dim_0
],
strides
[
dim_0
+
1
]);
std
::
vector
<
std
::
size_t
>
batch
(
strides
.
begin
(),
strides
.
begin
()
+
dim_0
);
if
(
std
::
all_of
(
batch
.
begin
(),
batch
.
end
(),
[
&
](
auto
i
)
{
return
(
i
<
matrix_size
);
}))
{
MIGRAPHX_THROW
(
"GPU_GEMM: matrix size and batch size {"
+
to_string_range
(
strides
)
+
"} are transposed!"
);
}
if
(
std
::
adjacent_find
(
batch
.
begin
(),
batch
.
end
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
i
<
j
or
i
<
matrix_size
or
j
<
matrix_size
);
})
!=
batch
.
end
())
{
MIGRAPHX_THROW
(
"GPU_GEMM: batch size {"
+
to_string_range
(
strides
)
+
"} is transposed!"
);
}
}
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
std
::
ptrdiff_t
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
{
return
shapes
.
size
()
-
1
;
return
shapes
.
size
()
-
1
;
...
...
src/targets/gpu/jit/gathernd.cpp
View file @
5a1af3d1
...
@@ -19,7 +19,7 @@ namespace gpu {
...
@@ -19,7 +19,7 @@ namespace gpu {
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
gathernd_kernel
=
R"__migraphx__(
static
const
char
*
const
gathernd_kernel
=
R"__migraphx__(
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/gathernd.hpp>
#include <migraphx/kernels/
basic_
ops.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <args.hpp>
...
...
src/targets/gpu/jit/pointwise.cpp
View file @
5a1af3d1
...
@@ -27,7 +27,7 @@ namespace migraphx {
...
@@ -27,7 +27,7 @@ namespace migraphx {
${preamble}
${preamble}
extern "C" {
extern "C" {
__global__ void kernel(${params})
__global__ void
${
kernel
}
(${params})
{
{
auto idx = make_index();
auto idx = make_index();
pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
pointwise(idx, auto_preload<${preloads}>(idx), vectorize<${vec_size}, ${axis}>())(${lambda}, ${args});
...
@@ -39,6 +39,18 @@ __global__ void kernel(${params})
...
@@ -39,6 +39,18 @@ __global__ void kernel(${params})
)__migraphx__"
;
)__migraphx__"
;
static
std
::
vector
<
std
::
string
>
get_op_names
(
const
module
&
m
)
{
std
::
vector
<
std
::
string
>
result
;
for
(
auto
&
ins
:
m
)
{
if
(
starts_with
(
ins
.
name
(),
"@"
))
continue
;
result
.
push_back
(
ins
.
name
());
}
return
result
;
}
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
struct
pointwise_compiler
:
compiler
<
pointwise_compiler
>
{
{
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
};
}
std
::
vector
<
std
::
string
>
names
()
const
{
return
{
"pointwise"
};
}
...
@@ -131,12 +143,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -131,12 +143,14 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto
preloads
=
preload
(
axis
,
options
.
virtual_inputs
);
auto
preloads
=
preload
(
axis
,
options
.
virtual_inputs
);
auto
is_preloading
=
auto
is_preloading
=
std
::
accumulate
(
preloads
.
begin
(),
preloads
.
end
(),
false
,
std
::
logical_or
<>
{});
std
::
accumulate
(
preloads
.
begin
(),
preloads
.
end
(),
false
,
std
::
logical_or
<>
{});
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"kernel"
);
options
.
set_launch_params
(
v
,
options
.
set_launch_params
(
v
,
compute_global_for
(
ctx
,
compute_global_for
(
ctx
,
options
.
output
.
elements
()
/
vec_size
,
options
.
output
.
elements
()
/
vec_size
,
oversubscribe_if
(
not
is_preloading
)));
oversubscribe_if
(
not
is_preloading
)));
auto
src
=
interpolate_string
(
pointwise_kernel
,
auto
src
=
interpolate_string
(
pointwise_kernel
,
{{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{{
"kernel"
,
options
.
kernel_name
},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"lambda"
,
v
.
at
(
"lambda"
).
to
<
std
::
string
>
()},
{
"vec_size"
,
std
::
to_string
(
vec_size
)},
{
"vec_size"
,
std
::
to_string
(
vec_size
)},
...
@@ -167,8 +181,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
...
@@ -167,8 +181,13 @@ struct pointwise_compiler : compiler<pointwise_compiler>
auto
name
=
g
.
create_function
(
auto
name
=
g
.
create_function
(
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
g
.
generate_module
(
*
pm
).
set_attributes
({
"__device__"
}).
set_generic_types
(
*
pm
));
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
std
::
string
lambda
=
"MIGRAPHX_LIFT("
+
name
+
")"
;
auto
op_names
=
get_op_names
(
*
pm
);
op_names
.
push_back
(
"kernel"
);
auto
op_name_string
=
join_strings
(
op_names
,
"_"
);
return
replace
(
return
replace
(
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()}}));
compile_op
(
ctx
,
to_shapes
(
ins
->
inputs
()),
{{
"lambda"
,
lambda
},
{
"preamble"
,
g
.
str
()},
{
"kernel"
,
op_name_string
}}));
}
}
};
};
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/jit/roialign.cpp
View file @
5a1af3d1
...
@@ -19,7 +19,6 @@ namespace gpu {
...
@@ -19,7 +19,6 @@ namespace gpu {
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
roialign_kernel
=
R"__migraphx__(
static
const
char
*
const
roialign_kernel
=
R"__migraphx__(
#include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/roialign.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <args.hpp>
...
...
src/targets/gpu/jit/scatternd.cpp
View file @
5a1af3d1
...
@@ -19,7 +19,6 @@ namespace gpu {
...
@@ -19,7 +19,6 @@ namespace gpu {
// NOLINTNEXTLINE
// NOLINTNEXTLINE
static
const
char
*
const
scatternd_kernel
=
R"__migraphx__(
static
const
char
*
const
scatternd_kernel
=
R"__migraphx__(
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/scatternd.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <migraphx/kernels/generic_constant.hpp>
#include <args.hpp>
#include <args.hpp>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/array.hpp
View file @
5a1af3d1
...
@@ -146,8 +146,8 @@ struct array
...
@@ -146,8 +146,8 @@ struct array
constexpr
array
carry
(
array
result
)
const
constexpr
array
carry
(
array
result
)
const
{
{
u
in
t32_
t
overflow
=
0
;
in
dex_in
t
overflow
=
0
;
for
(
std
::
ptr
diff_t
i
=
result
.
size
()
-
1
;
i
>
0
;
i
--
)
for
(
diff_
in
t
i
=
result
.
size
()
-
1
;
i
>
0
;
i
--
)
{
{
auto
z
=
result
[
i
]
+
overflow
;
auto
z
=
result
[
i
]
+
overflow
;
// Reset overflow
// Reset overflow
...
...
src/targets/gpu/kernels/include/migraphx/kernels/basic_ops.hpp
deleted
100755 → 0
View file @
dfc7bbac
#ifndef MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#define MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
#include <migraphx/kernels/types.hpp>
namespace
migraphx
{
struct
sum
{
template
<
class
T
,
class
U
>
constexpr
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
+
y
;
}
};
struct
product
{
template
<
class
T
,
class
U
>
constexpr
auto
operator
()(
T
x
,
U
y
)
const
{
return
x
*
y
;
}
};
struct
id
{
template
<
class
T
>
constexpr
auto
operator
()(
T
x
)
const
{
return
x
;
}
};
struct
mean
{
size_t
item_num
=
1
;
template
<
class
T
>
constexpr
auto
operator
()(
T
x
)
const
{
return
x
/
static_cast
<
T
>
(
item_num
);
}
};
struct
max_f
{
template
<
class
T
,
class
U
>
constexpr
auto
operator
()(
T
x
,
U
y
)
const
{
return
(
x
>
y
)
?
x
:
y
;
}
};
inline
constexpr
auto
max
=
max_f
{};
struct
min_f
{
template
<
class
T
,
class
U
>
constexpr
auto
operator
()(
T
x
,
U
y
)
const
{
return
(
x
<
y
)
?
x
:
y
;
}
};
inline
constexpr
auto
min
=
min_f
{};
struct
lowest
{
template
<
class
T
>
constexpr
operator
T
()
const
{
return
std
::
numeric_limits
<
T
>::
lowest
();
}
};
struct
highest
{
template
<
class
T
>
constexpr
operator
T
()
const
{
return
std
::
numeric_limits
<
T
>::
max
();
}
};
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_AMDMIGRAPHX_KERNELS_BASIC_OPS_HPP
src/targets/gpu/kernels/include/migraphx/kernels/functional.hpp
View file @
5a1af3d1
...
@@ -137,7 +137,7 @@ constexpr auto by(F f)
...
@@ -137,7 +137,7 @@ constexpr auto by(F f)
template
<
class
F
,
class
...
Ts
>
template
<
class
F
,
class
...
Ts
>
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
constexpr
void
each_args
(
F
f
,
Ts
&&
...
xs
)
{
{
swallow
{(
f
(
st
d
::
forward
<
Ts
>
(
xs
)),
0
)...};
swallow
{(
f
(
st
atic_cast
<
Ts
&&
>
(
xs
)),
0
)...};
}
}
template
<
class
F
>
template
<
class
F
>
...
...
src/targets/gpu/kernels/include/migraphx/kernels/iota_iterator.hpp
View file @
5a1af3d1
...
@@ -13,7 +13,7 @@ struct basic_iota_iterator
...
@@ -13,7 +13,7 @@ struct basic_iota_iterator
F
f
;
F
f
;
using
difference_type
=
diff_int
;
using
difference_type
=
diff_int
;
using
reference
=
decltype
(
f
(
std
::
declval
<
Iterator
>
()));
using
reference
=
decltype
(
f
(
declval
<
Iterator
>
()));
using
value_type
=
remove_reference_t
<
reference
>
;
using
value_type
=
remove_reference_t
<
reference
>
;
using
pointer
=
add_pointer_t
<
value_type
>
;
using
pointer
=
add_pointer_t
<
value_type
>
;
...
...
src/targets/gpu/kernels/include/migraphx/kernels/roialign.hpp
View file @
5a1af3d1
...
@@ -3,14 +3,15 @@
...
@@ -3,14 +3,15 @@
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/dfor.hpp>
#include <migraphx/kernels/dfor.hpp>
#include <migraphx/kernels/basic_ops.hpp>
#include <migraphx/kernels/ops.hpp>
#include <migraphx/kernels/math.hpp>
#include <migraphx/kernels/array.hpp>
#include <migraphx/kernels/array.hpp>
namespace
migraphx
{
namespace
migraphx
{
struct
max_pool
struct
max_pool
{
{
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
{
return
lowest
()
;
}
MIGRAPHX_DEVICE_CONSTEXPR
auto
init
()
{
return
lowest
{}
;
}
template
<
class
T
>
template
<
class
T
>
MIGRAPHX_DEVICE_CONSTEXPR
T
operator
()(
T
x
,
T
y
)
MIGRAPHX_DEVICE_CONSTEXPR
T
operator
()(
T
x
,
T
y
)
...
@@ -55,7 +56,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
...
@@ -55,7 +56,7 @@ MIGRAPHX_DEVICE_CONSTEXPR typename Iterator::value_type bilinear_interpolate(
return
0
;
return
0
;
}
}
xy
[
ii
]
=
max
(
xy
[
ii
],
0.0
f
);
xy
[
ii
]
=
migraphx
::
max
(
xy
[
ii
],
0.0
f
);
low
[
ii
]
=
xy
[
ii
];
low
[
ii
]
=
xy
[
ii
];
high
[
ii
]
=
low
[
ii
]
+
1
;
high
[
ii
]
=
low
[
ii
]
+
1
;
if
(
low
[
ii
]
>=
dims
[
ii
]
-
1
)
if
(
low
[
ii
]
>=
dims
[
ii
]
-
1
)
...
@@ -164,11 +165,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
...
@@ -164,11 +165,12 @@ __device__ void roialign(const T& x_t, const U& rois_t, const V& ind_t, W& y_t,
for
(
index_int
ii
=
0
;
ii
<
roi_size
.
size
();
++
ii
)
for
(
index_int
ii
=
0
;
ii
<
roi_size
.
size
();
++
ii
)
{
{
roi_size
[
ii
]
=
roi_ends
[
ii
]
-
roi_starts
[
ii
];
roi_size
[
ii
]
=
roi_ends
[
ii
]
-
roi_starts
[
ii
];
roi_size
[
ii
]
=
max
(
roi_size
[
ii
],
1.0
f
);
roi_size
[
ii
]
=
migraphx
::
max
(
roi_size
[
ii
],
1.0
f
);
bin_size
[
ii
]
=
roi_size
[
ii
]
/
out_dims
[
ii
];
bin_size
[
ii
]
=
roi_size
[
ii
]
/
out_dims
[
ii
];
bin_grid_size
[
ii
]
=
bin_grid_size
[
ii
]
=
(
s
.
sampling_ratio
>
0
)
(
s
.
sampling_ratio
>
0
)
?
s
.
sampling_ratio
:
std
::
ceil
(
roi_size
[
ii
]
/
out_dims
[
ii
]);
?
s
.
sampling_ratio
:
migraphx
::
ceil
(
roi_size
[
ii
]
/
out_dims
[
ii
]);
}
}
const
auto
offset_x
=
x
+
((
batch_ind
*
channel_num
+
c
)
*
in_dims
[
0
]
*
in_dims
[
1
]);
const
auto
offset_x
=
x
+
((
batch_ind
*
channel_num
+
c
)
*
in_dims
[
0
]
*
in_dims
[
1
]);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/tensor_view.hpp
View file @
5a1af3d1
...
@@ -11,7 +11,7 @@ template <class T>
...
@@ -11,7 +11,7 @@ template <class T>
struct
tensor_view_iterator_read
struct
tensor_view_iterator_read
{
{
T
*
view
;
T
*
view
;
constexpr
auto
&
operator
()(
std
::
size_
t
n
)
const
constexpr
auto
&
operator
()(
index_in
t
n
)
const
{
{
MIGRAPHX_ASSERT
(
view
!=
nullptr
);
MIGRAPHX_ASSERT
(
view
!=
nullptr
);
return
(
*
view
)[
n
];
return
(
*
view
)[
n
];
...
...
src/targets/gpu/kernels/include/migraphx/kernels/type_traits.hpp
View file @
5a1af3d1
...
@@ -35,6 +35,21 @@ struct enable_if<true, T>
...
@@ -35,6 +35,21 @@ struct enable_if<true, T>
template
<
bool
B
,
class
T
=
void
>
template
<
bool
B
,
class
T
=
void
>
using
enable_if_t
=
typename
enable_if
<
B
,
T
>::
type
;
using
enable_if_t
=
typename
enable_if
<
B
,
T
>::
type
;
template
<
bool
B
,
class
T
,
class
F
>
struct
conditional
{
using
type
=
T
;
};
template
<
class
T
,
class
F
>
struct
conditional
<
false
,
T
,
F
>
{
using
type
=
F
;
};
template
<
bool
B
,
class
T
,
class
F
>
using
conditional_t
=
typename
conditional
<
B
,
T
,
F
>::
type
;
// NOLINTNEXTLINE
// NOLINTNEXTLINE
#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
#define MIGRAPHX_BUILTIN_TYPE_TRAIT1(name) \
template <class T> \
template <class T> \
...
...
src/targets/gpu/kernels/include/migraphx/kernels/vec.hpp
View file @
5a1af3d1
...
@@ -79,7 +79,7 @@ __device__ __host__ auto as_vec(T* x)
...
@@ -79,7 +79,7 @@ __device__ __host__ auto as_vec(T* x)
}
}
template
<
class
T
,
index_int
N
>
template
<
class
T
,
index_int
N
>
using
safe_vec
=
vec
<
std
::
conditional_t
<
std
::
is_same
<
T
,
bool
>
{},
uint8_t
,
T
>
,
N
>
;
using
safe_vec
=
vec
<
conditional_t
<
is_same
<
T
,
bool
>
{},
uint8_t
,
T
>
,
N
>
;
template
<
class
...
Ts
>
template
<
class
...
Ts
>
constexpr
auto
vec_transform
(
Ts
...
xs
)
constexpr
auto
vec_transform
(
Ts
...
xs
)
...
...
Prev
1
2
3
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment