Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
09d332d9
Commit
09d332d9
authored
Jul 18, 2018
by
wsttiger
Browse files
Fixed up ROCBLAS implementation of GEMM on GPU
parent
afa4a833
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
19 additions
and
19 deletions
+19
-19
CMakeLists.txt
CMakeLists.txt
+0
-4
src/targets/miopen/CMakeLists.txt
src/targets/miopen/CMakeLists.txt
+5
-0
src/targets/miopen/include/migraph/miopen/context.hpp
src/targets/miopen/include/migraph/miopen/context.hpp
+2
-0
src/targets/miopen/miopen_lowering.cpp
src/targets/miopen/miopen_lowering.cpp
+11
-14
src/targets/miopen/miopen_target.cpp
src/targets/miopen/miopen_target.cpp
+1
-1
No files found.
CMakeLists.txt
View file @
09d332d9
...
@@ -20,10 +20,6 @@ endif()
...
@@ -20,10 +20,6 @@ endif()
add_compile_options
(
-std=c++14
)
add_compile_options
(
-std=c++14
)
# rocblas
find_package
(
rocblas REQUIRED PATHS /opt/rocm
)
message
(
STATUS
"Build with rocblas"
)
list
(
APPEND CMAKE_MODULE_PATH
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake
)
list
(
APPEND CMAKE_MODULE_PATH
${
CMAKE_CURRENT_SOURCE_DIR
}
/cmake
)
include
(
EnableCompilerWarnings
)
include
(
EnableCompilerWarnings
)
# Override clang-tidy to not find the version from hcc
# Override clang-tidy to not find the version from hcc
...
...
src/targets/miopen/CMakeLists.txt
View file @
09d332d9
...
@@ -2,6 +2,10 @@
...
@@ -2,6 +2,10 @@
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc
)
list
(
APPEND CMAKE_PREFIX_PATH /opt/rocm /opt/rocm/hip /opt/rocm/hcc
)
find_package
(
miopen
)
find_package
(
miopen
)
# rocblas
find_package
(
rocblas REQUIRED PATHS /opt/rocm
)
message
(
STATUS
"Build with rocblas"
)
if
(
NOT TARGET MIOpen
)
if
(
NOT TARGET MIOpen
)
message
(
SEND_ERROR
"Cant find miopen"
)
message
(
SEND_ERROR
"Cant find miopen"
)
endif
()
endif
()
...
@@ -11,6 +15,7 @@ add_library(migraph_miopen
...
@@ -11,6 +15,7 @@ add_library(migraph_miopen
miopen_target.cpp
miopen_target.cpp
miopen_lowering.cpp
miopen_lowering.cpp
miopen_write_literals.cpp
miopen_write_literals.cpp
rocblas.cpp
)
)
rocm_clang_tidy_check
(
migraph_miopen
)
rocm_clang_tidy_check
(
migraph_miopen
)
target_link_libraries
(
migraph_miopen migraph MIOpen rocblas
)
target_link_libraries
(
migraph_miopen migraph MIOpen rocblas
)
...
...
src/targets/miopen/include/migraph/miopen/context.hpp
View file @
09d332d9
...
@@ -2,6 +2,7 @@
...
@@ -2,6 +2,7 @@
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#define MIGRAPH_GUARD_RTGLIB_CONTEXT_HPP
#include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/miopen.hpp>
#include <migraph/miopen/rocblas.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
miopen
{
namespace
miopen
{
...
@@ -9,6 +10,7 @@ namespace miopen {
...
@@ -9,6 +10,7 @@ namespace miopen {
struct
miopen_context
struct
miopen_context
{
{
shared
<
miopen_handle
>
handle
;
shared
<
miopen_handle
>
handle
;
shared
<
rocblas_handle_ptr
>
rbhandle
;
};
};
}
// namespace miopen
}
// namespace miopen
...
...
src/targets/miopen/miopen_lowering.cpp
View file @
09d332d9
#include <rocblas.h>
#include <migraph/miopen/miopen_lowering.hpp>
#include <migraph/miopen/miopen_lowering.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/manage_ptr.hpp>
#include <migraph/instruction.hpp>
#include <migraph/instruction.hpp>
...
@@ -7,16 +8,12 @@
...
@@ -7,16 +8,12 @@
#include <migraph/miopen/hip.hpp>
#include <migraph/miopen/hip.hpp>
#include <migraph/dfor.hpp>
#include <migraph/dfor.hpp>
#include <migraph/iterator_for.hpp>
#include <migraph/iterator_for.hpp>
#include <rocblas.h>
#include <migraph/miopen/rocblas.hpp>
#include <migraph/miopen/context.hpp>
namespace
migraph
{
namespace
migraph
{
namespace
miopen
{
namespace
miopen
{
struct
miopen_context
{
shared
<
miopen_handle
>
handle
;
};
struct
miopen_convolution
struct
miopen_convolution
{
{
convolution
op
;
convolution
op
;
...
@@ -28,7 +25,7 @@ struct miopen_convolution
...
@@ -28,7 +25,7 @@ struct miopen_convolution
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
migraph
::
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
...
@@ -80,7 +77,7 @@ struct miopen_pooling
...
@@ -80,7 +77,7 @@ struct miopen_pooling
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
check_shapes
{
inputs
,
*
this
}.
has
(
2
);
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
1
)});
}
}
argument
compute
(
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
migraph
::
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
auto
x_desc
=
make_tensor
(
args
[
0
].
get_shape
());
...
@@ -113,7 +110,7 @@ struct miopen_add
...
@@ -113,7 +110,7 @@ struct miopen_add
return
inputs
.
at
(
0
);
return
inputs
.
at
(
0
);
}
}
argument
compute
(
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
migraph
::
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
if
(
args
[
1
].
get_shape
().
broadcasted
())
if
(
args
[
1
].
get_shape
().
broadcasted
())
{
{
...
@@ -160,10 +157,10 @@ struct miopen_gemm
...
@@ -160,10 +157,10 @@ struct miopen_gemm
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
return
op
.
compute_shape
({
inputs
.
at
(
0
),
inputs
.
at
(
1
)});
}
}
argument
compute
(
context
&
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
migraph
::
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
rocblas_handle
rochandle
;
//
rocblas_handle
_ptr handle_ptr = create_rocblas_handle_ptr();
rocblas_create_handle
(
&
rochandle
);
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
lens
()[
1
];
...
@@ -172,7 +169,7 @@ struct miopen_gemm
...
@@ -172,7 +169,7 @@ struct miopen_gemm
rocblas_int
m
=
output_shape
.
lens
()[
0
];
rocblas_int
m
=
output_shape
.
lens
()[
0
];
rocblas_int
n
=
output_shape
.
lens
()[
1
];
rocblas_int
n
=
output_shape
.
lens
()[
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_sgemm
(
rochandle
,
rocblas_sgemm
(
ctx
.
rbhandle
.
get
()
,
rocblas_operation_none
,
rocblas_operation_none
,
rocblas_operation_none
,
rocblas_operation_none
,
n
,
n
,
...
@@ -200,7 +197,7 @@ struct miopen_relu
...
@@ -200,7 +197,7 @@ struct miopen_relu
return
inputs
.
at
(
1
);
return
inputs
.
at
(
1
);
}
}
argument
compute
(
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
argument
compute
(
migraph
::
context
&
gctx
,
shape
output_shape
,
std
::
vector
<
argument
>
args
)
const
{
{
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
auto
&
ctx
=
any_cast
<
miopen_context
>
(
gctx
);
float
alpha
=
1
,
beta
=
0
;
float
alpha
=
1
,
beta
=
0
;
...
...
src/targets/miopen/miopen_target.cpp
View file @
09d332d9
...
@@ -15,7 +15,7 @@ std::string miopen_target::name() const { return "miopen"; }
...
@@ -15,7 +15,7 @@ std::string miopen_target::name() const { return "miopen"; }
context
miopen_target
::
get_context
()
const
context
miopen_target
::
get_context
()
const
{
{
return
miopen_context
{
share
(
make_obj
<
miopen_handle
>
(
&
miopenCreate
))};
return
miopen_context
{
share
(
make_obj
<
miopen_handle
>
(
&
miopenCreate
))
,
share
(
create_rocblas_handle_ptr
())
};
}
}
}
// namespace miopen
}
// namespace miopen
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment