Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d14dd144
Commit
d14dd144
authored
Aug 07, 2018
by
Aditya Atluri
Browse files
added batch norm inference for miopen
parent
d1481b13
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
58 additions
and
0 deletions
+58
-0
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+58
-0
No files found.
src/targets/gpu/lowering.cpp
View file @
d14dd144
...
@@ -15,6 +15,48 @@
...
@@ -15,6 +15,48 @@
namespace
migraph
{
namespace
migraph
{
namespace
gpu
{
namespace
gpu
{
struct
miopen_batch_norm_inference
{
batch_norm_inference
op
;
std
::
string
name
()
const
{
return
"gpu::batch_norm_inference"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
6
);
return
op
.
compute_shape
(
{
inputs
.
at
(
0
),
inputs
.
at
(
1
),
inputs
.
at
(
2
),
inputs
.
at
(
3
),
inputs
.
at
(
4
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
y_desc
=
make_tensor
(
output_shape
);
float
alpha
=
1.0
,
beta
=
0.0
f
;
// TODO: adityaatluri
// create bn-scale-bias-mean-variance descriptor for
// miopen call
miopenBatchNormalizationForwardInference
(
ctx
.
handle
.
get
(),
miopenBatchNormMode_t
(
op
.
bn_mode
),
&
alpha
,
&
beta
,
x_desc
.
get
(),
args
[
0
].
implicit
(),
y_desc
.
get
(),
args
[
5
].
implicit
(),
bn_desc
,
args
[
3
].
implicit
(),
args
[
4
].
implicit
(),
args
[
1
].
implicit
(),
args
[
2
].
implicit
(),
op
.
mode
.
epsilon
);
return
args
[
5
];
}
};
struct
miopen_convolution
struct
miopen_convolution
{
{
convolution
op
;
convolution
op
;
...
@@ -259,6 +301,12 @@ struct miopen_apply
...
@@ -259,6 +301,12 @@ struct miopen_apply
{
{
apply_contiguous
(
it
);
apply_contiguous
(
it
);
}
}
// TODO: adityaatluri
// tagging to easily find where code changed
else
if
(
it
->
op
.
name
()
==
"batch_norm_inference"
)
{
apply_batch_norm_inference
(
it
);
}
}
}
}
}
...
@@ -332,6 +380,16 @@ struct miopen_apply
...
@@ -332,6 +380,16 @@ struct miopen_apply
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
}
}
// TODO: adityaatluri
// Not sure how to write this. Review and fix required
void
apply_batch_norm_inference
(
instruction_ref
ins
)
{
auto
&&
op
=
any_cast
<
batch_norm_inference
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_batch_norm_inference
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
}
};
};
void
lowering
::
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
();
}
void
lowering
::
apply
(
program
&
p
)
const
{
miopen_apply
{
&
p
}.
apply
();
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment