Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
b9f1b198
Commit
b9f1b198
authored
Mar 28, 2023
by
Alan Turner
Browse files
Reconfigure to use int8 ck gemms
parent
ac7a0025
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
228 additions
and
59 deletions
+228
-59
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+5
-5
src/targets/gpu/jit/ck_gemm.cpp
src/targets/gpu/jit/ck_gemm.cpp
+66
-8
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
+79
-46
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+37
-0
test/onnx/int8_gemm.onnx
test/onnx/int8_gemm.onnx
+19
-0
test/onnx/int8_gemm_verify.onnx
test/onnx/int8_gemm_verify.onnx
+22
-0
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
b9f1b198
...
@@ -17,7 +17,7 @@ namespace gpu {
...
@@ -17,7 +17,7 @@ namespace gpu {
struct
ck_gemm
struct
ck_gemm
{
{
operation
op
=
make_op
(
"dot"
);
operation
op
=
make_op
(
"
quant_
dot"
);
template
<
class
Self
,
class
F
>
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
static
auto
reflect
(
Self
&
self
,
F
f
)
...
@@ -46,7 +46,7 @@ struct ck_gemm
...
@@ -46,7 +46,7 @@ struct ck_gemm
check_gemm_shape
(
input
);
check_gemm_shape
(
input
);
auto
r
=
op
.
compute_shape
({
a
,
b
});
auto
r
=
op
.
compute_shape
({
a
,
b
});
if
(
mods
.
empty
())
if
(
mods
.
empty
())
return
r
;
return
r
.
with_type
(
shape
::
int8_type
)
;
return
r
.
with_type
(
mods
.
front
()
->
get_output_shapes
().
front
().
type
());
return
r
.
with_type
(
mods
.
front
()
->
get_output_shapes
().
front
().
type
());
}
}
};
};
...
@@ -56,7 +56,7 @@ namespace {
...
@@ -56,7 +56,7 @@ namespace {
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
{
{
if
(
ins
->
name
()
!=
"dot"
)
if
(
ins
->
name
()
!=
"
quant_
dot"
)
return
false
;
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
...
@@ -87,7 +87,7 @@ struct find_ck_gemm_pointwise
...
@@ -87,7 +87,7 @@ struct find_ck_gemm_pointwise
auto
gemm_it
=
std
::
find
(
inputs
.
begin
(),
inputs
.
end
(),
x_ins
);
auto
gemm_it
=
std
::
find
(
inputs
.
begin
(),
inputs
.
end
(),
x_ins
);
auto
gemm_idx
=
gemm_it
-
inputs
.
begin
();
auto
gemm_idx
=
gemm_it
-
inputs
.
begin
();
assert
(
gemm_it
!=
inputs
.
end
());
assert
(
gemm_it
!=
inputs
.
end
());
if
(
ins
->
get_shape
().
type
()
!=
shape
::
half_
type
)
if
(
ins
->
get_shape
().
type
()
!=
shape
::
int8_type
and
ins
->
get_shape
().
type
()
)
return
;
return
;
if
(
gemm_idx
!=
0
)
if
(
gemm_idx
!=
0
)
{
{
...
@@ -110,7 +110,7 @@ struct find_ck_gemm_pointwise
...
@@ -110,7 +110,7 @@ struct find_ck_gemm_pointwise
struct
find_ck_gemm
struct
find_ck_gemm
{
{
auto
matcher
()
const
{
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
auto
matcher
()
const
{
return
match
::
name
(
"
quant_
dot"
)(
is_ck_gemm
().
bind
(
"gemm"
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
{
...
...
src/targets/gpu/jit/ck_gemm.cpp
View file @
b9f1b198
...
@@ -58,6 +58,60 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
...
@@ -58,6 +58,60 @@ static const char* const ck_gemm_kernel = R"__migraphx__(
#include <args.hpp>
#include <args.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/ck_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include <migraphx/kernels/pointwise.hpp>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_dl.hpp"
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Empty_Tuple = ck::Tuple<>;
using GEMM = ck::tensor_operation::device::DeviceGemmMultipleD_Dl<
Row,
Row,
Empty_Tuple,
Row,
int8_t,
int8_t,
int32_t,
Empty_Tuple,
int8_t, //EDataType
PassThrough,
PassThrough,
PassThrough,
ck::tensor_operation::device::GemmSpecialization::MNKPadding,
256,
128,
128,
16,
4,
4,
4,
1,
S<8,2>,
S<8,2>,
S<8,1,1,4>,
S<2,1,128,1>,
S<1,2,0,3>,
S<1,2,0,3>,
S<4,1,1,4>,
S<1,2,0,3>,
S<1,1,1,4>,
S<2,1,4,4>,
S<8,1,32,1>,
S<0,3,1,2>,
S<0,3,1,2>,
S<1,1,4,1>,
S<0,3,1,2>,
S<1,1,4,4>,
S<0,1,2,3,4,5>,
5,
4>;
namespace migraphx {
namespace migraphx {
...
@@ -68,7 +122,7 @@ extern "C" {
...
@@ -68,7 +122,7 @@ extern "C" {
__global__ void ${kernel}(${params})
__global__ void ${kernel}(${params})
{
{
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm<
CK_DeviceGemmMultipleD<${instance}>
, ${blocks_per_batch}>(xs...);
ck_gemm<
GEMM
, ${blocks_per_batch}>(xs...);
});
});
}
}
...
@@ -295,9 +349,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -295,9 +349,9 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
std
::
array
<
std
::
size_t
,
3
>
config
{
m
,
n
,
k
};
std
::
array
<
std
::
size_t
,
3
>
config
{
m
,
n
,
k
};
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
}));
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
c_shape
}));
auto
ip
=
instance
{
get_instance
(
tuning_val
,
[
&
](
const
auto
&
x
)
->
bool
{
auto
ip
=
instance
{
get_instance
(
tuning_val
,
[
&
](
const
auto
&
x
)
->
bool
{
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
return
true
;
/*
get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
get_type
(
b_shape
)
==
x
[
5
]
and
get_type
(
c_shape
)
==
x
[
9
];
get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
*/
})};
})};
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
assert
(
inputs
.
size
()
<
4
or
v
.
contains
(
"post"
));
if
(
v
.
contains
(
"post"
))
if
(
v
.
contains
(
"post"
))
...
@@ -320,19 +374,23 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
...
@@ -320,19 +374,23 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
gemm_type
+=
"Padding"
;
gemm_type
+=
"Padding"
;
ip
.
set_gemm
(
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
);
ip
.
set_gemm
(
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
);
auto
blocks_per_batch
=
ip
.
get_grid_size
(
config
);
auto
blocks_per_batch
=
int_div_ceil
(
m
,
128
)
*
int_div_ceil
(
n
,
128
);;
//
ip.get_grid_size(config);
hip_compile_options
options
;
hip_compile_options
options
;
auto
block_size
=
ip
.
get_block_size
();
auto
block_size
=
256
;
//
ip.get_block_size();
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
auto
grid_size
=
can_fold_batch
?
blocks_per_batch
:
batch_count
*
blocks_per_batch
;
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
options
.
inputs
=
inputs
;
//auto new_inputs = inputs;
auto
new_inputs
=
inputs
;
// auto out_s = inputs.back();
// new_inputs.back() = shape{shape::int8_type, out_s.lens(), out_s.strides()};
options
.
inputs
=
new_inputs
;
options
.
output
=
c_shape
;
options
.
output
=
c_shape
;
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
kernel_name
=
v
.
get
(
"kernel"
,
"ck_gemm_kernel"
);
options
.
virtual_inputs
=
inputs
;
options
.
virtual_inputs
=
new_
inputs
;
if
(
can_fold_batch
)
if
(
can_fold_batch
)
{
{
auto
vinputs
=
inputs
;
auto
vinputs
=
new_
inputs
;
fold_batch_dims
(
vinputs
[
0
]);
fold_batch_dims
(
vinputs
[
0
]);
remove_batch_dims
(
vinputs
[
1
]);
remove_batch_dims
(
vinputs
[
1
]);
std
::
for_each
(
vinputs
.
begin
()
+
2
,
vinputs
.
end
(),
fold_batch_dims
);
std
::
for_each
(
vinputs
.
begin
()
+
2
,
vinputs
.
end
(),
fold_batch_dims
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp
View file @
b9f1b198
...
@@ -45,55 +45,88 @@ template <class Tensor>
...
@@ -45,55 +45,88 @@ template <class Tensor>
using
ck_transposeb
=
decltype
(
make_shape
(
ck_transposeb_dims
(
get_shape_c
<
Tensor
>
{}.
lens
),
using
ck_transposeb
=
decltype
(
make_shape
(
ck_transposeb_dims
(
get_shape_c
<
Tensor
>
{}.
lens
),
ck_transposeb_dims
(
get_shape_c
<
Tensor
>
{}.
strides
)));
ck_transposeb_dims
(
get_shape_c
<
Tensor
>
{}.
strides
)));
template
<
class
...
Xs
>
constexpr
void
noop
(
Xs
...)
{}
template
<
class
G
,
class
E
,
class
A
,
class
B
,
class
...
Ds
>
template
<
class
G
,
class
E
,
class
A
,
class
B
,
class
...
Ds
>
__device__
void
ck_gemm_matrix
(
E
e
,
A
a
,
B
b
,
Ds
...
ds
)
__device__
void
ck_gemm_matrix
(
E
e
,
A
a
,
B
b
,
Ds
...
ds
)
{
{
constexpr
const
G
gemm
{};
constexpr
auto
desc
=
G
::
make_descriptor
(
to_ck_tensor
<
A
>
(),
to_ck_tensor
<
ck_transposeb
<
B
>>
(),
constexpr
const
auto
a_grid_desc_m_k
=
gemm
.
matrix_padder
.
PadADescriptor_M_K
(
to_ck_tensor
<
A
>
());
ck
::
make_tuple
(
to_ck_tensor
<
Ds
>
()...),
constexpr
const
auto
b_grid_desc_n_k
=
to_ck_tensor
<
E
>
());
gemm
.
matrix_padder
.
PadBDescriptor_N_K
(
to_ck_tensor
<
ck_transposeb
<
B
>>
());
G
::
Run
(
desc
,
to_ck_const_pointer
(
a
.
data
()),
constexpr
const
auto
e_grid_desc_m_n
=
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_tensor
<
E
>
());
to_ck_const_pointer
(
b
.
data
()),
constexpr
const
auto
ds_grid_desc_m_n
=
ck
::
make_tuple
(
to_ck_const_pointer
(
ds
.
data
())...),
ck
::
make_tuple
(
gemm
.
matrix_padder
.
PadCDescriptor_M_N
(
to_ck_tensor
<
Ds
>
())...);
to_ck_pointer
(
e
.
data
()));
constexpr
const
auto
block_2_etile_map
=
gemm
.
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n
);
// constexpr const auto M = a.get_shape().lens[0];
using
GridwiseGemm
=
typename
G
::
GridwiseGemm
;
// constexpr const auto N = b.get_shape().lens[1];
// constexpr const auto K = a.get_shape().lens[1];
// tensor descriptors for block/thread-wise copy
// constexpr const auto K1Number = ck::Number<4>{};
constexpr
auto
a_grid_desc_ak0_m_ak1
=
// constexpr const auto K0 = K / 4;
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k
);
// using GridwiseGemm = typename G::GridwiseGemm;
constexpr
auto
b_grid_desc_bk0_n_bk1
=
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k
);
// constexpr auto a_grid_desc_k0_m_k1 = ck::transform_tensor_descriptor(
// to_ck_tensor<A>(),
constexpr
auto
ds_grid_desc_mblock_mperblock_nblock_nperblock
=
// ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(K0, K1Number)),
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n
);
// ck::make_pass_through_transform(M)),
// ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
constexpr
auto
e_grid_desc_mblock_mperblock_nblock_nperblock
=
// ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n
);
// constexpr const auto a_grid_desc_k0_m0_m1_k1 =
// GridwiseGemm::MakeAGridDescriptor_K0_M0_M1_K1(a_grid_desc_k0_m_k1);
static_assert
(
GridwiseGemm
::
CheckValidity
(
// noop(a, b, e, M, N, K, a_grid_desc_k0_m_k1, a_grid_desc_k0_m0_m1_k1, ds...);
a_grid_desc_m_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
block_2_etile_map
));
__shared__
char
p_shared_block
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
//////////////////////////
constexpr
const
bool
HasMainKBlockLoop
=
// constexpr const G gemm{};
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
ck
::
Number
<
0
>
{})
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
ck
::
Number
<
2
>
{}));
// constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
to_ck_const_pointer
(
a
.
data
()),
// constexpr const auto b_grid_desc_n_k =
to_ck_const_pointer
(
b
.
data
()),
// gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
ck
::
make_tuple
(
to_ck_const_pointer
(
ds
.
data
())...),
to_ck_pointer
(
e
.
data
()),
// constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
p_shared_block
,
// constexpr const auto ds_grid_desc_m_n =
gemm
.
a_element_op
,
// ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
gemm
.
b_element_op
,
// constexpr const auto block_2_etile_map = gemm.MakeDefaultBlock2ETileMap(e_grid_desc_m_n);
gemm
.
cde_element_op
,
a_grid_desc_ak0_m_ak1
,
// using GridwiseGemm = typename G::GridwiseGemm;
b_grid_desc_bk0_n_bk1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
// // tensor descriptors for block/thread-wise copy
e_grid_desc_mblock_mperblock_nblock_nperblock
,
// constexpr auto a_grid_desc_ak0_m_ak1 =
block_2_etile_map
);
// GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
// constexpr auto b_grid_desc_bk0_n_bk1 =
// GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);
// constexpr auto ds_grid_desc_mblock_mperblock_nblock_nperblock =
// GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_grid_desc_m_n);
// constexpr auto e_grid_desc_mblock_mperblock_nblock_nperblock =
// GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(e_grid_desc_m_n);
// static_assert(GridwiseGemm::CheckValidity(
// a_grid_desc_m_k, b_grid_desc_n_k, ds_grid_desc_m_n, e_grid_desc_m_n, block_2_etile_map));
// __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
// constexpr const bool HasMainKBlockLoop =
// GridwiseGemm::CalculateHasMainKBlockLoop(a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
// a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{}));
// GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
// to_ck_const_pointer(b.data()),
// ck::make_tuple(to_ck_const_pointer(ds.data())...),
// to_ck_pointer(e.data()),
// p_shared_block,
// gemm.a_element_op,
// gemm.b_element_op,
// gemm.cde_element_op,
// a_grid_desc_ak0_m_ak1,
// b_grid_desc_bk0_n_bk1,
// ds_grid_desc_mblock_mperblock_nblock_nperblock,
// e_grid_desc_mblock_mperblock_nblock_nperblock,
// block_2_etile_map);
}
}
template
<
class
G
,
index_int
BlocksPerBatch
,
class
...
Ts
>
template
<
class
G
,
index_int
BlocksPerBatch
,
class
...
Ts
>
...
...
test/onnx/gen_onnx.py
View file @
b9f1b198
...
@@ -4031,6 +4031,43 @@ def matmulinteger_test():
...
@@ -4031,6 +4031,43 @@ def matmulinteger_test():
return
([
node
],
[
m1
,
m2
],
[
y
])
return
([
node
],
[
m1
,
m2
],
[
y
])
@
onnx_test
()
def
int8_gemm
():
m1
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
INT8
,
[
256
,
256
])
m2
=
helper
.
make_tensor_value_info
(
'2'
,
TensorProto
.
INT8
,
[
256
,
256
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
INT8
,
[
256
,
256
])
node
=
onnx
.
helper
.
make_node
(
'MatMulInteger'
,
inputs
=
[
'1'
,
'2'
],
outputs
=
[
'y'
],
)
return
([
node
],
[
m1
,
m2
],
[
y
])
@
onnx_test
()
def
int8_gemm_verify
():
m1
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
INT8
,
[
256
,
256
])
m2
=
helper
.
make_tensor_value_info
(
'2'
,
TensorProto
.
INT8
,
[
256
,
256
])
y
=
helper
.
make_tensor_value_info
(
'y'
,
TensorProto
.
INT32
,
[
256
,
256
])
node
=
onnx
.
helper
.
make_node
(
'MatMulInteger'
,
inputs
=
[
'1'
,
'2'
],
outputs
=
[
'x'
],
)
convert
=
onnx
.
helper
.
make_node
(
'Cast'
,
inputs
=
[
'x'
],
outputs
=
[
'y'
],
to
=
6
)
return
([
node
,
convert
],
[
m1
,
m2
],
[
y
])
@
onnx_test
()
@
onnx_test
()
def
matmulinteger_dyn_error
():
def
matmulinteger_dyn_error
():
m1
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
INT8
,
[
None
,
6
,
16
])
m1
=
helper
.
make_tensor_value_info
(
'1'
,
TensorProto
.
INT8
,
[
None
,
6
,
16
])
...
...
test/onnx/int8_gemm.onnx
0 → 100644
View file @
b9f1b198
int8_gemm:j
1
2y" MatMulInteger int8_gemmZ
1
Z
2
b
y
B
\ No newline at end of file
test/onnx/int8_gemm_verify.onnx
0 → 100644
View file @
b9f1b198
int8_gemm_verify:
1
2x" MatMulInteger
xy"Cast*
toint8_gemm_verifyZ
1
Z
2
b
y
B
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment