gaoqiong / MIGraphX / Commits

Commit 69145ea1
authored May 10, 2019 by Shucai Xiao

refine the implementation.

parent 32addf31
Showing 3 changed files with 20 additions and 69 deletions
src/targets/gpu/include/migraphx/gpu/quant_gemm.hpp   +2   -0
src/targets/gpu/lowering.cpp                           +1   -41
src/targets/gpu/quant_gemm.cpp                         +17  -28
src/targets/gpu/include/migraphx/gpu/quant_gemm.hpp

@@ -13,6 +13,8 @@ struct context;

 struct miopen_quant_gemm
 {
     op::quant_dot op;
+    mutable argument arg_a{};
+    mutable argument arg_b{};

     template <class Self, class F>
     static auto reflect(Self& self, F f)
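The two mutable members added here let the const compute() method cache its packing buffers across calls: they start empty and are filled with allocate_gpu on first use, instead of being pre-allocated by the lowering pass (see the add_quant_gemm_op removal below). A minimal, self-contained sketch of that lazily allocated, mutable-member cache pattern; device_buffer, allocate_buffer, and quant_gemm_like are hypothetical stand-ins, not MIGraphX code:

    #include <cstddef>
    #include <iostream>
    #include <vector>

    // Hypothetical stand-in for migraphx::argument: an owning buffer.
    struct device_buffer
    {
        std::vector<char> data;
        bool empty() const { return data.empty(); }
    };

    // Hypothetical stand-in for allocate_gpu(shape).
    device_buffer allocate_buffer(std::size_t bytes)
    {
        return device_buffer{std::vector<char>(bytes)};
    }

    struct quant_gemm_like
    {
        // Mirrors the new `mutable argument arg_a{}; mutable argument arg_b{};`
        // members: compute() is const, but may fill these caches on first use.
        mutable device_buffer arg_a{};
        mutable device_buffer arg_b{};

        void compute(std::size_t a_bytes, std::size_t b_bytes) const
        {
            if(arg_b.empty()) // allocated once, reused on later calls
                arg_b = allocate_buffer(b_bytes);
            if(arg_a.empty())
                arg_a = allocate_buffer(a_bytes);
            std::cout << "packing buffers: " << arg_a.data.size() << ", "
                      << arg_b.data.size() << " bytes\n";
        }
    };

    int main()
    {
        quant_gemm_like op;
        op.compute(64, 128); // allocates both buffers
        op.compute(64, 128); // reuses the cached buffers
    }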
src/targets/gpu/lowering.cpp

@@ -98,6 +98,7 @@ struct miopen_apply

         add_generic_op<hip_min>("min");

         add_extend_op<miopen_gemm, op::dot>("dot");
+        add_extend_op<miopen_quant_gemm, op::quant_dot>("quant_dot");
         add_extend_op<miopen_contiguous, op::contiguous>("contiguous");
         add_extend_op<hip_concat, op::concat>("concat");
         add_extend_op<miopen_softmax, op::softmax>("softmax");
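With quant_dot now registered through the generic add_extend_op path (the single added line above), the dedicated add_quant_gemm_op lambda removed below is no longer needed. The diff does not show add_extend_op itself; purely as an assumed, simplified model of what a name-to-rewriter registration map of this kind looks like (all types here are hypothetical, not the actual MIGraphX helpers):

    #include <functional>
    #include <iostream>
    #include <string>
    #include <unordered_map>

    // Toy model of an apply_map: op name -> rewrite callback.
    struct instruction { std::string name; };
    using instruction_ref = instruction*;

    struct apply_pass
    {
        std::unordered_map<std::string, std::function<void(instruction_ref)>> apply_map;

        // Roughly what a generic add_extend_op-style helper might do: register a
        // callback that rewrites the matched instruction into its GPU counterpart.
        template <class GpuOp>
        void add_extend_op(const std::string& name)
        {
            apply_map.emplace(name, [=](instruction_ref ins) {
                std::cout << "lowering " << ins->name << " to " << GpuOp::target_name() << "\n";
            });
        }

        void apply(instruction_ref ins)
        {
            auto it = apply_map.find(ins->name);
            if(it != apply_map.end())
                it->second(ins);
        }
    };

    struct miopen_quant_gemm_stub
    {
        static std::string target_name() { return "gpu::quant_gemm"; }
    };

    int main()
    {
        apply_pass pass;
        pass.add_extend_op<miopen_quant_gemm_stub>("quant_dot");
        instruction ins{"quant_dot"};
        pass.apply(&ins);
    }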
@@ -109,7 +110,6 @@ struct miopen_apply

         add_lrn_op();
         add_convolution_op();
         add_quant_convolution_op();
-        add_quant_gemm_op();
         add_pooling_op();
         add_batch_norm_inference_op();
     }

@@ -172,46 +172,6 @@ struct miopen_apply

         });
     }

-    void add_quant_gemm_op()
-    {
-        apply_map.emplace("quant_dot", [=](instruction_ref ins) {
-            auto&& op = any_cast<op::quant_dot>(ins->get_operator());
-            std::vector<instruction_ref> refs = ins->inputs();
-            // add additional arguments if need packing. Since lowering is added
-            // after auto_contiguous and before eliminate contiguous, the shapes
-            // of all inputs are standard, so the input shape cannot be transposed.
-            // To avoid that, we need to check whether this argument is an output
-            // of contiguous. If true, we should check the shape of the input
-            // of the contiguous operator.
-            auto prev_ins = refs.at(0);
-            if(prev_ins->name() == "gpu::contiguous")
-            {
-                auto input = prev_ins->inputs().front();
-                if(input->get_shape().transposed())
-                {
-                    auto pack_a = insert_allocation(input, input->get_shape());
-                    // replace one of the inputs of quant_gemm from the output to the
-                    // input of contiguous. Then the contiguous could become dead code
-                    // if prev_ins is its only output
-                    refs.at(0) = input;
-                    instruction::replace_argument(ins, prev_ins, input);
-                    refs.push_back(pack_a);
-                }
-            }
-
-            if(!refs.at(1)->get_shape().transposed())
-            {
-                auto pack_b = insert_allocation(refs.at(1), refs.at(1)->get_shape());
-                refs.push_back(pack_b);
-            }
-
-            auto output = insert_allocation(ins, ins->get_shape());
-            refs.push_back(output);
-            return prog->replace_instruction(ins, miopen_quant_gemm{op}, refs);
-        });
-    }
-
     void add_pooling_op()
     {
         apply_map.emplace("pooling", [=](instruction_ref ins) {
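The removed lambda worked around transposed inputs by looking through a preceding gpu::contiguous: it rewired the quant_dot input to the contiguous op's own input, so the contiguous instruction loses its last user and can later be removed as dead code. A small self-contained sketch of that rewire-then-eliminate idea on a toy instruction graph (the node type and replace_argument helper here are hypothetical, not the MIGraphX IR):

    #include <iostream>
    #include <string>
    #include <vector>

    // Toy IR node: a name, its inputs, and a user count.
    struct node
    {
        std::string name;
        std::vector<node*> inputs;
        int users = 0;
    };

    // Rewire `ins` so that it reads `replacement` instead of `old_input`.
    void replace_argument(node& ins, node* old_input, node* replacement)
    {
        for(auto& in : ins.inputs)
        {
            if(in == old_input)
            {
                in = replacement;
                old_input->users -= 1;
                replacement->users += 1;
            }
        }
    }

    int main()
    {
        node a{"transpose"};               // produces a transposed view
        node c{"gpu::contiguous", {&a}};   // copies it into standard layout
        a.users = 1;
        node dot{"quant_dot", {&c}};
        c.users = 1;

        // Bypass the contiguous copy: quant_dot reads the transposed input directly.
        replace_argument(dot, &c, &a);

        // With no remaining users, the contiguous node is dead and a
        // dead-code-elimination pass can drop it.
        std::cout << "gpu::contiguous users after rewiring: " << c.users << "\n";
    }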
src/targets/gpu/quant_gemm.cpp

@@ -56,16 +56,6 @@ shape miopen_quant_gemm::compute_shape(const std::vector<shape>& inputs) const
 {
     std::vector<shape> input_shapes(inputs);
     input_shapes.pop_back();
-    if(!inputs.at(1).transposed())
-    {
-        input_shapes.pop_back();
-    }
-    if(inputs.at(0).transposed())
-    {
-        input_shapes.pop_back();
-    }
     check_shapes{input_shapes}.not_broadcasted();
     return op.compute_shape(input_shapes);
 }
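The unconditional pop_back drops the trailing shape, the output allocation that lowering appends to the inputs, before validating and delegating to the operator. Once lowering no longer appends packing buffers, the extra conditional pop_backs become unnecessary. For reference, a minimal sketch of how a quantized dot's output shape follows from its operands; quant_dot_shape and matrix_shape are illustrative names, not the MIGraphX implementation:

    #include <cassert>
    #include <cstddef>
    #include <iostream>
    #include <vector>

    // A (batch-less) GEMM shape rule: (m x k) . (k x n) -> (m x n),
    // with int32 accumulation for the int8 case.
    struct matrix_shape
    {
        std::size_t rows;
        std::size_t cols;
    };

    matrix_shape quant_dot_shape(const std::vector<matrix_shape>& inputs)
    {
        assert(inputs.size() >= 2);
        const auto& a = inputs[0];
        const auto& b = inputs[1];
        assert(a.cols == b.rows && "inner dimensions must agree");
        return {a.rows, b.cols};
    }

    int main()
    {
        // Logical inputs only; an output allocation appended by a lowering pass
        // would have to be stripped (pop_back) before this call, as in the diff.
        std::vector<matrix_shape> inputs{{4, 8}, {8, 3}};
        auto out = quant_dot_shape(inputs);
        std::cout << out.rows << " x " << out.cols << "\n"; // 4 x 3
    }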
@@ -74,8 +64,6 @@ argument miopen_quant_gemm::compute(context& ctx,
                                     const shape& output_shape,
                                     const std::vector<argument>& args) const
 {
-    // handling the packing of B MUST be before handling that for A
-    auto arg_res = args.back();
     bool transa = args[0].get_shape().transposed();
     bool transb = args[1].get_shape().transposed();
     auto n_dim  = output_shape.lens().size();
@@ -83,28 +71,29 @@ argument miopen_quant_gemm::compute(context& ctx,
     auto dim_0 = n_dim - 2;
     rocblas_int lda = args[0].get_shape().strides()[transa ? dim_1 : dim_0];
     rocblas_int ldb = args[1].get_shape().strides()[transb ? dim_1 : dim_0];
-    rocblas_int ldc = arg_res.get_shape().strides()[dim_0];
+    rocblas_int ldc = args[2].get_shape().strides()[dim_0];

-    auto arg_b = args.at(1);
-    std::size_t pack_arg_num = 0;
     if(!transb)
     {
-        arg_b = args.at(args.size() - 2);
-        ++pack_arg_num;
+        if(arg_b.empty())
+        {
+            arg_b = allocate_gpu(args[1].get_shape());
+        }
         device::pack_a(ctx.get_stream().get(), arg_b, args[1]);
     }

     // need to pack A in this scenario, use the algorithm to pack B in the
     // comment of the API
-    auto arg_a = args.at(0);
     if(transa)
     {
-        arg_a = args.at(args.size() - 2 - pack_arg_num);
-        ++pack_arg_num;
+        if(arg_a.empty())
+        {
+            arg_a = allocate_gpu(args.at(0).get_shape());
+        }
         device::pack_b(ctx.get_stream().get(), arg_a, args[0]);
     }

-    bool is_3inputs = (args.size() - pack_arg_num == 4);
+    bool is_3inputs = (args.size() == 4);
     int8_t beta = 0;
     if(is_3inputs)
     {
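The lda/ldb/ldc values above are leading dimensions derived from the tensors' strides: when a matrix is transposed, the stride of the last axis is used, otherwise the stride of the second-to-last axis (dim_1 is presumably n_dim - 1, defined just above this hunk). Passing B before A in the rocBLAS calls below is consistent with the usual trick of computing a row-major C through a column-major BLAS. A small self-contained sketch of the stride-to-leading-dimension rule; mat2d and leading_dim are illustrative, not MIGraphX types:

    #include <array>
    #include <cstddef>
    #include <iostream>

    // For a 2-D tensor stored with explicit strides, the "leading dimension" handed
    // to a BLAS-style GEMM is the stride of the axis that varies slowest in memory:
    // strides()[dim_1] when the matrix is transposed, strides()[dim_0] otherwise.
    struct mat2d
    {
        std::array<std::size_t, 2> lens;    // {rows, cols}
        std::array<std::size_t, 2> strides; // element strides per axis
        bool transposed() const { return strides[1] > strides[0]; } // toy heuristic
    };

    std::size_t leading_dim(const mat2d& m)
    {
        return m.transposed() ? m.strides[1] : m.strides[0];
    }

    int main()
    {
        mat2d a{{4, 8}, {8, 1}};  // plain row-major 4x8
        mat2d at{{8, 4}, {1, 8}}; // a transposed view of the same storage
        std::cout << "ld (standard)   = " << leading_dim(a) << "\n"   // 8
                  << "ld (transposed) = " << leading_dim(at) << "\n"; // 8
    }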
@@ -138,17 +127,17 @@ argument miopen_quant_gemm::compute(context& ctx,
                     m,
                     k,
                     &alpha_r,
-                    to_pointer(arg_b),
+                    (!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)),
                     rocblas_datatype_i8_r,
                     ldb,
-                    to_pointer(arg_a),
+                    transa ? to_pointer(arg_a) : to_pointer(args.at(0)),
                     rocblas_datatype_i8_r,
                     lda,
                     &beta_r,
                     to_pointer(args[2]),
                     rocblas_datatype_i32_r,
                     ldc,
-                    to_pointer(arg_res),
+                    is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]),
                     rocblas_datatype_i32_r,
                     ldc,
                     rocblas_datatype_i32_r,
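This call performs an int8 x int8 GEMM that accumulates into int32 (rocblas_datatype_i8_r inputs, rocblas_datatype_i32_r output and compute type). As a reference point for what that arithmetic means, a tiny CPU-only sketch of an int8-to-int32 matrix product; this is an illustration of the math, not the rocBLAS API:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Plain row-major GEMM with int8 inputs and int32 accumulation:
    // C(m x n) = alpha * A(m x k) * B(k x n) + beta * C.
    void gemm_i8_i32(int m, int n, int k,
                     const std::vector<int8_t>& a,
                     const std::vector<int8_t>& b,
                     std::vector<int32_t>& c,
                     int32_t alpha = 1,
                     int32_t beta  = 0)
    {
        for(int i = 0; i < m; ++i)
        {
            for(int j = 0; j < n; ++j)
            {
                int32_t acc = 0;
                for(int p = 0; p < k; ++p)
                    acc += static_cast<int32_t>(a[i * k + p]) *
                           static_cast<int32_t>(b[p * n + j]);
                c[i * n + j] = alpha * acc + beta * c[i * n + j];
            }
        }
    }

    int main()
    {
        const int m = 2, n = 2, k = 3;
        std::vector<int8_t> a{1, 2, 3, 4, 5, 6};
        std::vector<int8_t> b{1, 0, 0, 1, 1, 1};
        std::vector<int32_t> c(m * n, 0);
        gemm_i8_i32(m, n, k, a, b, c);
        std::cout << c[0] << " " << c[1] << " " << c[2] << " " << c[3] << "\n"; // 4 5 10 11
    }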
@@ -168,11 +157,11 @@ argument miopen_quant_gemm::compute(context& ctx,
                     m,
                     k,
                     &alpha_r,
-                    to_pointer(arg_b),
+                    (!transb) ? to_pointer(arg_b) : to_pointer(args.at(1)),
                     rocblas_datatype_i8_r,
                     ldb,
                     k * n,
-                    to_pointer(arg_a),
+                    transa ? to_pointer(arg_a) : to_pointer(args.at(0)),
                     rocblas_datatype_i8_r,
                     lda,
                     m * k,
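This second call is the strided-batched variant: each matrix argument also carries a per-batch stride (k * n for B, m * k for A, and m * n for the output below), so batch i starts i strides into the underlying buffer. A small sketch of how such batch strides address one contiguous allocation; batch_ptr is an illustrative helper, not library code:

    #include <cstddef>
    #include <cstdint>
    #include <iostream>
    #include <vector>

    // For a strided-batched GEMM, matrix `i` of a batch lives at base + i * batch_stride.
    const int8_t* batch_ptr(const std::vector<int8_t>& buf, std::size_t batch_stride, std::size_t i)
    {
        return buf.data() + i * batch_stride;
    }

    int main()
    {
        const std::size_t batches = 3, m = 2, k = 4;
        // One contiguous allocation holding all A matrices back to back.
        std::vector<int8_t> a(batches * m * k, 0);
        const std::size_t stride_a = m * k; // matches the `m * k` stride in the diff

        for(std::size_t i = 0; i < batches; ++i)
            std::cout << "batch " << i << " offset: "
                      << (batch_ptr(a, stride_a, i) - a.data()) << "\n"; // 0, 8, 16
    }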
@@ -181,7 +170,7 @@ argument miopen_quant_gemm::compute(context& ctx,
                     rocblas_datatype_i32_r,
                     ldc,
                     m * n,
-                    to_pointer(arg_res),
+                    is_3inputs ? to_pointer(args.at(3)) : to_pointer(args[2]),
                     rocblas_datatype_i32_r,
                     ldc,
                     m * n,
@@ -195,7 +184,7 @@ argument miopen_quant_gemm::compute(context& ctx,
         }
     });

-    return arg_res;
+    return is_3inputs ? args.at(3) : args[2];
 }

 } // namespace gpu