Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
45e82cad
Commit
45e82cad
authored
Apr 03, 2019
by
Shucai Xiao
Browse files
clang format
parent
154b9287
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
41 additions
and
42 deletions
+41
-42
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+2
-1
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+39
-41
No files found.
src/include/migraphx/operators.hpp
View file @
45e82cad
...
@@ -857,7 +857,8 @@ struct dot
...
@@ -857,7 +857,8 @@ struct dot
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
out_lens
[
dim_1
]
=
b
.
lens
()[
dim_1
];
if
(
inputs
.
size
()
==
3
&&
out_lens
!=
inputs
.
at
(
2
).
lens
())
if
(
inputs
.
size
()
==
3
&&
out_lens
!=
inputs
.
at
(
2
).
lens
())
{
{
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
inputs
.
at
(
2
).
lens
())
+
MIGRAPHX_THROW
(
"DOT: dimension mismatch, operand C: {"
+
to_string_range
(
inputs
.
at
(
2
).
lens
())
+
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
"}, cannot add to operand A * B: {"
+
to_string_range
(
out_lens
)
+
"}"
);
}
}
...
...
src/targets/gpu/gemm.cpp
View file @
45e82cad
...
@@ -274,60 +274,58 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -274,60 +274,58 @@ argument miopen_gemm::compute(context& ctx,
// matrix * vector (b is a vector)
// matrix * vector (b is a vector)
else
if
(
b_lens
.
size
()
==
2
&&
b_lens
.
at
(
1
)
==
1
)
else
if
(
b_lens
.
size
()
==
2
&&
b_lens
.
at
(
1
)
==
1
)
{
{
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
rocblas_int
m
=
static_cast
<
rocblas_int
>
(
a_lens
[
0
]);
rocblas_int
m
=
static_cast
<
rocblas_int
>
(
a_lens
[
0
]);
rocblas_int
n
=
static_cast
<
rocblas_int
>
(
a_lens
[
1
]);
rocblas_int
n
=
static_cast
<
rocblas_int
>
(
a_lens
[
1
]);
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
1
:
0
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
1
:
0
];
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
assert
(
a_lens
.
back
()
==
args
[
1
].
get_shape
().
elements
());
assert
(
a_lens
.
back
()
==
args
[
1
].
get_shape
().
elements
());
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_gemv
(
generic_rocblas_gemv
(
as
,
as
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
m
,
n
,
n
,
&
alpha_r
,
&
alpha_r
,
to_pointer
(
args
[
0
]),
to_pointer
(
args
[
0
]),
lda
,
lda
,
to_pointer
(
args
[
1
]),
to_pointer
(
args
[
1
]),
1
,
1
,
&
beta_r
,
&
beta_r
,
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
1
);
1
);
});
});
}
}
// vector * matrix (a is a vector)
// vector * matrix (a is a vector)
else
if
(
a_lens
.
size
()
==
2
&&
a_lens
.
at
(
0
)
==
1
)
else
if
(
a_lens
.
size
()
==
2
&&
a_lens
.
at
(
0
)
==
1
)
{
{
bool
transb
=
!
args
[
1
].
get_shape
().
transposed
();
bool
transb
=
!
args
[
1
].
get_shape
().
transposed
();
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[(
transb
?
1
:
0
)];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[(
transb
?
1
:
0
)];
rocblas_int
m
=
b_lens
[
0
];
rocblas_int
m
=
b_lens
[
0
];
rocblas_int
n
=
b_lens
[
1
];
rocblas_int
n
=
b_lens
[
1
];
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
assert
(
b_lens
[
0
]
==
args
[
0
].
get_shape
().
elements
());
assert
(
b_lens
[
0
]
==
args
[
0
].
get_shape
().
elements
());
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
generic_rocblas_gemv
(
generic_rocblas_gemv
(
as
,
as
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
m
,
m
,
n
,
n
,
&
alpha_r
,
&
alpha_r
,
to_pointer
(
args
[
1
]),
to_pointer
(
args
[
1
]),
ldb
,
ldb
,
to_pointer
(
args
[
0
]),
to_pointer
(
args
[
0
]),
1
,
1
,
&
beta_r
,
&
beta_r
,
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
1
);
1
);
});
});
}
}
// batch matrix multiplication
// batch matrix multiplication
...
@@ -337,7 +335,7 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -337,7 +335,7 @@ argument miopen_gemm::compute(context& ctx,
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
n_dim
=
output_shape
.
lens
().
size
();
auto
dim_1
=
n_dim
-
1
;
auto
dim_1
=
n_dim
-
1
;
auto
dim_0
=
n_dim
-
2
;
auto
dim_0
=
n_dim
-
2
;
float
beta
=
0.0
f
;
float
beta
=
0.0
f
;
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
to_rocblas_type
(
as
(
beta
));
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
bool
transa
=
args
[
0
].
get_shape
().
transposed
();
...
@@ -374,9 +372,9 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -374,9 +372,9 @@ argument miopen_gemm::compute(context& ctx,
ldc
,
ldc
,
m
*
n
,
m
*
n
,
num_matrices
);
num_matrices
);
});
});
}
}
return
args
[
2
];
return
args
[
2
];
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment