Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
bf214665
Commit
bf214665
authored
Aug 22, 2024
by
carlushuang
Browse files
add b_nr_kr_waveflatten pattern
parent
22ab193c
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
178 additions
and
39 deletions
+178
-39
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp
..._tile/06_permute/alternative_impl/matrix_core_swizzle.cpp
+24
-0
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp
...6_permute/alternative_impl/matrix_core_swizzle_kernel.hpp
+74
-4
example/ck_tile/06_permute/permute.cpp
example/ck_tile/06_permute/permute.cpp
+77
-35
example/ck_tile/06_permute/script/smoke_test.sh
example/ck_tile/06_permute/script/smoke_test.sh
+3
-0
No files found.
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle.cpp
View file @
bf214665
...
...
@@ -35,6 +35,18 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
auto
k
=
Kernel
(
a
);
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
k
);
return
ave_time
;
}
else
if
(
t
.
permute
.
compare
(
"0,1,3,4,2,5"
)
==
0
)
{
constexpr
matrix_core_permute_style
pstyle
=
matrix_core_permute_style
::
permute_b_nr_kr_kw_nw_kv
;
using
Kernel
=
matrix_core_swizzle_kernel
<
BLOCK_SIZE
,
NPerBlock
,
KPerBlock
,
pstyle
,
Inst
>
;
auto
k
=
Kernel
(
a
);
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
k
);
return
ave_time
;
}
}
...
...
@@ -66,6 +78,18 @@ float matrix_core_swizzle(matrix_core_swizzle_traits t,
auto
k
=
Kernel
(
a
);
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
k
);
return
ave_time
;
}
else
if
(
t
.
permute
.
compare
(
"0,1,3,4,2,5"
)
==
0
)
{
constexpr
matrix_core_permute_style
pstyle
=
matrix_core_permute_style
::
permute_b_nr_kr_kw_nw_kv
;
using
Kernel
=
matrix_core_swizzle_kernel
<
BLOCK_SIZE
,
NPerBlock
,
KPerBlock
,
pstyle
,
Inst
>
;
auto
k
=
Kernel
(
a
);
float
ave_time
=
ck_tile
::
launch_kernel
(
s
,
k
);
return
ave_time
;
}
}
...
...
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp
View file @
bf214665
...
...
@@ -32,10 +32,13 @@ struct to_warp_gemm<matrix_core_inst_enum::MFMA_16x16x16_F16>
template
<
matrix_core_inst_enum
Inst
>
using
to_warp_gemm_t
=
typename
detail
::
to_warp_gemm
<
Inst
>::
type
;
// TODO: in below permute pattern, the last 3 dim is within wave
enum
class
matrix_core_permute_style
{
permute_b_n0_k0_n1_k1_n2_k2
=
0
,
// 0,1,4,2,5,3,6
permute_b_n0_n1_k0_k1_n2_k2
=
1
,
// 0,1,2,4,5,3,6
permute_b_nr_kr_kw_nw_kv
=
2
,
// 0,1,3,4,2,5
permute_b_nr_kr_waveflatten
=
permute_b_nr_kr_kw_nw_kv
,
};
// assume this is B matrix, originally we have batch*n*k
...
...
@@ -81,6 +84,9 @@ struct matrix_core_swizzle_kernel
using
harg
=
matrix_core_swizzle_host_args
;
static
constexpr
int
BLOCK_SIZE
=
BLOCK_SIZE_
;
static
constexpr
int
WavesPerBlock_N
=
4
;
static
constexpr
int
WavesPerBlock_K
=
1
;
static_assert
(
WavesPerBlock_N
*
WavesPerBlock_K
*
64
==
BLOCK_SIZE
);
static
constexpr
int
NPerBlock
=
NPerBlock_
;
static
constexpr
int
KPerBlock
=
KPerBlock_
;
static
constexpr
matrix_core_permute_style
pstyle
=
pstyle_
;
...
...
@@ -171,7 +177,7 @@ struct matrix_core_swizzle_kernel
sequence
<
0
,
0
,
0
>>
{});
// clang-format on
}
else
else
if
constexpr
(
pstyle
==
matrix_core_permute_style
::
permute_b_n0_n1_k0_k1_n2_k2
)
{
// clang-format off
return
make_static_tile_distribution
(
...
...
@@ -189,6 +195,39 @@ struct matrix_core_swizzle_kernel
sequence
<
0
,
0
,
0
>>
{});
// clang-format on
}
else
{
// clang-format off
// permute_b_nr_kr_kw_nw_kv or permute_b_nr_kr_waveflatten
constexpr
index_t
Kv
=
Alignment
;
constexpr
index_t
Nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
Kw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
static_assert
(
KPerBlock
%
(
K1
*
K2
)
==
0
);
constexpr
index_t
Nr
=
NPerBlock
/
Nw
;
constexpr
index_t
Kr
=
KPerBlock
/
(
Kv
*
Kw
);
constexpr
index_t
Nr_p
=
WavesPerBlock_N
;
constexpr
index_t
Kr_p
=
WavesPerBlock_K
;
constexpr
index_t
Nr_y
=
Nr
/
Nr_p
;
constexpr
index_t
Kr_y
=
Kr
/
Kr_p
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
// 0
// major 1 2 3
// minor 0 1 0 1 0 1 2
tuple
<
sequence
<
Nr_y
,
Nr_p
>
,
sequence
<
Kr_y
,
Kr_p
>
,
sequence
<
Kw
,
Nw
,
Kv
>>
,
// Nr_p, Kr_p Kw Nw
tuple
<
sequence
<
1
,
2
>
,
sequence
<
3
,
3
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
0
,
1
>>
,
// Nr_y Kr_y Kv
sequence
<
1
,
2
,
3
>
,
sequence
<
0
,
0
,
2
>>
{});
// clang-format on
}
}
__device__
void
operator
()(
karg
a_
)
...
...
@@ -242,7 +281,7 @@ struct matrix_core_swizzle_kernel
number
<
Alignment
>
{});
// control vector load
return
tmp
;
}
else
else
if
constexpr
(
pstyle
==
matrix_core_permute_style
::
permute_b_n0_n1_k0_k1_n2_k2
)
{
auto
tmp
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_dst
,
...
...
@@ -250,6 +289,21 @@ struct matrix_core_swizzle_kernel
number
<
Alignment
>
{});
// control vector load
return
tmp
;
}
else
{
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
constexpr
index_t
kv
=
Alignment
;
constexpr
index_t
nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
kw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
waveflatten
=
kw
*
nw
*
kv
;
const
index_t
kr
=
a_
.
k
/
(
k1
*
k2
);
const
index_t
nr
=
a_
.
n
/
nw
;
auto
tmp
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_dst
,
make_tuple
(
nr
,
kr
,
waveflatten
),
number
<
Alignment
>
{});
// control vector load
return
tmp
;
}
}();
auto
dst_window
=
[
&
]()
{
...
...
@@ -265,7 +319,7 @@ struct matrix_core_swizzle_kernel
{
i_n
*
n0_tile
,
i_k
*
k0_tile
,
0
,
0
,
0
,
0
},
get_dst_dist
());
}
else
else
if
constexpr
(
pstyle
==
matrix_core_permute_style
::
permute_b_n0_n1_k0_k1_n2_k2
)
{
return
make_tile_window
(
dst_view
,
make_tuple
(
number
<
n0_tile
>
{},
...
...
@@ -277,6 +331,22 @@ struct matrix_core_swizzle_kernel
{
i_n
*
n0_tile
,
0
,
i_k
*
k0_tile
,
0
,
0
,
0
},
get_dst_dist
());
}
else
{
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
constexpr
index_t
kv
=
Alignment
;
constexpr
index_t
nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
kw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
constexpr
index_t
waveflatten_tile
=
kw
*
nw
*
kv
;
constexpr
index_t
nr_tile
=
NPerBlock
/
nw
;
constexpr
index_t
kr_tile
=
KPerBlock
/
(
kw
*
kv
);
return
make_tile_window
(
dst_view
,
make_tuple
(
number
<
nr_tile
>
{},
number
<
kr_tile
>
{},
number
<
waveflatten_tile
>
{}),
{
i_n
*
nr_tile
,
i_k
*
kr_tile
,
0
},
get_dst_dist
());
}
}();
// actual load store
...
...
example/ck_tile/06_permute/permute.cpp
View file @
bf214665
...
...
@@ -258,8 +258,49 @@ bool run(const ck_tile::ArgParser& arg_parser)
};
#ifdef PERMUTE_USE_ALTERNATIVE_IMPL
// batch* n0*n1*n2*k0*k1*k2 -> batch* n0*k0*n1*k1*n2*k2
if
(
rank
==
7
&&
(
arg_parser
.
get_str
(
"perm"
)
==
std
::
string
(
"0,1,4,2,5,3,6"
)
||
arg_parser
.
get_str
(
"perm"
)
==
std
::
string
(
"0,1,2,4,5,3,6"
)))
if
((
arg_parser
.
get_str
(
"perm"
)
==
std
::
string
(
"0,1,4,2,5,3,6"
)
||
arg_parser
.
get_str
(
"perm"
)
==
std
::
string
(
"0,1,2,4,5,3,6"
)
||
arg_parser
.
get_str
(
"perm"
)
==
std
::
string
(
"0,1,3,4,2,5"
)))
{
if
(
arg_parser
.
get_str
(
"perm"
)
==
std
::
string
(
"0,1,3,4,2,5"
))
{
// permute_b_nr_kr_kw_nw_kv = 2, // 0,1,3,4,2,5
matrix_core_swizzle_traits
t
;
t
.
data_type
=
data_type
;
t
.
permute
=
arg_parser
.
get_str
(
"perm"
);
matrix_core_swizzle_args
a
;
a
.
p_src
=
x_buf
.
GetDeviceBuffer
();
a
.
p_dst
=
y_buf
.
GetDeviceBuffer
();
a
.
batch
=
shape
[
0
];
auto
nr
=
shape
[
1
];
auto
nw
=
shape
[
2
];
auto
kr
=
shape
[
3
];
auto
kw
=
shape
[
4
];
auto
kv
=
shape
[
5
];
a
.
n
=
nr
*
nw
;
a
.
k
=
kr
*
kw
*
kv
;
if
(
kv
==
8
&&
kw
==
4
&&
nw
==
16
&&
nr
%
4
==
0
&&
kr
%
8
==
0
)
{
t
.
inst
=
"16x16x16"
;
std
::
cout
<<
", matrix_core_swizzle_waveflatten_"
<<
t
.
inst
<<
std
::
flush
;
ave_time
=
matrix_core_swizzle
(
t
,
a
,
stream_config
);
}
else
if
(
kv
==
8
&&
kw
==
2
&&
nw
==
32
&&
nr
%
4
==
0
&&
kr
%
8
==
0
)
{
t
.
inst
=
"32x32x8"
;
std
::
cout
<<
", matrix_core_swizzle_waveflatten_"
<<
t
.
inst
<<
std
::
flush
;
ave_time
=
matrix_core_swizzle
(
t
,
a
,
stream_config
);
}
else
{
ave_time
=
run_permute
();
}
}
else
{
matrix_core_swizzle_traits
t
;
t
.
data_type
=
data_type
;
...
...
@@ -271,8 +312,8 @@ bool run(const ck_tile::ArgParser& arg_parser)
a
.
batch
=
shape
[
0
];
a
.
n
=
shape
[
1
]
*
shape
[
2
]
*
shape
[
3
];
a
.
k
=
shape
[
4
]
*
shape
[
5
]
*
shape
[
6
];
if
(
shape
[
6
]
==
8
&&
shape
[
3
]
==
32
&&
shape
[
5
]
==
2
&&
shape
[
2
]
==
4
&&
shape
[
4
]
%
8
==
0
&&
shape
[
1
]
%
2
==
0
)
if
(
shape
[
6
]
==
8
&&
shape
[
3
]
==
32
&&
shape
[
5
]
==
2
&&
shape
[
2
]
==
4
&&
shape
[
4
]
%
8
==
0
&&
shape
[
1
]
%
2
==
0
)
{
// 32x32x8 inst
// perm=0,1,4,2,5,3,6
...
...
@@ -301,6 +342,7 @@ bool run(const ck_tile::ArgParser& arg_parser)
ave_time
=
run_permute
();
}
}
}
else
#endif
{
...
...
example/ck_tile/06_permute/script/smoke_test.sh
View file @
bf214665
...
...
@@ -15,6 +15,9 @@ $EXE -prec=fp16 -shape=3,8,4,16,16,4,8 -perm=0,1,4,2,5,3,6 $COMMON_ARGS
$EXE
-prec
=
fp16
-shape
=
3,6,4,32,16,2,8
-perm
=
0,1,2,4,5,3,6
$COMMON_ARGS
$EXE
-prec
=
fp16
-shape
=
5,10,4,32,8,2,8
-perm
=
0,1,2,4,5,3,6
$COMMON_ARGS
$EXE
-prec
=
fp16
-shape
=
3,8,4,16,16,4,8
-perm
=
0,1,2,4,5,3,6
$COMMON_ARGS
$EXE
-prec
=
fp16
-shape
=
2,8,16,8,4,8
-perm
=
0,1,3,4,2,5
$COMMON_ARGS
$EXE
-prec
=
fp16
-shape
=
1,24,32,16,2,8
-perm
=
0,1,3,4,2,5
$COMMON_ARGS
echo
"------------------------------------------------------------------"
for
prec
in
"fp8"
"fp16"
"fp32"
;
do
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment