gaoqiong / MIGraphX / Commits

Commit f3fcfcc7, authored Nov 16, 2022 by Alan Turner
Parent: 3133fd79

    Fix fusion pass and add tuning
Showing 12 changed files with 1067 additions and 2651 deletions (+1067 -2651).
src/targets/gpu/CMakeLists.txt                                               +1    -0
src/targets/gpu/fuse_ck.cpp                                                  +73   -73
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp                                +102  -0
src/targets/gpu/include/migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp           +25   -0
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp                                 +48   -92
src/targets/gpu/jit/ck_gsg_instances.cpp                                     +610  -2462
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp    +9    -9
src/targets/gpu/target.cpp                                                   +3    -0
test/verify/0ck_gemm_softmax_gemm.cpp                                        +37   -7
tools/tune_ck.py                                                             +33   -8
tools/tune_ck.sh                                                             +13   -0
tools/tune_ck_gsg.py                                                         +113  -0
src/targets/gpu/CMakeLists.txt (+1 -0)

@@ -90,6 +90,7 @@ add_library(migraphx_gpu
     device_name.cpp
     elu.cpp
     fuse_ck.cpp
+    fuse_ck_gemm_softmax_gemm.cpp
     fuse_mlir.cpp
     fuse_ops.cpp
     gather.cpp
src/targets/gpu/fuse_ck.cpp (+73 -73)

@@ -49,43 +49,43 @@ struct ck_gemm
 };
 MIGRAPHX_REGISTER_OP(ck_gemm);

-struct ck_gemm_scale_bias_softmax_gemm
-{
-    operation op = make_op("dot");
-
-    template <class Self, class F>
-    static auto reflect(Self& self, F f)
-    {
-        return pack(f(self.op, "op"));
-    }
-
-    std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
-
-    void check_gemm_shape(const shape& s) const
-    {
-        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
-            MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
-    }
-
-    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
-    {
-        check_shapes{inputs, *this}.same_ndims();
-        // if(mods.size() != 1)
-        //     MIGRAPHX_THROW("should have one submodule.");
-        if(inputs.size() < 2)
-            MIGRAPHX_THROW("should have at least two inputs.");
-        auto a  = inputs[0];
-        auto b  = inputs[1];
-        auto b1 = inputs[2];
-        for(const auto& input : inputs)
-        {
-            // std::cout << input << std::endl;
-            check_gemm_shape(input);
-        }
-        return op.compute_shape({op.compute_shape({a, b}), b1});
-    }
-};
-MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);
+// struct ck_gemm_scale_bias_softmax_gemm
+// {
+//     operation op = make_op("dot");
+//
+//     template <class Self, class F>
+//     static auto reflect(Self& self, F f)
+//     {
+//         return pack(f(self.op, "op"));
+//     }
+//
+//     std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
+//
+//     void check_gemm_shape(const shape& s) const
+//     {
+//         if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
+//             MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
+//     }
+//
+//     shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
+//     {
+//         check_shapes{inputs, *this}.same_ndims();
+//         // if(mods.size() != 1)
+//         //     MIGRAPHX_THROW("should have one submodule.");
+//         if(inputs.size() < 2)
+//             MIGRAPHX_THROW("should have at least two inputs.");
+//         auto a  = inputs[0];
+//         auto b  = inputs[1];
+//         auto b1 = inputs[2];
+//         for(const auto& input : inputs)
+//         {
+//             // std::cout << input << std::endl;
+//             check_gemm_shape(input);
+//         }
+//         return op.compute_shape({op.compute_shape({a, b}), b1});
+//     }
+// };
+// MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);

 namespace {

@@ -156,38 +156,38 @@ struct find_ck_gemm
 struct find_ck_gemm_scale_bias_softmax_gemm
 {
-    auto matcher() const
-    {
-        auto gemm1 =
-            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
-        auto pw =
-            match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
-        auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
-        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
-            match::any_of[match::inputs()](softmax));
-    }
-
-    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
-    {
-        std::cout << "Matched" << std::endl;
-        auto ins       = r.result;
-        auto gemm2_ins = r.instructions["gemm2"];
-        auto sm_ins    = r.instructions["softmax"];
-        auto pw_ins    = r.instructions["scale_bias"];
-        auto gemm1_ins = r.instructions["gemm1"];
-        gemm2_ins->debug_print();
-        sm_ins->debug_print();
-        pw_ins->debug_print();
-        gemm1_ins->debug_print();
-        auto inputs = gemm1_ins->inputs();            // A, B
-        inputs.push_back(gemm2_ins->inputs().back()); // B1
-        // inputs.push_back(pw_ins->inputs().back()); // C
-        mpm.get_module().replace_instruction(
-            ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
-    }
+    // auto matcher() const
+    // {
+    //     auto gemm1 =
+    //         match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
+    //     auto pw =
+    //         match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
+    //     auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
+    //     return match::name("dot")(is_ck_gemm().bind("gemm2"))(
+    //         match::any_of[match::inputs()](softmax));
+    // }
+
+    // void apply(module_pass_manager& mpm, const match::matcher_result& r) const
+    // {
+    //     std::cout << "Matched" << std::endl;
+    //     auto ins       = r.result;
+    //     auto gemm2_ins = r.instructions["gemm2"];
+    //     auto sm_ins    = r.instructions["softmax"];
+    //     auto pw_ins    = r.instructions["scale_bias"];
+    //     auto gemm1_ins = r.instructions["gemm1"];
+    //     gemm2_ins->debug_print();
+    //     sm_ins->debug_print();
+    //     pw_ins->debug_print();
+    //     gemm1_ins->debug_print();
+    //     auto inputs = gemm1_ins->inputs();            // A, B
+    //     inputs.push_back(gemm2_ins->inputs().back()); // B1
+    //     // inputs.push_back(pw_ins->inputs().back()); // C
+    //     mpm.get_module().replace_instruction(
+    //         ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
+    // }
     // auto matcher() const
     // {

@@ -223,11 +223,11 @@ struct find_ck_gemm_scale_bias_softmax_gemm
 void fuse_ck::apply(module_pass_manager& mpm) const
 {
     // mpm.get_module().debug_print();
-    match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
-    // if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
-    //     match::find_matches(mpm, find_ck_gemm_pointwise{});
-    // if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
-    //     match::find_matches(mpm, find_ck_gemm{});
+    // match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
+    if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
+        match::find_matches(mpm, find_ck_gemm_pointwise{});
+    if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
+        match::find_matches(mpm, find_ck_gemm{});
 }

 } // namespace gpu
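For orientation: the pattern this finder targeted (disabled here and reworked in the new pass below) is the unfused attention chain dot -> pointwise scale/bias -> softmax -> dot. A minimal sketch of a module containing that chain, built with the MIGraphX C++ API, follows; the parameter names and the {1, 12, 256, 256} shapes are illustrative, not taken from this commit.

#include <migraphx/program.hpp>
#include <migraphx/make_op.hpp>

// Builds the dot -> mul -> softmax -> dot subgraph that the finder above
// rewrites into a single gpu::ck_gemm_softmax_gemm instruction.
migraphx::program make_attention_subgraph()
{
    migraphx::program p;
    auto* mm = p.get_main_module();
    migraphx::shape s{migraphx::shape::half_type, {1, 12, 256, 256}};
    auto a     = mm->add_parameter("a", s);  // matched as "gemm1" input A
    auto b     = mm->add_parameter("b", s);  // matched as "gemm1" input B
    auto b1    = mm->add_parameter("b1", s); // becomes fused input B1
    auto scale = mm->add_parameter("scale", s);
    auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);         // bound "gemm1"
    auto pw      = mm->add_instruction(migraphx::make_op("mul"), gemm1, scale); // bound "scale_bias"
    auto softmax = mm->add_instruction(
        migraphx::make_op("softmax", {{"axis", -1}}), pw);                      // bound "softmax"
    mm->add_instruction(migraphx::make_op("dot"), softmax, b1);                 // bound "gemm2"
    return p;
}

Note that the matcher looks for a "pointwise" instruction, so in practice the mul above is first grouped by the fuse_pointwise pass before this finder can see it.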
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp (new file, mode 100644, +102)

#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module;

namespace gpu {

struct gemm_softmax_gemm_gemm
{
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }

    void check_gemm_shape(const shape& s) const
    {
        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
            MIGRAPHX_THROW("Invalid shape for gemm_softmax_gemm_gemm");
    }

    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        auto a  = inputs[0];
        auto b  = inputs[1];
        auto b1 = inputs[2];
        for(const auto& input : inputs)
        {
            check_gemm_shape(input);
        }
        return op.compute_shape({op.compute_shape({a, b}), b1});
    }
};
MIGRAPHX_REGISTER_OP(gemm_softmax_gemm_gemm);

namespace {

MIGRAPHX_PRED_MATCHER(is_ck_gemm, instruction_ref ins)
{
    if(ins->name() != "dot")
        return false;
    auto a = ins->inputs().front()->get_shape();
    auto b = ins->inputs().back()->get_shape();
    if(a.lens().back() > 2048)
        return false;
    return true;
}

struct find_gemm_softmax_gemm_gemm
{
    auto matcher() const
    {
        auto gemm1 =
            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
        auto mul = match::name("mul")(match::any_of[match::inputs()](gemm1)).bind("scale");
        auto add = match::name("add")(match::any_of[match::inputs()](mul));
        auto softmax = match::name("softmax")(match::any_of[match::inputs()](add)).bind("softmax");
        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
            match::any_of[match::inputs()](softmax));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins       = r.result;
        auto gemm2_ins = r.instructions["gemm2"];
        auto gemm1_ins = r.instructions["gemm1"];
        auto inputs    = gemm1_ins->inputs();         // A, B
        inputs.push_back(gemm2_ins->inputs().back()); // B1
        mpm.get_module().replace_instruction(
            ins, gemm_softmax_gemm_gemm{gemm2_ins->get_operator()}, inputs);
    }
};

} // namespace

void fuse_ck_gemm_softmax_gemm::apply(module_pass_manager& mpm) const
{
    match::find_matches(mpm, find_gemm_softmax_gemm_gemm{});
}

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
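The fused op reports its output shape by chaining the wrapped dot twice, so the result is the shape of (A x B) x B1. A worked sketch with illustrative BERT-like dimensions (the shapes are assumptions, not taken from the commit), mirroring gemm_softmax_gemm_gemm::compute_shape:

#include <migraphx/make_op.hpp>
#include <migraphx/shape.hpp>

// Chains the two dots exactly as the fused op's compute_shape does.
migraphx::shape fused_output_shape()
{
    auto dot = migraphx::make_op("dot");
    migraphx::shape a{migraphx::shape::half_type, {2, 12, 384, 64}};  // queries
    migraphx::shape b{migraphx::shape::half_type, {2, 12, 64, 384}};  // keys, transposed
    migraphx::shape b1{migraphx::shape::half_type, {2, 12, 384, 64}}; // values
    // A x B    : {2,12,384,64}  x {2,12,64,384} -> {2,12,384,384}
    // (AB) x B1: {2,12,384,384} x {2,12,384,64} -> {2,12,384,64}
    return dot.compute_shape({dot.compute_shape({a, b}), b1});
}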
src/targets/gpu/include/migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp (new file, mode 100644, +25)

#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP

#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

struct module_pass_manager;

namespace gpu {

struct fuse_ck_gemm_softmax_gemm
{
    context* ctx = nullptr;
    std::string name() const { return "gpu::fuse_ck_gemm_softmax_gemm"; }
    void apply(module_pass_manager& mpm) const;
};

} // namespace gpu
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp (+48 -92)

@@ -39,7 +39,7 @@
 #include <migraphx/file_buffer.hpp>

 const std::vector<std::string>&
-get_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);
+get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::string>&)>& pred);

 namespace migraphx {
 inline namespace MIGRAPHX_INLINE_NS {

@@ -62,34 +62,12 @@ namespace migraphx {
 ${preamble}

-template <ck::index_t... Is>
-using S = ck::Sequence<Is...>;
-using Row = ck::tensor_layout::gemm::RowMajor;
-using Col = ck::tensor_layout::gemm::ColumnMajor;
-using F16 = ck::half_t;
-using F32 = float;
-using PassThrough = ck_passthrough;
-static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
-using AElementOp = PassThrough;
-using B0ElementOp = PassThrough;
-using Acc0ElementOp = ck_scale;//ck::tensor_operation::element_wise::Scale;
-using B1ElementOp = PassThrough;
-using CElementOp = PassThrough;//ck_add;//ck::tensor_operation::element_wise::Add;
-using gemm = CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false, std::ratio<1, 8>>;
-
 extern "C" {

 __global__ void ${kernel}(${params})
 {
-    // transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
-    //     ck_gemm_softmax_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
-    // });
     transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
-        ck_gemm_softmax_gemm<gemm, ${blocks_per_batch}>(xs...);
+        ck_gemm_softmax_gemm<CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${instance}>, ${blocks_per_batch}>(xs...);
     });
 }

@@ -112,13 +90,13 @@ struct instance
     std::size_t get_pb(std::size_t i) const
     {
-        assert(i < 4);
+        assert(i <= 4);
         return int_at(block_size_index + 1 + i);
     }

-    std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
+    std::array<std::size_t, 4> get_pad(const std::array<std::size_t, 4>& config) const
     {
-        std::array<std::size_t, 3> result{};
+        std::array<std::size_t, 4> result{};
         for(auto i : range(config.size()))
         {
             result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];

@@ -126,33 +104,16 @@ struct instance
         return result;
     }

-    std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
+    std::size_t get_grid_size(const std::array<std::size_t, 4>& config) const
     {
-        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
+        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[3], get_pb(3));
     }

-    void set_ds_layout(const std::string& s)
-    {
-        assert(params[2] == "ck::Tuple<>");
-        params[2] = s;
-    }
-
-    void set_ds_type(const std::string& s)
-    {
-        assert(params[8] == "ck::Tuple<>");
-        params[8] = s;
-    }
-
-    void set_ds_op(const std::string& s)
-    {
-        assert(params[12] == "ck_passthrough");
-        params[12] = s;
-    }
-
     void set_gemm(const std::string& s)
     {
-        assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
-        params[13] = s;
+        assert(params[15] == "ck::tensor_operation::device::GemmSpecialization::Default" or
+               params[15] == "ck::tensor_operation::device::GemmSpecialization::MNKOPadding");
+        params[15] = s;
     }

     std::string str() const { return join_strings(params, ","); }

@@ -179,12 +140,12 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
 {
     static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
     if(tuning.empty())
-        std::cout << "*********** Warning: No CK tuning!" << std::endl;
+        std::cout << "*********** Warning: No CK GSG tuning!" << std::endl;
     auto it = std::find_if(
         tuning.begin(), tuning.end(), [&](const auto& p) { return p.first == inputs; });
     if(it == tuning.end())
     {
-        std::cout << "*********** Warning: CK tuning missing for config!" << std::endl;
+        std::cout << "*********** Warning: CK GSG tuning missing for config!" << std::endl;
         return 4;
     }
     return it->second;

@@ -194,8 +155,12 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
 {
     static std::string get_layout(const shape& s)
     {
-        return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
-                              : "ck::tensor_layout::gemm::RowMajor";
+        if(not s.transposed())
+            return "ck::tensor_layout::gemm::RowMajor";
+        auto lens = s.lens();
+        return lens[lens.size() - 1] > lens[lens.size() - 2]
+                   ? "ck::tensor_layout::gemm::ColumnMajor"
+                   : "ck::tensor_layout::gemm::RowMajor";
     }

     static std::string get_type(const shape& s)

@@ -222,6 +187,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
 {
     auto a_shape  = inputs[0];
     auto b_shape  = inputs[1];
+    auto b1_shape = inputs[2];
     auto c_shape  = inputs.back();
     auto m        = a_shape.lens()[0];
     auto k        = a_shape.lens()[1];

@@ -229,48 +195,39 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
     auto rank = a_shape.lens().size();
-    // std::array<char, 3> keys{'M', 'N', 'K'};
-    // std::array<std::size_t, 3> config{
-    //     c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
-    // auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
-    // auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
-    //     return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
-    //            get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
-    //            get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
-    // })};
-    // assert(inputs.size() < 4 or v.contains("post"));
-    // if(v.contains("post"))
-    // {
-    //     ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
-    //     ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
-    //     ip.set_ds_op(v.at("post").to<std::string>());
-    // }
-    // auto padding = ip.get_pad(config);
-    // std::string gemm_type;
-    // for(auto i : range(padding.size()))
-    // {
-    //     if(padding[i] != 0)
-    //         gemm_type += keys[i];
-    // }
-    // if(gemm_type.empty())
-    //     gemm_type = "Default";
-    // else
-    //     gemm_type += "Padding";
-    // ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
-    auto gemm1_nperblock  = 64;
-    auto gemm01_mperblock = 256;
-    auto blocks_per_batch = int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock);
-    // ip.get_grid_size(config);
+    // config (m0, n0, k0, n1)
+    std::array<char, 4> keys{'M', 'N', 'K', 'O'};
+    std::array<std::size_t, 4> config{c_shape.lens()[rank - 2],
+                                      b_shape.lens()[rank - 2],
+                                      a_shape.lens().back(),
+                                      c_shape.lens().back()};
+    auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, b1_shape, c_shape}));
+    auto ip = instance{get_gsg_instance(tuning_val, [&](const auto& x) -> bool {
+        return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
+               get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
+               get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
+    })};
+    auto padding = ip.get_pad(config);
+    std::string gemm_type;
+    for(auto i : range(padding.size()))
+    {
+        if(padding[i] != 0)
+            gemm_type += keys[i];
+    }
+    if(gemm_type.empty())
+        gemm_type = "Default";
+    else
+        gemm_type += "Padding";
+    ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
+    auto blocks_per_batch = ip.get_grid_size(config);
     auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
                                        c_shape.lens().rend(),
                                        std::size_t{1},
                                        std::multiplies<std::size_t>());
     hip_compile_options options;
-    auto block_size = 256; // ip.get_block_size();
+    auto block_size = ip.get_block_size();
     auto grid_size  = batch_count * blocks_per_batch;
     options.set_launch_params(v, grid_size * block_size, block_size);
     options.inputs = inputs;

@@ -282,7 +239,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
     options.params += " -DMIGRAPHX_CK_CHECK=1";
     auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
-                                  {{"instance", "" /* ip.str() */},
+                                  {{"instance", ip.str()},
                                    {"params", enum_params(inputs.size(), "void * private_p")},
                                    {"args", enum_params(inputs.size(), "private_p")},
                                    {"blocks_per_batch", to_string(blocks_per_batch)},

@@ -302,7 +259,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
         v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
                         "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
                         "post_ck_gemm_softmax_gemm_function);";
-        v["post"]   = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
         v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
     }

@@ -310,7 +266,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
     return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
         if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
         {
-            std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
+            std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes[2], shapes.back()};
             std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
                       << std::endl;
         }
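The GemmSpecialization suffix is derived purely from divisibility of the (M, N, K, O) problem sizes by the selected instance's per-block tile sizes: any dimension that needs padding contributes its key letter, and "Padding" is appended. A standalone sketch of that arithmetic (the tile and problem sizes below are illustrative, not a real tuned instance):

#include <array>
#include <cstddef>
#include <iostream>
#include <string>

// Same rounding used by instance::get_pad above.
std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

int main()
{
    // Illustrative per-block tile sizes for (M, N, K, O) and a problem config (m0, n0, k0, n1).
    std::array<std::size_t, 4> per_block{256, 128, 32, 64};
    std::array<std::size_t, 4> config{384, 384, 64, 64};
    const char keys[] = {'M', 'N', 'K', 'O'};
    std::string gemm_type;
    for(std::size_t i = 0; i < config.size(); i++)
    {
        // Pad each dimension up to a multiple of the tile size.
        auto pad = int_div_ceil(config[i], per_block[i]) * per_block[i] - config[i];
        if(pad != 0)
            gemm_type += keys[i];
    }
    gemm_type = gemm_type.empty() ? "Default" : gemm_type + "Padding";
    // 384 is not a multiple of 256, while 384/128, 64/32 and 64/64 divide evenly,
    // so this prints "MPadding".
    std::cout << gemm_type << "\n";
}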
src/targets/gpu/jit/ck_gsg_instances.cpp (+610 -2462)

This diff is collapsed (a large generated table of CK gemm-softmax-gemm tuning instances).
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp (+9 -9)

@@ -69,7 +69,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
         ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

     constexpr const auto b_shape = get_shape_c<B>{};
-    constexpr const auto n       = b_shape.lens[0]; // col-major
+    constexpr const auto n       = b_shape.lens[1];
     constexpr const auto sb      = b_shape.strides[1]; // col-major
     constexpr const auto BK1     = gemm.get_BK1();
     constexpr const auto BK0     = k / BK1;

@@ -85,9 +85,9 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
         ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

     constexpr const auto b1_shape = get_shape_c<B1>{};
-    constexpr const auto k1       = b1_shape.lens[0]; // row-major
-    constexpr const auto n1       = b1_shape.lens[1]; // row-major
-    constexpr const auto sb1      = b1_shape.strides[0]; // rowl-major
+    constexpr const auto k1       = b1_shape.lens[0];
+    constexpr const auto n1       = b1_shape.lens[1];
+    constexpr const auto sb1      = b1_shape.strides[0]; // row-major
     constexpr const auto B1K1     = gemm.get_B1K1();
     constexpr const auto B1K0     = k1 / B1K1;

@@ -139,11 +139,11 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

-    static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
-                                              b_grid_desc_bk0_n_bk1,
-                                              b1_grid_desc_bk0_n_bk1,
-                                              c_grid_desc_m_n,
-                                              block_2_ctile_map));
+    // static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
+    //                                           b_grid_desc_bk0_n_bk1,
+    //                                           b1_grid_desc_bk0_n_bk1,
+    //                                           c_grid_desc_m_n,
+    //                                           block_2_ctile_map));

     GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
                                                   to_ck_const_pointer(b.data()),
                                                   to_ck_const_pointer(b1.data()),
src/targets/gpu/target.cpp (+3 -0)

@@ -56,6 +56,7 @@
 #include <migraphx/gpu/context.hpp>
 #include <migraphx/gpu/device_name.hpp>
 #include <migraphx/gpu/fuse_ck.hpp>
+#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>
 #include <migraphx/gpu/fuse_mlir.hpp>
 #include <migraphx/gpu/fuse_ops.hpp>
 #include <migraphx/gpu/prefuse_ops.hpp>

@@ -131,6 +132,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
         simplify_reshapes{},
         propagate_constant{},
         dead_code_elimination{},
+        fuse_ck_gemm_softmax_gemm{&ctx},
+        dead_code_elimination{},
         enable_pass(not enabled(MIGRAPHX_DISABLE_POINTWISE_FUSION{}), fuse_pointwise{}),
         dead_code_elimination{},
         fuse_mlir{&ctx},
test/verify/0ck_gemm_softmax_gemm.cpp (+37 -7)

@@ -31,15 +31,39 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
 {
     migraphx::program create_program() const
     {
+        // migraphx::program p;
+        // auto* mm = p.get_main_module();
+        // migraphx::shape m1_shape{migraphx::shape::half_type, {16, 12, 384, 64}};
+        // migraphx::shape m2_shape{migraphx::shape::half_type, {16, 12, 384, 384}};
+        // auto m2_elements = 16 * 12 * 384 * 384;
+        // auto a = mm->add_parameter("1", m1_shape);
+        // auto b = mm->add_parameter("2", m1_shape);
+        // auto b1 = mm->add_parameter("3", m1_shape);
+        // auto c = mm->add_parameter("4", m1_shape);
+        // std::vector<float> eights(m2_elements, 0.125);
+        // auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
+        // std::vector<float> zeros(m2_elements, 0);
+        // auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
+        // std::vector<float> ones(m2_elements, 1);
+        // auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
+        // // a = one;
+        // // b = one;
+        // // b1 = one;
+        // b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
+        // auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
+        // auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
+        // auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
+        // auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
+        // mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
+
         migraphx::program p;
         auto* mm = p.get_main_module();
-        migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
-        migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
-        auto m2_elements = 1 * 12 * 256 * 256;
-        auto a = mm->add_parameter("1", m1_shape);
-        auto b = mm->add_parameter("2", m1_shape);
-        auto b1 = mm->add_parameter("3", m1_shape);
-        auto c = mm->add_parameter("4", m1_shape);
+        size_t batch = 2;
+        migraphx::shape m1_shape{migraphx::shape::half_type, {batch, 384, 2304}};
+        migraphx::shape m2_shape{migraphx::shape::half_type, {batch, 12, 384, 384}};
+        auto m2_elements = batch * 12 * 384 * 384;
+        auto g = mm->add_parameter("1", m1_shape);
         std::vector<float> eights(m2_elements, 0.125);
         auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
         std::vector<float> zeros(m2_elements, 0);

@@ -47,7 +71,13 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
         std::vector<float> ones(m2_elements, 1);
         auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
+        g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g);
+        g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g);
+        auto a = mm->add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
+        auto b = mm->add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
+        auto b1 = mm->add_instruction(
+            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
         b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
         auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
         auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
         auto bias  = mm->add_instruction(migraphx::make_op("add"), scale, zero);
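The reworked test feeds the fusion a BERT-style fused-QKV layout rather than three independent parameters. A small standalone sketch of the shape bookkeeping behind it (numbers mirror the test above; nothing here is new API):

#include <cassert>
#include <cstddef>

// The single parameter g has shape {batch, 384, 2304}; since 2304 = 36 * 64,
// reshape + transpose produce {batch, 36, 384, 64}, and slicing axis 1 into
// [0,12), [12,24), [24,36) yields the Q, K and V tensors of 12 heads each.
int main()
{
    const std::size_t seq = 384, hidden = 2304, heads = 36, head_dim = 64, qkv_heads = 12;
    assert(hidden == heads * head_dim); // reshape to {batch, seq, heads, head_dim} is valid
    assert(3 * qkv_heads == heads);     // Q, K, V each take 12 of the 36 heads
    (void)seq;
    return 0;
}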
tools/tune_ck.py (+33 -8)

 import os, json, subprocess, tempfile, sys, argparse, contextlib

+ck_function = -1

 @contextlib.contextmanager
 def tmp_file(dump=None):

@@ -20,6 +21,9 @@ def pretty_print(obj):
 def run_driver(b):
     print(b)
+    outfile = open("temp2.json", "w")
+    json.dump(b, outfile)
+    outfile.close()
     with tmp_file(lambda tf: json.dump(b, tf)) as tf:
         cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
                             capture_output=True,

@@ -45,7 +49,7 @@ def get_device_time(s):
 def benchmark_ck(config, tuning):
     try:
-        b = {
+        b0 = {
             'settings': {
                 'iterations': 100
             },

@@ -56,6 +60,18 @@ def benchmark_ck(config, tuning):
                 'inputs': config
             }
         }
+        b1 = {
+            'settings': {
+                'iterations': 100
+            },
+            'compile_op': {
+                'name': 'ck_gemm_softmax_gemm',
+                'check': True,
+                'tuning_val': tuning,
+                'inputs': config
+            }
+        }
+        b = b0 if (ck_function == 0) else b1
         for line in run_driver(b):
             dtime = get_device_time(line)
             print(dtime)

@@ -72,17 +88,26 @@ def benchmark(config, size):
 def parse_log(f):
     for line in open(f).readlines():
         line = line.strip()
-        if not line.startswith('ck_gemm:'):
-            continue
-        line = line[len('ck_gemm:'):].strip()
-        config = json.loads(line)
-        yield config
+        global ck_function
+        if line.startswith('ck_gemm:'):
+            line = line[len('ck_gemm:'):].strip()
+            config = json.loads(line)
+            ck_function = 0
+            yield config
+        if line.startswith('ck_gemm_softmax_gemm:'):
+            line = line[len('ck_gemm_softmax_gemm:'):].strip()
+            config = json.loads(line)
+            ck_function = 1
+            yield config

 def benchmark_log(f, n):
     result = []
-    for config in parse_log(f):
-        tuned = benchmark(config, n)
+    logs = parse_log(f)
+    for config in logs:
+        additional_tv = ck_function * 2
+        tuned = benchmark(config, n + additional_tv)
         print("Tuned:", tuned)
         result.append([config, tuned])
     return result
tools/tune_ck.sh (new file, mode 100755, +13)

#!/bin/bash

MODEL=$1
LOG="ck_bbc.log"
TUNING_DB="ck_bbc.json"

rm $LOG
touch $LOG

for N in 1 16 32 64
do
    MIGRAPHX_LOG_CK_GEMM=1 ./bin/driver run $MODEL -g --fill1 input_ids --input-dim @input_ids $N 384 | grep 'ck_gemm.*: \[{' | sort -u >> $LOG
done

python3 ../tools/tune_ck.py -n 16 -l $LOG -o $TUNING_DB
\ No newline at end of file
tools/tune_ck_gsg.py (new file, mode 100644, +113)

import os, json, subprocess, tempfile, sys, argparse, contextlib


@contextlib.contextmanager
def tmp_file(dump=None):
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        os.unlink(tmp_name)


def pretty_print(obj):
    print(json.dumps(obj, indent=2))


def run_driver(b):
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as tf:
        cp = subprocess.run('./bin/gpu-driver {}'.format(tf),
                            capture_output=True,
                            check=True,
                            shell=True)
        for line in cp.stdout.decode().split("\n"):
            s = line.strip()
            if not s:
                continue
            if not ']: ' in s:
                continue
            yield s.split(']: ')[1].strip()


def convert_to_float(s):
    return s[:-2]


def get_device_time(s):
    fields = s.split(',')
    return convert_to_float(fields[-1].strip())


def benchmark_ck(config, tuning):
    try:
        b = {
            'settings': {
                'iterations': 100
            },
            'compile_op': {
                'name': 'ck_gemm_softmax_gemm',
                'check': True,
                'tuning_val': tuning,
                'inputs': config
            }
        }
        for line in run_driver(b):
            dtime = get_device_time(line)
            print(dtime)
            return float(dtime)
    except:
        return sys.float_info.max


def benchmark(config, size):
    times = [benchmark_ck(config, i) for i in range(size)]
    return times.index(min(times))


def parse_log(f):
    for line in open(f).readlines():
        line = line.strip()
        if not line.startswith('ck_gemm_softmax_gemm:'):
            continue
        line = line[len('ck_gemm_softmax_gemm:'):].strip()
        config = json.loads(line)
        yield config


def benchmark_log(f, n):
    result = []
    for config in parse_log(f):
        tuned = benchmark(config, n)
        print("Tuned:", tuned)
        result.append([config, tuned])
    return result


def parse_args():
    parser = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    parser.add_argument('--log', '-l', type=str, metavar='file', help='Path to logfile')
    parser.add_argument('--out', '-o', type=str, metavar='file', help='Output json file to save tunings')
    parser.add_argument('-n', type=int, help='Number of instances to tune')
    args = parser.parse_args()
    return args


def run(args):
    tuned = benchmark_log(args.log, args.n)
    json.dump(tuned, open(args.out, 'w+'))


run(parse_args())