Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
2c779d0b
Commit
2c779d0b
authored
Aug 23, 2019
by
Shucai Xiao
Browse files
simplify the gemm call for int8
parent
8f9a766f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
50 deletions
+5
-50
src/targets/gpu/quant_gemm.cpp
src/targets/gpu/quant_gemm.cpp
+5
-50
No files found.
src/targets/gpu/quant_gemm.cpp
View file @
2c779d0b
...
@@ -8,51 +8,6 @@ namespace migraphx {
...
@@ -8,51 +8,6 @@ namespace migraphx {
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
template
<
class
...
Ts
>
rocblas_status
generic_rocblas_gemm_ex
(
Ts
&&
...
xs
)
{
return
rocblas_gemm_ex
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
...
Ts
>
rocblas_status
generic_rocblas_batched_gemm_ex
(
Ts
&&
...
xs
)
{
return
rocblas_gemm_strided_batched_ex
(
std
::
forward
<
Ts
>
(
xs
)...);
}
template
<
class
T
>
struct
compute_rocblas_type
{
using
type
=
T
;
};
template
<
class
T
>
struct
compute_rocblas_type
<
const
T
>
{
using
type
=
const
typename
compute_rocblas_type
<
T
>::
type
;
};
template
<
>
struct
compute_rocblas_type
<
half
>
{
using
type
=
rocblas_half
;
};
template
<
class
T
>
using
rb_type
=
typename
compute_rocblas_type
<
T
>::
type
;
template
<
class
T
>
rb_type
<
T
>
to_rocblas_type
(
T
x
)
{
return
reinterpret_cast
<
const
rb_type
<
T
>&>
(
x
);
}
template
<
class
T
>
rb_type
<
T
>*
to_rocblas_type
(
T
*
x
)
{
return
reinterpret_cast
<
rb_type
<
T
>*>
(
x
);
}
shape
rocblas_quant_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
shape
rocblas_quant_gemm
::
compute_shape
(
const
std
::
vector
<
shape
>&
inputs
)
const
{
{
std
::
vector
<
shape
>
in_shapes
(
inputs
);
std
::
vector
<
shape
>
in_shapes
(
inputs
);
...
@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
...
@@ -102,13 +57,13 @@ argument rocblas_quant_gemm::compute(context& ctx,
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
a_lens
=
args
[
0
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
auto
b_lens
=
args
[
1
].
get_shape
().
lens
();
output_shape
.
visit_type
([
&
](
auto
as
)
{
output_shape
.
visit_type
([
&
](
auto
as
)
{
auto
alpha_r
=
to_rocblas_type
(
as
(
op
.
alpha
)
)
;
auto
alpha_r
=
as
(
op
.
alpha
);
auto
beta_r
=
to_rocblas_type
(
as
(
beta
)
)
;
auto
beta_r
=
as
(
beta
);
auto
out_lens
=
output_shape
.
lens
();
auto
out_lens
=
output_shape
.
lens
();
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
m
=
out_lens
[
dim_0
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
n
=
out_lens
[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
rocblas_int
k
=
args
[
0
].
get_shape
().
lens
()[
dim_1
];
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
to_rocblas_type
(
as
.
from
(
arg
.
data
())
)
;
};
auto
to_pointer
=
[
&
](
auto
&&
arg
)
{
return
as
.
from
(
arg
.
data
());
};
assert
(
k
%
4
==
0
);
assert
(
k
%
4
==
0
);
auto
num_matrices
=
std
::
accumulate
(
auto
num_matrices
=
std
::
accumulate
(
...
@@ -119,7 +74,7 @@ argument rocblas_quant_gemm::compute(context& ctx,
...
@@ -119,7 +74,7 @@ argument rocblas_quant_gemm::compute(context& ctx,
// column-major format. When doing a C = A * B, we actually do
// column-major format. When doing a C = A * B, we actually do
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// C^T = (B^T) * (A^T). That is the reason we input args[1] as
// A and args[0] as B in calling the rocblas_gemm.
// A and args[0] as B in calling the rocblas_gemm.
generic_
rocblas_gemm_ex
(
ctx
.
get_stream
().
get_rocblas
(),
rocblas_gemm_ex
(
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
n
,
n
,
...
@@ -148,7 +103,7 @@ argument rocblas_quant_gemm::compute(context& ctx,
...
@@ -148,7 +103,7 @@ argument rocblas_quant_gemm::compute(context& ctx,
}
}
else
else
{
{
generic_rocblas
_batched_
gemm_
ex
(
rocblas_gemm_strided
_batched_ex
(
ctx
.
get_stream
().
get_rocblas
(),
ctx
.
get_stream
().
get_rocblas
(),
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transb
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
transa
?
rocblas_operation_transpose
:
rocblas_operation_none
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment