Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
MIGraphX
Commits
d7ea085c
"vscode:/vscode.git/clone" did not exist on "c93e86956ca421336d1dae830e6d3f45a7d1e4e3"
Commit
d7ea085c
authored
Nov 15, 2022
by
Alan Turner
Browse files
Add gemm-softmax-gemm fusion
parent
4c370d64
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
4084 additions
and
4 deletions
+4084
-4
src/targets/gpu/fuse_ck.cpp
src/targets/gpu/fuse_ck.cpp
+101
-4
src/targets/gpu/gemm_impl.cpp
src/targets/gpu/gemm_impl.cpp
+17
-0
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
+317
-0
src/targets/gpu/jit/ck_gsg_instances.cpp
src/targets/gpu/jit/ck_gsg_instances.cpp
+3025
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
+22
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
...kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
+170
-0
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp
...nclude/migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp
+291
-0
test/onnx/gemm_softmax_gemm_test.onnx
test/onnx/gemm_softmax_gemm_test.onnx
+46
-0
test/onnx/gen_onnx.py
test/onnx/gen_onnx.py
+36
-0
test/verify/0ck_gemm_softmax_gemm.cpp
test/verify/0ck_gemm_softmax_gemm.cpp
+59
-0
No files found.
src/targets/gpu/fuse_ck.cpp
View file @
d7ea085c
...
@@ -49,6 +49,44 @@ struct ck_gemm
...
@@ -49,6 +49,44 @@ struct ck_gemm
};
};
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
MIGRAPHX_REGISTER_OP
(
ck_gemm
);
struct ck_gemm_scale_bias_softmax_gemm
{
    // Wrapped dot operation used to derive the fused op's output shape
    // (shape of the second gemm's result).
    operation op = make_op("dot");

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.op, "op"));
    }

    std::string name() const { return "gpu::ck_gemm_softmax_gemm"; }

    // Reject shapes whose last three strides contain no unit stride; the CK
    // kernel's tensor descriptors require one packed (stride == 1) dimension
    // among the trailing dims.
    void check_gemm_shape(const shape& s) const
    {
        if(not contains(range(s.strides().rbegin(), s.strides().rbegin() + 3), 1))
            MIGRAPHX_THROW("Invalid shape for ck_gemm_scale_bias_softmax_gemm");
    }

    // Computes the shape of gemm(softmax(gemm(a, b)), b1).
    // inputs: [0]=A and [1]=B of the first gemm, [2]=B1 of the second gemm.
    // mods is currently unused (see the commented-out submodule check).
    shape compute_shape(std::vector<shape> inputs, const std::vector<module_ref>& mods) const
    {
        check_shapes{inputs, *this}.same_ndims();
        // if(mods.size() != 1)
        //     MIGRAPHX_THROW("should have one submodule.");
        // inputs[2] is read below, so three inputs are required; the previous
        // check for "< 2" allowed an out-of-bounds access on inputs[2].
        if(inputs.size() < 3)
            MIGRAPHX_THROW("should have at least three inputs.");
        auto a  = inputs[0];
        auto b  = inputs[1];
        auto b1 = inputs[2];
        for(const auto& input : inputs)
            check_gemm_shape(input);
        return op.compute_shape({op.compute_shape({a, b}), b1});
    }
};
MIGRAPHX_REGISTER_OP
(
ck_gemm_scale_bias_softmax_gemm
);
namespace
{
namespace
{
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
MIGRAPHX_PRED_MATCHER
(
is_ck_gemm
,
instruction_ref
ins
)
...
@@ -116,14 +154,73 @@ struct find_ck_gemm
...
@@ -116,14 +154,73 @@ struct find_ck_gemm
}
}
};
};
struct find_ck_gemm_scale_bias_softmax_gemm
{
    // Matches the chain dot -> pointwise(scale/bias) -> softmax -> dot,
    // skipping any contiguous op after the first dot; both dots must satisfy
    // the is_ck_gemm predicate.
    auto matcher() const
    {
        auto gemm1 =
            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
        auto pw = match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
        auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
        return match::name("dot")(is_ck_gemm().bind("gemm2"))(
            match::any_of[match::inputs()](softmax));
    }

    // Replaces the matched subgraph with the fused operator. The fused op's
    // inputs are the first gemm's A and B plus the second gemm's B1.
    // (Removed leftover debug output: std::cout tracing and debug_print()
    // calls on the matched instructions, plus a dead commented-out duplicate
    // of this struct's methods.)
    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins       = r.result;
        auto gemm2_ins = r.instructions["gemm2"];
        auto gemm1_ins = r.instructions["gemm1"];

        auto inputs = gemm1_ins->inputs();            // A, B
        inputs.push_back(gemm2_ins->inputs().back()); // B1
        // NOTE(review): the scale/bias pointwise input (C) is not yet
        // forwarded to the fused op — confirm this is intentional.
        // inputs.push_back(pw_ins->inputs().back()); // C
        mpm.get_module().replace_instruction(
            ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
    }
};
}
// namespace
}
// namespace
void fuse_ck::apply(module_pass_manager& mpm) const
{
    // Both fusion matchers are gated by MIGRAPHX_DISABLE_CK_GEMM_FUSION.
    // Braces ensure the gemm-softmax-gemm matcher is covered by the same
    // guard; previously only the first find_matches call was guarded, so the
    // softmax fusion ran even when fusion was disabled.
    if(not enabled(MIGRAPHX_DISABLE_CK_GEMM_FUSION{}))
    {
        match::find_matches(mpm, find_ck_gemm_pointwise{});
        match::find_matches(mpm, find_ck_gemm_scale_bias_softmax_gemm{});
    }
    // Plain (unfused) CK gemm replacement has its own kill switch.
    if(not enabled(MIGRAPHX_DISABLE_CK_GEMM{}))
        match::find_matches(mpm, find_ck_gemm{});
}
}
// namespace gpu
}
// namespace gpu
...
...
src/targets/gpu/gemm_impl.cpp
View file @
d7ea085c
...
@@ -26,6 +26,8 @@
...
@@ -26,6 +26,8 @@
#include <migraphx/reduce_dims.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/permutation.hpp>
#include <migraphx/permutation.hpp>
#include <cstdio>
namespace
migraphx
{
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
namespace
gpu
{
...
@@ -116,6 +118,21 @@ void gemm_impl(context& ctx,
...
@@ -116,6 +118,21 @@ void gemm_impl(context& ctx,
{
{
beta
=
0
;
beta
=
0
;
}
}
// else
// {
// if (args[2].get_shape().lens()[1] == 12 and args[2].get_shape().lens()[2] == 2)
// {
// args[2].visit([&](auto output){
// std::cout << args[2].get_shape() << std::endl;
// for (auto i = 0; i < args[2].get_shape().elements(); ++i)
// {
// //if (output[i] == 0 )
// std::cout << output[i] << ", ";
// }
// std::cout << std::endl;
// });
// }
// }
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transa
=
is_transposed
(
args
[
0
].
get_shape
());
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
bool
transb
=
is_transposed
(
args
[
1
].
get_shape
());
...
...
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp
0 → 100644
View file @
d7ea085c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include <fstream>
#include <filesystem>
#include <migraphx/gpu/compiler.hpp>
#include <migraphx/make_op.hpp>
#include <migraphx/gpu/context.hpp>
#include <migraphx/gpu/compile_hip_code_object.hpp>
#include <migraphx/gpu/compile_hip.hpp>
#include <migraphx/gpu/compile_gen.hpp>
#include <migraphx/ranges.hpp>
#include <migraphx/env.hpp>
#include <migraphx/reduce_dims.hpp>
#include <migraphx/stringutils.hpp>
#include <migraphx/module.hpp>
#include <migraphx/env.hpp>
#include <migraphx/file_buffer.hpp>
const
std
::
vector
<
std
::
string
>&
get_instance
(
std
::
size_t
i
,
const
std
::
function
<
bool
(
const
std
::
vector
<
std
::
string
>&
)
>&
pred
);
namespace
migraphx
{
inline
namespace
MIGRAPHX_INLINE_NS
{
namespace
gpu
{
using
namespace
migraphx
::
gpu
::
gen
;
// NOLINT
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_LOG_CK_GEMM
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_TUNING
);
MIGRAPHX_DECLARE_ENV_VAR
(
MIGRAPHX_CK_DEBUG
);
// NOLINTNEXTLINE
static
const
char
*
const
ck_gemm_softmax_gemm_kernel
=
R"__migraphx__(
#include <args.hpp>
#include <migraphx/kernels/ck_gemm_softmax_gemm.hpp>
#include <migraphx/kernels/pointwise.hpp>
namespace migraphx {
${preamble}
template <ck::index_t... Is>
using S = ck::Sequence<Is...>;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using F16 = ck::half_t;
using F32 = float;
using PassThrough = ck_passthrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
using AElementOp = PassThrough;
using B0ElementOp = PassThrough;
using Acc0ElementOp = ck_scale;//ck::tensor_operation::element_wise::Scale;
using B1ElementOp = PassThrough;
using CElementOp = PassThrough;//ck_add;//ck::tensor_operation::element_wise::Add;
using gemm = CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle< Row, Col, Row, Row, F16, F16, F16, F16, F32, F16, AElementOp, B0ElementOp, Acc0ElementOp, B1ElementOp, CElementOp, GemmDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, false, std::ratio<1, 8>>;
extern "C" {
__global__ void ${kernel}(${params})
{
// transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
// ck_gemm_softmax_gemm<CK_DeviceGemmMultipleD<${instance}>, ${blocks_per_batch}>(xs...);
// });
transform_args(make_tensors(), rotate_last())(${args})([](auto... xs) {
ck_gemm_softmax_gemm<gemm, ${blocks_per_batch}>(xs...);
});
}
}
} // namespace migraphx
)__migraphx__"
;
// Integer ceiling division: smallest q with q * y >= x (requires y > 0).
static std::size_t int_div_ceil(std::size_t x, std::size_t y)
{
    const auto quotient  = x / y;
    const auto remainder = x % y;
    return remainder == 0 ? quotient : quotient + 1;
}
// Parsed representation of one CK device-gemm instance: `params` holds the
// instance's template arguments as strings, addressed by fixed positional
// indices (the slot layout mirrors the generated instance table — see the
// asserts in the setters for the expected defaults in each slot).
struct instance
{
    std::vector<std::string> params;
    // Index of the BlockSize template argument within `params`; the
    // per-block tile sizes follow immediately after it.
    static const std::size_t block_size_index = 17;

    // Parse params[i] as an unsigned integer.
    std::size_t int_at(std::size_t i) const { return std::stoull(params[i]); }

    std::size_t get_block_size() const { return int_at(block_size_index); }

    // Per-block tile size for dimension i; only the first four slots after
    // block_size_index are valid tile sizes.
    std::size_t get_pb(std::size_t i) const
    {
        assert(i < 4);
        return int_at(block_size_index + 1 + i);
    }

    // Padding required to round each config dimension up to a multiple of
    // its per-block tile size (assumes config is {M, N, K} — TODO confirm).
    std::array<std::size_t, 3> get_pad(const std::array<std::size_t, 3>& config) const
    {
        std::array<std::size_t, 3> result{};
        for(auto i : range(config.size()))
        {
            result[i] = int_div_ceil(config[i], get_pb(i)) * get_pb(i) - config[i];
        }
        return result;
    }

    // Number of blocks needed to tile the first two config dimensions.
    std::size_t get_grid_size(const std::array<std::size_t, 3>& config) const
    {
        return int_div_ceil(config[0], get_pb(0)) * int_div_ceil(config[1], get_pb(1));
    }

    // Each setter overwrites one template-argument slot; the assert checks
    // the slot still holds its expected default before overwriting.
    void set_ds_layout(const std::string& s)
    {
        assert(params[2] == "ck::Tuple<>");
        params[2] = s;
    }

    void set_ds_type(const std::string& s)
    {
        assert(params[8] == "ck::Tuple<>");
        params[8] = s;
    }

    void set_ds_op(const std::string& s)
    {
        assert(params[12] == "ck_passthrough");
        params[12] = s;
    }

    void set_gemm(const std::string& s)
    {
        assert(params[13] == "ck::tensor_operation::device::GemmSpecialization::Default");
        params[13] = s;
    }

    // Render the instance back to a comma-separated template-argument list.
    std::string str() const { return join_strings(params, ","); }
};
// Wraps callable `f` so that `action()` runs before every invocation.
// The wrapper now forwards f's return value (the original discarded it),
// so decorated callables with non-void results keep working; void-returning
// callables behave exactly as before.
template <class F, class Action>
auto action_decorate(F f, Action action)
{
    return [=](auto&&... xs) -> decltype(auto) {
        action();
        return f(std::forward<decltype(xs)>(xs)...);
    };
}
// One tuning entry maps a set of input shapes to the index of the CK
// instance selected for that configuration.
using tuning_entry = std::pair<std::vector<shape>, size_t>;

// Load the tuning table from a JSON file at path `s`.
// Returns an empty table when the file does not exist.
static std::vector<tuning_entry> read_tuning(const std::string& s)
{
    if(not fs::exists(s))
        return {};
    return from_value<std::vector<tuning_entry>>(from_json_string(read_string(s)));
}
// Look up the tuned instance index for the given input shapes in the table
// referenced by the MIGRAPHX_CK_TUNING environment variable. Falls back to a
// fixed default index when the table is missing or has no matching entry.
static std::size_t get_tuning_for(const std::vector<shape>& inputs)
{
    // Fallback instance index used when no tuning entry is available.
    constexpr std::size_t default_tuning_value = 4;
    // Loaded once; subsequent calls reuse the parsed table.
    static auto tuning = read_tuning(string_value_of(MIGRAPHX_CK_TUNING{}, ""));
    // Diagnostics go to stderr so they do not pollute program output
    // (previously written to stdout).
    if(tuning.empty())
        std::cerr << "*********** Warning: No CK tuning!" << std::endl;
    auto it = std::find_if(tuning.begin(), tuning.end(), [&](const auto& p) {
        return p.first == inputs;
    });
    if(it == tuning.end())
    {
        std::cerr << "*********** Warning: CK tuning missing for config!" << std::endl;
        return default_tuning_value;
    }
    return it->second;
}
// Compiler for the fused gemm-softmax-gemm operator: fills in the
// ck_gemm_softmax_gemm_kernel source template and compiles it to a HIP code
// object, optionally appending a fused pointwise module as a preamble.
struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>
{
    // CK layout tag string for a (possibly transposed) shape.
    static std::string get_layout(const shape& s)
    {
        return s.transposed() ? "ck::tensor_layout::gemm::ColumnMajor"
                              : "ck::tensor_layout::gemm::RowMajor";
    }

    // CK element type name for the shape's data type (half maps to
    // ck::half_t, everything else uses the C++ type name).
    static std::string get_type(const shape& s)
    {
        if(s.type() == shape::half_type)
            return "ck::half_t";
        return shape::cpp_type(s.type());
    }

    // Render "ck::Tuple<f(*start), ..., f(*(last-1))>".
    template <class Iterator, class F>
    static std::string ck_tuple(Iterator start, Iterator last, F f)
    {
        std::vector<std::string> s;
        std::transform(start, last, std::back_inserter(s), f);
        return "ck::Tuple<" + join_strings(s, ",") + ">";
    }

    std::vector<std::string> names() const
    {
        return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"};
    }

    // Generates the kernel source for the given input shapes and compiles it.
    // The commented-out section is the (disabled) tuned instance-selection
    // path; the active code uses a single hard-coded instance instead.
    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {
        auto a_shape = inputs[0];
        auto b_shape = inputs[1];
        auto c_shape = inputs.back();
        auto m       = a_shape.lens()[0];
        auto k       = a_shape.lens()[1];
        auto n       = c_shape.lens()[1];
        auto rank    = a_shape.lens().size();
        // NOTE(review): k and rank are currently unused by the active path
        // (they are referenced by the disabled tuning code below).
        // std::array<char, 3> keys{'M', 'N', 'K'};
        // std::array<std::size_t, 3> config{
        //     c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
        // auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
        // auto ip = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
        //     return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
        //            get_layout(c_shape) == x[3] and get_type(a_shape) == x[4] and
        //            get_type(b_shape) == x[5] and get_type(c_shape) == x[9];
        // })};
        // assert(inputs.size() < 4 or v.contains("post"));
        // if(v.contains("post"))
        // {
        //     ip.set_ds_layout(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_layout));
        //     ip.set_ds_type(ck_tuple(inputs.begin() + 2, inputs.end() - 1, &get_type));
        //     ip.set_ds_op(v.at("post").to<std::string>());
        // }
        // auto padding = ip.get_pad(config);
        // std::string gemm_type;
        // for(auto i : range(padding.size()))
        // {
        //     if(padding[i] != 0)
        //         gemm_type += keys[i];
        // }
        // if(gemm_type.empty())
        //     gemm_type = "Default";
        // else
        //     gemm_type += "Padding";
        // ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
        // Tile sizes below presumably match the hard-coded instance in the
        // kernel source template (MPerBlock=256, Gemm1NPerBlock=64) —
        // TODO confirm against the template's parameter list.
        auto gemm1_nperblock  = 64;
        auto gemm01_mperblock = 256;
        auto blocks_per_batch =
            int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock); //ip.get_grid_size(config);
        // Product of all dims before the last two = flattened batch count.
        auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
                                           c_shape.lens().rend(),
                                           std::size_t{1},
                                           std::multiplies<std::size_t>());
        hip_compile_options options;
        auto block_size = 256; //ip.get_block_size();
        auto grid_size  = batch_count * blocks_per_batch;
        options.set_launch_params(v, grid_size * block_size, block_size);
        options.inputs         = inputs;
        options.output         = c_shape;
        options.kernel_name    = v.get("kernel", "ck_gemm_softmax_gemm_kernel");
        options.virtual_inputs = inputs;
        if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
            options.params += " -DMIGRAPHX_CK_CHECK=1";
        auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
                                      {{"instance", "" /* ip.str() */},
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"blocks_per_batch", to_string(blocks_per_batch)},
                                       {"preamble", v.get("preamble", std::string{})},
                                       {"kernel", options.kernel_name}});
        return compile_hip_code_object(src, options);
    }

    // Builds the compile-time value map (preamble / post op / kernel name)
    // from an optional fused pointwise submodule, compiles the kernel, and
    // wraps the replacement with an optional shape-logging action.
    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
    {
        auto v      = op.to_value();
        v["kernel"] = "ck_gemm_softmax_gemm_kernel";
        if(not ins->module_inputs().empty())
        {
            auto* pm      = ins->module_inputs().front();
            v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, post_ck_gemm_softmax_gemm_function);";
            v["post"]   = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
            v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
        }
        auto shapes = to_shapes(ins->inputs());
        // When MIGRAPHX_LOG_CK_GEMM is set, log the gemm shapes as JSON each
        // time the compiled replacement is applied.
        return action_decorate(replace(compile_op(ctx, shapes, v)), [=] {
            if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
            {
                std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
                std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
                          << std::endl;
            }
        });
    }
};
}
// namespace gpu
}
// namespace MIGRAPHX_INLINE_NS
}
// namespace migraphx
src/targets/gpu/jit/ck_gsg_instances.cpp
0 → 100644
View file @
d7ea085c
This diff is collapsed.
Click to expand it.
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp
View file @
d7ea085c
...
@@ -109,6 +109,28 @@ struct ck_passthrough
...
@@ -109,6 +109,28 @@ struct ck_passthrough
}
}
};
};
// Element-wise scaling functor: writes y = x * scale, with the scale factor
// converted to the input's type before the multiply.
struct ck_scale
{
    constexpr ck_scale(float s) : scale{s} {}

    template <class T, class U>
    constexpr void operator()(T& y, U x) const
    {
        const auto factor = static_cast<U>(scale);
        y                 = x * factor;
    }

    float scale;
};
// Element-wise accumulate functor: adds the input into the output in place.
struct ck_add
{
    template <class Acc, class Val>
    constexpr void operator()(Acc& y, Val x) const
    {
        y += x;
    }
};
#ifdef MIGRAPHX_CK_CHECK
#ifdef MIGRAPHX_CK_CHECK
#define MIGRAPHX_CK_STATIC_ASSERT static_assert
#define MIGRAPHX_CK_STATIC_ASSERT static_assert
#else
#else
...
...
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp
0 → 100644
View file @
d7ea085c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GEMM_SOFTMAX_GEMM_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GEMM_SOFTMAX_GEMM_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <migraphx/kernels/ck.hpp>
#include <migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp>
#include <migraphx/kernels/gemm_batcher.hpp>
namespace
migraphx
{
// In CK, the B matrix is ordered as N,K instead of K,N
// Swap a two-element dimension pack {k, n} into {n, k}. Assumes the pack
// unpacks into exactly two values (rank-2 dims) — TODO confirm at call sites.
template <class Dims>
constexpr auto ck_transposeb_dims(Dims dims)
{
    return unpack(dims, [](auto k, auto n) { return make_const_array(n, k); });
}

// Shape type of Tensor with both its lens and strides swapped — the
// CK-style (N,K) view of the B matrix.
template <class Tensor>
using ck_transposeb = decltype(make_shape(ck_transposeb_dims(get_shape_c<Tensor>{}.lens),
                                          ck_transposeb_dims(get_shape_c<Tensor>{}.strides)));
// Runs one batch slice of the fused gemm-softmax-gemm:
// c = gemm(softmax(gemm(a, b)), b1), via the CK gridwise kernel carried by
// the device-op type G. All tensor descriptors are built at compile time
// from the static shapes of the tensor_view arguments.
template <class G, class C, class A, class B, class B1>
__device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)
{
    constexpr const G gemm{};

    // --- A descriptor: (M, K), row-major, split K into (AK0, AK1) ---
    constexpr const auto a_shape = get_shape_c<A>{};
    constexpr const auto m       = a_shape.lens[0];
    constexpr const auto k       = a_shape.lens[1];
    constexpr const auto sa      = a_shape.strides[0];
    constexpr const auto a_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(m, k), ck::make_tuple(sa, 1));
    constexpr const auto a_grid_desc_mraw_kraw = gemm.matrix_padder.PadADescriptor_M_K(a_tensor);
    constexpr const auto AK1                   = gemm.get_AK1();
    // NOTE(review): assumes k is divisible by AK1 — confirm upstream checks.
    constexpr const auto AK0                   = k / AK1;
    constexpr const auto a_grid_desc_ak0_m_ak1 = ck::transform_tensor_descriptor(
        a_grid_desc_mraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(AK0, AK1)),
                       ck::make_pass_through_transform(m)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

    // --- B descriptor: (N, K) with column-major strides ---
    constexpr const auto b_shape = get_shape_c<B>{};
    constexpr const auto n       = b_shape.lens[0];    // col-major
    constexpr const auto sb      = b_shape.strides[1]; // col-major
    constexpr const auto BK1     = gemm.get_BK1();
    constexpr const auto BK0     = k / BK1;
    constexpr const auto b_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(n, k), ck::make_tuple(sb, 1));
    constexpr const auto b_grid_desc_nraw_kraw = gemm.matrix_padder.PadBDescriptor_N_K(b_tensor);
    constexpr const auto b_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
        b_grid_desc_nraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(BK0, BK1)),
                       ck::make_pass_through_transform(n)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

    // --- B1 descriptor (second gemm's B): (K1, N1), row-major source ---
    constexpr const auto b1_shape = get_shape_c<B1>{};
    constexpr const auto k1       = b1_shape.lens[0];    // row-major
    constexpr const auto n1       = b1_shape.lens[1];    // row-major
    constexpr const auto sb1      = b1_shape.strides[0]; // rowl-major
    constexpr const auto B1K1     = gemm.get_B1K1();
    constexpr const auto B1K0     = k1 / B1K1;
    constexpr const auto b1_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(n1, k1), ck::make_tuple(1, sb1));
    constexpr const auto b1_grid_desc_nraw_kraw = gemm.matrix_padder.PadB1Descriptor_N_K(b1_tensor);
    constexpr const auto b1_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
        b1_grid_desc_nraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(B1K0, B1K1)),
                       ck::make_pass_through_transform(n1)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

    // --- C descriptor: (M, N1), row-major, then tiled into block coords ---
    constexpr const auto c_shape = get_shape_c<C>{};
    constexpr const auto sc      = c_shape.strides[0];
    constexpr const auto c_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(m, n1), ck::make_tuple(sc, 1));
    constexpr const auto c_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(c_tensor);

    constexpr const auto MPerBlock      = gemm.get_mperblock();
    constexpr const auto Gemm1NPerBlock = gemm.get_gemm1nperblock();
    constexpr const auto MBlock         = m / MPerBlock;
    constexpr const auto NBlock         = n1 / Gemm1NPerBlock;
    constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        ck::transform_tensor_descriptor(
            c_grid_desc_m_n,
            ck::make_tuple(
                ck::make_unmerge_transform(ck::make_tuple(MBlock, ck::Number<MPerBlock>{})),
                ck::make_unmerge_transform(ck::make_tuple(NBlock, ck::Number<Gemm1NPerBlock>{}))),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}));

    // Block-id -> C tile mapping and the softmax padding mask over N.
    constexpr const auto block_2_ctile_map =
        BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
            c_grid_desc_m_n);
    const C0MatrixMask c0_matrix_mask(n);

    // Total (padded) K length, reassembled from the unmerged descriptor.
    const auto K = a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
                   a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});

    // Instantiate the CK gridwise gemm for these descriptor types and run it
    // out of this block's shared memory.
    using gridwise = typename G::template rt_gridwisegemm<decltype(a_grid_desc_ak0_m_ak1),
                                                          decltype(b_grid_desc_bk0_n_bk1),
                                                          decltype(b1_grid_desc_bk0_n_bk1),
                                                          decltype(c_grid_desc_m_n)>;
    using GridwiseGemm = typename gridwise::GridwiseGemm;
    constexpr const bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(K);
    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
    static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                              b_grid_desc_bk0_n_bk1,
                                              b1_grid_desc_bk0_n_bk1,
                                              c_grid_desc_m_n,
                                              block_2_ctile_map));
    GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
                                                  to_ck_const_pointer(b.data()),
                                                  to_ck_const_pointer(b1.data()),
                                                  to_ck_pointer(c.data()),
                                                  p_shared,
                                                  gemm.a_element_op,
                                                  gemm.b_element_op,
                                                  gemm.acc_element_op,
                                                  gemm.b1_element_op,
                                                  gemm.c_element_op,
                                                  a_grid_desc_ak0_m_ak1,
                                                  b_grid_desc_bk0_n_bk1,
                                                  b1_grid_desc_bk0_n_bk1,
                                                  c_grid_desc_mblock_mperblock_nblock_nperblock,
                                                  block_2_ctile_map,
                                                  c0_matrix_mask);
}
// Kernel entry helper: distributes the flattened batch across thread blocks
// (BlocksPerBatch blocks per batch element) and invokes the single-matrix
// fused gemm on each per-batch slice of the arguments.
template <class G, index_int BlocksPerBatch, class... Ts>
__device__ void ck_gemm_softmax_gemm(Ts... xs)
{
    gemm_batch_args(make_index(), _c<BlocksPerBatch>, xs...)(
        [](auto... ys) { ck_gemm_softmax_gemm_matrix<G>(ys...); });
}
}
// namespace migraphx
#endif
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp
0 → 100644
View file @
d7ea085c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_KERNELS_CK_GSG_INCLUDES_HPP
#define MIGRAPHX_GUARD_KERNELS_CK_GSG_INCLUDES_HPP
#include <migraphx/kernels/index.hpp>
#include <migraphx/kernels/algorithm.hpp>
#include <migraphx/kernels/integral_constant.hpp>
#include <migraphx/kernels/tensor_view.hpp>
#include <ratio>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_batched_gemm_softmax_gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_xdl_cshuffle_v1.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/tensor_specialization.hpp"
namespace
migraphx
{
// Maps a 1-D block id onto an (m-tile, n-tile) coordinate of the C output
// grid, grouping up to M01 consecutive m-tiles per sweep (presumably for
// cache locality, following CK's BlockToCTileMap_M00_N0_M01Adapt — confirm
// against the CK original).
template <ck::index_t MPerBlock, ck::index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt
{
    static constexpr auto I0 = ck::Number<0>{};
    static constexpr auto I1 = ck::Number<1>{};
    static constexpr auto I2 = ck::Number<2>{};
    static constexpr auto I3 = ck::Number<3>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

    // M01 = number of m-tiles grouped together per column sweep.
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
        const CGridDesc_M_N& c_grid_desc_m_n, ck::index_t M01 = 8)
        : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    // Number of blocks needed to tile the whole C matrix.
    __host__ __device__ constexpr ck::index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);
        const ck::index_t grid_size = M0 * N0;
        return grid_size;
    }

    // Convert a 1-D block id to (m-tile, n-tile) coordinates.
    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        auto block_1d_id = idx_top[I0];

        const auto M0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto N0 = ck::math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

        ck::index_t idx_N0 = block_1d_id % N0;
        ck::index_t idx_M0 = block_1d_id / N0;

        // The last group of m-tiles may be smaller than M01_ when M0 is not
        // a multiple of M01_.
        const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

        ck::index_t idx_M00          = idx_M0 / M01_;
        ck::index_t idx_M01          = idx_M0 % M01_;
        ck::index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

        return ck::make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                              idx_N0_M01_local / M01_adapt);
    }

    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ bool constexpr ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                       const CTileDim& /* c_tile_dim */) const
    {
        return true; // always valid provided that user gets grid size from CalculateGridSize()
    }

    __host__ __device__ constexpr bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    private:
    ck::index_t M01_;
    CGridDesc_M_N c_grid_desc_m_n_;
};
// to track the points which need to be set to -inf on C0
// Note: no need to reset M padding value, because they will not be stored out.
struct C0MatrixMask
{
    __device__ C0MatrixMask(ck::index_t NRaw) : NRaw_(NRaw) {}

    // True for strictly upper-triangular positions (n > m).
    __device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return m < n; }

    // True when the column index lies in the N padding region.
    __device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const
    {
        return not(n < NRaw_);
    }

    // A point is masked if it is above the diagonal or past the valid N range.
    __device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
    {
        if(IsUpperTriangle(m, n))
            return true;
        return IsNOutOfBound(n);
    }

    private:
    // ck::index_t MRaw_;
    ck::index_t NRaw_;
};
// Compile-time configuration wrapper around CK's batched
// GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle kernel. It bundles the tuning
// parameters (block/tile sizes, transfer cluster shapes, shuffle config) and
// exposes small accessors used by the MIGraphX JIT when instantiating the
// fused gemm->softmax->gemm kernel. The template parameter order must match
// the positional argument list expected by the CK gridwise gemm below.
template <typename ALayout,
          typename BLayout, // B0Layout
          typename B1Layout,
          typename CLayout,
          typename ADataType,
          typename BDataType,
          typename B1DataType,
          typename CDataType,
          typename GemmAccDataType,
          typename CShuffleDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename AccElementwiseOperation,
          typename B1ElementwiseOperation,
          typename CElementwiseOperation,
          ck::tensor_operation::device::GemmSpecialization GemmSpec,
          ck::index_t NumGemmKPrefetchStage,
          ck::index_t BlockSize,
          ck::index_t MPerBlock,
          ck::index_t NPerBlock, // Gemm0NPerBlock
          ck::index_t KPerBlock, // Gemm0KPerBlock
          ck::index_t Gemm1NPerBlock,
          ck::index_t Gemm1KPerBlock,
          ck::index_t AK1,
          ck::index_t BK1,
          ck::index_t B1K1,
          ck::index_t MPerXDL,
          ck::index_t NPerXDL,
          ck::index_t MXdlPerWave,
          ck::index_t NXdlPerWave,
          ck::index_t Gemm1NXdlPerWave,
          typename ABlockTransferThreadClusterLengths_AK0_M_AK1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          ck::index_t ABlockTransferSrcVectorDim,
          ck::index_t ABlockTransferSrcScalarPerVector,
          ck::index_t ABlockTransferDstScalarPerVector_AK1,
          bool ABlockLdsExtraM,
          typename BBlockTransferThreadClusterLengths_BK0_N_BK1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          ck::index_t BBlockTransferSrcVectorDim,
          ck::index_t BBlockTransferSrcScalarPerVector,
          ck::index_t BBlockTransferDstScalarPerVector_BK1,
          bool BBlockLdsExtraN,
          typename B1BlockTransferThreadClusterLengths_BK0_N_BK1,
          typename B1BlockTransferThreadClusterArrangeOrder,
          typename B1BlockTransferSrcAccessOrder,
          ck::index_t B1BlockTransferSrcVectorDim,
          ck::index_t B1BlockTransferSrcScalarPerVector,
          ck::index_t B1BlockTransferDstScalarPerVector_BK1,
          bool B1BlockLdsExtraN,
          ck::index_t CShuffleMXdlPerWavePerShuffle,
          ck::index_t CShuffleNXdlPerWavePerShuffle,
          typename CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
          ck::index_t CShuffleBlockTransferScalarPerVector_NPerBlock,
          bool MaskOutUpperTriangle,
          typename Alpha,
          ck::LoopScheduler LoopSched = ck::LoopScheduler::Default>
struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
    // Padder handling M/N/K/Gemm1N padding according to GemmSpec; its PadN
    // flag is forwarded to the gridwise gemm below.
    static constexpr auto matrix_padder =
        ck::tensor_operation::device::GemmGemmPadder<GemmSpec,
                                                     ck::index_t,
                                                     ck::index_t,
                                                     ck::index_t,
                                                     ck::index_t>{
            MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};

    // Accessors over the tuning constants, used by the kernel-generation code.
    static constexpr auto get_AK1() { return AK1; };
    static constexpr auto get_BK1() { return BK1; };
    static constexpr auto get_B1K1() { return B1K1; };
    static constexpr auto get_mperblock() { return MPerBlock; };
    static constexpr auto get_gemm1nperblock() { return Gemm1NPerBlock; };

    // Softmax/gemm0 scaling factor, carried as a compile-time rational
    // (Alpha exposes ::num / ::den, std::ratio-style) and evaluated to float.
    static constexpr float alpha = float(Alpha::num) / Alpha::den;
    static constexpr auto get_alpha() { return alpha; };

    // Default-constructed elementwise functors; the accumulator functor is
    // seeded with alpha so the first gemm's result is scaled before softmax.
    AElementwiseOperation a_element_op{};
    BElementwiseOperation b_element_op{};
    B1ElementwiseOperation b1_element_op{};
    CElementwiseOperation c_element_op{};
    AccElementwiseOperation acc_element_op{alpha};

    // Binds the runtime grid descriptors to the configured gridwise gemm.
    template <typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename B1GridDesc_BK0_N_BK1,
              typename CGridDesc_M_N>
    struct rt_gridwisegemm
    {
        // GridwiseGemm
        // NOTE(review): the bare true/true/false literals below fill
        // positional bool parameters of the CK gridwise gemm (argument names
        // are not visible here) — consult the CK header before changing them.
        using GridwiseGemm = ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<
            ADataType, // TODO: distinguish A/B datatype
            GemmAccDataType,
            CShuffleDataType,
            CDataType,
            AElementwiseOperation,
            BElementwiseOperation,
            AccElementwiseOperation,
            B1ElementwiseOperation,
            CElementwiseOperation,
            ck::InMemoryDataOperationEnum::Set,
            AGridDesc_AK0_M_AK1,
            BGridDesc_BK0_N_BK1,
            B1GridDesc_BK0_N_BK1,
            CGridDesc_M_N,
            NumGemmKPrefetchStage,
            BlockSize,
            MPerBlock,
            NPerBlock,
            KPerBlock,
            Gemm1NPerBlock,
            Gemm1KPerBlock,
            AK1,
            BK1,
            B1K1,
            MPerXDL,
            NPerXDL,
            MXdlPerWave,
            NXdlPerWave,
            Gemm1NXdlPerWave,
            ABlockTransferThreadClusterLengths_AK0_M_AK1,
            ABlockTransferThreadClusterArrangeOrder,
            ABlockTransferSrcAccessOrder,
            ABlockTransferSrcVectorDim,
            ABlockTransferSrcScalarPerVector,
            ABlockTransferDstScalarPerVector_AK1,
            true,
            ABlockLdsExtraM,
            BBlockTransferThreadClusterLengths_BK0_N_BK1,
            BBlockTransferThreadClusterArrangeOrder,
            BBlockTransferSrcAccessOrder,
            BBlockTransferSrcVectorDim,
            BBlockTransferSrcScalarPerVector,
            BBlockTransferDstScalarPerVector_BK1,
            true,
            BBlockLdsExtraN,
            B1BlockTransferThreadClusterLengths_BK0_N_BK1,
            B1BlockTransferThreadClusterArrangeOrder,
            B1BlockTransferSrcAccessOrder,
            B1BlockTransferSrcVectorDim,
            B1BlockTransferSrcScalarPerVector,
            B1BlockTransferDstScalarPerVector_BK1,
            false,
            B1BlockLdsExtraN,
            CShuffleMXdlPerWavePerShuffle,
            CShuffleNXdlPerWavePerShuffle,
            CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
            CShuffleBlockTransferScalarPerVector_NPerBlock,
            LoopSched,
            matrix_padder.PadN,
            MaskOutUpperTriangle>;
    };
};
}
// namespace migraphx
#endif
test/onnx/gemm_softmax_gemm_test.onnx
0 → 100644
View file @
d7ea085c
gemm_softmax_gemm_test:
a
b gemm1_out"MatMul
!
gemm1_out
scalemul1_out"Mul
mul1_out
cadd1_out"Add
add1_outsoftmax_out"Softmax
softmax_out
b1out"MatMulgemm_softmax_gemm_test*
*`BscaleZ
a
Z
b
Z
c
Z
b1
Z
bias
b
out
B
\ No newline at end of file
test/onnx/gen_onnx.py
View file @
d7ea085c
...
@@ -1986,6 +1986,42 @@ def gemm_half_test():
...
@@ -1986,6 +1986,42 @@ def gemm_half_test():
return
([
node
],
[
m1
,
m2
,
m3
],
[
y
])
return
([
node
],
[
m1
,
m2
,
m3
],
[
y
])
@onnx_test
def gemm_softmax_gemm_test():
    """fp16 graph MatMul -> Mul(1/8 scale) -> Add -> Softmax -> MatMul,
    the pattern targeted by the gemm-softmax-gemm fusion.

    NOTE(review): 'bias' is declared as a graph input but no node consumes
    it — confirm whether it is intentionally unused.
    """
    fp16 = TensorProto.FLOAT16

    def value_info(name):
        # All tensors in this test share the same trivial [1, 1] fp16 shape.
        return helper.make_tensor_value_info(name, fp16, [1, 1])

    a, b, c, b1, bias, out = [
        value_info(n) for n in ('a', 'b', 'c', 'b1', 'bias', 'out')
    ]

    # Constant 1/8 scale applied to the first gemm's output.
    scale_array = np.array([(1 / 8)])
    scale_tensor = helper.make_tensor(name='scale',
                                      data_type=fp16,
                                      dims=scale_array.shape,
                                      vals=scale_array.flatten().astype(
                                          np.float16))

    nodes = [
        onnx.helper.make_node('MatMul', inputs=['a', 'b'],
                              outputs=['gemm1_out']),
        onnx.helper.make_node('Mul', inputs=['gemm1_out', 'scale'],
                              outputs=['mul1_out']),
        onnx.helper.make_node('Add', inputs=['mul1_out', 'c'],
                              outputs=['add1_out']),
        onnx.helper.make_node('Softmax', inputs=['add1_out'],
                              outputs=['softmax_out']),
        onnx.helper.make_node('MatMul', inputs=['softmax_out', 'b1'],
                              outputs=['out']),
    ]

    return (nodes, [a, b, c, b1, bias], [out], [scale_tensor])
@
onnx_test
@
onnx_test
def
globalavgpool_test
():
def
globalavgpool_test
():
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
1
,
3
,
16
,
16
])
x
=
helper
.
make_tensor_value_info
(
'0'
,
TensorProto
.
FLOAT
,
[
1
,
3
,
16
,
16
])
...
...
test/verify/0ck_gemm_softmax_gemm.cpp
0 → 100644
View file @
d7ea085c
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>
// Verify test for the gemm -> scale -> bias-add -> softmax -> gemm pattern
// targeted by the ck_gemm_softmax_gemm fusion; the framework compares the
// reference and GPU results of the program built here.
struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
{
    migraphx::program create_program() const
    {
        migraphx::program p;
        auto* mm = p.get_main_module();

        migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
        migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
        auto m2_elements = 1 * 12 * 256 * 256;

        auto a  = mm->add_parameter("1", m1_shape);
        auto b  = mm->add_parameter("2", m1_shape);
        auto b1 = mm->add_parameter("3", m1_shape);
        // NOTE(review): parameter "4" and the all-ones literal below are never
        // consumed by any instruction; they are kept to reproduce the original
        // program exactly.
        auto c = mm->add_parameter("4", m1_shape);
        (void)c;

        // Broadcast-shaped constant literals filled with a single value.
        auto filled_literal = [&](float value) {
            return mm->add_literal(
                migraphx::literal{m2_shape, std::vector<float>(m2_elements, value)});
        };
        auto eight = filled_literal(0.125); // 1/8 softmax scale
        auto zero  = filled_literal(0);     // zero bias
        auto one   = filled_literal(1);
        (void)one;

        auto bt = mm->add_instruction(
            migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
        auto gemm1  = mm->add_instruction(migraphx::make_op("dot"), a, bt);
        auto scaled = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
        auto biased = mm->add_instruction(migraphx::make_op("add"), scaled, zero);
        auto probs =
            mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), biased);
        mm->add_instruction(migraphx::make_op("dot"), probs, b1);
        return p;
    }
};
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment