Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
9c16a90e
Commit
9c16a90e
authored
Feb 25, 2019
by
Shucai Xiao
Browse files
clang format
parent
c43eba64
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
24 additions
and
28 deletions
+24
-28
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+4
-8
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+8
-8
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+12
-12
No files found.
src/include/migraphx/operators.hpp
View file @
9c16a90e
...
@@ -844,19 +844,15 @@ struct dot
...
@@ -844,19 +844,15 @@ struct dot
auto
t
=
a
.
type
();
auto
t
=
a
.
type
();
// change to support cases like {1, 1, 3, 5} X {1, 1, 5, 6},
// change to support cases like {1, 1, 3, 5} X {1, 1, 5, 6},
// which can be handled by numpy. as long as all previous
// which can be handled by numpy. as long as all previous
// dims are 1 except the last two dims, the two matrices
// dims are 1 except the last two dims, the two matrices
// are multipliable
// are multipliable
if
(
std
::
any_of
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
[](
auto
i
)
{
if
(
std
::
any_of
(
a
.
lens
().
rbegin
()
+
2
,
a
.
lens
().
rend
(),
[](
auto
i
)
{
return
(
i
!=
1
);
}))
return
(
i
!=
1
);
}))
{
{
MIGRAPHX_THROW
(
"DOT: first matrix, dimensions before matrix dims must be 1"
);
MIGRAPHX_THROW
(
"DOT: first matrix, dimensions before matrix dims must be 1"
);
}
}
if
(
std
::
any_of
(
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
(),
[](
auto
i
)
{
if
(
std
::
any_of
(
b
.
lens
().
rbegin
()
+
2
,
b
.
lens
().
rend
(),
[](
auto
i
)
{
return
(
i
!=
1
);
}))
return
(
i
!=
1
);
}))
{
{
MIGRAPHX_THROW
(
"DOT: second matrix, dimensions before matrix dims must be 1"
);
MIGRAPHX_THROW
(
"DOT: second matrix, dimensions before matrix dims must be 1"
);
}
}
...
@@ -865,7 +861,7 @@ struct dot
...
@@ -865,7 +861,7 @@ struct dot
if
(
a
.
lens
()[
n_dims
-
1
]
!=
b
.
lens
()[
n_dims
-
2
])
if
(
a
.
lens
()[
n_dims
-
1
]
!=
b
.
lens
()[
n_dims
-
2
])
MIGRAPHX_THROW
(
"Inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
MIGRAPHX_THROW
(
"Inner dimensions do not match: {"
+
to_string_range
(
a
.
lens
())
+
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
"} x {"
+
to_string_range
(
b
.
lens
())
+
"}"
);
auto
out_lens
=
a
.
lens
();
auto
out_lens
=
a
.
lens
();
out_lens
[
n_dims
-
1
]
=
b
.
lens
()[
n_dims
-
1
];
out_lens
[
n_dims
-
1
]
=
b
.
lens
()[
n_dims
-
1
];
return
{
t
,
out_lens
};
return
{
t
,
out_lens
};
}
}
...
...
src/targets/cpu/gemm.cpp
View file @
9c16a90e
...
@@ -14,10 +14,10 @@ template <class T>
...
@@ -14,10 +14,10 @@ template <class T>
static
auto
make_mat
(
tensor_view
<
T
>
x
)
static
auto
make_mat
(
tensor_view
<
T
>
x
)
{
{
const
auto
&
s
=
x
.
get_shape
();
const
auto
&
s
=
x
.
get_shape
();
//assert(s.lens().size() == 2);
//
assert(s.lens().size() == 2);
std
::
size_t
n_dims
=
s
.
lens
().
size
();
std
::
size_t
n_dims
=
s
.
lens
().
size
();
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_1
=
n_dims
-
1
;
std
::
size_t
dim_1
=
n_dims
-
1
;
if
(
s
.
transposed
())
if
(
s
.
transposed
())
return
matrix
<
T
>
{
x
.
data
(),
s
.
lens
()[
dim_1
],
s
.
lens
()[
dim_0
],
s
.
strides
()[
dim_1
]};
return
matrix
<
T
>
{
x
.
data
(),
s
.
lens
()[
dim_1
],
s
.
lens
()[
dim_0
],
s
.
strides
()[
dim_1
]};
return
matrix
<
T
>
{
x
.
data
(),
s
.
lens
()[
dim_0
],
s
.
lens
()[
dim_1
],
s
.
strides
()[
dim_0
]};
return
matrix
<
T
>
{
x
.
data
(),
s
.
lens
()[
dim_0
],
s
.
lens
()[
dim_1
],
s
.
strides
()[
dim_0
]};
...
@@ -68,11 +68,11 @@ void migemm_impl(tensor_view<T> cmat,
...
@@ -68,11 +68,11 @@ void migemm_impl(tensor_view<T> cmat,
std
::
false_type
)
std
::
false_type
)
{
{
std
::
size_t
n_dims
=
cmat
.
get_shape
().
lens
().
size
();
std
::
size_t
n_dims
=
cmat
.
get_shape
().
lens
().
size
();
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_1
=
n_dims
-
1
;
std
::
size_t
dim_1
=
n_dims
-
1
;
auto
m
=
cmat
.
get_shape
().
lens
()[
dim_0
];
auto
m
=
cmat
.
get_shape
().
lens
()[
dim_0
];
auto
n
=
cmat
.
get_shape
().
lens
()[
dim_1
];
auto
n
=
cmat
.
get_shape
().
lens
()[
dim_1
];
auto
k
=
amat
.
get_shape
().
lens
()[
dim_1
];
auto
k
=
amat
.
get_shape
().
lens
()[
dim_1
];
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
amat
.
get_shape
().
lens
()[
dim_1
]
==
bmat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
m
==
amat
.
get_shape
().
lens
()[
dim_0
]);
assert
(
m
==
amat
.
get_shape
().
lens
()[
dim_0
]);
...
...
src/targets/gpu/gemm.cpp
View file @
9c16a90e
...
@@ -76,19 +76,19 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -76,19 +76,19 @@ argument miopen_gemm::compute(context& ctx,
const
shape
&
output_shape
,
const
shape
&
output_shape
,
const
std
::
vector
<
argument
>&
args
)
const
const
std
::
vector
<
argument
>&
args
)
const
{
{
float
alpha
=
1.0
f
;
float
alpha
=
1.0
f
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
bool
transb
=
args
[
1
].
get_shape
().
transposed
();
std
::
size_t
n_dims
=
args
[
0
].
get_shape
().
lens
().
size
();
std
::
size_t
n_dims
=
args
[
0
].
get_shape
().
lens
().
size
();
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_0
=
n_dims
-
2
;
std
::
size_t
dim_1
=
n_dims
-
1
;
std
::
size_t
dim_1
=
n_dims
-
1
;
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
dim_1
:
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
dim_0
];
rocblas_int
m
=
output_shape
.
lens
()[
dim_0
];
rocblas_int
m
=
output_shape
.
lens
()[
dim_0
];
rocblas_int
n
=
output_shape
.
lens
()[
dim_1
];
rocblas_int
n
=
output_shape
.
lens
()[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment