Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
33ceea62
Commit
33ceea62
authored
Aug 24, 2024
by
carlushuang
Browse files
merge to convert address
parent
1ba8a08f
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
63 additions
and
13 deletions
+63
-13
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp
...6_permute/alternative_impl/matrix_core_swizzle_kernel.hpp
+63
-13
No files found.
example/ck_tile/06_permute/alternative_impl/matrix_core_swizzle_kernel.hpp
View file @
33ceea62
...
@@ -7,6 +7,11 @@
...
@@ -7,6 +7,11 @@
#include "ck_tile/host.hpp"
#include "ck_tile/host.hpp"
#include "ck_tile/ops/gemm.hpp"
#include "ck_tile/ops/gemm.hpp"
// if set to 1, slightly more instructions generated to calculate address
#ifndef MERGE_2D_013425
#define MERGE_2D_013425 0
#endif
enum
class
matrix_core_inst_enum
enum
class
matrix_core_inst_enum
{
{
MFMA_32x32x8_F16
=
0
,
MFMA_32x32x8_F16
=
0
,
...
@@ -213,19 +218,35 @@ struct matrix_core_swizzle_kernel
...
@@ -213,19 +218,35 @@ struct matrix_core_swizzle_kernel
constexpr
index_t
Kr_y
=
Kr
/
Kr_p
;
constexpr
index_t
Kr_y
=
Kr
/
Kr_p
;
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
#if MERGE_2D_013425
tile_distribution_encoding
<
tile_distribution_encoding
<
sequence
<
1
>
,
// 0
sequence
<
1
>
,
// 0 R
// major 1 2
// minor 0 1 2 0 1 2 3
tuple
<
sequence
<
Nr_y
,
Nr_p
,
Nw
>
,
sequence
<
Kr_y
,
Kr_p
,
Kw
,
Kv
>>
,
// H
// Nr_p, Kr_p Kw Nw
tuple
<
sequence
<
1
,
2
>
,
sequence
<
2
,
1
>>
,
// p major
tuple
<
sequence
<
1
,
1
>
,
sequence
<
2
,
2
>>
,
// p minor
// Nr_y Kr_y Kv
sequence
<
1
,
2
,
2
>
,
// Y major
sequence
<
0
,
0
,
3
>>
{});
// y minor
#else
tile_distribution_encoding
<
sequence
<
1
>
,
// 0 R
// major 1 2 3
// major 1 2 3
// minor 0 1 0 1 0 1 2
// minor 0 1 0 1 0 1 2
tuple
<
sequence
<
Nr_y
,
Nr_p
>
,
sequence
<
Kr_y
,
Kr_p
>
,
sequence
<
Kw
,
Nw
,
Kv
>>
,
tuple
<
sequence
<
Nr_y
,
Nr_p
>
,
sequence
<
Kr_y
,
Kr_p
>
,
sequence
<
Kw
,
Nw
,
Kv
>>
,
// H
// Nr_p, Kr_p Kw Nw
// Nr_p, Kr_p Kw Nw
tuple
<
sequence
<
1
,
2
>
,
sequence
<
3
,
3
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
3
,
3
>>
,
// p major
tuple
<
sequence
<
1
,
1
>
,
sequence
<
0
,
1
>>
,
tuple
<
sequence
<
1
,
1
>
,
sequence
<
0
,
1
>>
,
// p minor
// Nr_y Kr_y Kv
// Nr_y Kr_y Kv
sequence
<
1
,
2
,
3
>
,
sequence
<
1
,
2
,
3
>
,
// Y major
sequence
<
0
,
0
,
2
>>
{});
sequence
<
0
,
0
,
2
>>
{});
// y minor
#endif
// clang-format on
// clang-format on
}
}
}
}
...
@@ -291,6 +312,26 @@ struct matrix_core_swizzle_kernel
...
@@ -291,6 +312,26 @@ struct matrix_core_swizzle_kernel
}
}
else
else
{
{
#if MERGE_2D_013425
constexpr
index_t
kv
=
Alignment
;
constexpr
index_t
nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
kw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kABKLane
;
// constexpr index_t waveflatten = kw*nw*kv;
const
index_t
kr
=
a_
.
k
/
(
k1
*
k2
);
const
index_t
nr
=
a_
.
n
/
nw
;
auto
tmp
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_dst
,
make_tuple
(
nr
,
kr
,
number
<
kw
>
{},
number
<
nw
>
{},
number
<
kv
>
{}),
number
<
Alignment
>
{});
// control vector load
auto
tmp_1
=
transform_tensor_view
(
tmp
,
make_tuple
(
make_merge_transform
(
make_tuple
(
nr
,
number
<
nw
>
{})),
make_merge_transform
(
make_tuple
(
kr
,
number
<
kw
>
{},
number
<
kv
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
,
4
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
tmp_1
;
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv,
constexpr
index_t
kv
=
Alignment
;
constexpr
index_t
kv
=
Alignment
;
constexpr
index_t
nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
...
@@ -303,6 +344,7 @@ struct matrix_core_swizzle_kernel
...
@@ -303,6 +344,7 @@ struct matrix_core_swizzle_kernel
make_tuple
(
nr
,
kr
,
waveflatten
),
make_tuple
(
nr
,
kr
,
waveflatten
),
number
<
Alignment
>
{});
// control vector load
number
<
Alignment
>
{});
// control vector load
return
tmp
;
return
tmp
;
#endif
}
}
}();
}();
...
@@ -333,6 +375,13 @@ struct matrix_core_swizzle_kernel
...
@@ -333,6 +375,13 @@ struct matrix_core_swizzle_kernel
}
}
else
else
{
{
#if MERGE_2D_013425
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
return
make_tile_window
(
dst_view
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
i_n
*
NPerBlock
,
i_k
*
KPerBlock
},
get_dst_dist
());
#else
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
// permute_b_nr_kr_waveflatten = permute_b_nr_kr_kw_nw_kv
constexpr
index_t
kv
=
Alignment
;
constexpr
index_t
kv
=
Alignment
;
constexpr
index_t
nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
constexpr
index_t
nw
=
WarpGemm
::
WarpGemmAttribute
::
Impl
::
kAMLane
;
...
@@ -346,6 +395,7 @@ struct matrix_core_swizzle_kernel
...
@@ -346,6 +395,7 @@ struct matrix_core_swizzle_kernel
number
<
waveflatten_tile
>
{}),
number
<
waveflatten_tile
>
{}),
{
i_n
*
nr_tile
,
i_k
*
kr_tile
,
0
},
{
i_n
*
nr_tile
,
i_k
*
kr_tile
,
0
},
get_dst_dist
());
get_dst_dist
());
#endif
}
}
}();
}();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment