Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
5c5115af
Commit
5c5115af
authored
Aug 04, 2018
by
Paul
Browse files
Add support for transpose gemm using blaze
parent
cd07476b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
6 deletions
+51
-6
src/include/migraph/operators.hpp
src/include/migraph/operators.hpp
+1
-1
src/targets/gpu/lowering.cpp
src/targets/gpu/lowering.cpp
+7
-5
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+43
-0
No files found.
src/include/migraph/operators.hpp
View file @
5c5115af
...
@@ -324,7 +324,7 @@ struct gemm
...
@@ -324,7 +324,7 @@ struct gemm
auto
t
=
a
.
type
();
auto
t
=
a
.
type
();
if
(
a
.
lens
()[
1
]
!=
b
.
lens
()[
0
])
if
(
a
.
lens
()[
1
]
!=
b
.
lens
()[
0
])
MIGRAPH_THROW
(
"Inner dimensions do not match"
);
MIGRAPH_THROW
(
"Inner dimensions do not match
: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}
"
);
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
return
{
t
,
{
a
.
lens
()[
0
],
b
.
lens
()[
1
]}};
}
}
...
...
src/targets/gpu/lowering.cpp
View file @
5c5115af
...
@@ -159,15 +159,17 @@ struct miopen_gemm
...
@@ -159,15 +159,17 @@ struct miopen_gemm
{
{
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
rocblas_int
ldb
=
args
[
1
].
get_shape
().
lens
()[
1
];
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
rocblas_int
ldc
=
args
[
2
].
get_shape
().
lens
()[
1
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
1
:
0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
1
:
0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
0
];
rocblas_int
m
=
output_shape
.
lens
()[
0
];
rocblas_int
m
=
output_shape
.
lens
()[
0
];
rocblas_int
n
=
output_shape
.
lens
()[
1
];
rocblas_int
n
=
output_shape
.
lens
()[
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_sgemm
(
ctx
.
rbhandle
.
get
(),
rocblas_sgemm
(
ctx
.
rbhandle
.
get
(),
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
n
,
m
,
m
,
k
,
k
,
...
...
test/gpu/miopen.cpp
View file @
5c5115af
...
@@ -135,6 +135,46 @@ struct test_gemm
...
@@ -135,6 +135,46 @@ struct test_gemm
}
}
};
};
struct
test_gemm_transposeb
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
4
,
5
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
3
,
5
}});
auto
bt
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
b
);
p
.
add_instruction
(
migraph
::
gemm
{},
a
,
bt
);
return
p
;
}
};
struct
test_gemm_transposea
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
4
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
3
}});
auto
at
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
a
);
p
.
add_instruction
(
migraph
::
gemm
{},
at
,
b
);
return
p
;
}
};
struct
test_gemm_transposeab
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
5
,
4
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
float_type
,
{
3
,
5
}});
auto
at
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
a
);
auto
bt
=
p
.
add_instruction
(
migraph
::
transpose
{{
1
,
0
}},
b
);
p
.
add_instruction
(
migraph
::
gemm
{},
at
,
bt
);
return
p
;
}
};
struct
test_contiguous
struct
test_contiguous
{
{
migraph
::
program
create_program
()
const
migraph
::
program
create_program
()
const
...
@@ -168,6 +208,9 @@ int main()
...
@@ -168,6 +208,9 @@ int main()
verify_program
<
test_conv_relu
>
();
verify_program
<
test_conv_relu
>
();
verify_program
<
test_conv_pooling
>
();
verify_program
<
test_conv_pooling
>
();
verify_program
<
test_gemm
>
();
verify_program
<
test_gemm
>
();
verify_program
<
test_gemm_transposeb
>
();
verify_program
<
test_gemm_transposea
>
();
verify_program
<
test_gemm_transposeab
>
();
verify_program
<
test_contiguous
>
();
verify_program
<
test_contiguous
>
();
verify_program
<
test_transpose
>
();
verify_program
<
test_transpose
>
();
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment