gaoqiong / MIGraphX

Commit 3133fd79, authored Nov 15, 2022 by Alan Turner
Formatting
Parent: d7ea085c

Showing 6 changed files with 133 additions and 119 deletions (+133, -119)
src/targets/gpu/fuse_ck.cpp                                                           +23  -16
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp                                          +17  -11
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp                                +1   -1
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp             +70  -67
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp     +8  -12
test/onnx/gen_onnx.py                                                                 +14  -12
src/targets/gpu/fuse_ck.cpp (view file @ 3133fd79)
@@ -79,7 +79,7 @@ struct ck_gemm_scale_bias_softmax_gemm
    auto b1 = inputs[2];
    for(const auto& input : inputs)
    {
        // std::cout << input << std::endl;
        check_gemm_shape(input);
    }
    return op.compute_shape({op.compute_shape({a, b}), b1});
@@ -158,10 +158,13 @@ struct find_ck_gemm_scale_bias_softmax_gemm
{
    auto matcher() const
    {
        auto gemm1 =
            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
        auto pw =
            match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
        auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
            match::any_of[match::inputs()](softmax));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
@@ -180,16 +183,19 @@ struct find_ck_gemm_scale_bias_softmax_gemm
        auto inputs = gemm1_ins->inputs();            // A, B
        inputs.push_back(gemm2_ins->inputs().back()); // B1
        // inputs.push_back(pw_ins->inputs().back()); // C

        mpm.get_module().replace_instruction(
            ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
    }

    // auto matcher() const
    // {
    //     auto gemm1 =
    //     match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
    //     auto softmax =
    //     match::name("softmax")(match::any_of[match::inputs()](gemm1)).bind("softmax"); return
    //     match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
    // }

    // void apply(module_pass_manager& mpm, const match::matcher_result& r) const

@@ -207,7 +213,8 @@ struct find_ck_gemm_scale_bias_softmax_gemm
    //     auto inputs = gemm1_ins->inputs(); // A, B
    //     inputs.push_back(gemm2_ins->inputs().back()); // B1
    //     mpm.get_module().replace_instruction(ins,
    //     ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
    // }
};
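For orientation: taken together, the hunks above match a dot -> pointwise -> softmax -> dot chain and replace it with a single ck_gemm_scale_bias_softmax_gemm instruction whose inputs are A, B and B1 (the extra C input from the pointwise module is still commented out). A rough NumPy reference of the fused computation, assuming the matched "scale_bias" pointwise module applies a scale and a bias; its actual body is not shown in this diff:

import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# dot(A, B) -> pointwise scale/bias -> softmax -> dot(., B1)
# 'scale' and 'bias' stand in for whatever the matched pointwise module computes.
def ck_gemm_scale_bias_softmax_gemm_ref(a, b, b1, scale=1.0, bias=0.0):
    return softmax(scale * (a @ b) + bias) @ b1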
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp (view file @ 3133fd79)
@@ -213,7 +213,10 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
        return "ck::Tuple<" + join_strings(s, ",") + ">";
    }

    std::vector<std::string> names() const
    {
        return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
    }

    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {
@@ -259,14 +262,15 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
        auto gemm1_nperblock  = 64;
        auto gemm01_mperblock = 256;
        auto blocks_per_batch =
            int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock); // ip.get_grid_size(config);
        auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
                                           c_shape.lens().rend(),
                                           std::size_t{1},
                                           std::multiplies<std::size_t>());
        hip_compile_options options;
        auto block_size = 256; // ip.get_block_size();
        auto grid_size  = batch_count * blocks_per_batch;
        options.set_launch_params(v, grid_size * block_size, block_size);
        options.inputs = inputs;
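The launch-size arithmetic in this hunk reads as follows; a minimal Python sketch, where m and n come from surrounding code that is not part of this hunk and int_div_ceil is assumed to be a ceiling-division helper:

import math
from functools import reduce
from operator import mul

def launch_params(m, n, out_lens, block_size=256, gemm01_mperblock=256, gemm1_nperblock=64):
    int_div_ceil = lambda a, b: math.ceil(a / b)   # stand-in for MIGraphX's int_div_ceil
    blocks_per_batch = int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock)
    batch_count = reduce(mul, out_lens[:-2], 1)    # every output dim except the trailing M x N
    grid_size = batch_count * blocks_per_batch
    return grid_size * block_size, block_size      # (global work size, threads per block)

# e.g. an output of shape (2, 12, 384, 64): batch_count = 24, blocks_per_batch = 2
print(launch_params(384, 64, [2, 12, 384, 64]))    # -> (12288, 256)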
@@ -278,7 +282,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
            options.params += " -DMIGRAPHX_CK_CHECK=1";
        auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
                                      {{"instance", "" /* ip.str() */},
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"blocks_per_batch", to_string(blocks_per_batch)},
@@ -296,7 +300,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
        {
            auto* pm      = ins->module_inputs().front();
            v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
                            "post_ck_gemm_softmax_gemm_function);";
            v["post"]   = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
            v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
        }
@@ -306,7 +311,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
        if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
        {
            std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
            std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
                      << std::endl;
        }
    });
}
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp (view file @ 3133fd79)
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp (view file @ 3133fd79)
@@ -54,14 +54,15 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto m  = a_shape.lens[0];
    constexpr const auto k  = a_shape.lens[1];
    constexpr const auto sa = a_shape.strides[0];
    constexpr const auto a_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(m, k), ck::make_tuple(sa, 1));
    constexpr const auto a_grid_desc_mraw_kraw = gemm.matrix_padder.PadADescriptor_M_K(a_tensor);
    constexpr const auto AK1 = gemm.get_AK1();
    constexpr const auto AK0 = k / AK1;
    constexpr const auto a_grid_desc_ak0_m_ak1 = ck::transform_tensor_descriptor(
        a_grid_desc_mraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(AK0, AK1)),
                       ck::make_pass_through_transform(m)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
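The descriptor built above re-indexes the M x K input as a 3-D (AK0, M, AK1) view with K = AK0 * AK1. A small NumPy analogue, purely for intuition; the axis ordering is inferred from the name a_grid_desc_ak0_m_ak1, and the real CK descriptor only remaps indices, it never copies data:

import numpy as np

def view_as_ak0_m_ak1(a, ak1):
    # (M, K) -> (AK0, M, AK1): unmerge K into (AK0, AK1), pass M through
    m, k = a.shape
    assert k % ak1 == 0
    return a.reshape(m, k // ak1, ak1).transpose(1, 0, 2)

a = np.arange(3 * 8).reshape(3, 8)        # M=3, K=8
print(view_as_ak0_m_ak1(a, ak1=2).shape)  # (4, 3, 2) -> (AK0, M, AK1)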
@@ -73,10 +74,11 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto BK1 = gemm.get_BK1();
    constexpr const auto BK0 = k / BK1;
    constexpr const auto b_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(n, k), ck::make_tuple(sb, 1));
    constexpr const auto b_grid_desc_nraw_kraw = gemm.matrix_padder.PadBDescriptor_N_K(b_tensor);
    constexpr const auto b_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
        b_grid_desc_nraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(BK0, BK1)),
                       ck::make_pass_through_transform(n)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
@@ -89,10 +91,11 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto B1K1 = gemm.get_B1K1();
    constexpr const auto B1K0 = k1 / B1K1;
    constexpr const auto b1_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(n1, k1), ck::make_tuple(1, sb1));
    constexpr const auto b1_grid_desc_nraw_kraw =
        gemm.matrix_padder.PadB1Descriptor_N_K(b1_tensor);
    constexpr const auto b1_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
        b1_grid_desc_nraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(B1K0, B1K1)),
                       ck::make_pass_through_transform(n1)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
@@ -100,11 +103,10 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto c_shape = get_shape_c<C>{};
    constexpr const auto sc      = c_shape.strides[0];
    constexpr const auto c_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(m, n1), ck::make_tuple(sc, 1));
    constexpr const auto c_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(c_tensor);
    constexpr const auto MPerBlock       = gemm.get_mperblock();
    constexpr const auto Gemm1NPerBlock  = gemm.get_gemm1nperblock();
    constexpr const auto MBlock          = m / MPerBlock;
@@ -112,18 +114,20 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        ck::transform_tensor_descriptor(
            c_grid_desc_m_n,
            ck::make_tuple(
                ck::make_unmerge_transform(ck::make_tuple(MBlock, ck::Number<MPerBlock>{})),
                ck::make_unmerge_transform(ck::make_tuple(NBlock, ck::Number<Gemm1NPerBlock>{}))),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}));

    constexpr const auto block_2_ctile_map =
        BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
            c_grid_desc_m_n);

    const C0MatrixMask c0_matrix_mask(n);

    const auto K = a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
                   a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});

    using gridwise = typename G::template rt_gridwisegemm<decltype(a_grid_desc_ak0_m_ak1),
                                                          decltype(b_grid_desc_bk0_n_bk1),
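Here the M x N1 output is tiled into MBlock x NBlock tiles of MPerBlock x Gemm1NPerBlock each, and block_2_ctile_map assigns one workgroup to one tile. A naive row-major version of that bookkeeping, for illustration only; the real BlockToCTileMap_M00_N0_M01Adapt may reorder tiles for locality:

def naive_block_to_ctile(block_id, n, mperblock=256, gemm1_nperblock=64):
    # Map a linear workgroup id to the top-left corner of the output tile it computes.
    nblock = (n + gemm1_nperblock - 1) // gemm1_nperblock
    m_tile, n_tile = divmod(block_id, nblock)
    return m_tile * mperblock, n_tile * gemm1_nperblock

print(naive_block_to_ctile(3, n=128))  # -> (256, 64)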
@@ -135,7 +139,6 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                              b_grid_desc_bk0_n_bk1,
                                              b1_grid_desc_bk0_n_bk1,
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp (view file @ 3133fd79)
@@ -121,10 +121,7 @@ struct C0MatrixMask
    __device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }

    __device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const { return n >= NRaw_; }

    __device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
    {
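C0MatrixMask exposes two predicates: IsUpperTriangle (n > m, a causal mask) and IsNOutOfBound (n >= NRaw_, padding past the real sequence length). A tiny NumPy illustration of masking those positions before the softmax; how IsMaskedElement actually combines them is not visible in this hunk, so the combination below is an assumption:

import numpy as np

def c0_matrix_mask(m_size, n_size, n_raw, causal=False):
    m_idx = np.arange(m_size)[:, None]
    n_idx = np.arange(n_size)[None, :]
    mask = n_idx >= n_raw              # IsNOutOfBound
    if causal:
        mask = mask | (n_idx > m_idx)  # IsUpperTriangle
    return mask

scores = np.zeros((4, 6))
scores[c0_matrix_mask(4, 6, n_raw=5)] = -np.inf  # masked logits get -inf before softmax
print(scores)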
@@ -197,8 +194,8 @@ template <typename ALayout,
          ck::LoopScheduler LoopSched = ck::LoopScheduler::Default>
struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
    static constexpr auto matrix_padder = ck::tensor_operation::device::
        GemmGemmPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t, ck::index_t>{
            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};

    static constexpr auto get_AK1() { return AK1; };
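matrix_padder is built from the tile sizes {MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock}; the Pad*Descriptor_* calls used earlier in this file presumably round each raw problem dimension up to a multiple of the matching tile size so the kernel never touches a partial tile. Roughly:

def pad_to_multiple(x, tile):
    return ((x + tile - 1) // tile) * tile

# e.g. M=100 with a 256-wide tile and Gemm1N=70 with a 64-wide tile
print(pad_to_multiple(100, 256), pad_to_multiple(70, 64))  # 256 128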
@@ -215,7 +212,7 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    CElementwiseOperation c_element_op{};
    AccElementwiseOperation acc_element_op{alpha};

    template <typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename B1GridDesc_BK0_N_BK1,
              typename CGridDesc_M_N>
@@ -286,6 +283,5 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
    };
};

} // namespace migraphx
#endif
test/onnx/gen_onnx.py (view file @ 3133fd79)
@@ -31,6 +31,7 @@ from onnx import TensorProto
def onnx_test(op_test):
    def run_test():
        op_info = op_test()
        if len(op_info) > 3:
@@ -1995,12 +1996,13 @@ def gemm_softmax_gemm_test():
    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT16, [1, 1])
    out = helper.make_tensor_value_info('out', TensorProto.FLOAT16, [1, 1])

    scale_array = np.array([(1 / 8)])
    scale_tensor = helper.make_tensor(name='scale',
                                      data_type=TensorProto.FLOAT16,
                                      dims=scale_array.shape,
                                      vals=scale_array.flatten().astype(np.float16))

    gemm1 = onnx.helper.make_node('MatMul',
                                  inputs=['a', 'b'],
@@ -2018,8 +2020,8 @@ def gemm_softmax_gemm_test():
                                  inputs=['softmax_out', 'b1'],
                                  outputs=['out'])

    return ([gemm1, mul1, add1, softmax, gemm2], [a, b, c, b1, bias], [out], [scale_tensor])


@onnx_test
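The test returns a (nodes, inputs, outputs, initializers) tuple that the @onnx_test decorator (defined near the top of gen_onnx.py, not shown here) turns into a model. A hypothetical sketch of that assembly step using the standard onnx helpers, to show how the pieces fit together; the decorator's real implementation may differ:

import onnx
from onnx import helper

def build_model(nodes, inputs, outputs, initializers, name='gemm_softmax_gemm_test'):
    graph = helper.make_graph(nodes, name, inputs, outputs, initializer=initializers)
    model = helper.make_model(graph)
    onnx.checker.check_model(model)
    return model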