gaoqiong / MIGraphX · Commits

Commit 3133fd79, authored Nov 15, 2022 by Alan Turner

Formatting

Parent: d7ea085c

Showing 7 changed files with 140 additions and 126 deletions.
src/targets/gpu/fuse_ck.cpp                                                        +23 -16
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp                                       +17 -11
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp                             +1 -1
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp          +70 -67
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp  +8 -12
test/onnx/gen_onnx.py                                                              +14 -12
test/verify/0ck_gemm_softmax_gemm.cpp                                               +7 -7
src/targets/gpu/fuse_ck.cpp

@@ -74,12 +74,12 @@ struct ck_gemm_scale_bias_softmax_gemm

        // MIGRAPHX_THROW("should have one submodule.");
        if(inputs.size() < 2)
            MIGRAPHX_THROW("should have at least two inputs.");
        auto a  = inputs[0];
        auto b  = inputs[1];
        auto b1 = inputs[2];
        for(const auto& input : inputs)
        {
            // std::cout << input << std::endl;
            check_gemm_shape(input);
        }
        return op.compute_shape({op.compute_shape({a, b}), b1});

@@ -158,19 +158,22 @@ struct find_ck_gemm_scale_bias_softmax_gemm

{
    auto matcher() const
    {
        auto gemm1 =
            match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
        auto pw =
            match::name("pointwise")(match::any_of[match::inputs()](gemm1)).bind("scale_bias");
        auto softmax = match::name("softmax")(match::any_of[match::inputs()](pw)).bind("softmax");
        return match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        std::cout << "Matched" << std::endl;
        auto ins       = r.result;
        auto gemm2_ins = r.instructions["gemm2"];
        auto sm_ins    = r.instructions["softmax"];
        auto pw_ins    = r.instructions["scale_bias"];
        auto gemm1_ins = r.instructions["gemm1"];
        gemm2_ins->debug_print();

@@ -178,18 +181,21 @@ struct find_ck_gemm_scale_bias_softmax_gemm

        pw_ins->debug_print();
        gemm1_ins->debug_print();
        auto inputs = gemm1_ins->inputs();            // A, B
        inputs.push_back(gemm2_ins->inputs().back()); // B1
        // inputs.push_back(pw_ins->inputs().back()); // C
        mpm.get_module().replace_instruction(
            ins, ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
    }
    // auto matcher() const
    // {
    //     auto gemm1 =
    //     match::skip(match::name("contiguous"))(match::name("dot")(is_ck_gemm().bind("gemm1")));
    //     auto softmax =
    //     match::name("softmax")(match::any_of[match::inputs()](gemm1)).bind("softmax"); return
    //     match::name("dot")(is_ck_gemm().bind("gemm2"))(match::any_of[match::inputs()](softmax));
    // }
    // void apply(module_pass_manager& mpm, const match::matcher_result& r) const

@@ -207,7 +213,8 @@ struct find_ck_gemm_scale_bias_softmax_gemm

    //     auto inputs = gemm1_ins->inputs(); // A, B
    //     inputs.push_back(gemm2_ins->inputs().back()); // B1
    //     mpm.get_module().replace_instruction(ins,
    //     ck_gemm_scale_bias_softmax_gemm{gemm2_ins->get_operator()}, inputs);
    // }
};
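For reference, the pattern fused above computes out = softmax(scale * (A x B) + bias) x B1, where the scale/bias come from the bound pointwise module and B1 feeds the second dot. A minimal host-side sketch of that computation, on tiny hypothetical matrices in plain C++ (not the MIGraphX API):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

using Mat = std::vector<std::vector<float>>;

// Naive reference for: out = softmax(scale * (a*b) + bias) * b1
Mat matmul(const Mat& x, const Mat& y)
{
    Mat out(x.size(), std::vector<float>(y[0].size(), 0.0f));
    for(std::size_t i = 0; i < x.size(); ++i)
        for(std::size_t j = 0; j < y[0].size(); ++j)
            for(std::size_t l = 0; l < y.size(); ++l)
                out[i][j] += x[i][l] * y[l][j];
    return out;
}

// Row-wise softmax with the usual max subtraction for stability
void softmax_rows(Mat& x)
{
    for(auto& row : x)
    {
        float mx  = *std::max_element(row.begin(), row.end());
        float sum = 0.0f;
        for(float& v : row)
        {
            v = std::exp(v - mx);
            sum += v;
        }
        for(float& v : row)
            v /= sum;
    }
}

int main()
{
    Mat a{{1, 2}, {3, 4}};
    Mat b{{1, 0}, {0, 1}};
    Mat b1{{1, 1}, {1, 1}};
    const float scale = 0.125f; // same 1/8 scale the tests use
    const float bias  = 0.0f;

    Mat s = matmul(a, b); // gemm1
    for(auto& row : s)
        for(float& v : row)
            v = scale * v + bias; // pointwise scale_bias
    softmax_rows(s);              // softmax
    Mat out = matmul(s, b1);      // gemm2

    for(const auto& row : out)
    {
        for(float v : row)
            std::cout << v << ' ';
        std::cout << '\n';
    }
    return 0;
}

With b1 all ones, each output element is the sum of a softmax row, so the sketch prints 1s; real weights exercise the same dot -> scale/bias -> softmax -> dot chain that the matcher binds.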
src/targets/gpu/jit/ck_gemm_softmax_gemm.cpp

@@ -213,16 +213,19 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>

        return "ck::Tuple<" + join_strings(s, ",") + ">";
    }

    std::vector<std::string> names() const { return {"ck_gemm_softmax_gemm", "gpu::ck_gemm_softmax_gemm"}; }

    operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
    {
        auto a_shape = inputs[0];
        auto b_shape = inputs[1];
        auto c_shape = inputs.back();
        auto m       = a_shape.lens()[0];
        auto k       = a_shape.lens()[1];
        auto n       = c_shape.lens()[1];
        auto rank    = a_shape.lens().size();

@@ -257,16 +260,17 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>

        // gemm_type += "Padding";
        // ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
        auto gemm1_nperblock  = 64;
        auto gemm01_mperblock = 256;
        auto blocks_per_batch =
            int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock);
        // ip.get_grid_size(config);
        auto batch_count = std::accumulate(c_shape.lens().rbegin() + 2,
                                           c_shape.lens().rend(),
                                           std::size_t{1},
                                           std::multiplies<std::size_t>());
        hip_compile_options options;
        auto block_size = 256; // ip.get_block_size();
        auto grid_size  = batch_count * blocks_per_batch;
        options.set_launch_params(v, grid_size * block_size, block_size);
        options.inputs = inputs;

@@ -278,7 +282,7 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>

        options.params += " -DMIGRAPHX_CK_CHECK=1";
        auto src = interpolate_string(ck_gemm_softmax_gemm_kernel,
                                      {{"instance", "" /* ip.str() */},
                                       {"params", enum_params(inputs.size(), "void * private_p")},
                                       {"args", enum_params(inputs.size(), "private_p")},
                                       {"blocks_per_batch", to_string(blocks_per_batch)},

@@ -296,7 +300,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>

        {
            auto* pm      = ins->module_inputs().front();
            v["preamble"] = generate_pointwise(*pm, "post_ck_gemm_softmax_gemm_function") +
                            "\nMIGRAPHX_LIFT_CLASS(post_ck_gemm_softmax_gemm, "
                            "post_ck_gemm_softmax_gemm_function);";
            v["post"]   = "ck_function_adaptor<post_ck_gemm_softmax_gemm>";
            v["kernel"] = "ck_gemm_softmax_gemm_" + generate_name_from_ops(*pm) + "_kernel";
        }

@@ -306,7 +311,8 @@ struct ck_gemm_softmax_gemm_compiler : compiler<ck_gemm_softmax_gemm_compiler>

            if(enabled(MIGRAPHX_LOG_CK_GEMM{}))
            {
                std::vector<shape> gemm_shapes{shapes[0], shapes[1], shapes.back()};
                std::cout << "ck_gemm_softmax_gemm: " << to_json_string(to_value(gemm_shapes))
                          << std::endl;
            }
        });
    }
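The launch arithmetic above reduces to ceiling division per tile axis plus a fold over the leading (batch) dimensions of the output shape. A standalone check of those numbers, assuming int_div_ceil is plain ceiling division and using the 1x12x256x256 shape from the verify test below:

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Assumed to match the int_div_ceil used in the compiler: ceiling division.
std::size_t int_div_ceil(std::size_t x, std::size_t y) { return (x + y - 1) / y; }

int main()
{
    std::vector<std::size_t> c_lens{1, 12, 256, 256}; // output lens from the verify test
    std::size_t m = c_lens[c_lens.size() - 2];
    std::size_t n = c_lens.back();

    std::size_t gemm01_mperblock = 256;
    std::size_t gemm1_nperblock  = 64;
    std::size_t blocks_per_batch =
        int_div_ceil(m, gemm01_mperblock) * int_div_ceil(n, gemm1_nperblock);

    // batch count = product of every dimension except the trailing M x N
    std::size_t batch_count = std::accumulate(
        c_lens.rbegin() + 2, c_lens.rend(), std::size_t{1}, std::multiplies<std::size_t>());

    std::size_t block_size = 256;
    std::size_t grid_size  = batch_count * blocks_per_batch;
    std::cout << "grid: " << grid_size << " blocks of " << block_size << " threads\n";
    // 12 batches * (1 * 4) tiles per batch = 48 blocks
    return 0;
}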
src/targets/gpu/kernels/include/migraphx/kernels/ck.hpp

@@ -116,7 +116,7 @@ struct ck_scale

    template <class T, class U>
    constexpr void operator()(T& y, U x) const
    {
        y = x * static_cast<U>(scale);
    }
    float scale;
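The ck_scale functor in this hunk is an ordinary elementwise operation. A standalone analogue showing how it applies (the ck_scale_like name is made up here; the operator() body is the one from the hunk):

#include <iostream>

// Analogue of ck_scale: elementwise y = x * scale, as CK element-wise ops do.
struct ck_scale_like
{
    template <class T, class U>
    constexpr void operator()(T& y, U x) const
    {
        y = x * static_cast<U>(scale);
    }
    float scale;
};

int main()
{
    ck_scale_like op{0.125f};
    float y = 0.0f;
    op(y, 16.0f);
    std::cout << y << '\n'; // prints 2
    return 0;
}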
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm.hpp

@@ -51,96 +51,99 @@ __device__ void ck_gemm_softmax_gemm_matrix(C c, A a, B b, B1 b1)

    constexpr const G gemm{};

    constexpr const auto a_shape = get_shape_c<A>{};
    constexpr const auto m       = a_shape.lens[0];
    constexpr const auto k       = a_shape.lens[1];
    constexpr const auto sa      = a_shape.strides[0];
    constexpr const auto a_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(m, k), ck::make_tuple(sa, 1));
    constexpr const auto a_grid_desc_mraw_kraw = gemm.matrix_padder.PadADescriptor_M_K(a_tensor);
    constexpr const auto AK1                   = gemm.get_AK1();
    constexpr const auto AK0                   = k / AK1;
    constexpr const auto a_grid_desc_ak0_m_ak1 = ck::transform_tensor_descriptor(
        a_grid_desc_mraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(AK0, AK1)),
                       ck::make_pass_through_transform(m)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

    constexpr const auto b_shape = get_shape_c<B>{};
    constexpr const auto n       = b_shape.lens[0];    // col-major
    constexpr const auto sb      = b_shape.strides[1]; // col-major
    constexpr const auto BK1     = gemm.get_BK1();
    constexpr const auto BK0     = k / BK1;
    constexpr const auto b_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(n, k), ck::make_tuple(sb, 1));
    constexpr const auto b_grid_desc_nraw_kraw = gemm.matrix_padder.PadBDescriptor_N_K(b_tensor);
    constexpr const auto b_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
        b_grid_desc_nraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(BK0, BK1)),
                       ck::make_pass_through_transform(n)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

    constexpr const auto b1_shape = get_shape_c<B1>{};
    constexpr const auto k1       = b1_shape.lens[0];    // row-major
    constexpr const auto n1       = b1_shape.lens[1];    // row-major
    constexpr const auto sb1      = b1_shape.strides[0]; // row-major
    constexpr const auto B1K1     = gemm.get_B1K1();
    constexpr const auto B1K0     = k1 / B1K1;
    constexpr const auto b1_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(n1, k1), ck::make_tuple(1, sb1));
    constexpr const auto b1_grid_desc_nraw_kraw =
        gemm.matrix_padder.PadB1Descriptor_N_K(b1_tensor);
    constexpr const auto b1_grid_desc_bk0_n_bk1 = ck::transform_tensor_descriptor(
        b1_grid_desc_nraw_kraw,
        ck::make_tuple(ck::make_unmerge_transform(ck::make_tuple(B1K0, B1K1)),
                       ck::make_pass_through_transform(n1)),
        ck::make_tuple(ck::Sequence<1>{}, ck::Sequence<0>{}),
        ck::make_tuple(ck::Sequence<0, 2>{}, ck::Sequence<1>{}));

    constexpr const auto c_shape = get_shape_c<C>{};
    constexpr const auto sc      = c_shape.strides[0];
    constexpr const auto c_tensor =
        ck::make_naive_tensor_descriptor(ck::make_tuple(m, n1), ck::make_tuple(sc, 1));
    constexpr const auto c_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(c_tensor);

    constexpr const auto MPerBlock      = gemm.get_mperblock();
    constexpr const auto Gemm1NPerBlock = gemm.get_gemm1nperblock();
    constexpr const auto MBlock         = m / MPerBlock;
    constexpr const auto NBlock         = n1 / Gemm1NPerBlock;
    constexpr const auto c_grid_desc_mblock_mperblock_nblock_nperblock =
        ck::transform_tensor_descriptor(
            c_grid_desc_m_n,
            ck::make_tuple(
                ck::make_unmerge_transform(ck::make_tuple(MBlock, ck::Number<MPerBlock>{})),
                ck::make_unmerge_transform(ck::make_tuple(NBlock, ck::Number<Gemm1NPerBlock>{}))),
            ck::make_tuple(ck::Sequence<0>{}, ck::Sequence<1>{}),
            ck::make_tuple(ck::Sequence<0, 1>{}, ck::Sequence<2, 3>{}));

    constexpr const auto block_2_ctile_map =
        BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, Gemm1NPerBlock, decltype(c_grid_desc_m_n)>(
            c_grid_desc_m_n);
    const C0MatrixMask c0_matrix_mask(n);

    const auto K = a_grid_desc_ak0_m_ak1.GetLength(ck::Number<0>{}) *
                   a_grid_desc_ak0_m_ak1.GetLength(ck::Number<2>{});

    using gridwise     = typename G::template rt_gridwisegemm<decltype(a_grid_desc_ak0_m_ak1),
                                                          decltype(b_grid_desc_bk0_n_bk1),
                                                          decltype(b1_grid_desc_bk0_n_bk1),
                                                          decltype(c_grid_desc_m_n)>;
    using GridwiseGemm = typename gridwise::GridwiseGemm;
    constexpr const bool HasMainKBlockLoop = GridwiseGemm::CalculateHasMainKBlockLoop(K);

    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];

    static_assert(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1,
                                              b_grid_desc_bk0_n_bk1,
                                              b1_grid_desc_bk0_n_bk1,
                                              c_grid_desc_m_n,
                                              block_2_ctile_map));

    GridwiseGemm::template Run<HasMainKBlockLoop>(to_ck_const_pointer(a.data()),
                                                  to_ck_const_pointer(b.data()),
                                                  to_ck_const_pointer(b1.data()),
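The unmerge transform that builds a_grid_desc_ak0_m_ak1 splits the K axis into (AK0, AK1) with K = AK0 * AK1, so element (ak0, m, ak1) of the transformed view aliases element (m, ak0 * AK1 + ak1) of the raw row-major matrix. A small illustration of that index mapping in plain C++ rather than CK descriptors:

#include <cassert>
#include <cstddef>
#include <iostream>

int main()
{
    constexpr std::size_t M = 4, K = 8, AK1 = 2, AK0 = K / AK1;
    float a[M][K];
    for(std::size_t m = 0; m < M; ++m)
        for(std::size_t k = 0; k < K; ++k)
            a[m][k] = static_cast<float>(m * K + k);

    // Walking the (AK0, M, AK1) view touches exactly the raw (m, k) elements
    // with k = ak0 * AK1 + ak1, which is what the unmerge transform encodes.
    for(std::size_t ak0 = 0; ak0 < AK0; ++ak0)
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t ak1 = 0; ak1 < AK1; ++ak1)
                assert(a[m][ak0 * AK1 + ak1] ==
                       static_cast<float>(m * K + ak0 * AK1 + ak1));

    std::cout << "unmerge mapping verified for " << M << "x" << K << '\n';
    return 0;
}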
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm_softmax_gemm_includes.hpp

@@ -121,10 +121,7 @@ struct C0MatrixMask

    __device__ bool IsUpperTriangle(ck::index_t m, ck::index_t n) const { return n > m; }

    __device__ bool IsNOutOfBound(/*ck::index_t m, */ ck::index_t n) const { return n >= NRaw_; }

    __device__ bool IsMaskedElement(ck::index_t m, ck::index_t n) const
    {

@@ -197,8 +194,8 @@ template <typename ALayout,

          ck::LoopScheduler LoopSched = ck::LoopScheduler::Default>
struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle
{
    static constexpr auto matrix_padder =
        ck::tensor_operation::device::
            GemmGemmPadder<GemmSpec, ck::index_t, ck::index_t, ck::index_t, ck::index_t>{
                MPerBlock, NPerBlock, KPerBlock, Gemm1NPerBlock};
    static constexpr auto get_AK1() { return AK1; };

@@ -215,11 +212,11 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

    CElementwiseOperation c_element_op{};
    AccElementwiseOperation acc_element_op{alpha};

    template <typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
              typename B1GridDesc_BK0_N_BK1,
              typename CGridDesc_M_N>
    struct rt_gridwisegemm
    {
        // GridwiseGemm
        using GridwiseGemm = ck::GridwiseBatchedGemmSoftmaxGemm_Xdl_CShuffle<

@@ -286,6 +283,5 @@ struct CK_DeviceBatchedGemmSoftmaxGemm_Xdl_CShuffle

    };
};
} // namespace migraphx
#endif
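The body of IsMaskedElement falls outside this hunk. Below is a sketch of how predicates like IsUpperTriangle and IsNOutOfBound are typically combined for a causal mask; the OR is an assumption here, not taken from the source:

#include <iostream>

// Host-side sketch of a C0MatrixMask-style predicate set; is_masked combining
// the two checks with || is an assumption for illustration.
struct mask
{
    int n_raw;
    bool is_upper_triangle(int m, int n) const { return n > m; }
    bool is_n_out_of_bound(int n) const { return n >= n_raw; }
    bool is_masked(int m, int n) const { return is_upper_triangle(m, n) || is_n_out_of_bound(n); }
};

int main()
{
    mask msk{3};
    for(int m = 0; m < 4; ++m)
    {
        for(int n = 0; n < 4; ++n)
            std::cout << (msk.is_masked(m, n) ? '#' : '.');
        std::cout << '\n';
    }
    // Prints a lower-triangular pattern clipped at n_raw columns; the '#'
    // cells are the ones such a mask removes before the softmax.
    return 0;
}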
test/onnx/gen_onnx.py

@@ -31,6 +31,7 @@ from onnx import TensorProto

def onnx_test(op_test):
    def run_test():
        op_info = op_test()
        if len(op_info) > 3:

@@ -1995,16 +1996,17 @@ def gemm_softmax_gemm_test():

    bias = helper.make_tensor_value_info('bias', TensorProto.FLOAT16, [1, 1])
    out = helper.make_tensor_value_info('out', TensorProto.FLOAT16, [1, 1])
    scale_array = np.array([(1 / 8)])
    scale_tensor = helper.make_tensor(name='scale',
                                      data_type=TensorProto.FLOAT16,
                                      dims=scale_array.shape,
                                      vals=scale_array.flatten().astype(np.float16))
    gemm1 = onnx.helper.make_node('MatMul',
                                  inputs=['a', 'b'],
                                  outputs=['gemm1_out'])
    mul1 = onnx.helper.make_node('Mul',
                                 inputs=['gemm1_out', 'scale'],
                                 outputs=['mul1_out'])

@@ -2012,14 +2014,14 @@ def gemm_softmax_gemm_test():

                                 inputs=['mul1_out', 'c'],
                                 outputs=['add1_out'])
    softmax = onnx.helper.make_node('Softmax',
                                    inputs=['add1_out'],
                                    outputs=['softmax_out'])
    gemm2 = onnx.helper.make_node('MatMul',
                                  inputs=['softmax_out', 'b1'],
                                  outputs=['out'])
    return ([gemm1, mul1, add1, softmax, gemm2], [a, b, c, b1, bias], [out], [scale_tensor])


@onnx_test
test/verify/0ck_gemm_softmax_gemm.cpp

@@ -36,10 +36,10 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>

        migraphx::shape m1_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
        migraphx::shape m2_shape{migraphx::shape::half_type, {1, 12, 256, 256}};
        auto m2_elements = 1 * 12 * 256 * 256;
        auto a           = mm->add_parameter("1", m1_shape);
        auto b           = mm->add_parameter("2", m1_shape);
        auto b1          = mm->add_parameter("3", m1_shape);
        auto c           = mm->add_parameter("4", m1_shape);
        std::vector<float> eights(m2_elements, 0.125);
        auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
        std::vector<float> zeros(m2_elements, 0);

@@ -48,9 +48,9 @@ struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>

        auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
        b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
        auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);
        auto scale   = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
        auto bias    = mm->add_instruction(migraphx::make_op("add"), scale, zero);
        auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
        mm->add_instruction(migraphx::make_op("dot"), softmax, b1);