Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
4e64e2c2
Commit
4e64e2c2
authored
Nov 16, 2022
by
Alan Turner
Browse files
Formatting
parent
f3fcfcc7
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
160 additions
and
154 deletions
+160
-154
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+3
-2
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
+2
-4
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+23
-19
src/targets/gpu/jit/ck_gsg_instances.cpp
src/targets/gpu/jit/ck_gsg_instances.cpp
+116
-116
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
...kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
+3
-3
test/verify/0ck_gemm_softmax_gemm.cpp
test/verify/0ck_gemm_softmax_gemm.cpp
+12
-9
tools/tune_ck.py
tools/tune_ck.py
+1
-1
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
4e64e2c2
...
...
@@ -162,8 +162,9 @@ struct find_ck_gemm_scale_bias_softmax_gemm
// match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
// auto pw =
// match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
// auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
// return match::name("dot")(is_ck_gemm().bind("gemm2"))(
// auto softmax =
// match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax"); return
// match::name("dot")(is_ck_gemm().bind("gemm2"))(
// match::any_of[match::inputs()](softmax));
// }
...
...
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
View file @
4e64e2c2
...
...
@@ -66,10 +66,8 @@ struct find_gemm_softmax_gemm_gemm
{
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm1"
)));
auto
mul
=
match
::
name
(
"mul"
)(
match
::
any_of
[
match
::
inputs
()](
gemm1
)).
bind
(
"scale"
);
auto
add
=
match
::
name
(
"add"
)(
match
::
any_of
[
match
::
inputs
()](
mul
));
auto
mul
=
match
::
name
(
"mul"
)(
match
::
any_of
[
match
::
inputs
()](
gemm1
)).
bind
(
"scale"
);
auto
add
=
match
::
name
(
"add"
)(
match
::
any_of
[
match
::
inputs
()](
mul
));
auto
softmax
=
match
::
name
(
"softmax"
)(
match
::
any_of
[
match
::
inputs
()](
add
)).
bind
(
"softmax"
);
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm2"
))(
match
::
any_of
[
match
::
inputs
()](
softmax
));
...
...
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
View file @
4e64e2c2
...
...
@@ -111,8 +111,8 @@ struct instance
void
set_gemm
(
const
std
::
string
&
s
)
{
assert
(
params
[
15
]
==
"ck::tensor_operation::device::GemmSpecialization::Default"
or
params
[
15
]
==
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
);
assert
(
params
[
15
]
==
"ck::tensor_operation::device::GemmSpecialization::Default"
or
params
[
15
]
==
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
);
params
[
15
]
=
s
;
}
...
...
@@ -155,12 +155,13 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
static
std
::
string
get_layout
(
const
shape
&
s
)
{
if
(
not
s
.
transposed
())
if
(
not
s
.
transposed
())
return
"ck::tensor_layout::gemm::RowMajor"
;
auto
lens
=
s
.
lens
();
return
lens
[
lens
.
size
()
-
1
]
>
lens
[
lens
.
size
()
-
2
]
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
return
lens
[
lens
.
size
()
-
1
]
>
lens
[
lens
.
size
()
-
2
]
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
}
static
std
::
string
get_type
(
const
shape
&
s
)
...
...
@@ -185,23 +186,26 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
operation
compile_op
(
context
&
/* ctx */
,
const
std
::
vector
<
shape
>&
inputs
,
const
value
&
v
)
const
{
auto
a_shape
=
inputs
[
0
];
auto
b_shape
=
inputs
[
1
];
auto
a_shape
=
inputs
[
0
];
auto
b_shape
=
inputs
[
1
];
auto
b1_shape
=
inputs
[
2
];
auto
c_shape
=
inputs
.
back
();
auto
m
=
a_shape
.
lens
()[
0
];
auto
k
=
a_shape
.
lens
()[
1
];
auto
n
=
c_shape
.
lens
()[
1
];
auto
c_shape
=
inputs
.
back
();
auto
m
=
a_shape
.
lens
()[
0
];
auto
k
=
a_shape
.
lens
()[
1
];
auto
n
=
c_shape
.
lens
()[
1
];
auto
rank
=
a_shape
.
lens
().
size
();
std
::
array
<
char
,
4
>
keys
{
'M'
,
'N'
,
'K'
,
'O'
};
// config (m0, n0, k0, n1)
std
::
array
<
std
::
size_t
,
4
>
config
{
c_shape
.
lens
()[
rank
-
2
],
b_shape
.
lens
()[
rank
-
2
],
a_shape
.
lens
().
back
(),
c_shape
.
lens
().
back
()};
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
b1_shape
,
c_shape
}));
auto
ip
=
instance
{
get_gsg_instance
(
tuning_val
,
[
&
](
const
auto
&
x
)
->
bool
{
std
::
array
<
std
::
size_t
,
4
>
config
{
c_shape
.
lens
()[
rank
-
2
],
b_shape
.
lens
()[
rank
-
2
],
a_shape
.
lens
().
back
(),
c_shape
.
lens
().
back
()};
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
b1_shape
,
c_shape
}));
auto
ip
=
instance
{
get_gsg_instance
(
tuning_val
,
[
&
](
const
auto
&
x
)
->
bool
{
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
get_layout
(
c_shape
)
==
x
[
3
]
and
get_type
(
a_shape
)
==
x
[
4
]
and
get_type
(
b_shape
)
==
x
[
5
]
and
get_type
(
c_shape
)
==
x
[
9
];
...
...
@@ -220,8 +224,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
gemm_type
+=
"Padding"
;
ip
.
set_gemm
(
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
);
auto
blocks_per_batch
=
ip
.
get_grid_size
(
config
);
auto
batch_count
=
std
::
accumulate
(
c_shape
.
lens
().
rbegin
()
+
2
,
auto
blocks_per_batch
=
ip
.
get_grid_size
(
config
);
auto
batch_count
=
std
::
accumulate
(
c_shape
.
lens
().
rbegin
()
+
2
,
c_shape
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
...
...
src/targets/gpu/jit/ck_gsg_instances.cpp
View file @
4e64e2c2
...
...
@@ -935,122 +935,122 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"8"
,
"false"
,
"std::ratio<1, 8>"
},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "64",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "2",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<16,16,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "128",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "4",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<8,32,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "64",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "2",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<16,16,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "128",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "4",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<8,32,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
View file @
4e64e2c2
...
...
@@ -69,7 +69,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
constexpr
const
auto
b_shape
=
get_shape_c
<
B
>
{};
constexpr
const
auto
n
=
b_shape
.
lens
[
1
];
constexpr
const
auto
n
=
b_shape
.
lens
[
1
];
constexpr
const
auto
sb
=
b_shape
.
strides
[
1
];
// col-major
constexpr
const
auto
BK1
=
gemm
.
get_BK1
();
constexpr
const
auto
BK0
=
k
/
BK1
;
...
...
@@ -85,8 +85,8 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
constexpr
const
auto
b1_shape
=
get_shape_c
<
B1
>
{};
constexpr
const
auto
k1
=
b1_shape
.
lens
[
0
];
constexpr
const
auto
n1
=
b1_shape
.
lens
[
1
];
constexpr
const
auto
k1
=
b1_shape
.
lens
[
0
];
constexpr
const
auto
n1
=
b1_shape
.
lens
[
1
];
constexpr
const
auto
sb1
=
b1_shape
.
strides
[
0
];
// row-major
constexpr
const
auto
B1K1
=
gemm
.
get_B1K1
();
constexpr
const
auto
B1K0
=
k1
/
B1K1
;
...
...
test/verify/0ck_gemm_softmax_gemm.cpp
View file @
4e64e2c2
...
...
@@ -50,15 +50,15 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
// // a = one;
// // b = one;
// // b1 = one;
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
b);
// auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
//
auto scale =
mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
//
auto bias =
mm->add_instruction(migraphx::make_op("add"), scale, zero);
//
auto softmax =
mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}),
//
b);
auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
auto scale =
// mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
auto bias =
// mm->add_instruction(migraphx::make_op("add"), scale, zero);
auto softmax =
// mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
auto
*
mm
=
p
.
get_main_module
();
size_t
batch
=
2
;
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
batch
,
384
,
2304
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
batch
,
12
,
384
,
384
}};
...
...
@@ -73,9 +73,12 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
g
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
batch
,
384
,
36
,
64
}}}),
g
);
g
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
g
);
auto
a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
g
);
auto
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
g
);
auto
b1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
g
);
auto
a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
g
);
auto
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
g
);
auto
b1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
g
);
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
b
);
auto
gemm1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
...
...
tools/tune_ck.py
View file @
4e64e2c2
...
...
@@ -2,6 +2,7 @@ import os, json, subprocess, tempfile, sys, argparse, contextlib
ck_function
=
-
1
@
contextlib
.
contextmanager
def
tmp_file
(
dump
=
None
):
tmp_name
=
None
...
...
@@ -99,7 +100,6 @@ def parse_log(f):
config
=
json
.
loads
(
line
)
ck_function
=
1
yield
config
def
benchmark_log
(
f
,
n
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment