Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
29d47724
"...composable_kernel_rocm.git" did not exist on "73f02a108347d626ee9b31789f0ff8b26ef87006"
Commit
29d47724
authored
Jun 19, 2019
by
Khalique
Browse files
add fixes to device code for pad
parent
6e53e190
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
12 additions
and
95 deletions
+12
-95
src/eliminate_pad.cpp
src/eliminate_pad.cpp
+0
-2
src/include/migraphx/op/convolution.hpp
src/include/migraphx/op/convolution.hpp
+2
-29
src/include/migraphx/op/pooling.hpp
src/include/migraphx/op/pooling.hpp
+0
-32
src/include/migraphx/pad_calc.hpp
src/include/migraphx/pad_calc.hpp
+0
-5
src/targets/gpu/device/pad.cpp
src/targets/gpu/device/pad.cpp
+10
-26
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+0
-1
No files found.
src/eliminate_pad.cpp
View file @
29d47724
...
@@ -44,8 +44,6 @@ void eliminate_pad::update_op(T,
...
@@ -44,8 +44,6 @@ void eliminate_pad::update_op(T,
std
::
array
<
size_t
,
2
>
new_pads
{
static_cast
<
size_t
>
(
pads
[
2
]),
static_cast
<
size_t
>
(
pads
[
3
])};
std
::
array
<
size_t
,
2
>
new_pads
{
static_cast
<
size_t
>
(
pads
[
2
]),
static_cast
<
size_t
>
(
pads
[
3
])};
T
op
=
any_cast
<
T
>
(
ins
->
get_operator
());
T
op
=
any_cast
<
T
>
(
ins
->
get_operator
());
// if(op.padding_mode != op::padding_mode_t::default_)
// return;
op
.
padding
=
new_pads
;
op
.
padding
=
new_pads
;
std
::
vector
<
instruction_ref
>
new_inputs
{
ins
->
inputs
()};
std
::
vector
<
instruction_ref
>
new_inputs
{
ins
->
inputs
()};
...
...
src/include/migraphx/op/convolution.hpp
View file @
29d47724
...
@@ -44,8 +44,7 @@ struct convolution
...
@@ -44,8 +44,7 @@ struct convolution
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
input
=
inputs
.
at
(
0
);
const
shape
&
weights
=
inputs
.
at
(
1
);
const
shape
&
weights
=
inputs
.
at
(
1
);
auto
t
=
input
.
type
();
auto
t
=
input
.
type
();
// if(padding_mode == default_)
// {
return
{
t
,
return
{
t
,
{
{
input
.
lens
()[
0
],
input
.
lens
()[
0
],
...
@@ -63,33 +62,7 @@ struct convolution
...
@@ -63,33 +62,7 @@ struct convolution
stride
[
1
]
+
stride
[
1
]
+
1
)),
1
)),
}};
}};
// }
// else if(padding_mode == same)
// {
// return {t,
// {input.lens()[0],
// weights.lens()[0],
// static_cast<std::size_t>(
// std::ceil(static_cast<double>(input.lens()[2]) / stride[0])),
// static_cast<std::size_t>(
// std::ceil(static_cast<double>(input.lens()[3]) / stride[1]))}};
// }
// else if(padding_mode == valid)
// {
// return {
// t,
// {input.lens()[0],
// weights.lens()[0],
// static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[2] - weights.lens()[2] + 1) / stride[0])),
// static_cast<std::size_t>(std::ceil(
// static_cast<double>(input.lens()[3] - weights.lens()[3] + 1) /
// stride[1]))}};
// }
// else
// {
// MIGRAPHX_THROW("Invalid padding mode");
// }
}
}
};
};
...
...
src/include/migraphx/op/pooling.hpp
View file @
29d47724
...
@@ -48,8 +48,6 @@ struct pooling
...
@@ -48,8 +48,6 @@ struct pooling
assert
(
lengths
[
0
]
<=
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
0
]
<=
(
input
.
lens
()[
2
]
+
2
*
padding
[
0
]));
assert
(
lengths
[
1
]
<=
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
assert
(
lengths
[
1
]
<=
(
input
.
lens
()[
3
]
+
2
*
padding
[
1
]));
// if(padding_mode == default_)
// {
return
{
t
,
return
{
t
,
{
{
input
.
lens
()[
0
],
input
.
lens
()[
0
],
...
@@ -65,36 +63,6 @@ struct pooling
...
@@ -65,36 +63,6 @@ struct pooling
stride
[
1
])
+
stride
[
1
])
+
1
)),
1
)),
}};
}};
// }
// else if(padding_mode == same)
// {
// return {t,
// {input.lens()[0],
// input.lens()[1],
// ceil_divide<std::size_t>(input.lens()[2], stride[0]),
// ceil_divide<std::size_t>(input.lens()[3], stride[1])}};
// }
// else if(padding_mode == valid)
// {
// return {
// t,
// {
// input.lens()[0],
// input.lens()[1],
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// floor_divide<std::ptrdiff_t>(input.lens()[2] - lengths[0], stride[0]) +
// 1)),
// std::size_t(std::max<std::ptrdiff_t>(
// 1,
// floor_divide<std::ptrdiff_t>(input.lens()[3] - lengths[1], stride[1]) +
// 1)),
// }};
// }
// else
// {
// MIGRAPHX_THROW("Invalid padding mode");
// }
}
}
};
};
...
...
src/include/migraphx/pad_calc.hpp
View file @
29d47724
...
@@ -8,11 +8,6 @@
...
@@ -8,11 +8,6 @@
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
std
::
size_t
calculate_padding
(
std
::
size_t
weight_dim
,
std
::
size_t
dilation
)
{
return
(
dilation
*
(
weight_dim
-
1
))
/
2
;
}
inline
void
calculate_padding
(
int64_t
idx
,
inline
void
calculate_padding
(
int64_t
idx
,
std
::
vector
<
int64_t
>&
pads
,
std
::
vector
<
int64_t
>&
pads
,
int64_t
input_dim
,
int64_t
input_dim
,
...
...
src/targets/gpu/device/pad.cpp
View file @
29d47724
...
@@ -15,33 +15,17 @@ argument
...
@@ -15,33 +15,17 @@ argument
pad
(
hipStream_t
stream
,
argument
result
,
argument
arg1
,
float
value
,
std
::
vector
<
std
::
int64_t
>
pads
)
pad
(
hipStream_t
stream
,
argument
result
,
argument
arg1
,
float
value
,
std
::
vector
<
std
::
int64_t
>
pads
)
{
{
std
::
size_t
nelements
=
arg1
.
get_shape
().
elements
();
std
::
size_t
nelements
=
arg1
.
get_shape
().
elements
();
if
(
float_equal
(
value
,
std
::
numeric_limits
<
float
>::
lowest
()))
visit_all
(
result
)([
&
](
auto
output
)
{
{
auto
*
outptr
=
device_cast
(
output
.
data
());
auto
val
=
device_cast
(
std
::
numeric_limits
<
decltype
(
value
)
>::
lowest
());
using
type
=
typename
decltype
(
output
)
::
value_type
;
nary
(
stream
,
result
)([
=
]
{
return
val
;
});
device_type
<
type
>
device_val
=
value
;
// visit_all(result)([&](auto output) {
if
(
float_equal
(
value
,
std
::
numeric_limits
<
float
>::
lowest
()))
// auto* outptr = device_cast(output.data());
{
// auto val =
device_val
=
device_cast
(
std
::
numeric_limits
<
type
>::
lowest
());
// device_cast(std::numeric_limits<typename
}
// decltype(output)::value_type>::lowest());
gs_launch
(
stream
,
result
.
get_shape
().
elements
())([
=
](
auto
i
)
{
outptr
[
i
]
=
device_val
;
});
});
// gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; });
// });
}
else
{
// visit_all(result)([&](auto output) {
// auto* outptr = device_cast(output.data());
// auto val =
// device_cast(value);
// gs_launch(stream, nelements)([=](auto i) { outptr[i] = val; });
// });'
auto
val
=
device_cast
(
value
);
nary
(
stream
,
result
)([
=
]
{
return
val
;
});
}
// nary(stream, result)([=] { return value; });
visit_all
(
result
,
arg1
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
arg1
)([
&
](
auto
output
,
auto
input
)
{
visit_tensor_size
(
result
.
get_shape
().
lens
().
size
(),
[
&
](
auto
ndim
)
{
visit_tensor_size
(
result
.
get_shape
().
lens
().
size
(),
[
&
](
auto
ndim
)
{
std
::
size_t
offsets
[
ndim
];
std
::
size_t
offsets
[
ndim
];
...
...
src/targets/gpu/lowering.cpp
View file @
29d47724
...
@@ -100,7 +100,6 @@ struct miopen_apply
...
@@ -100,7 +100,6 @@ struct miopen_apply
add_extend_op
<
miopen_contiguous
,
op
::
contiguous
>
(
"contiguous"
);
add_extend_op
<
miopen_contiguous
,
op
::
contiguous
>
(
"contiguous"
);
add_extend_op
<
hip_concat
,
op
::
concat
>
(
"concat"
);
add_extend_op
<
hip_concat
,
op
::
concat
>
(
"concat"
);
add_extend_op
<
hip_softmax
,
op
::
softmax
>
(
"softmax"
);
add_extend_op
<
hip_softmax
,
op
::
softmax
>
(
"softmax"
);
// add_extend_op<miopen_softmax, op::softmax>("softmax");
add_extend_op
<
hip_logsoftmax
,
op
::
logsoftmax
>
(
"logsoftmax"
);
add_extend_op
<
hip_logsoftmax
,
op
::
logsoftmax
>
(
"logsoftmax"
);
add_extend_op
<
hip_gather
,
op
::
gather
>
(
"gather"
);
add_extend_op
<
hip_gather
,
op
::
gather
>
(
"gather"
);
add_extend_op
<
hip_pad
,
op
::
pad
>
(
"pad"
);
add_extend_op
<
hip_pad
,
op
::
pad
>
(
"pad"
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment