Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
52b9cf14
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "0dac827c2859fb304c1d36778630cab4fe9edda7"
Commit
52b9cf14
authored
Mar 13, 2019
by
Shucai Xiao
Browse files
clang format
parent
3ccd7e15
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
162 additions
and
147 deletions
+162
-147
src/include/migraphx/operators.hpp
src/include/migraphx/operators.hpp
+14
-14
src/targets/cpu/gemm.cpp
src/targets/cpu/gemm.cpp
+1
-1
src/targets/cpu/lowering.cpp
src/targets/cpu/lowering.cpp
+21
-23
src/targets/gpu/gemm.cpp
src/targets/gpu/gemm.cpp
+122
-108
src/targets/gpu/include/migraphx/gpu/gemm.hpp
src/targets/gpu/include/migraphx/gpu/gemm.hpp
+4
-1
No files found.
src/include/migraphx/operators.hpp
View file @
52b9cf14
...
@@ -869,9 +869,9 @@ struct dot
...
@@ -869,9 +869,9 @@ struct dot
}
}
else
else
{
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
MIGRAPHX_THROW
(
"DOT : dimension mismatch, matrix A: {"
+
to_string_range
(
a
)
+
to_string_range
(
a
)
+
"}, and matrix B: {"
+
"}, and matrix B: {"
+
to_string_range
(
b
)
+
to_string_range
(
b
)
+
"} are not broadcastable"
);
"} are not broadcastable"
);
}
}
}
}
...
@@ -895,7 +895,7 @@ struct dot
...
@@ -895,7 +895,7 @@ struct dot
{
{
// If there are 3 inputs, then A and B must be matrices and
// If there are 3 inputs, then A and B must be matrices and
// C is broadcastable to A * B
// C is broadcastable to A * B
if
(
inputs
.
size
()
==
3
)
if
(
inputs
.
size
()
==
3
)
{
{
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
same_type
();
check_shapes
{
inputs
,
*
this
}.
has
(
3
).
same_type
();
check_shapes
{{
inputs
[
0
]},
*
this
}.
only_dims
(
2
);
check_shapes
{{
inputs
[
0
]},
*
this
}.
only_dims
(
2
);
...
@@ -904,7 +904,7 @@ struct dot
...
@@ -904,7 +904,7 @@ struct dot
auto
a_lens
=
inputs
[
0
].
lens
();
auto
a_lens
=
inputs
[
0
].
lens
();
auto
b_lens
=
inputs
[
1
].
lens
();
auto
b_lens
=
inputs
[
1
].
lens
();
auto
t
=
inputs
[
0
].
type
();
auto
t
=
inputs
[
0
].
type
();
if
(
a_lens
[
1
]
!=
b_lens
[
0
])
if
(
a_lens
[
1
]
!=
b_lens
[
0
])
{
{
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
MIGRAPHX_THROW
(
"DOT : dimension mismatch, operand A: {"
+
to_string_range
(
a_lens
)
+
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
"}, cannot multiply operand B: {"
+
to_string_range
(
b_lens
)
+
"}"
);
...
@@ -915,7 +915,7 @@ struct dot
...
@@ -915,7 +915,7 @@ struct dot
// check whether C is broadcastable to A * B
// check whether C is broadcastable to A * B
auto
c_lens
=
inputs
[
2
].
lens
();
auto
c_lens
=
inputs
[
2
].
lens
();
if
(
c_lens
.
size
()
>
2
||
if
(
c_lens
.
size
()
>
2
||
(
c_lens
.
size
()
>=
1
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
b_lens
[
0
]))
||
(
c_lens
.
size
()
>=
1
&&
(
c_lens
[
0
]
!=
1
&&
c_lens
[
0
]
!=
b_lens
[
0
]))
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
1
]
!=
1
&&
c_lens
[
1
]
!=
a_lens
[
1
])))
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
1
]
!=
1
&&
c_lens
[
1
]
!=
a_lens
[
1
])))
{
{
...
...
src/targets/cpu/gemm.cpp
View file @
52b9cf14
src/targets/cpu/lowering.cpp
View file @
52b9cf14
...
@@ -376,27 +376,26 @@ struct cpu_gemm
...
@@ -376,27 +376,26 @@ struct cpu_gemm
auto
out_lens
=
result
.
get_shape
().
lens
();
auto
out_lens
=
result
.
get_shape
().
lens
();
auto
c_lens
=
c
.
get_shape
().
lens
();
auto
c_lens
=
c
.
get_shape
().
lens
();
if
(
out_lens
==
c_lens
)
if
(
out_lens
==
c_lens
)
{
{
visit_all
(
result
,
c
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
c
)([
&
](
auto
output
,
auto
input
)
{
std
::
memcpy
(
output
.
data
(),
input
.
data
(),
c_shape
.
bytes
());
std
::
memcpy
(
output
.
data
(),
input
.
data
(),
c_shape
.
bytes
());
});
});
}
}
// need broadcast
// need broadcast
else
if
(
c
.
single
())
else
if
(
c
.
single
())
{
{
visit_all
(
result
,
c
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
c
)([
&
](
auto
output
,
auto
input
)
{
std
::
fill
(
output
.
begin
(),
output
.
end
(),
input
.
front
());
std
::
fill
(
output
.
begin
(),
output
.
end
(),
input
.
front
());
});
});
}
}
// must be c_lens[0] == output_lens[1]
// must be c_lens[0] == output_lens[1]
else
if
(
c_lens
.
size
()
==
1
||
else
if
(
c_lens
.
size
()
==
1
||
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
1
]
==
out_lens
[
1
])))
(
c_lens
.
size
()
==
2
&&
(
c_lens
[
1
]
==
out_lens
[
1
])))
{
{
std
::
size_t
m
=
out_lens
[
0
];
std
::
size_t
m
=
out_lens
[
0
];
std
::
size_t
n
=
out_lens
[
1
];
std
::
size_t
n
=
out_lens
[
1
];
visit_all
(
result
,
c
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
c
)([
&
](
auto
output
,
auto
input
)
{
for
(
std
::
size_t
i
=
0
;
i
<
m
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
m
;
i
++
)
{
{
std
::
memcpy
((
output
.
data
()
+
i
*
n
),
input
.
data
(),
c_shape
.
bytes
());
std
::
memcpy
((
output
.
data
()
+
i
*
n
),
input
.
data
(),
c_shape
.
bytes
());
}
}
...
@@ -408,10 +407,11 @@ struct cpu_gemm
...
@@ -408,10 +407,11 @@ struct cpu_gemm
std
::
size_t
m
=
out_lens
[
0
];
std
::
size_t
m
=
out_lens
[
0
];
std
::
size_t
n
=
out_lens
[
1
];
std
::
size_t
n
=
out_lens
[
1
];
visit_all
(
result
,
c
)([
&
](
auto
output
,
auto
input
)
{
visit_all
(
result
,
c
)([
&
](
auto
output
,
auto
input
)
{
for
(
std
::
size_t
i
=
0
;
i
<
m
;
i
++
)
for
(
std
::
size_t
i
=
0
;
i
<
m
;
i
++
)
{
{
std
::
fill
(
output
.
begin
()
+
i
*
n
,
std
::
fill
(
output
.
begin
()
+
i
*
n
,
(
i
+
1
==
m
)
?
output
.
end
()
:
output
.
begin
()
+
((
i
+
1
)
*
n
),
input
[
i
]);
(
i
+
1
==
m
)
?
output
.
end
()
:
output
.
begin
()
+
((
i
+
1
)
*
n
),
input
[
i
]);
}
}
});
});
}
}
...
@@ -422,14 +422,13 @@ struct cpu_gemm
...
@@ -422,14 +422,13 @@ struct cpu_gemm
argument
result
{
output_shape
};
argument
result
{
output_shape
};
// 3 inputs, it is alpha * A * B + beta * C, then
// 3 inputs, it is alpha * A * B + beta * C, then
// A and B are matrics, and C is broadcastable to A * B
// A and B are matrics, and C is broadcastable to A * B
if
(
args
.
size
()
==
3
)
if
(
args
.
size
()
==
3
)
{
{
// no need to consider the value of args[2]
// no need to consider the value of args[2]
if
(
op
.
beta
==
0.0
f
)
if
(
op
.
beta
==
0.0
f
)
{
{
result
.
visit
([
&
](
auto
output
)
{
result
.
visit
(
std
::
memset
(
output
.
data
(),
0
,
output_shape
.
bytes
());
[
&
](
auto
output
)
{
std
::
memset
(
output
.
data
(),
0
,
output_shape
.
bytes
());
});
});
}
}
else
else
{
{
...
@@ -445,9 +444,8 @@ struct cpu_gemm
...
@@ -445,9 +444,8 @@ struct cpu_gemm
// all args are scalar
// all args are scalar
if
(
output_shape
.
scalar
())
if
(
output_shape
.
scalar
())
{
{
visit_all
(
result
,
args
[
0
],
args
[
1
])([
&
](
auto
res
,
auto
a
,
auto
b
)
{
visit_all
(
result
,
args
[
0
],
args
[
1
])(
res
[
0
]
=
op
.
alpha
*
a
[
0
]
*
b
[
0
];
[
&
](
auto
res
,
auto
a
,
auto
b
)
{
res
[
0
]
=
op
.
alpha
*
a
[
0
]
*
b
[
0
];
});
});
return
result
;
return
result
;
}
}
...
...
src/targets/gpu/gemm.cpp
View file @
52b9cf14
...
@@ -171,24 +171,24 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
...
@@ -171,24 +171,24 @@ shape miopen_gemm::compute_shape(const std::vector<shape>& inputs) const
return
op
.
compute_shape
(
inputs
);
return
op
.
compute_shape
(
inputs
);
}
}
void
miopen_gemm
::
fill_result
(
context
&
ctx
,
const
shape
&
output_shape
,
void
miopen_gemm
::
fill_result
(
context
&
ctx
,
const
argument
&
result
,
const
argument
&
c
)
const
const
shape
&
output_shape
,
const
argument
&
result
,
const
argument
&
c
)
const
{
{
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
auto
c_lens
=
c
.
get_shape
().
lens
();
auto
c_lens
=
c
.
get_shape
().
lens
();
if
(
output_shape
==
c
.
get_shape
())
if
(
output_shape
==
c
.
get_shape
())
{
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()));
};
hipMemcpy
(
to_pointer
(
args
[
3
]),
hipMemcpy
(
to_pointer
(
args
[
3
]),
to_pointer
(
args
[
2
]),
to_pointer
(
args
[
2
]),
output_shape
.
bytes
(),
output_shape
.
bytes
(),
hipMemcpyDeviceToDevice
);
hipMemcpyDeviceToDevice
);
});
});
}
}
else
if
(
c
.
single
())
else
if
(
c
.
single
())
{
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
)
{
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
)
{
...
@@ -204,8 +204,7 @@ void miopen_gemm::fill_result(context& ctx, const shape& output_shape,
...
@@ -204,8 +204,7 @@ void miopen_gemm::fill_result(context& ctx, const shape& output_shape,
}
}
});
});
}
}
else
if
(
c_lens
.
size
()
==
1
||
else
if
(
c_lens
.
size
()
==
1
||
(
c_lens
.
size
()
==
2
&&
c_lens
[
1
]
==
out_lens
[
1
]))
(
c_lens
.
size
()
==
2
&&
c_lens
[
1
]
==
out_lens
[
1
]))
{
{
auto
m
=
out_lens
[
0
];
auto
m
=
out_lens
[
0
];
auto
n
=
out_lens
[
1
];
auto
n
=
out_lens
[
1
];
...
@@ -247,7 +246,7 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -247,7 +246,7 @@ argument miopen_gemm::compute(context& ctx,
const
std
::
vector
<
argument
>&
args
)
const
const
std
::
vector
<
argument
>&
args
)
const
{
{
bool
is_3inputs
=
(
args
.
size
()
==
4
);
bool
is_3inputs
=
(
args
.
size
()
==
4
);
if
(
is_3inputs
)
if
(
is_3inputs
)
{
{
fill_result
(
ctx
,
output_shape
,
args
[
3
],
args
[
2
]);
fill_result
(
ctx
,
output_shape
,
args
[
3
],
args
[
2
]);
...
@@ -302,16 +301,13 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -302,16 +301,13 @@ argument miopen_gemm::compute(context& ctx,
1
,
1
,
to_pointer
(
args
[
2
]));
to_pointer
(
args
[
2
]));
generic_rocblas_scal
(
as
,
generic_rocblas_scal
(
ctx
.
get_stream
().
get_rocblas
(),
as
,
ctx
.
get_stream
().
get_rocblas
(),
1
,
&
alpha_r
,
to_pointer
(
args
[
2
]));
1
,
&
alpha_r
,
to_pointer
(
args
[
2
]));
1
);
1
);
});
});
}
}
// matrix * vector
// matrix * vector
else
if
(
args
[
1
].
get_shape
().
lens
().
size
()
==
1
)
else
if
(
args
[
1
].
get_shape
().
lens
().
size
()
==
1
)
{
{
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
std
::
size_t
dim_0
=
a_lens
.
size
()
-
2
;
std
::
size_t
dim_0
=
a_lens
.
size
()
-
2
;
...
@@ -323,12 +319,15 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -323,12 +319,15 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
trans
?
dim_1
:
dim_0
];
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
trans
?
dim_1
:
dim_0
];
assert
(
a_lens
.
back
()
==
args
[
1
].
get_shape
().
elements
());
assert
(
a_lens
.
back
()
==
args
[
1
].
get_shape
().
elements
());
std
::
size_t
batch_num
=
std
::
accumulate
(
a_lens
.
rbegin
()
+
2
,
a_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
batch_num
=
std
::
accumulate
(
a_lens
.
rbegin
()
+
2
,
a_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
};
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
for
(
std
::
size_t
batch_no
=
0
;
batch_no
<
batch_num
;
++
batch_no
)
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
};
for
(
std
::
size_t
batch_no
=
0
;
batch_no
<
batch_num
;
++
batch_no
)
{
{
generic_rocblas_gemv
(
as
,
generic_rocblas_gemv
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
...
@@ -341,13 +340,12 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -341,13 +340,12 @@ argument miopen_gemm::compute(context& ctx,
to_pointer
(
args
[
1
]),
to_pointer
(
args
[
1
]),
1
,
1
,
&
beta_r
,
&
beta_r
,
to_pointer
(
args
[
2
],
batch_no
*
n
)
to_pointer
(
args
[
2
],
batch_no
*
n
)
1
);
1
);
}
}
});
});
}
}
// vector * matrix
// vector * matrix
else
if
(
args
[
0
].
get_shape
().
lens
().
size
()
==
1
)
else
if
(
args
[
0
].
get_shape
().
lens
().
size
()
==
1
)
{
{
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
std
::
size_t
dim_0
=
b_lens
.
size
()
-
2
;
std
::
size_t
dim_0
=
b_lens
.
size
()
-
2
;
...
@@ -359,12 +357,15 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -359,12 +357,15 @@ argument miopen_gemm::compute(context& ctx,
rocblas_int
lda
=
args
[
1
].
get_shape
().
strides
()[
trans
?
dim_1
:
dim_0
];
rocblas_int
lda
=
args
[
1
].
get_shape
().
strides
()[
trans
?
dim_1
:
dim_0
];
assert
(
b_lens
.
back
()
==
args
[
0
].
get_shape
().
elements
());
assert
(
b_lens
.
back
()
==
args
[
0
].
get_shape
().
elements
());
std
::
size_t
batch_num
=
std
::
accumulate
(
b_lens
.
rbegin
()
+
2
,
b_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
batch_num
=
std
::
accumulate
(
b_lens
.
rbegin
()
+
2
,
b_lens
.
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
};
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
for
(
std
::
size_t
batch_no
=
0
;
batch_no
<
batch_num
;
++
batch_no
)
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
};
for
(
std
::
size_t
batch_no
=
0
;
batch_no
<
batch_num
;
++
batch_no
)
{
{
generic_rocblas_gemv
(
as
,
generic_rocblas_gemv
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
...
@@ -377,8 +378,7 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -377,8 +378,7 @@ argument miopen_gemm::compute(context& ctx,
to_pointer
(
args
[
1
],
batch_no
*
m
*
n
),
to_pointer
(
args
[
1
],
batch_no
*
m
*
n
),
1
,
1
,
&
beta_r
,
&
beta_r
,
to_pointer
(
args
[
2
],
batch_no
*
m
)
to_pointer
(
args
[
2
],
batch_no
*
m
)
1
);
1
);
}
}
});
});
}
}
...
@@ -391,17 +391,19 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -391,17 +391,19 @@ argument miopen_gemm::compute(context& ctx,
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
lda
=
args
[
0
].
get_shape
().
strides
()[
transa
?
a_lens
.
size
()
-
1
:
a_lens
.
size
()
-
2
];
rocblas_int
lda
=
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
b_lens
.
size
()
-
1
:
b_lens
.
size
()
-
2
];
args
[
0
].
get_shape
().
strides
()[
transa
?
a_lens
.
size
()
-
1
:
a_lens
.
size
()
-
2
];
rocblas_int
ldb
=
args
[
1
].
get_shape
().
strides
()[
transb
?
b_lens
.
size
()
-
1
:
b_lens
.
size
()
-
2
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
out_lens
.
size
()
-
2
];
rocblas_int
ldc
=
args
[
2
].
get_shape
().
strides
()[
out_lens
.
size
()
-
2
];
rocblas_int
m
=
out_lens
[
out_lens
.
size
()
-
2
];
rocblas_int
m
=
out_lens
[
out_lens
.
size
()
-
2
];
rocblas_int
n
=
out_lens
[
out_lens
.
size
()
-
1
];
rocblas_int
n
=
out_lens
[
out_lens
.
size
()
-
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
a_lens
.
size
()
-
1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
a_lens
.
size
()
-
1
];
auto
input_dims
=
std
::
min
(
a_lens
.
size
(),
b_lens
.
size
());
auto
input_dims
=
std
::
min
(
a_lens
.
size
(),
b_lens
.
size
());
std
::
size_t
axis
{
0
};
std
::
size_t
axis
{
0
};
for
(
axis
=
2
;
axis
<
input_dims
;
++
axis
)
for
(
axis
=
2
;
axis
<
input_dims
;
++
axis
)
{
{
if
(
a_lens
[
a_lens
.
size
()
-
axis
]
!=
b_lens
[
b_lens
.
size
()
-
axis
])
if
(
a_lens
[
a_lens
.
size
()
-
axis
]
!=
b_lens
[
b_lens
.
size
()
-
axis
])
{
{
break
;
break
;
}
}
...
@@ -410,14 +412,19 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -410,14 +412,19 @@ argument miopen_gemm::compute(context& ctx,
// The number of matrices that can be computed in one call
// The number of matrices that can be computed in one call
// batch_num > 1, we need to call the batch_gemm function,
// batch_num > 1, we need to call the batch_gemm function,
// otherwise, call the gemm function directly
// otherwise, call the gemm function directly
std
::
size_t
num_matrices
=
std
::
accumulate
(
a_lens
.
rbegin
()
+
2
,
std
::
size_t
num_matrices
=
std
::
accumulate
(
a_lens
.
rbegin
()
+
2
,
(
axis
==
a_lens
.
size
()
?
a_lens
.
rend
()
:
a_lens
.
rbegin
()
+
axis
),
(
axis
==
a_lens
.
size
()
?
a_lens
.
rend
()
:
a_lens
.
rbegin
()
+
axis
),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
std
::
size_t
a_len_diff
=
out_lens
.
size
()
-
a_lens
.
size
();
std
::
size_t
a_len_diff
=
out_lens
.
size
()
-
a_lens
.
size
();
std
::
size_t
b_len_diff
=
out_lens
.
size
()
-
b_lens
.
size
();
std
::
size_t
b_len_diff
=
out_lens
.
size
()
-
b_lens
.
size
();
std
::
vector
<
std
::
size_t
>
a_batch_lens
(
a_lens
.
begin
(),
a_lens
.
begin
()
+
a_lens
.
size
()
-
axis
);
std
::
vector
<
std
::
size_t
>
a_batch_lens
(
a_lens
.
begin
(),
std
::
vector
<
std
::
size_t
>
b_batch_lens
(
b_lens
.
begin
(),
b_lens
.
begin
()
+
b_lens
.
size
()
-
axis
);
a_lens
.
begin
()
+
a_lens
.
size
()
-
axis
);
std
::
vector
<
std
::
size_t
>
out_batch_lens
(
out_lens
.
begin
(),
out_lens
.
begin
()
+
out_lens
.
size
()
-
axis
);
std
::
vector
<
std
::
size_t
>
b_batch_lens
(
b_lens
.
begin
(),
b_lens
.
begin
()
+
b_lens
.
size
()
-
axis
);
std
::
vector
<
std
::
size_t
>
out_batch_lens
(
out_lens
.
begin
(),
out_lens
.
begin
()
+
out_lens
.
size
()
-
axis
);
shape
::
type_t
t
=
output_shape
.
type
();
shape
::
type_t
t
=
output_shape
.
type
();
shape
a_batch_shape
{
t
,
a_batch_lens
};
shape
a_batch_shape
{
t
,
a_batch_lens
};
...
@@ -428,12 +435,16 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -428,12 +435,16 @@ argument miopen_gemm::compute(context& ctx,
std
::
size_t
out_ind
=
out_batch_shape
.
index
(
out_idx
.
begin
(),
out_idx
.
end
());
std
::
size_t
out_ind
=
out_batch_shape
.
index
(
out_idx
.
begin
(),
out_idx
.
end
());
std
::
vector
<
std
::
size_t
>
a_idx
(
a_lens
.
size
()
-
axis
);
std
::
vector
<
std
::
size_t
>
a_idx
(
a_lens
.
size
()
-
axis
);
std
::
vector
<
std
::
size_t
>
b_idx
(
b_lens
.
size
()
-
axis
);
std
::
vector
<
std
::
size_t
>
b_idx
(
b_lens
.
size
()
-
axis
);
std
::
transform
(
out_idx
.
begin
()
+
a_len_diff
,
out_idx
.
end
(),
a_batch_lens
.
begin
(),
a_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
std
::
transform
(
out_idx
.
begin
()
+
a_len_diff
,
return
(
j
==
1
)
?
0
:
i
;
out_idx
.
end
(),
});
a_batch_lens
.
begin
(),
std
::
transform
(
out_idx
.
begin
()
+
b_len_diff
,
out_idx
.
end
(),
b_batch_lens
.
begin
(),
b_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
a_idx
.
begin
(),
return
(
j
==
1
)
?
0
:
i
;
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
});
std
::
transform
(
out_idx
.
begin
()
+
b_len_diff
,
out_idx
.
end
(),
b_batch_lens
.
begin
(),
b_idx
.
begin
(),
[
&
](
auto
i
,
auto
j
)
{
return
(
j
==
1
)
?
0
:
i
;
});
std
::
size_t
a_ind
=
a_batch_shape
.
index
(
a_idx
.
begin
(),
b_idx
.
end
());
std
::
size_t
a_ind
=
a_batch_shape
.
index
(
a_idx
.
begin
(),
b_idx
.
end
());
std
::
size_t
b_ind
=
b_batch_shape
.
index
(
b_idx
.
begin
(),
b_idx
.
end
());
std
::
size_t
b_ind
=
b_batch_shape
.
index
(
b_idx
.
begin
(),
b_idx
.
end
());
...
@@ -441,8 +452,11 @@ argument miopen_gemm::compute(context& ctx,
...
@@ -441,8 +452,11 @@ argument miopen_gemm::compute(context& ctx,
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
));
auto
beta_r
=
=
to_rocblas_type
(
as
(
beta
));
auto
beta_r
=
=
to_rocblas_type
(
as
(
beta
));
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
};
auto
to_pointer
=
[
&
](
auto
&&
arg
,
std
::
size_t
offset
=
0
)
{
generic_rocblas_batched_gemm
(
as
,
return
to_rocblas_type
(
as
.
from
(
arg
.
data
()
+
offset
));
};
generic_rocblas_batched_gemm
(
as
,
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
...
src/targets/gpu/include/migraphx/gpu/gemm.hpp
View file @
52b9cf14
...
@@ -20,7 +20,10 @@ struct miopen_gemm
...
@@ -20,7 +20,10 @@ struct miopen_gemm
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
int
output_alias
(
const
std
::
vector
<
shape
>&
shapes
)
const
{
return
shapes
.
size
()
-
1
;
}
private:
private:
void
fill_result
(
context
&
ctx
,
const
shape
&
output_shape
,
const
argument
&
result
,
const
argument
&
c
)
const
;
void
fill_result
(
context
&
ctx
,
const
shape
&
output_shape
,
const
argument
&
result
,
const
argument
&
c
)
const
;
};
};
}
// namespace gpu
}
// namespace gpu
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment