MIGraphX · commit 4b96da8d
Authored Dec 14, 2022 by Alan Turner

Add batch folding to CK Gemms with broadcasted B
Parent: 6be3baa1

Showing 3 changed files, with 51 additions and 101 deletions:

  src/targets/gpu/jit/ck_gemm.cpp                                 +50   -9
  src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp    +1    -0
  test/verify/0ck_gemm_softmax_gemm.cpp                           +0    -92
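The premise of the change: when B is broadcast across the batch dimension (its batch stride is 0), every batch element of the GEMM multiplies by the same B, so the batched A matrices can be stacked along M and the whole computation runs as a single large GEMM with M' = batch_count * M, launched with only one batch's worth of blocks. Below is a self-contained sketch of that equivalence in plain C++; the `gemm` helper and all names are illustrative, not MIGraphX code.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive row-major GEMM: c[m x n] = a[m x k] * b[k x n].
void gemm(const float* a, const float* b, float* c,
          std::size_t m, std::size_t n, std::size_t k)
{
    for(std::size_t i = 0; i < m; ++i)
        for(std::size_t j = 0; j < n; ++j)
        {
            float acc = 0.0f;
            for(std::size_t l = 0; l < k; ++l)
                acc += a[i * k + l] * b[l * n + j];
            c[i * n + j] = acc;
        }
}

int main()
{
    const std::size_t batch = 4, m = 3, n = 5, k = 2;
    std::vector<float> a(batch * m * k);
    std::vector<float> b(k * n); // no batch dimension: B is broadcast
    for(std::size_t i = 0; i < a.size(); ++i)
        a[i] = 0.25f * static_cast<float>(i);
    for(std::size_t i = 0; i < b.size(); ++i)
        b[i] = 1.0f - 0.5f * static_cast<float>(i);

    // Batched form: one m x n GEMM per batch element, reusing the same B.
    std::vector<float> c_batched(batch * m * n);
    for(std::size_t g = 0; g < batch; ++g)
        gemm(a.data() + g * m * k, b.data(), c_batched.data() + g * m * n, m, n, k);

    // Folded form: a single GEMM with M' = batch * m. Valid because A's batch
    // dimension is contiguous with its row dimension and B is batch-invariant.
    std::vector<float> c_folded(batch * m * n);
    gemm(a.data(), b.data(), c_folded.data(), batch * m, n, k);

    assert(c_batched == c_folded); // same operations in the same order
}
```

This is exactly what the `can_fold_batch` guard in the diff below checks: `rank >= 3 and b_strides[rank - 3] == 0`, i.e. a genuinely broadcast B.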
src/targets/gpu/jit/ck_gemm.cpp

```diff
@@ -244,6 +244,38 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         return inputs;
     }
 
+    static std::size_t get_batch_count(const shape& s)
+    {
+        return std::accumulate(s.lens().rbegin() + 2,
+                               s.lens().rend(),
+                               std::size_t{1},
+                               std::multiplies<std::size_t>());
+    }
+
+    static void fold_batch_dims(shape& s)
+    {
+        auto lens = s.lens();
+        if(lens.size() <= 2)
+            return;
+        auto batch_count = get_batch_count(s);
+        auto m1          = lens.at(lens.size() - 2);
+        auto m2          = lens.at(lens.size() - 1);
+        if(transposed_matrix(s))
+            s = shape{s.type(), {m1, m2 * batch_count}};
+        else
+            s = shape{s.type(), {m1 * batch_count, m2}};
+    }
+
+    static void remove_batch_dims(shape& s)
+    {
+        auto lens = s.lens();
+        if(lens.size() <= 2)
+            return;
+        auto m1 = lens.at(lens.size() - 2);
+        auto m2 = lens.at(lens.size() - 1);
+        s = shape{s.type(), {m1, m2}};
+    }
+
     std::vector<std::string> names() const { return {"ck_gemm", "gpu::ck_gemm"}; }
 
     operation compile_op(context& /* ctx */, const std::vector<shape>& inputs, const value& v) const
@@ -253,11 +285,16 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         auto c_shape = inputs.back();
         auto rank    = a_shape.lens().size();
+        auto b_strides      = b_shape.strides();
+        bool can_fold_batch = rank >= 3 and b_strides[rank - 3] == 0;
+        auto batch_count    = get_batch_count(c_shape);
+        auto m              = c_shape.lens()[rank - 2];
+        m                   = can_fold_batch ? m * batch_count : m;
+        auto n              = c_shape.lens().back();
+        auto k              = a_shape.lens().back();
 
         std::array<char, 3> keys{'M', 'N', 'K'};
-        std::array<std::size_t, 3> config{
-            c_shape.lens()[rank - 2], c_shape.lens().back(), a_shape.lens().back()};
+        std::array<std::size_t, 3> config{m, n, k};
 
         auto tuning_val = v.get("tuning_val", get_tuning_for({a_shape, b_shape, c_shape}));
         auto ip         = instance{get_instance(tuning_val, [&](const auto& x) -> bool {
             return get_layout(a_shape) == x[0] and get_layout(b_shape) == x[1] and
@@ -286,19 +323,23 @@ struct ck_gemm_compiler : compiler<ck_gemm_compiler>
         ip.set_gemm("ck::tensor_operation::device::GemmSpecialization::" + gemm_type);
 
         auto blocks_per_batch = ip.get_grid_size(config);
-        auto batch_count      = std::accumulate(c_shape.lens().rbegin() + 2,
-                                           c_shape.lens().rend(),
-                                           std::size_t{1},
-                                           std::multiplies<std::size_t>());
 
         hip_compile_options options;
         auto block_size = ip.get_block_size();
-        auto grid_size  = batch_count * blocks_per_batch;
+        auto grid_size  = can_fold_batch ? blocks_per_batch : batch_count * blocks_per_batch;
         options.set_launch_params(v, grid_size * block_size, block_size);
         options.inputs         = inputs;
         options.output         = c_shape;
         options.kernel_name    = v.get("kernel", "ck_gemm_kernel");
         options.virtual_inputs = inputs;
+        if(can_fold_batch)
+        {
+            auto vinputs = inputs;
+            fold_batch_dims(vinputs[0]);
+            remove_batch_dims(vinputs[1]);
+            std::for_each(vinputs.begin() + 2, vinputs.end(), fold_batch_dims);
+            options.virtual_inputs = vinputs;
+        }
+
         if(v.get("check", false) or enabled(MIGRAPHX_CK_DEBUG{}))
             options.params += " -DMIGRAPHX_CK_CHECK=1";
```
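For reference, `get_batch_count` reduces every dimension above the trailing two matrix dimensions with `std::accumulate`. A standalone illustration of that shape arithmetic, using a plain `std::vector<std::size_t>` in place of MIGraphX's `shape` class (names here are illustrative):

```cpp
#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Product of all dimensions except the trailing two (the matrix dims),
// mirroring get_batch_count in the diff above. Requires lens.size() >= 2.
std::size_t batch_count(const std::vector<std::size_t>& lens)
{
    return std::accumulate(
        lens.rbegin() + 2, lens.rend(), std::size_t{1}, std::multiplies<>{});
}

int main()
{
    std::vector<std::size_t> lens{2, 12, 384, 64}; // {B1, B2, M, K}
    std::cout << batch_count(lens) << "\n";        // prints 24 (= 2 * 12)
    // fold_batch_dims would then rewrite {2, 12, 384, 64} as
    // {24 * 384, 64} = {9216, 64} for a non-transposed matrix.
}
```

Note that folding only rewrites the virtual inputs used for code generation (`options.virtual_inputs`); the kernel's real argument shapes in `options.inputs` are left untouched.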
src/targets/gpu/kernels/include/migraphx/kernels/ck_gemm.hpp

```diff
@@ -53,6 +53,7 @@ __device__ void ck_gemm_matrix(E e, A a, B b, Ds... ds)
     constexpr const auto a_grid_desc_m_k = gemm.matrix_padder.PadADescriptor_M_K(to_ck_tensor<A>());
     constexpr const auto b_grid_desc_n_k =
         gemm.matrix_padder.PadBDescriptor_N_K(to_ck_tensor<ck_transposeb<B>>());
     constexpr const auto e_grid_desc_m_n = gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<E>());
     constexpr const auto ds_grid_desc_m_n =
         ck::make_tuple(gemm.matrix_padder.PadCDescriptor_M_N(to_ck_tensor<Ds>())...);
```
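In this kernel, the B descriptor is built from a transposed view of B (`ck_transposeb<B>`), so Composable Kernel sees an N x K matrix. Conceptually that is a zero-copy stride swap; the following plain-C++ sketch of a strided transpose view is illustrative only, not the CK descriptor API:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// A row-major K x N matrix viewed as N x K by swapping strides: no data moves.
struct transposed_view
{
    const float* data;
    std::size_t rows, cols;             // dimensions of the view (N x K)
    std::size_t row_stride, col_stride; // strides of the view into the buffer
    float operator()(std::size_t i, std::size_t j) const
    {
        return data[i * row_stride + j * col_stride];
    }
};

int main()
{
    const std::size_t k = 2, n = 3;
    std::vector<float> b = {1, 2, 3,  // row-major K x N
                            4, 5, 6};
    // Transposed view: N x K with strides (1, n).
    transposed_view bt{b.data(), n, k, 1, n};
    std::cout << bt(2, 1) << "\n"; // element (2,1) of B^T == B(1,2) == 6
}
```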
test/verify/0ck_gemm_softmax_gemm.cpp (deleted, file mode 100644 → 0)

```cpp
/*
 * The MIT License (MIT)
 *
 * Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in
 * all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 * THE SOFTWARE.
 */
#include "verify_program.hpp"
#include <migraphx/program.hpp>
#include <migraphx/generate.hpp>
#include <migraphx/make_op.hpp>

struct ck_gemm_softmax_gemm : verify_program<ck_gemm_softmax_gemm>
{
    migraphx::program create_program() const
    {
        // migraphx::program p;
        // auto* mm = p.get_main_module();
        // migraphx::shape m1_shape{migraphx::shape::half_type, {16, 12, 384, 64}};
        // migraphx::shape m2_shape{migraphx::shape::half_type, {16, 12, 384, 384}};
        // auto m2_elements = 16 * 12 * 384 * 384;
        // auto a = mm->add_parameter("1", m1_shape);
        // auto b = mm->add_parameter("2", m1_shape);
        // auto b1 = mm->add_parameter("3", m1_shape);
        // auto c = mm->add_parameter("4", m1_shape);
        // std::vector<float> eights(m2_elements, 0.125);
        // auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
        // std::vector<float> zeros(m2_elements, 0);
        // auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
        // std::vector<float> ones(m2_elements, 1);
        // auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
        // // a = one;
        // // b = one;
        // // b1 = one;
        // b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
        // auto gemm1 = mm->add_instruction(migraphx::make_op("dot"), a, b);
        // auto scale = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
        // auto bias = mm->add_instruction(migraphx::make_op("add"), scale, zero);
        // auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
        // mm->add_instruction(migraphx::make_op("dot"), softmax, b1);

        migraphx::program p;
        auto* mm     = p.get_main_module();
        size_t batch = 2;
        migraphx::shape m1_shape{migraphx::shape::half_type, {batch, 384, 2304}};
        migraphx::shape m2_shape{migraphx::shape::half_type, {batch, 12, 384, 384}};
        auto m2_elements = batch * 12 * 384 * 384;
        auto g           = mm->add_parameter("1", m1_shape);
        std::vector<float> eights(m2_elements, 0.125);
        auto eight = mm->add_literal(migraphx::literal{m2_shape, eights});
        std::vector<float> zeros(m2_elements, 0);
        auto zero = mm->add_literal(migraphx::literal{m2_shape, zeros});
        std::vector<float> ones(m2_elements, 1);
        auto one = mm->add_literal(migraphx::literal{m2_shape, ones});
        g = mm->add_instruction(migraphx::make_op("reshape", {{"dims", {batch, 384, 36, 64}}}), g);
        g = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 2, 1, 3}}}), g);
        auto a = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {0}}, {"ends", {12}}}), g);
        auto b = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {12}}, {"ends", {24}}}), g);
        auto b1 = mm->add_instruction(
            migraphx::make_op("slice", {{"axes", {1}}, {"starts", {24}}, {"ends", {36}}}), g);
        b = mm->add_instruction(migraphx::make_op("transpose", {{"permutation", {0, 1, 3, 2}}}), b);
        auto gemm1   = mm->add_instruction(migraphx::make_op("dot"), a, b);
        auto scale   = mm->add_instruction(migraphx::make_op("mul"), gemm1, eight);
        auto bias    = mm->add_instruction(migraphx::make_op("add"), scale, zero);
        auto softmax = mm->add_instruction(migraphx::make_op("softmax", {{"axis", -1}}), bias);
        mm->add_instruction(migraphx::make_op("dot"), softmax, b1);
        return p;
    }
};
```