Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
359ec2f8
"vscode:/vscode.git/clone" did not exist on "eb06cc6bd5ea01a3bc3ef535bf463b9289f84c94"
Commit
359ec2f8
authored
Mar 08, 2019
by
Shucai Xiao
Browse files
clang format
parent
9d52515a
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
65 additions
and
65 deletions
+65
-65
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+46
-45
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+19
-20
No files found.
src/include/migraphx/operators.hpp
View file @
359ec2f8
...
@@ -830,45 +830,46 @@ struct dot
...
@@ -830,45 +830,46 @@ struct dot
return
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
));
return
pack
(
f
(
self
.
alpha
,
"alpha"
),
f
(
self
.
beta
,
"beta"
));
}
}
std
::
vector
<
std
::
size_t
>
shape_broadcast
(
std
::
vector
<
std
::
size_t
>
&
a
,
std
::
vector
<
std
::
size_t
>
&
b
)
const
std
::
vector
<
std
::
size_t
>
shape_broadcast
(
std
::
vector
<
std
::
size_t
>&
a
,
std
::
vector
<
std
::
size_t
>&
b
)
const
{
{
if
(
a
.
empty
())
if
(
a
.
empty
())
return
b
;
return
b
;
else
if
(
b
.
empty
())
else
if
(
b
.
empty
())
return
a
;
return
a
;
auto
a_size
=
a
.
size
();
auto
a_size
=
a
.
size
();
auto
b_size
=
b
.
size
();
auto
b_size
=
b
.
size
();
auto
n_dim
=
std
::
min
(
a_size
,
b_size
);
auto
n_dim
=
std
::
min
(
a_size
,
b_size
);
std
::
vector
<
std
::
size_t
>
out_lens
(
std
::
max
(
a_size
,
b_size
));
std
::
vector
<
std
::
size_t
>
out_lens
(
std
::
max
(
a_size
,
b_size
));
for
(
std
::
size_t
i
=
0
;
i
<
n_dim
;
++
i
)
for
(
std
::
size_t
i
=
0
;
i
<
n_dim
;
++
i
)
{
{
if
(
a
[
a_size
-
1
-
i
]
==
b
[
b_size
-
1
-
i
])
if
(
a
[
a_size
-
1
-
i
]
==
b
[
b_size
-
1
-
i
])
{
{
out_lens
[
i
]
=
a
[
a_size
-
1
-
i
];
out_lens
[
i
]
=
a
[
a_size
-
1
-
i
];
}
}
else
if
(
a
[
a_size
-
1
-
i
]
==
1
)
else
if
(
a
[
a_size
-
1
-
i
]
==
1
)
{
{
out_lens
[
i
]
=
b
[
b_size
-
1
-
i
];
out_lens
[
i
]
=
b
[
b_size
-
1
-
i
];
}
}
else
if
(
b
[
b_size
-
1
-
i
]
==
1
)
else
if
(
b
[
b_size
-
1
-
i
]
==
1
)
{
{
out_lens
[
i
]
=
a
[
a_size
-
1
-
i
];
out_lens
[
i
]
=
a
[
a_size
-
1
-
i
];
}
}
else
else
{
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a
)
+
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a
)
+
"}, and matrix B: {"
+
to_string_range
(
b
)
"}, and matrix B: {"
+
to_string_range
(
b
)
+
+
"} are not broadcastable"
);
"} are not broadcastable"
);
}
}
}
}
if
(
a_size
>
n_dim
)
if
(
a_size
>
n_dim
)
{
{
std
::
copy
(
a
.
rbegin
()
+
n_dim
,
a
.
rend
(),
out_lens
.
begin
()
+
n_dim
);
std
::
copy
(
a
.
rbegin
()
+
n_dim
,
a
.
rend
(),
out_lens
.
begin
()
+
n_dim
);
}
}
if
(
b_size
>
n_dim
)
if
(
b_size
>
n_dim
)
{
{
std
::
copy
(
b
.
rbegin
()
+
n_dim
,
b
.
rend
(),
out_lens
.
rbegin
()
+
n_dim
);
std
::
copy
(
b
.
rbegin
()
+
n_dim
,
b
.
rend
(),
out_lens
.
rbegin
()
+
n_dim
);
}
}
...
@@ -886,7 +887,7 @@ struct dot
...
@@ -886,7 +887,7 @@ struct dot
const
shape
&
b
=
inputs
.
at
(
1
);
const
shape
&
b
=
inputs
.
at
(
1
);
auto
t
=
a
.
type
();
auto
t
=
a
.
type
();
if
(
a
.
scalar
()
||
b
.
scalar
())
if
(
a
.
scalar
()
||
b
.
scalar
())
{
{
MIGRAPHX_THROW
(
"DOT: scalar operands are not allowed, use op::mul{} instead"
);
MIGRAPHX_THROW
(
"DOT: scalar operands are not allowed, use op::mul{} instead"
);
}
}
...
@@ -894,26 +895,26 @@ struct dot
...
@@ -894,26 +895,26 @@ struct dot
auto
a_lens
=
a
.
lens
();
auto
a_lens
=
a
.
lens
();
auto
b_lens
=
b
.
lens
();
auto
b_lens
=
b
.
lens
();
std
::
vector
<
std
::
size_t
>
out_lens
;
std
::
vector
<
std
::
size_t
>
out_lens
;
if
(
a_lens
.
size
()
==
1
)
if
(
a_lens
.
size
()
==
1
)
{
{
// inner product, output is a scalar, following numpy.matmul()
// inner product, output is a scalar, following numpy.matmul()
if
(
b_lens
.
size
()
==
1
)
if
(
b_lens
.
size
()
==
1
)
{
{
if
(
a_lens
.
front
()
!=
b_lens
.
front
())
if
(
a_lens
.
front
()
!=
b_lens
.
front
())
{
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, vector A: {"
+
to_string_range
(
a_lens
)
+
MIGRAPHX_THROW
(
"DOT : dimension mismatch, vector A: {"
+
"}, cannot multiply vector B: {"
+
to_string_range
(
b_lens
)
to_string_range
(
a_lens
)
+
"}, cannot multiply vector B: {"
+
+
"}"
);
to_string_range
(
b_lens
)
+
"}"
);
}
}
}
}
else
else
{
{
std
::
size_t
dim_0
=
b_lens
.
size
()
-
2
;
std
::
size_t
dim_0
=
b_lens
.
size
()
-
2
;
if
(
a_lens
.
front
()
!=
b_lens
[
dim_0
])
if
(
a_lens
.
front
()
!=
b_lens
[
dim_0
])
{
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, vector A: {"
+
to_string_range
(
a_lens
)
+
MIGRAPHX_THROW
(
"DOT : dimension mismatch, vector A: {"
+
"}, cannot multiply matrix B: {"
+
to_string_range
(
b_lens
)
to_string_range
(
a_lens
)
+
"}, cannot multiply matrix B: {"
+
+
"}"
);
to_string_range
(
b_lens
)
+
"}"
);
}
}
out_lens
=
b_lens
;
out_lens
=
b_lens
;
...
@@ -923,13 +924,13 @@ struct dot
...
@@ -923,13 +924,13 @@ struct dot
else
else
{
{
std
::
size_t
dim_0
=
a_lens
.
size
()
-
1
;
std
::
size_t
dim_0
=
a_lens
.
size
()
-
1
;
if
(
b_lens
.
size
()
==
1
)
if
(
b_lens
.
size
()
==
1
)
{
{
if
(
a_lens
.
back
()
!=
b_lens
.
back
())
if
(
a_lens
.
back
()
!=
b_lens
.
back
())
{
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a_lens
)
+
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
"}, cannot multiply vector B: {"
+
to_string_range
(
b_lens
)
to_string_range
(
a_lens
)
+
"}, cannot multiply vector B: {"
+
+
"}"
);
to_string_range
(
b_lens
)
+
"}"
);
}
}
out_lens
=
a_lens
;
out_lens
=
a_lens
;
...
@@ -939,11 +940,11 @@ struct dot
...
@@ -939,11 +940,11 @@ struct dot
{
{
std
::
size_t
dim_0
=
a_lens
.
size
()
-
1
;
std
::
size_t
dim_0
=
a_lens
.
size
()
-
1
;
std
::
size_t
dim_1
=
b_lens
.
size
()
-
2
;
std
::
size_t
dim_1
=
b_lens
.
size
()
-
2
;
if
(
a_lens
[
dim_0
]
!=
b_lens
[
dim_1
])
if
(
a_lens
[
dim_0
]
!=
b_lens
[
dim_1
])
{
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a_lens
)
+
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
"}, cannot multiply matrix B: {"
+
to_string_range
(
b_lens
)
to_string_range
(
a_lens
)
+
"}, cannot multiply matrix B: {"
+
+
"}"
);
to_string_range
(
b_lens
)
+
"}"
);
}
}
a_lens
.
pop_back
();
a_lens
.
pop_back
();
...
@@ -961,7 +962,7 @@ struct dot
...
@@ -961,7 +962,7 @@ struct dot
}
}
// c is broadcast
// c is broadcast
if
(
inputs
.
size
()
==
3
)
if
(
inputs
.
size
()
==
3
)
// according to the specification of the numpy.matmul()
// according to the specification of the numpy.matmul()
// inputs with the shape dims more than 2 are acceptable
// inputs with the shape dims more than 2 are acceptable
...
...
src/targets/gpu/gemm.cpp
View file @
359ec2f8
...
@@ -127,8 +127,7 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -127,8 +127,7 @@ argument miopen_gemm::compute(context& ctx,
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
op
.
beta
));
auto
beta_r
=
to_rocblas_type
(
as
(
op
.
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_batched_gemm
(
generic_rocblas_batched_gemm
(
as
,
as
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment