Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
734deb37
Unverified
Commit
734deb37
authored
Nov 07, 2018
by
Paul Fultz II
Committed by
GitHub
Nov 07, 2018
Browse files
Merge branch 'master' into transpose
parents
272bf6b1
63978fb4
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
99 additions
and
18 deletions
+99
-18
src/onnx/onnx.cpp
src/onnx/onnx.cpp
+4
-4
src/targets/gpu/fuse_ops.cpp
src/targets/gpu/fuse_ops.cpp
+2
-0
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+80
-14
test/gpu/miopen.cpp
test/gpu/miopen.cpp
+13
-0
No files found.
src/onnx/onnx.cpp
View file @
734deb37
...
...
@@ -534,7 +534,7 @@ struct onnx_parser
case
onnx
::
TensorProto
::
INT64
:
return
literal
{{
shape
::
int64_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
)
;
case
onnx
::
TensorProto
::
FLOAT16
:
return
literal
{{
shape
::
half_type
,
dims
},
s
.
data
()}
;
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{{
shape
::
double_type
,
dims
},
s
.
data
()};
case
onnx
::
TensorProto
::
UINT32
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
UINT64
:
throw
std
::
runtime_error
(
""
);
...
...
@@ -562,7 +562,8 @@ struct onnx_parser
case
onnx
::
TensorProto
::
STRING
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
BOOL
:
return
literal
{{
shape
::
int32_type
,
dims
},
t
.
int32_data
().
begin
(),
t
.
int32_data
().
end
()};
case
onnx
::
TensorProto
::
FLOAT16
:
throw
std
::
runtime_error
(
""
);
case
onnx
::
TensorProto
::
FLOAT16
:
return
literal
{{
shape
::
half_type
,
dims
},
t
.
float_data
().
begin
(),
t
.
float_data
().
end
()};
case
onnx
::
TensorProto
::
DOUBLE
:
return
literal
{
{
shape
::
double_type
,
dims
},
t
.
double_data
().
begin
(),
t
.
double_data
().
end
()};
...
...
@@ -593,8 +594,7 @@ struct onnx_parser
break
;
// throw std::runtime_error("Unsupported type STRING");
case
onnx
::
TensorProto
::
BOOL
:
break
;
// throw std::runtime_error("Unsupported type BOOL");
case
onnx
::
TensorProto
::
FLOAT16
:
break
;
// throw std::runtime_error("Unsupported type FLOAT16");
case
onnx
::
TensorProto
::
FLOAT16
:
shape_type
=
shape
::
half_type
;
break
;
case
onnx
::
TensorProto
::
DOUBLE
:
shape_type
=
shape
::
double_type
;
break
;
case
onnx
::
TensorProto
::
UINT32
:
shape_type
=
shape
::
uint32_type
;
break
;
case
onnx
::
TensorProto
::
UINT64
:
shape_type
=
shape
::
uint64_type
;
break
;
...
...
src/targets/gpu/fuse_ops.cpp
View file @
734deb37
...
...
@@ -132,6 +132,8 @@ MIGRAPH_PRED_MATCHER(fusable_conv, instruction_ref ins)
{
if
(
ins
->
name
()
!=
"gpu::convolution"
)
return
false
;
if
(
ins
->
get_shape
().
type
()
!=
shape
::
float_type
)
return
false
;
auto
wei
=
ins
->
inputs
().
at
(
1
)
->
get_shape
();
assert
(
wei
.
lens
().
size
()
==
4
);
auto
conv
=
any_cast
<
miopen_convolution
>
(
ins
->
get_operator
());
...
...
src/targets/gpu/gemm.cpp
View file @
734deb37
...
...
@@ -8,6 +8,65 @@ namespace migraph {
inline
namespace
MIGRAPH_INLINE_NS
{
namespace
gpu
{
template
<
class
...
Ts
>
void
generic_rocblas_gemm
(
shape
::
as
<
float
>
,
Ts
&&
...
xs
)
{
rocblas_sgemm
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
...
Ts
>
void
generic_rocblas_gemm
(
shape
::
as
<
double
>
,
Ts
&&
...
xs
)
{
rocblas_dgemm
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
...
Ts
>
void
generic_rocblas_gemm
(
shape
::
as
<
half
>
,
Ts
&&
...
xs
)
{
rocblas_hgemm
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
T
,
class
...
Ts
>
void
generic_rocblas_gemm
(
shape
::
as
<
T
>
,
Ts
&&
...)
{
MIGRAPH_THROW
(
"Type unsupported by rocblas"
);
}
template
<
class
T
>
struct
compute_rocblas_type
{
using
type
=
T
;
};
template
<
class
T
>
struct
compute_rocblas_type
<
const
T
>
{
using
type
=
const
typename
compute_rocblas_type
<
T
>::
type
;
};
template
<
>
struct
compute_rocblas_type
<
half
>
{
using
type
=
rocblas_half
;
};
template
<
class
T
>
using
rb_type
=
typename
compute_rocblas_type
<
T
>::
type
;
template
<
class
T
>
rb_type
<
T
>
to_rocblas_type
(
T
x
)
{
return
reinterpret_cast
<
const
rb_type
<
T
>&>
(
x
);
}
template
<
class
T
>
rb_type
<
T
>*
to_rocblas_type
(
T
*
x
)
{
return
reinterpret_cast
<
rb_type
<
T
>*>
(
x
);
}
rocblas_half
to_rocblas_type
(
half
x
)
{
return
reinterpret_cast
<
const
rocblas_half
&>
(
x
);
}
shape
miopen_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
);
...
...
@@ -27,20 +86,27 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int
m
=
output_shape
.
lens
()[
0
];
rocblas_int
n
=
output_shape
.
lens
()[
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
1
];
rocblas_sgemm
(
ctx
.
get_stream
().
get_rocblas
(),
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
m
,
k
,
&
alpha
,
args
[
1
].
implicit
(
),
&
alpha
_r
,
to_pointer
(
args
[
1
]
),
ldb
,
args
[
0
].
implicit
(
),
to_pointer
(
args
[
0
]
),
lda
,
&
beta
,
args
[
2
].
implicit
(
),
&
beta
_r
,
to_pointer
(
args
[
2
]
),
ldc
);
});
return
args
[
2
];
}
...
...
test/gpu/miopen.cpp
View file @
734deb37
...
...
@@ -508,6 +508,18 @@ struct test_gemm
}
};
struct
test_gemm_half
{
migraph
::
program
create_program
()
const
{
migraph
::
program
p
;
auto
a
=
p
.
add_parameter
(
"a"
,
migraph
::
shape
{
migraph
::
shape
::
half_type
,
{
4
,
5
}});
auto
b
=
p
.
add_parameter
(
"b"
,
migraph
::
shape
{
migraph
::
shape
::
half_type
,
{
5
,
3
}});
p
.
add_instruction
(
migraph
::
op
::
dot
{},
a
,
b
);
return
p
;
}
};
struct
test_gemm_ld
{
migraph
::
program
create_program
()
const
...
...
@@ -844,6 +856,7 @@ int main()
verify_program
<
test_global_avg_pooling
>
();
verify_program
<
test_global_max_pooling
>
();
verify_program
<
test_gemm
>
();
verify_program
<
test_gemm_half
>
();
// verify_program<test_gemm_ld>();
verify_program
<
test_gemm_transposeb
>
();
verify_program
<
test_gemm_transposea
>
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment