Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
f3fcfcc7
Commit
f3fcfcc7
authored
Nov 16, 2022
by
Alan Turner
Browse files
Fix fusion pass and add tuning
parent
3133fd79
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
1067 additions
and
2651 deletions
+1067
-2651
src/targets/gpu/CMakeLists.txt
src/targets/gpu/CMakeLists.txt
+1
-0
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+73
-73
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
+102
-0
src/targets/gpu/include/migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp
...ts/gpu/include/migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp
+25
-0
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+48
-92
src/targets/gpu/jit/ck_gsg_instances.cpp
src/targets/gpu/jit/ck_gsg_instances.cpp
+610
-2462
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
...kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
+9
-9
src/targets/gpu/target.cpp
src/targets/gpu/target.cpp
+3
-0
test/verify/0ck_gemm_softmax_gemm.cpp
test/verify/0ck_gemm_softmax_gemm.cpp
+37
-7
tools/tune_ck.py
tools/tune_ck.py
+33
-8
tools/tune_ck.sh
tools/tune_ck.sh
+13
-0
tools/tune_ck_gsg.py
tools/tune_ck_gsg.py
+113
-0
No files found.
src/targets/gpu/CMakeLists.txt
View file @
f3fcfcc7
...
...
@@ -90,6 +90,7 @@ add_library(migraphx_gpu
device_name.cpp
elu.cpp
fuse_ck.cpp
fuse_ck_gemm_softmax_gemm.cpp
fuse_mlir.cpp
fuse_ops.cpp
gather.cpp
...
...
src/targets/gpu/fuse_ck.cpp
View file @
f3fcfcc7
...
...
@@ -49,43 +49,43 @@ struct ck_gemm
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
struct
ck_gemm_scale_bias_softmax_gemm
{
operation
op
=
make_op
(
"dot"
);
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
));
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm_softmax_gemm"
;
}
void
check_gemm_shape
(
const
shape
&
s
)
const
{
if
(
not
contains
(
range
(
s
.
strides
().
rbegin
(),
s
.
strides
().
rbegin
()
+
3
),
1
))
MIGRAPHX_THROW
(
"Invalid shape for ck_gemm_scale_bias_softmax_gemm"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
// if(mods.size() != 1)
// MIGRAPHX_THROW("should have one submodule.");
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
auto
b1
=
inputs
[
2
];
for
(
const
auto
&
input
:
inputs
)
{
// std::cout << input << std::endl;
check_gemm_shape
(
input
);
}
return
op
.
compute_shape
({
op
.
compute_shape
({
a
,
b
}),
b1
});
}
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_scale_bias_softmax_gemm
);
//
struct ck_gemm_scale_bias_softmax_gemm
//
{
//
operation op = make_op("dot");
//
template <class Self, class F>
//
static auto reflect(Self& self, F f)
//
{
//
return pack(f(self.op, "op"));
//
}
//
std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }
//
void check_gemm_shape(const shape& s) const
//
{
//
if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
//
MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
//
}
//
shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
//
{
//
check_shapes{inputs, *this}.same_ndims();
//
// if(mods.size() != 1)
//
// MIGRAPHX_THROW("should have one submodule.");
//
if(inputs.size() < 2)
//
MIGRAPHX_THROW("should have at least two inputs.");
//
auto a = inputs[0];
//
auto b = inputs[1];
//
auto b1 = inputs[2];
//
for(const auto& input : inputs)
//
{
//
// std::cout << input << std::endl;
//
check_gemm_shape(input);
//
}
//
return op.compute_shape({op.compute_shape({a, b}), b1});
//
}
//
};
//
MIGRAPHX_REGISTER_OP(ck_gemm_scale_bias_softmax_gemm);
namespace
{
...
...
@@ -156,38 +156,38 @@ struct find_ck_gemm
struct
find_ck_gemm_scale_bias_softmax_gemm
{
auto
matcher
()
const
{
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm1"
)));
auto
pw
=
match
::
name
(
"pointwise"
)(
match
::
any_of
[
match
::
inputs
()](
gemm1
)).
bind
(
"scale_bias"
);
auto
softmax
=
match
::
name
(
"softmax"
)(
match
::
any_of
[
match
::
inputs
()](
pw
)).
bind
(
"softmax"
);
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm2"
))(
match
::
any_of
[
match
::
inputs
()](
softmax
));
}
//
auto matcher() const
//
{
//
auto gemm1 =
//
match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
//
auto pw =
//
match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
//
auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
//
return match::name("dot")(is_ck_gemm().bind("gemm2"))(
//
match::any_of[match::inputs()](softmax));
//
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
std
::
cout
<<
"Matched"
<<
std
::
endl
;
auto
ins
=
r
.
result
;
auto
gemm2_ins
=
r
.
instructions
[
"gemm2"
];
auto
sm_ins
=
r
.
instructions
[
"softmax"
];
auto
pw_ins
=
r
.
instructions
[
"scale_bias"
];
auto
gemm1_ins
=
r
.
instructions
[
"gemm1"
];
gemm2_ins
->
debug_print
();
sm_ins
->
debug_print
();
pw_ins
->
debug_print
();
gemm1_ins
->
debug_print
();
auto
inputs
=
gemm1_ins
->
inputs
();
// A, B
inputs
.
push_back
(
gemm2_ins
->
inputs
().
back
());
// B1
// inputs.push_back(pw_ins->inputs().back()); // C
mpm
.
get_module
().
replace_instruction
(
ins
,
ck_gemm_scale_bias_softmax_gemm
{
gemm2_ins
->
get_operator
()},
inputs
);
}
//
void apply(module_pass_manager& mpm, const match::matcher_result& r) const
//
{
//
std::cout << "Matched" << std::endl;
//
auto ins = r.result;
//
auto gemm2_ins = r.instructions["gemm2"];
//
auto sm_ins = r.instructions["softmax"];
//
auto pw_ins = r.instructions["scale_bias"];
//
auto gemm1_ins = r.instructions["gemm1"];
//
gemm2_ins->debug_print();
//
sm_ins->debug_print();
//
pw_ins->debug_print();
//
gemm1_ins->debug_print();
//
auto inputs = gemm1_ins->inputs(); // A, B
//
inputs.push_back(gemm2_ins->inputs().back()); // B1
//
// inputs.push_back(pw_ins->inputs().back()); // C
//
mpm.get_module().replace_instruction(
//
ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
//
}
// auto matcher() const
// {
...
...
@@ -223,11 +223,11 @@ struct find_ck_gemm_scale_bias_softmax_gemm
void
fuse_ck
::
apply
(
module_pass_manager
&
mpm
)
const
{
// mpm.get_module().debug_print();
match
::
find_matches
(
mpm
,
find_ck_gemm_scale_bias_softmax_gemm
{});
//
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
//
match::find_matches(mpm, find_ck_gemm_pointwise{});
//
if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
//
match::find_matches(mpm, find_ck_gemm{});
//
match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_GEMM_FUSION
{}))
match
::
find_matches
(
mpm
,
find_ck_gemm_pointwise
{});
if
(
not
enabled
(
MIGRAPHX_DISABLE_CK_GEMM
{}))
match
::
find_matches
(
mpm
,
find_ck_gemm
{});
}
}
// namespace gpu
...
...
src/targets/gpu/fuse_ck_gemm_softmax_gemm.cpp
0 → 100644
View file @
f3fcfcc7
#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>
#include <migraphx/matcher.hpp>
#include <migraphx/pass_manager.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/register_op.hpp>
#include <migraphx/env.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module
;
namespace
gpu
{
struct
gemm_softmax_gemm_gemm
{
operation
op
=
make_op
(
"dot"
);
template
<
class
Self
,
class
F
>
static
auto
reflect
(
Self
&
self
,
F
f
)
{
return
pack
(
f
(
self
.
op
,
"op"
));
}
std
::
string
name
()
const
{
return
"gpu::ck_gemm_softmax_gemm"
;
}
void
check_gemm_shape
(
const
shape
&
s
)
const
{
if
(
not
contains
(
range
(
s
.
strides
().
rbegin
(),
s
.
strides
().
rbegin
()
+
3
),
1
))
MIGRAPHX_THROW
(
"Invalid shape for gemm_softmax_gemm_gemm"
);
}
shape
compute_shape
(
std
::
vector
<
shape
>
inputs
,
const
std
::
vector
<
module_ref
>&
mods
)
const
{
check_shapes
{
inputs
,
*
this
}.
same_ndims
();
if
(
inputs
.
size
()
<
2
)
MIGRAPHX_THROW
(
"should have at least two inputs."
);
auto
a
=
inputs
[
0
];
auto
b
=
inputs
[
1
];
auto
b1
=
inputs
[
2
];
for
(
const
auto
&
input
:
inputs
)
{
check_gemm_shape
(
input
);
}
return
op
.
compute_shape
({
op
.
compute_shape
({
a
,
b
}),
b1
});
}
};
MIGRAPHX_REGISTER_OP
(
gemm_softmax_gemm_gemm
);
namespace
{
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
{
if
(
ins
->
name
()
!=
"dot"
)
return
false
;
auto
a
=
ins
->
inputs
().
front
()
->
get_shape
();
auto
b
=
ins
->
inputs
().
back
()
->
get_shape
();
if
(
a
.
lens
().
back
()
>
2048
)
return
false
;
return
true
;
}
struct
find_gemm_softmax_gemm_gemm
{
auto
matcher
()
const
{
auto
gemm1
=
match
::
skip
(
match
::
name
(
"contiguous"
))(
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm1"
)));
auto
mul
=
match
::
name
(
"mul"
)(
match
::
any_of
[
match
::
inputs
()](
gemm1
)).
bind
(
"scale"
);
auto
add
=
match
::
name
(
"add"
)(
match
::
any_of
[
match
::
inputs
()](
mul
));
auto
softmax
=
match
::
name
(
"softmax"
)(
match
::
any_of
[
match
::
inputs
()](
add
)).
bind
(
"softmax"
);
return
match
::
name
(
"dot"
)(
is_ck_gemm
().
bind
(
"gemm2"
))(
match
::
any_of
[
match
::
inputs
()](
softmax
));
}
void
apply
(
module_pass_manager
&
mpm
,
const
match
::
matcher_result
&
r
)
const
{
auto
ins
=
r
.
result
;
auto
gemm2_ins
=
r
.
instructions
[
"gemm2"
];
auto
gemm1_ins
=
r
.
instructions
[
"gemm1"
];
auto
inputs
=
gemm1_ins
->
inputs
();
// A, B
inputs
.
push_back
(
gemm2_ins
->
inputs
().
back
());
// B1
mpm
.
get_module
().
replace_instruction
(
ins
,
gemm_softmax_gemm_gemm
{
gemm2_ins
->
get_operator
()},
inputs
);
}
};
}
// namespace
void
fuse_ck_gemm_softmax_gemm
::
apply
(
module_pass_manager
&
mpm
)
const
{
match
::
find_matches
(
mpm
,
find_gemm_softmax_gemm_gemm
{});
}
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/include/migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp
0 → 100644
View file @
f3fcfcc7
#ifndef MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_GPU_FUSE_CK_GEMM_SOFTMAX_GEMM_HPP
#include <migraphx/config.hpp>
#include <migraphx/gpu/context.hpp>
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
struct
module_pass_manager
;
namespace
gpu
{
struct
fuse_ck_gemm_softmax_gemm
{
context
*
ctx
=
nullptr
;
std
::
string
name
()
const
{
return
"gpu::fuse_ck_gemm_softmax_gemm"
;
}
void
apply
(
module_pass_manager
&
mpm
)
const
;
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
#endif // MIGRAPHX_GUARD_GPU_FUSE_CK_HPP
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
View file @
f3fcfcc7
...
...
@@ -39,7 +39,7 @@
#include <migraphx/file_buffer.hpp>
const
std
::
vector
<
std
::
string
>&
get_instance
(
std
::
size_t
i
,
const
std
::
function
<
bool
(
const
std
::
vector
<
std
::
string
>&
)
>&
pred
);
get_
gsg_
instance
(
std
::
size_t
i
,
const
std
::
function
<
bool
(
const
std
::
vector
<
std
::
string
>&
)
>&
pred
);
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
...
...
@@ -62,34 +62,12 @@ namespace migraphx {
${preamble}
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck_passthrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck_scale;//ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;//ck_add;//ck::tensor_operation::element_wise::Add;
using gemm = CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false, std::ratio<1, 8>>;
extern "C" {
__global__ void ${kernel}(${params})
{
// transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
// ck_gemm_softmax_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
// });
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm_softmax_gemm<
gemm
, ${blocks_per_batch}>(xs...);
ck_gemm_softmax_gemm<
CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle<${instance}>
, ${blocks_per_batch}>(xs...);
});
}
...
...
@@ -112,13 +90,13 @@ struct instance
std
::
size_t
get_pb
(
std
::
size_t
i
)
const
{
assert
(
i
<
4
);
assert
(
i
<
=
4
);
return
int_at
(
block_size_index
+
1
+
i
);
}
std
::
array
<
std
::
size_t
,
3
>
get_pad
(
const
std
::
array
<
std
::
size_t
,
3
>&
config
)
const
std
::
array
<
std
::
size_t
,
4
>
get_pad
(
const
std
::
array
<
std
::
size_t
,
4
>&
config
)
const
{
std
::
array
<
std
::
size_t
,
3
>
result
{};
std
::
array
<
std
::
size_t
,
4
>
result
{};
for
(
auto
i
:
range
(
config
.
size
()))
{
result
[
i
]
=
int_div_ceil
(
config
[
i
],
get_pb
(
i
))
*
get_pb
(
i
)
-
config
[
i
];
...
...
@@ -126,33 +104,16 @@ struct instance
return
result
;
}
std
::
size_t
get_grid_size
(
const
std
::
array
<
std
::
size_t
,
3
>&
config
)
const
std
::
size_t
get_grid_size
(
const
std
::
array
<
std
::
size_t
,
4
>&
config
)
const
{
return
int_div_ceil
(
config
[
0
],
get_pb
(
0
))
*
int_div_ceil
(
config
[
1
],
get_pb
(
1
));
}
void
set_ds_layout
(
const
std
::
string
&
s
)
{
assert
(
params
[
2
]
==
"ck::Tuple<>"
);
params
[
2
]
=
s
;
}
void
set_ds_type
(
const
std
::
string
&
s
)
{
assert
(
params
[
8
]
==
"ck::Tuple<>"
);
params
[
8
]
=
s
;
}
void
set_ds_op
(
const
std
::
string
&
s
)
{
assert
(
params
[
12
]
==
"ck_passthrough"
);
params
[
12
]
=
s
;
return
int_div_ceil
(
config
[
0
],
get_pb
(
0
))
*
int_div_ceil
(
config
[
3
],
get_pb
(
3
));
}
void
set_gemm
(
const
std
::
string
&
s
)
{
assert
(
params
[
13
]
==
"ck::tensor_operation::device::GemmSpecialization::Default"
);
params
[
13
]
=
s
;
assert
(
params
[
15
]
==
"ck::tensor_operation::device::GemmSpecialization::Default"
or
params
[
15
]
==
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
);
params
[
15
]
=
s
;
}
std
::
string
str
()
const
{
return
join_strings
(
params
,
","
);
}
...
...
@@ -179,12 +140,12 @@ static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
static
auto
tuning
=
read_tuning
(
string_value_of
(
MIGRAPHX_CK_TUNING
{},
""
));
if
(
tuning
.
empty
())
std
::
cout
<<
"*********** Warning: No CK tuning!"
<<
std
::
endl
;
std
::
cout
<<
"*********** Warning: No CK
GSG
tuning!"
<<
std
::
endl
;
auto
it
=
std
::
find_if
(
tuning
.
begin
(),
tuning
.
end
(),
[
&
](
const
auto
&
p
)
{
return
p
.
first
==
inputs
;
});
if
(
it
==
tuning
.
end
())
{
std
::
cout
<<
"*********** Warning: CK tuning missing for config!"
<<
std
::
endl
;
std
::
cout
<<
"*********** Warning: CK
GSG
tuning missing for config!"
<<
std
::
endl
;
return
4
;
}
return
it
->
second
;
...
...
@@ -194,8 +155,12 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
static
std
::
string
get_layout
(
const
shape
&
s
)
{
return
s
.
transposed
()
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
if
(
not
s
.
transposed
())
return
"ck::tensor_layout::gemm::RowMajor"
;
auto
lens
=
s
.
lens
();
return
lens
[
lens
.
size
()
-
1
]
>
lens
[
lens
.
size
()
-
2
]
?
"ck::tensor_layout::gemm::ColumnMajor"
:
"ck::tensor_layout::gemm::RowMajor"
;
}
static
std
::
string
get_type
(
const
shape
&
s
)
...
...
@@ -222,6 +187,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
auto
a_shape
=
inputs
[
0
];
auto
b_shape
=
inputs
[
1
];
auto
b1_shape
=
inputs
[
2
];
auto
c_shape
=
inputs
.
back
();
auto
m
=
a_shape
.
lens
()[
0
];
auto
k
=
a_shape
.
lens
()[
1
];
...
...
@@ -229,48 +195,39 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
auto
rank
=
a_shape
.
lens
().
size
();
// std::array<char, 3> keys{'M', 'N', 'K'};
// std::array<std::size_t, 3> config{
// c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
// auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
// auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
// return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
// get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
// get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
// })};
// assert(inputs.size() < 4 or v.contains("post"));
// if(v.contains("post"))
// {
// ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
// ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
// ip.set_ds_op(v.at("post").to<std::string>());
// }
// auto padding = ip.get_pad(config);
// std::string gemm_type;
// for(auto i : range(padding.size()))
// {
// if(padding[i] != 0)
// gemm_type += keys[i];
// }
// if(gemm_type.empty())
// gemm_type = "Default";
// else
// gemm_type += "Padding";
// ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
auto
gemm1_nperblock
=
64
;
auto
gemm01_mperblock
=
256
;
auto
blocks_per_batch
=
int_div_ceil
(
m
,
gemm01_mperblock
)
*
int_div_ceil
(
n
,
gemm1_nperblock
);
// ip.get_grid_size(config);
std
::
array
<
char
,
4
>
keys
{
'M'
,
'N'
,
'K'
,
'O'
};
// config (m0, n0, k0, n1)
std
::
array
<
std
::
size_t
,
4
>
config
{
c_shape
.
lens
()[
rank
-
2
],
b_shape
.
lens
()[
rank
-
2
],
a_shape
.
lens
().
back
(),
c_shape
.
lens
().
back
()};
auto
tuning_val
=
v
.
get
(
"tuning_val"
,
get_tuning_for
({
a_shape
,
b_shape
,
b1_shape
,
c_shape
}));
auto
ip
=
instance
{
get_gsg_instance
(
tuning_val
,
[
&
](
const
auto
&
x
)
->
bool
{
return
get_layout
(
a_shape
)
==
x
[
0
]
and
get_layout
(
b_shape
)
==
x
[
1
]
and
get_layout
(
c_shape
)
==
x
[
3
]
and
get_type
(
a_shape
)
==
x
[
4
]
and
get_type
(
b_shape
)
==
x
[
5
]
and
get_type
(
c_shape
)
==
x
[
9
];
})};
auto
padding
=
ip
.
get_pad
(
config
);
std
::
string
gemm_type
;
for
(
auto
i
:
range
(
padding
.
size
()))
{
if
(
padding
[
i
]
!=
0
)
gemm_type
+=
keys
[
i
];
}
if
(
gemm_type
.
empty
())
gemm_type
=
"Default"
;
else
gemm_type
+=
"Padding"
;
ip
.
set_gemm
(
"ck::tensor_operation::device::GemmSpecialization::"
+
gemm_type
);
auto
blocks_per_batch
=
ip
.
get_grid_size
(
config
);
auto
batch_count
=
std
::
accumulate
(
c_shape
.
lens
().
rbegin
()
+
2
,
c_shape
.
lens
().
rend
(),
std
::
size_t
{
1
},
std
::
multiplies
<
std
::
size_t
>
());
hip_compile_options
options
;
auto
block_size
=
256
;
//
ip.get_block_size();
auto
block_size
=
ip
.
get_block_size
();
auto
grid_size
=
batch_count
*
blocks_per_batch
;
options
.
set_launch_params
(
v
,
grid_size
*
block_size
,
block_size
);
options
.
inputs
=
inputs
;
...
...
@@ -282,7 +239,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
options
.
params
+=
" -DMIGRAPHX_CK_CHECK=1"
;
auto
src
=
interpolate_string
(
ck_gemm_softmax_gemm_kernel
,
{{
"instance"
,
""
/*
ip.str()
*/
},
{{
"instance"
,
ip
.
str
()},
{
"params"
,
enum_params
(
inputs
.
size
(),
"void * private_p"
)},
{
"args"
,
enum_params
(
inputs
.
size
(),
"private_p"
)},
{
"blocks_per_batch"
,
to_string
(
blocks_per_batch
)},
...
...
@@ -302,7 +259,6 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
v
[
"preamble"
]
=
generate_pointwise
(
*
pm
,
"post_ck_gemm_softmax_gemm_function"
)
+
"
\n
MIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
"post_ck_gemm_softmax_gemm_function);"
;
v
[
"post"
]
=
"ck_function_adaptor<post_ck_gemm_softmax_gemm>"
;
v
[
"kernel"
]
=
"ck_gemm_softmax_gemm_"
+
generate_name_from_ops
(
*
pm
)
+
"_kernel"
;
}
...
...
@@ -310,7 +266,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
return
action_decorate
(
replace
(
compile_op
(
ctx
,
shapes
,
v
)),
[
=
]
{
if
(
enabled
(
MIGRAPHX_LOG_CK_GEMM
{}))
{
std
::
vector
<
shape
>
gemm_shapes
{
shapes
[
0
],
shapes
[
1
],
shapes
.
back
()};
std
::
vector
<
shape
>
gemm_shapes
{
shapes
[
0
],
shapes
[
1
],
shapes
[
2
],
shapes
.
back
()};
std
::
cout
<<
"ck_gemm_softmax_gemm: "
<<
to_json_string
(
to_value
(
gemm_shapes
))
<<
std
::
endl
;
}
...
...
src/targets/gpu/jit/ck_gsg_instances.cpp
View file @
f3fcfcc7
#include <algorithm>
#include <vector>
#include <string>
...
...
@@ -10,17 +8,19 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
{
static
std
::
vector
<
std
::
vector
<
std
::
vector
<
std
::
string
>>>
instances
=
{
{{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
...
...
@@ -28,10 +28,14 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"256"
,
"128"
,
"32"
,
"64"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"2"
,
"4"
,
"2"
,
"ck::Sequence<4,64,1>"
,
...
...
@@ -40,30 +44,41 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<8,32,1>"
,
"true"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<16,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
...
...
@@ -71,42 +86,57 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"256"
,
"128"
,
"32"
,
"128"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"4"
,
"2"
,
"4"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"
true
"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"8"
,
"1"
,
"1"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
...
...
@@ -114,42 +144,57 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"128"
,
"256"
,
"32"
,
"64"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"1"
,
"8"
,
"2"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"
true
"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<16,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
...
...
@@ -157,11 +202,15 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"128"
,
"256"
,
"32"
,
"128"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"2"
,
"1"
,
"8"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
...
...
@@ -169,159 +218,215 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"2"
,
"8"
,
"8"
,
"
1
"
,
"
true
"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"2"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"128"
,
"128"
,
"64"
,
"64"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"1"
,
"4"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"false"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"ck::Sequence<
4,32
,1>"
,
"
false
"
,
"ck::Sequence<
16,16
,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"
0
"
,
"
false
"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"128"
,
"128"
,
"32"
,
"64"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"1"
,
"4"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"ck::Sequence<
4,32
,1>"
,
"
true
"
,
"ck::Sequence<
16,16
,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"2"
,
"false"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"128"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"1"
,
"4"
,
"4"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"8"
,
"8"
,
"false"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"
false
"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
...
...
@@ -329,2696 +434,739 @@ get_gsg_instance(std::size_t i, const std::function<bool(const std::vector<std::
"128"
,
"128"
,
"32"
,
"128"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"2"
,
"2"
,
"1"
,
"4"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"
true
"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"8"
,
"1"
,
"1"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"256"
,
"64"
,
"256"
,
"32"
,
"8"
,
"2"
,
"32"
,
"128"
,
"32"
,
"8"
,
"8"
,
"2"
,
"16"
,
"16"
,
"1"
,
"16"
,
"8"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"ck::Sequence<8,
16
,1>"
,
"
true
"
,
"ck::Sequence<8,
32
,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"false"
,
"1"
,
"ck::Sequence<1,32,1,4>"
,
"8"
},
"8"
,
"ck::Sequence<1,16,1,16>"
,
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"256"
,
"64"
,
"256"
,
"32"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"16"
,
"16"
,
"1"
,
"16"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"ck::Sequence<
4,32
,1>"
,
"
true
"
,
"ck::Sequence<
16,16
,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"8"
,
"1"
,
"false"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,4>"
,
"8"
},
"4"
,
"ck::Sequence<1,32,1,8>"
,
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"256"
,
"64"
,
"256"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"2"
,
"16"
,
"16"
,
"1"
,
"16"
,
"8"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"ck::Sequence<
4
,32,1>"
,
"
true
"
,
"ck::Sequence<
8
,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"false"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
"8"
,
"ck::Sequence<1,16,1,16>"
,
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"256"
,
"64"
,
"256"
,
"64"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"16"
,
"16"
,
"1"
,
"16"
,
"4"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"ck::Sequence<
4,32
,1>"
,
"
true
"
,
"ck::Sequence<
16,16
,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"2"
,
"false"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
"4"
,
"ck::Sequence<1,32,1,8>"
,
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
,
"1"
,
"256"
,
"128"
,
"128"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"4"
,
"4"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<16,16,1>"
,
"false"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"false"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
,
"1"
,
"256"
,
"128"
,
"64"
,
"32"
,
"128"
,
"32"
,
"8"
,
"8"
,
"2"
,
"32"
,
"32"
,
"2"
,
"1"
,
"2"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"
1
"
,
"
true
"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"true"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"false"
,
"1"
,
"8"
,
"1"
,
"1"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
,
"1"
,
"256"
,
"6
4
"
,
"
25
6"
,
"128"
,
"40"
,
"64"
,
"32"
,
"8"
,
"4"
,
"4"
,
"2"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"4"
,
"2"
,
"ck::Sequence<2,128,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<8,32,1>"
,
"4"
,
"4"
,
"false"
,
"ck::Sequence<2,128,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"4"
,
"4"
,
"false"
,
"ck::Sequence<16,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
,
"1"
,
"256"
,
"64"
,
"256"
,
"128"
,
"40"
,
"128"
,
"32"
,
"8"
,
"8"
,
"4"
,
"4"
,
"2"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"4"
,
"4"
,
"ck::Sequence<2,128,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"4"
,
"4"
,
"false"
,
"ck::Sequence<2,128,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"4"
,
"4"
,
"false"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"8"
,
"1"
,
"1"
,
"false"
,
"1"
,
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
"8"
,
"false"
,
"std::ratio<1, 8>"
},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "64",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "2",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<16,16,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
// {"ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::ColumnMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::tensor_layout::gemm::RowMajor",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "ck::half_t",
// "float",
// "ck::half_t",
// "ck_passthrough",
// "ck_passthrough",
// "ck_scale",
// "ck_passthrough",
// "ck_passthrough",
// "ck::tensor_operation::device::GemmSpecialization::MNKOPadding",
// "1",
// "256",
// "128",
// "256",
// "40",
// "128",
// "32",
// "4",
// "4",
// "2",
// "32",
// "32",
// "1",
// "8",
// "4",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<2,128,1>",
// "ck::Sequence<1,0,2>",
// "ck::Sequence<1,0,2>",
// "2",
// "4",
// "4",
// "false",
// "ck::Sequence<8,32,1>",
// "ck::Sequence<0,2,1>",
// "ck::Sequence<0,2,1>",
// "1",
// "4",
// "2",
// "false",
// "1",
// "2",
// "ck::Sequence<1,32,1,8>",
// "8",
// "false",
// "std::ratio<1, 8>"},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
,
"1"
,
"64"
,
"64"
,
"256"
,
"128"
,
"128"
,
"40"
,
"64"
,
"32"
,
"8"
,
"8"
,
"4"
,
"4"
,
"2"
,
"32"
,
"32"
,
"1"
,
"4"
,
"2"
,
"ck::Sequence<2,128,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"ck::Sequence<4,16,1>"
,
"4"
,
"4"
,
"false"
,
"ck::Sequence<2,128,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"
8
"
,
"
8
"
,
"
1
"
,
"ck::Sequence<
4
,16,1>"
,
"
4
"
,
"
4
"
,
"
false
"
,
"ck::Sequence<
16
,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"2"
,
"false"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
},
"2"
,
"ck::Sequence<1,32,1,8>"
,
"8"
,
"false"
,
"std::ratio<1, 8>"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_scale"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::MNKOPadding"
,
"1"
,
"64"
,
"64"
,
"32"
,
"256"
,
"128"
,
"128"
,
"40"
,
"128"
,
"32"
,
"8"
,
"8"
,
"4"
,
"4"
,
"2"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"4"
,
"4"
,
"ck::Sequence<2,128,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"4"
,
"4"
,
"false"
,
"ck::Sequence<2,128,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"4"
,
"4"
,
"false"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"32"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
}},
{{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"256"
,
"128"
,
"32"
,
"2"
,
"8"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"256"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"256"
,
"32"
,
"2"
,
"8"
,
"32"
,
"32"
,
"2"
,
"4"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"256"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"128"
,
"32"
,
"2"
,
"8"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"128"
,
"32"
,
"2"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"64"
,
"32"
,
"2"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"64"
,
"128"
,
"32"
,
"2"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<8,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"64"
,
"32"
,
"2"
,
"8"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"64"
,
"128"
,
"32"
,
"2"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<16,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"1"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"64"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"64"
,
"32"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"32"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
}},
{{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"256"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"256"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"32"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"32"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"64"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"64"
,
"32"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::RowMajor"
,
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"32"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<1,0,2>"
,
"ck::Sequence<1,0,2>"
,
"2"
,
"8"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
}},
{{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"256"
,
"128"
,
"32"
,
"2"
,
"2"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"256"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"256"
,
"32"
,
"2"
,
"2"
,
"32"
,
"32"
,
"2"
,
"4"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"256"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"4"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"128"
,
"32"
,
"2"
,
"2"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"4"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"128"
,
"32"
,
"2"
,
"2"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"64"
,
"32"
,
"2"
,
"2"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"128"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"64"
,
"128"
,
"32"
,
"2"
,
"2"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<8,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"128"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"64"
,
"32"
,
"2"
,
"2"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<16,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"128"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"1"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"64"
,
"128"
,
"32"
,
"2"
,
"2"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<16,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"ck::Sequence<8,32,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"2"
,
"0"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"256"
,
"64"
,
"128"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"1"
,
"8"
,
"1"
,
"ck::Sequence<4,64,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"false"
,
"1"
,
"2"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,32,1,8>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"64"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"2"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"64"
,
"32"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"2"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
},
{
"ck::tensor_layout::gemm::ColumnMajor"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::Tuple<>"
,
"ck::tensor_layout::gemm::RowMajor"
,
"ck::half_t"
,
"ck::half_t"
,
"float"
,
"float"
,
"ck::Tuple<>"
,
"ck::half_t"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck_passthrough"
,
"ck::tensor_operation::device::GemmSpecialization::Default"
,
"1"
,
"64"
,
"32"
,
"64"
,
"32"
,
"8"
,
"8"
,
"32"
,
"32"
,
"1"
,
"2"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"2"
,
"8"
,
"1"
,
"ck::Sequence<4,16,1>"
,
"ck::Sequence<0,2,1>"
,
"ck::Sequence<0,2,1>"
,
"1"
,
"4"
,
"8"
,
"1"
,
"1"
,
"1"
,
"ck::Sequence<1,16,1,4>"
,
"8"
}}};
"false"
,
"std::ratio<1, 8>"
}}};
auto
it
=
std
::
find_if
(
instances
.
begin
(),
instances
.
end
(),
[
&
](
const
auto
&
v
)
{
return
pred
(
v
[
0
]);
});
return
it
->
at
(
i
);
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
View file @
f3fcfcc7
...
...
@@ -69,7 +69,7 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
constexpr
const
auto
b_shape
=
get_shape_c
<
B
>
{};
constexpr
const
auto
n
=
b_shape
.
lens
[
0
];
// col-major
constexpr
const
auto
n
=
b_shape
.
lens
[
1
];
constexpr
const
auto
sb
=
b_shape
.
strides
[
1
];
// col-major
constexpr
const
auto
BK1
=
gemm
.
get_BK1
();
constexpr
const
auto
BK0
=
k
/
BK1
;
...
...
@@ -85,9 +85,9 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
ck
::
make_tuple
(
ck
::
Sequence
<
0
,
2
>
{},
ck
::
Sequence
<
1
>
{}));
constexpr
const
auto
b1_shape
=
get_shape_c
<
B1
>
{};
constexpr
const
auto
k1
=
b1_shape
.
lens
[
0
];
// row-major
constexpr
const
auto
n1
=
b1_shape
.
lens
[
1
];
// row-major
constexpr
const
auto
sb1
=
b1_shape
.
strides
[
0
];
// row
l
-major
constexpr
const
auto
k1
=
b1_shape
.
lens
[
0
];
constexpr
const
auto
n1
=
b1_shape
.
lens
[
1
];
constexpr
const
auto
sb1
=
b1_shape
.
strides
[
0
];
// row-major
constexpr
const
auto
B1K1
=
gemm
.
get_B1K1
();
constexpr
const
auto
B1K0
=
k1
/
B1K1
;
...
...
@@ -139,11 +139,11 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
static_assert
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1
,
b_grid_desc_bk0_n_bk1
,
b1_grid_desc_bk0_n_bk1
,
c_grid_desc_m_n
,
block_2_ctile_map
));
//
static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
//
b_grid_desc_bk0_n_bk1,
//
b1_grid_desc_bk0_n_bk1,
//
c_grid_desc_m_n,
//
block_2_ctile_map));
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
to_ck_const_pointer
(
a
.
data
()),
to_ck_const_pointer
(
b
.
data
()),
to_ck_const_pointer
(
b1
.
data
()),
...
...
src/targets/gpu/target.cpp
View file @
f3fcfcc7
...
...
@@ -56,6 +56,7 @@
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/device_name.hpp>
#include <migraphx/gpu/fuse_ck.hpp>
#include <migraphx/gpu/fuse_ck_gemm_softmax_gemm.hpp>
#include <migraphx/gpu/fuse_mlir.hpp>
#include <migraphx/gpu/fuse_ops.hpp>
#include <migraphx/gpu/prefuse_ops.hpp>
...
...
@@ -131,6 +132,8 @@ std::vector<pass> target::get_passes(migraphx::context& gctx, const compile_opti
simplify_reshapes
{},
propagate_constant
{},
dead_code_elimination
{},
fuse_ck_gemm_softmax_gemm
{
&
ctx
},
dead_code_elimination
{},
enable_pass
(
not
enabled
(
MIGRAPHX_DISABLE_POINTWISE_FUSION
{}),
fuse_pointwise
{}),
dead_code_elimination
{},
fuse_mlir
{
&
ctx
},
...
...
test/verify/0ck_gemm_softmax_gemm.cpp
View file @
f3fcfcc7
...
...
@@ -31,15 +31,39 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
{
migraphx
::
program
create_program
()
const
{
// migraphx::program p;
// auto* mm = p.get_main_module();
// migraphx::shape m1_shape{migraphx::shape::half_type, {16, 12, 384, 64}};
// migraphx::shape m2_shape{migraphx::shape::half_type, {16, 12, 384, 384}};
// auto m2_elements = 16 * 12 * 384 * 384;
// auto a = mm->add_parameter("1", m1_shape);
// auto b = mm->add_parameter("2", m1_shape);
// auto b1 = mm->add_parameter("3", m1_shape);
// auto c = mm->add_parameter("4", m1_shape);
// std::vector<float> eights(m2_elements, 0.125);
// auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
// std::vector<float> zeros(m2_elements, 0);
// auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
// std::vector<float> ones(m2_elements, 1);
// auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
// // a = one;
// // b = one;
// // b1 = one;
// b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
// auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
// auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
// auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
// auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
// mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
migraphx
::
program
p
;
auto
*
mm
=
p
.
get_main_module
();
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
1
,
12
,
256
,
256
}};
auto
m2_elements
=
1
*
12
*
256
*
256
;
auto
a
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
auto
b
=
mm
->
add_parameter
(
"2"
,
m1_shape
);
auto
b1
=
mm
->
add_parameter
(
"3"
,
m1_shape
);
auto
c
=
mm
->
add_parameter
(
"4"
,
m1_shape
);
size_t
batch
=
2
;
migraphx
::
shape
m1_shape
{
migraphx
::
shape
::
half_type
,
{
batch
,
384
,
2304
}};
migraphx
::
shape
m2_shape
{
migraphx
::
shape
::
half_type
,
{
batch
,
12
,
384
,
384
}};
auto
m2_elements
=
batch
*
12
*
384
*
384
;
auto
g
=
mm
->
add_parameter
(
"1"
,
m1_shape
);
std
::
vector
<
float
>
eights
(
m2_elements
,
0.125
);
auto
eight
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
eights
});
std
::
vector
<
float
>
zeros
(
m2_elements
,
0
);
...
...
@@ -47,7 +71,13 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
std
::
vector
<
float
>
ones
(
m2_elements
,
1
);
auto
one
=
mm
->
add_literal
(
migraphx
::
literal
{
m2_shape
,
ones
});
g
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"reshape"
,
{{
"dims"
,
{
batch
,
384
,
36
,
64
}}}),
g
);
g
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
2
,
1
,
3
}}}),
g
);
auto
a
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
0
}},
{
"ends"
,
{
12
}}}),
g
);
auto
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
12
}},
{
"ends"
,
{
24
}}}),
g
);
auto
b1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"slice"
,
{{
"axes"
,
{
1
}},
{
"starts"
,
{
24
}},
{
"ends"
,
{
36
}}}),
g
);
b
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"transpose"
,
{{
"permutation"
,
{
0
,
1
,
3
,
2
}}}),
b
);
auto
gemm1
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"dot"
),
a
,
b
);
auto
scale
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"mul"
),
gemm1
,
eight
);
auto
bias
=
mm
->
add_instruction
(
migraphx
::
make_op
(
"add"
),
scale
,
zero
);
...
...
tools/tune_ck.py
View file @
f3fcfcc7
import
os
,
json
,
subprocess
,
tempfile
,
sys
,
argparse
,
contextlib
ck_function
=
-
1
@
contextlib
.
contextmanager
def
tmp_file
(
dump
=
None
):
...
...
@@ -20,6 +21,9 @@ def pretty_print(obj):
def
run_driver
(
b
):
print
(
b
)
outfile
=
open
(
"temp2.json"
,
"w"
)
json
.
dump
(
b
,
outfile
)
outfile
.
close
()
with
tmp_file
(
lambda
tf
:
json
.
dump
(
b
,
tf
))
as
tf
:
cp
=
subprocess
.
run
(
'./bin/gpu-driver {}'
.
format
(
tf
),
capture_output
=
True
,
...
...
@@ -45,7 +49,7 @@ def get_device_time(s):
def
benchmark_ck
(
config
,
tuning
):
try
:
b
=
{
b
0
=
{
'settings'
:
{
'iterations'
:
100
},
...
...
@@ -56,6 +60,18 @@ def benchmark_ck(config, tuning):
'inputs'
:
config
}
}
b1
=
{
'settings'
:
{
'iterations'
:
100
},
'compile_op'
:
{
'name'
:
'ck_gemm_softmax_gemm'
,
'check'
:
True
,
'tuning_val'
:
tuning
,
'inputs'
:
config
}
}
b
=
b0
if
(
ck_function
==
0
)
else
b1
for
line
in
run_driver
(
b
):
dtime
=
get_device_time
(
line
)
print
(
dtime
)
...
...
@@ -72,17 +88,26 @@ def benchmark(config, size):
def
parse_log
(
f
):
for
line
in
open
(
f
).
readlines
():
line
=
line
.
strip
()
if
not
line
.
startswith
(
'ck_gemm:'
):
continue
line
=
line
[
len
(
'ck_gemm:'
):].
strip
()
config
=
json
.
loads
(
line
)
yield
config
global
ck_function
if
line
.
startswith
(
'ck_gemm:'
):
line
=
line
[
len
(
'ck_gemm:'
):].
strip
()
config
=
json
.
loads
(
line
)
ck_function
=
0
yield
config
if
line
.
startswith
(
'ck_gemm_softmax_gemm:'
):
line
=
line
[
len
(
'ck_gemm_softmax_gemm:'
):].
strip
()
config
=
json
.
loads
(
line
)
ck_function
=
1
yield
config
def
benchmark_log
(
f
,
n
):
result
=
[]
for
config
in
parse_log
(
f
):
tuned
=
benchmark
(
config
,
n
)
logs
=
parse_log
(
f
)
for
config
in
logs
:
additional_tv
=
ck_function
*
2
tuned
=
benchmark
(
config
,
n
+
additional_tv
)
print
(
"Tuned:"
,
tuned
)
result
.
append
([
config
,
tuned
])
return
result
...
...
tools/tune_ck.sh
0 → 100755
View file @
f3fcfcc7
#!/bin/bash
# Collect CK gemm configs by running MODEL at several batch sizes with
# MIGRAPHX_LOG_CK_GEMM logging enabled, then tune the deduplicated configs
# with tune_ck.py and write the tuning database.
MODEL="$1"
LOG="ck_bbc.log"
TUNING_DB="ck_bbc.json"

# -f: do not fail on the very first run, when the log does not exist yet
rm -f "$LOG"
touch "$LOG"

for N in 1 16 32 64
do
    MIGRAPHX_LOG_CK_GEMM=1 ./bin/driver run "$MODEL" -g --fill1 input_ids \
        --input-dim @input_ids "$N" 384 | grep 'ck_gemm.*: \[{' | sort -u >> "$LOG"
done

python3 ../tools/tune_ck.py -n 16 -l "$LOG" -o "$TUNING_DB"
\ No newline at end of file
tools/tune_ck_gsg.py
0 → 100644
View file @
f3fcfcc7
import
os
,
json
,
subprocess
,
tempfile
,
sys
,
argparse
,
contextlib
@contextlib.contextmanager
def tmp_file(dump=None):
    """Context manager that yields the path of a temporary file.

    If `dump` is given it is called with the open file handle so the caller
    can write initial contents. The file is deleted when the context exits.
    """
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(mode='w+', delete=False) as f:
            tmp_name = f.name
            if dump:
                dump(f)
        yield tmp_name
    finally:
        # Guard against NamedTemporaryFile failing before a name was
        # assigned; os.unlink(None) would raise TypeError in the finally
        # block and mask the original error.
        if tmp_name is not None:
            os.unlink(tmp_name)
def pretty_print(obj):
    """Print `obj` to stdout as two-space-indented JSON."""
    text = json.dumps(obj, indent=2)
    print(text)
def run_driver(b):
    """Run ./bin/gpu-driver on problem description `b`, yielding result lines.

    The description is serialized to a temporary JSON file whose path is
    passed to the driver; only stdout lines containing ']: ' are yielded,
    trimmed to the text after that marker.
    """
    print(b)
    with tmp_file(lambda tf: json.dump(b, tf)) as path:
        proc = subprocess.run('./bin/gpu-driver {}'.format(path),
                              capture_output=True,
                              check=True,
                              shell=True)
        for raw in proc.stdout.decode().split("\n"):
            stripped = raw.strip()
            if stripped and ']: ' in stripped:
                yield stripped.split(']: ')[1].strip()
def convert_to_float(s):
    """Drop the trailing two-character unit suffix (e.g. 'ms') from `s`."""
    return s[:len(s) - 2]
def get_device_time(s):
    """Extract the device time from a driver result line.

    The time is the last comma-separated field with its two-character unit
    suffix (e.g. 'ms') removed; it is returned as a string.
    """
    last_field = s.split(',')[-1].strip()
    return last_field[:-2]
def benchmark_ck(config, tuning):
    """Benchmark one ck_gemm_softmax_gemm tuning instance.

    Compiles and runs the op described by `config` using tuning instance
    `tuning` and returns the measured device time (in ms) as a float.
    Returns sys.float_info.max when the instance fails to compile/run or
    produces no parsable output, so failing instances sort last.
    """
    b = {
        'settings': {'iterations': 100},
        'compile_op': {
            'name': 'ck_gemm_softmax_gemm',
            'check': True,
            'tuning_val': tuning,
            'inputs': config
        }
    }
    try:
        for line in run_driver(b):
            dtime = get_device_time(line)
            print(dtime)
            return float(dtime)
    except Exception:
        # A bare `except:` here would also swallow KeyboardInterrupt and
        # SystemExit; only driver/parse failures should be treated as a
        # failed instance.
        pass
    # The driver produced no parsable output (the original fell through and
    # returned None, which breaks min() in benchmark()): treat as failed.
    return sys.float_info.max
def benchmark(config, size):
    """Return the index of the fastest tuning instance in [0, size)."""
    return min(range(size), key=lambda i: benchmark_ck(config, i))
def parse_log(f):
    """Yield JSON configs from log lines prefixed with 'ck_gemm_softmax_gemm:'.

    Lines without the prefix are ignored.
    """
    prefix = 'ck_gemm_softmax_gemm:'
    # Use a context manager: the original `open(f).readlines()` leaked the
    # file handle.
    with open(f) as log:
        for line in log:
            line = line.strip()
            if not line.startswith(prefix):
                continue
            yield json.loads(line[len(prefix):].strip())
def benchmark_log(f, n):
    """Tune every config found in log file `f` over `n` tuning instances.

    Returns a list of [config, best_instance_index] pairs.
    """
    result = []
    for cfg in parse_log(f):
        best = benchmark(cfg, n)
        print("Tuned:", best)
        result.append([cfg, best])
    return result
def parse_args():
    """Parse command-line options for the CK gemm_softmax_gemm tuner."""
    cli = argparse.ArgumentParser(description="Simple tuner for CK gemms")
    cli.add_argument('--log', '-l', type=str, metavar='file',
                     help='Path to logfile')
    cli.add_argument('--out', '-o', type=str, metavar='file',
                     help='Output json file to save tunings')
    cli.add_argument('-n', type=int, help='Number of instances to tune')
    return cli.parse_args()
def run(args):
    """Tune all configs from args.log and write the results to args.out.

    The output is a JSON list of [config, best_instance_index] pairs.
    """
    tuned = benchmark_log(args.log, args.n)
    # Use a context manager so the output file is flushed and closed even if
    # json.dump raises; the original left the handle dangling.
    with open(args.out, 'w') as out:
        json.dump(tuned, out)
run
(
parse_args
())
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment