Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
04ca2e74
Commit
04ca2e74
authored
Jul 13, 2018
by
wsttiger
Browse files
Merge branch 'contiguous_on_cpu' into hip_contiguous
parents
5f68a283
e29a2396
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
102 additions
and
0 deletions
+102
-0
src/targets/miopen/miopen_target.cpp
src/targets/miopen/miopen_target.cpp
+59
-0
test/miopen/miopen.cpp
test/miopen/miopen.cpp
+43
-0
No files found.
src/targets/miopen/miopen_target.cpp
View file @
04ca2e74
...
...
@@ -173,6 +173,43 @@ struct miopen_gemm
}
};
struct
miopen_transpose
{
transpose
op
;
std
::
string
name
()
const
{
return
"miopen::transpose"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
return
{
output_shape
,
std
::
move
(
args
.
front
().
data
)};
}
};
struct
miopen_contiguous
{
contiguous
op
;
std
::
string
name
()
const
{
return
"miopen::contiguous"
;
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
0
)});
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
argument
result
{
output_shape
};
visit_all
(
result
,
from_gpu
(
args
[
0
]))([
&
](
auto
output
,
auto
input
)
{
shape_for_each
(
output
.
get_shape
(),
[
&
](
const
auto
&
idx
)
{
output
(
idx
.
begin
(),
idx
.
end
())
=
input
(
idx
.
begin
(),
idx
.
end
());
});
});
return
to_gpu
(
result
);
}
};
struct
miopen_relu
{
shared
<
activation_descriptor
>
ad
;
...
...
@@ -231,6 +268,14 @@ struct miopen_apply
{
apply_gemm
(
it
);
}
else
if
(
it
->
op
.
name
()
==
"transpose"
)
{
apply_transpose
(
it
);
}
else
if
(
it
->
op
.
name
()
==
"contiguous"
)
{
apply_contiguous
(
it
);
}
}
}
...
...
@@ -297,6 +342,20 @@ struct miopen_apply
prog
->
replace_instruction
(
ins
,
miopen_gemm
{
op
},
ins
->
arguments
.
at
(
0
),
ins
->
arguments
.
at
(
1
),
output
);
}
void
apply_transpose
(
instruction_ref
ins
)
{
auto
&&
op
=
any_cast
<
transpose
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_transpose
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
}
void
apply_contiguous
(
instruction_ref
ins
)
{
auto
&&
op
=
any_cast
<
contiguous
>
(
ins
->
op
);
auto
output
=
insert_allocation
(
ins
,
ins
->
result
);
prog
->
replace_instruction
(
ins
,
miopen_contiguous
{
op
},
ins
->
arguments
.
at
(
0
),
output
);
}
};
struct
miopen_pass
...
...
test/miopen/miopen.cpp
View file @
04ca2e74
...
...
@@ -158,6 +158,47 @@ struct test_gemm
}
};
struct
test_contiguous
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
4
,
4
,
4
,
3
},
{
48
,
4
,
1
,
16
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
p
.
add_instruction
(
migraph
::
contiguous
{},
x
);
return
p
;
}
migraph
::
program
::
parameter_map
create_params
()
const
{
migraph
::
program
::
parameter_map
m
;
m
[
"x"
]
=
migraph
::
generate_argument
({
migraph
::
shape
::
float_type
,
{
4
,
4
,
4
,
3
},
{
48
,
4
,
1
,
16
}});
return
m
;
}
};
struct
test_transpose
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
migraph
::
shape
s
{
migraph
::
shape
::
float_type
,
{
4
,
3
,
4
,
4
}};
auto
x
=
p
.
add_parameter
(
"x"
,
s
);
std
::
vector
<
int64_t
>
perm
=
{
0
,
2
,
3
,
1
};
auto
l
=
p
.
add_instruction
(
migraph
::
transpose
{
perm
},
x
);
p
.
add_instruction
(
migraph
::
contiguous
{},
l
);
return
p
;
}
migraph
::
program
::
parameter_map
create_params
()
const
{
migraph
::
program
::
parameter_map
m
;
m
[
"x"
]
=
migraph
::
generate_argument
({
migraph
::
shape
::
float_type
,
{
4
,
3
,
4
,
4
}});
return
m
;
}
};
int
main
()
{
verify_program
<
test_add
>
();
...
...
@@ -165,4 +206,6 @@ int main()
verify_program
<
test_conv_relu
>
();
verify_program
<
test_conv_pooling
>
();
verify_program
<
test_gemm
>
();
verify_program
<
test_contiguous
>
();
verify_program
<
test_transpose
>
();
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment