gaoqiong / composable_kernel_ROCM / Commits

Commit 02f8c487, authored Oct 09, 2024 by carlushuang

    add single issue api

parent bafb600b

Showing 12 changed files with 530 additions and 436 deletions
include/ck_tile/core/tensor/load_tile.hpp                                        +54  -16
include/ck_tile/core/tensor/store_tile.hpp                                       +30  -18
include/ck_tile/core/tensor/tile_window.hpp                                      +253 -245
include/ck_tile/core/tensor/tile_window_linear.hpp                               +73  -42
include/ck_tile/core/tensor/update_tile.hpp                                      +10  -6
include/ck_tile/ops/fmha.hpp                                                     +2   -2
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp            +67  -69
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp     +22  -20
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp                            +12  -12
include/ck_tile/ops/gemm/block/block_gemm_utils.hpp                              +4   -4
include/ck_tile/ops/image_to_column.hpp                                          +2   -1
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp +1   -1
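The "single issue api" threads a compile-time access index, index_t i_access = -1, through every tile load/store/update entry point below. The default -1 preserves the old behavior and loops over all vectorized accesses of a tile window; a non-negative number<i_access> argument issues exactly one of those accesses, so a kernel can schedule individual memory issues in between other work. A minimal sketch of the new calling convention (the tile_window object is assumed to have been built elsewhere):

    // Hedged sketch; tile_window construction is elided and assumed valid.
    auto full = load_tile(tile_window);              // i_access = -1: issue all accesses
    auto part = load_tile(tile_window, number<0>{}); // issue only access #0
    store_tile(tile_window, full, number<1>{});      // store only access #1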
include/ck_tile/core/tensor/load_tile.hpp
...
@@ -21,28 +21,32 @@ template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access           = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomTensorView_,
                                                                          WindowLengths_,
                                                                          TileDistribution_,
                                                                          NumCoord>& tile_window,
+                              number<i_access>                     = {},
                               bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.load(bool_constant<oob_conditional_check>{});
+    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access           = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
                                                        WindowLengths_,
                                                        TileDistribution_,
                                                        LinearBottomDims_>& tile_window,
+                              number<i_access>                     = {},
                               bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.load(bool_constant<oob_conditional_check>{});
+    return tile_window.load(number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 template <typename T,
...
@@ -50,6 +54,7 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   WindowLengths_,
                                   TileDistribution_,
                                   NumCoord>& tile_window,
+                                  number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {})
 {
-    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    tile_window.load_raw(
+        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 template <typename T,
...
@@ -68,6 +75,7 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   WindowLengths_,
                                   TileDistribution_,
                                   LinearBottomDims_>& tile_window,
+                                  number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {})
 {
-    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    tile_window.load_raw(
+        tile, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }

 // for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
...
@@ -89,6 +99,7 @@ template <typename LdsTileWindow_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access           = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
...
@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile,
                                     WindowLengths_,
                                     TileDistribution_,
                                     NumCoord>& tile_window,
+                                    number<i_access>                     = {},
                                     bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
+    return tile_window.async_load(
+        lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 template <typename LdsTileWindow_,
...
@@ -106,15 +119,18 @@ template <typename LdsTileWindow_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access           = -1,
           bool oob_conditional_check = true>
 CK_TILE_DEVICE auto async_load_tile(LdsTileWindow_&& lds_tile,
                                     const tile_window_linear<BottomTensorView_,
                                                              WindowLengths_,
                                                              TileDistribution_,
                                                              LinearBottomDims_>& tile_window,
+                                    number<i_access>                     = {},
                                     bool_constant<oob_conditional_check> = {})
 {
-    return tile_window.async_load(lds_tile, bool_constant<oob_conditional_check>{});
+    return tile_window.async_load(
+        lds_tile, number<i_access>{}, bool_constant<oob_conditional_check>{});
 }

 template <typename LdsTileWindow_,
...
@@ -122,6 +138,7 @@ template <typename LdsTileWindow_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
+          index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
 CK_TILE_DEVICE auto
...
@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                         WindowLengths_,
                                         TileDistribution_,
                                         NumCoord>& tile_window,
+                                        number<i_access>                     = {},
                                         bool_constant<oob_conditional_check> = {},
                                         bool_constant<pre_nop>               = {})
 {
-    return tile_window.async_load_raw(
-        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    return tile_window.async_load_raw(lds_tile,
+                                      number<i_access>{},
+                                      bool_constant<oob_conditional_check>{},
+                                      bool_constant<pre_nop>{});
 }

 template <typename LdsTileWindow_,
...
@@ -142,6 +162,7 @@ template <typename LdsTileWindow_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
+          index_t i_access           = -1,
           bool oob_conditional_check = true,
           bool pre_nop               = false>
 CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
...
@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                         WindowLengths_,
                                         TileDistribution_,
                                         LinearBottomDims_>& tile_window,
+                                        number<i_access>                     = {},
                                         bool_constant<oob_conditional_check> = {},
                                         bool_constant<pre_nop>               = {})
 {
-    return tile_window.async_load_raw(
-        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    return tile_window.async_load_raw(lds_tile,
+                                      number<i_access>{},
+                                      bool_constant<oob_conditional_check>{},
+                                      bool_constant<pre_nop>{});
 }

-template <typename WindowLengths>
-CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&)
+template <typename WindowLengths, index_t i_access = -1, bool oob_conditional_check = true>
+CK_TILE_DEVICE auto load_tile(const null_tile_window<WindowLengths>&,
+                              number<i_access>                     = {},
+                              bool_constant<oob_conditional_check> = {})
 {
     return null_tensor{};
 }

-template <typename T, typename WindowLengths>
-CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/, const null_tile_window<WindowLengths>&)
+template <typename T,
+          typename WindowLengths,
+          index_t i_access           = -1,
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
+CK_TILE_DEVICE auto load_tile_raw(T& /*null_tile*/,
+                                  const null_tile_window<WindowLengths>&,
+                                  number<i_access>                     = {},
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop>               = {})
 {
 }

 // TODO: this function requires some sub-fileds exist for the target tile window
-template <typename TileWindow, bool oob_conditional_check = true, bool pre_nop = false>
+template <typename TileWindow,
+          index_t i_access           = -1,
+          bool oob_conditional_check = true,
+          bool pre_nop               = false>
 CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
+                                  number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {})
 {
...
@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
     auto t = make_static_distributed_tensor<DataType>(TileDstr{});

-    load_tile_raw(t, w, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
+    load_tile_raw(
+        t, w, number<i_access>{}, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});

     return t;
 }
...
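Because number<i_access> is appended as a defaulted argument (and the template parameter defaults to -1), every pre-existing call site keeps compiling unchanged; only callers that want single-issue behavior pass the extra tag. A sketch of both forms against the new load_tile_raw signature (tile and tile_window are assumed objects, not from the commit):

    // Old-style call: unchanged, loops over every access.
    load_tile_raw(tile, tile_window);
    // New-style call: issue only access #2, with explicit OOB-check and pre-nop tags.
    load_tile_raw(tile, tile_window, number<2>{}, bool_constant<true>{}, bool_constant<false>{});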
include/ck_tile/core/tensor/store_tile.hpp
...
@@ -18,10 +18,12 @@ namespace ck_tile {
 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1>
 CK_TILE_DEVICE void
 store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
-           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+           number<i_access> = {})
 {
     using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
     using TileDstr = remove_cvref_t<TileDistribution_>;
...
@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
                                                  tile_window_tmp.get_window_origin(),
                                                  tile_dstr);

-    tile_window.store(dstr_tensor);
+    tile_window.store(dstr_tensor, number<i_access>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1>
 CK_TILE_DEVICE void
 store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
-               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+               number<i_access> = {})
 {
     using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
     using TileDstr = remove_cvref_t<TileDistribution_>;
...
@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
                                                  tile_window_tmp.get_window_origin(),
                                                  tile_dstr);

-    tile_window.store_raw(dstr_tensor);
+    tile_window.store_raw(dstr_tensor, number<i_access>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1>
 CK_TILE_DEVICE void store_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                                     WindowLengths_,
                                                                     TileDistribution_,
                                                                     NumCoord>& tile_window,
-                               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+                               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                               number<i_access> = {})
 {
-    tile_window.store(dstr_tensor);
+    tile_window.store(dstr_tensor, number<i_access>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1>
 CK_TILE_DEVICE void store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
-                                   const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+                                   const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                                   number<i_access> = {})
 {
-    tile_window.store_raw(dstr_tensor);
+    tile_window.store_raw(dstr_tensor, number<i_access>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1>
 CK_TILE_DEVICE void store_tile(tile_window_linear<BottomTensorView_,
                                                   WindowLengths_,
                                                   TileDistribution_,
                                                   LinearBottomDims_>& tile_window,
-                               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+                               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                               number<i_access> = {})
 {
-    tile_window.store(dstr_tensor);
+    tile_window.store(dstr_tensor, number<i_access>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           typename LinearBottomDims_,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1>
 CK_TILE_DEVICE void store_tile_raw(tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
-                                   const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+                                   const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                                   number<i_access> = {})
 {
-    tile_window.store_raw(dstr_tensor);
+    tile_window.store_raw(dstr_tensor, number<i_access>{});
 }

 } // namespace ck_tile
include/ck_tile/core/tensor/tile_window.hpp
...
@@ -18,6 +18,23 @@
 namespace ck_tile {

+// TODO: NumCoord no need anymore?
+#define WINDOW_DISPATCH_ISSUE_2()                                                      \
+    if constexpr(i_access < 0)                                                         \
+    {                                                                                  \
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {                                \
+            static_for<0, NumAccessPerCoord, 1>{}(                                     \
+                [&](auto iCoordAccess) { issue(iCoord, iCoordAccess); });              \
+        });                                                                            \
+    }                                                                                  \
+    else                                                                               \
+    {                                                                                  \
+        static_assert(i_access < (NumCoord * NumAccessPerCoord));                      \
+        constexpr auto iCoordAccess = number<i_access % NumAccessPerCoord>{};          \
+        constexpr auto iCoord       = number<i_access / NumAccessPerCoord>{};          \
+        issue(iCoord, iCoordAccess);                                                   \
+    }
+
 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename StaticTileDistribution_,
...
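The macro flattens the two-level access space of this window type: the single index decomposes as iCoord = i_access / NumAccessPerCoord and iCoordAccess = i_access % NumAccessPerCoord, with the static_assert guarding the upper bound NumCoord * NumAccessPerCoord. A small compile-time check of that arithmetic, with illustrative values (the real ones come from load_store_traits):

    // Illustrative values only, not taken from the commit.
    constexpr int NumCoord = 2, NumAccessPerCoord = 4; // 8 accesses total
    constexpr int i_access = 5;
    static_assert(i_access / NumAccessPerCoord == 1);  // iCoord: second coordinate
    static_assert(i_access % NumAccessPerCoord == 1);  // iCoordAccess: its second access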
@@ -283,8 +300,8 @@ struct tile_window_with_static_distribution
     CK_TILE_DEVICE constexpr auto get_num_access() const { return load_store_traits::NumAccess; }

-    template <bool oob_conditional_check = true>
-    CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
+    template <index_t i_access = -1, bool oob_conditional_check = true>
+    CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
     {
         using Traits = load_store_traits;
...
@@ -296,65 +313,66 @@ struct tile_window_with_static_distribution
         auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+        auto issue = [&](auto iCoord, auto iCoordAccess) {
             /// TODO: use structure binding (to be captured later) if compiled in C++20
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
             auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

-            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

-                // data index [y0, y1, ...]
-                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+            // data index [y0, y1, ...]
+            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

-                // read from bottom tensor
-                const vector_t vec_value =
-                    get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
-                        bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
+            // read from bottom tensor
+            const vector_t vec_value =
+                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
+                    bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
 #if 1
-                // write into distributed tensor
-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
-                    constexpr auto idx_ys = generate_array(
-                        [&](auto jj) {
-                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
-                                                            : idx_ys_start[jj];
-                        },
-                        number<NDimY>{});
-                    constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                    dst_tensor.get_thread_buffer().template at<d>() =
-                        vec_value.template get_as<DataType>()[j];
-                });
+            // write into distributed tensor
+            static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                constexpr auto idx_ys = generate_array(
+                    [&](auto jj) {
+                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                        : idx_ys_start[jj];
+                    },
+                    number<NDimY>{});
+                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                dst_tensor.get_thread_buffer().template at<d>() =
+                    vec_value.template get_as<DataType>()[j];
+            });
 #else
-                constexpr index_t d =
-                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
-                static_assert(d % Traits::ScalarPerVector == 0);
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
+            static_assert(d % Traits::ScalarPerVector == 0);

-                dst_tensor.get_thread_buffer()
-                    .template get_as<vector_t>()(number<d / Traits::ScalarPerVector>{}) =
-                    bit_cast<vector_t>(vec_value);
+            dst_tensor.get_thread_buffer()
+                .template get_as<vector_t>()(number<d / Traits::ScalarPerVector>{}) =
+                bit_cast<vector_t>(vec_value);
 #endif
-                // move thread coordinate
-                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
-                {
-                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(iAccess);
+            // move thread coordinate
+            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+            {
+                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step_static(iAccess);

-                    constexpr auto idx_diff_ps_ys = container_concat(
-                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
-                        idx_diff_ys);
+                constexpr auto idx_diff_ps_ys = container_concat(
+                    generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
+                    idx_diff_ys);

-                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
-                }
-            });
-        });
+                move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+            }
+        };
+        WINDOW_DISPATCH_ISSUE_2();

         return dst_tensor;
     }

-    template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
+    template <typename DstTile,
+              index_t i_access           = -1,
+              bool oob_conditional_check = true,
+              bool pre_nop               = false>
     CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
+                                 number<i_access>                     = {},
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop>               = {}) const
     {
...
@@ -377,59 +395,57 @@ struct tile_window_with_static_distribution
         auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+        auto issue = [&](auto iCoord, auto iCoordAccess) {
             /// TODO: use structure binding (to be captured later) if compiled in C++20
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
             auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

-            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess  = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
-                constexpr auto pre_nop_ = [&]() {
-                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
-                        return bool_constant<true>{};
-                    else
-                        return bool_constant<false>{};
-                }();
+            constexpr auto iAccess  = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+            constexpr auto pre_nop_ = [&]() {
+                if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                    return bool_constant<true>{};
+                else
+                    return bool_constant<false>{};
+            }();

-                // data index [y0, y1, ...]
-                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
-                constexpr index_t d =
-                    tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
-                static_assert(d % Traits::ScalarPerVector == 0);
+            // data index [y0, y1, ...]
+            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
+            static_assert(d % Traits::ScalarPerVector == 0);

-                get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
-                    dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
-                    bottom_tensor_thread_coord,
-                    0 /**/,
-                    bool_constant<oob_conditional_check>{},
-                    pre_nop_);
+            get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
+                dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
+                bottom_tensor_thread_coord,
+                0 /**/,
+                bool_constant<oob_conditional_check>{},
+                pre_nop_);
 #if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
     CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
-                asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
+            asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
 #endif
-                // move thread coordinate
-                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
-                {
-                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+            // move thread coordinate
+            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+            {
+                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

-                    constexpr auto idx_diff_ps_ys =
-                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+                constexpr auto idx_diff_ps_ys =
+                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

-                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
-                }
-            });
-        });
-#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
-        asm volatile("; this inline asm is workaround to prevent compiler from using too much "
-                     "scratch memory" ::);
-#endif
+                move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+            }
+        };
+        WINDOW_DISPATCH_ISSUE_2();
     }
     // TODO: currently async load only implemented in inline asm
-    template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
+    template <typename LdsTileWindow_,
+              index_t i_access           = -1,
+              bool oob_conditional_check = true,
+              bool pre_nop               = false>
     CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
+                                       number<i_access>                     = {},
                                        bool_constant<oob_conditional_check> = {},
                                        bool_constant<pre_nop>               = {}) const
     {
...
@@ -470,43 +486,44 @@ struct tile_window_with_static_distribution
         LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+        auto issue = [&](auto iCoord, auto iCoordAccess) {
             // TODO: use structure binding (to be captured later) if compiled in C++20
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
             auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

-            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess  = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
-                constexpr auto pre_nop_ = [&]() {
-                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
-                        return bool_constant<true>{};
-                    else
-                        return bool_constant<false>{};
-                }();
+            constexpr auto iAccess  = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+            constexpr auto pre_nop_ = [&]() {
+                if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                    return bool_constant<true>{};
+                else
+                    return bool_constant<false>{};
+            }();

-                // read from bottom tensor
-                get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
-                    smem, bottom_tensor_thread_coord, 0, pre_nop_);
+            // read from bottom tensor
+            get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
+                smem, bottom_tensor_thread_coord, 0, pre_nop_);

-                // move thread coordinate
-                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
-                {
-                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+            // move thread coordinate
+            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+            {
+                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

-                    constexpr auto idx_diff_ps_ys =
-                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+                constexpr auto idx_diff_ps_ys =
+                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

-                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+                move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

-                    m0_inc_with_memory(size_per_issue);
-                }
-            });
-        });
+                m0_inc_with_memory(size_per_issue);
+            }
+        };
+        WINDOW_DISPATCH_ISSUE_2();
     }
-    template <typename LdsTileWindow_, bool oob_conditional_check = true>
+    template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
     CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
+                                   number<i_access>                     = {},
                                    bool_constant<oob_conditional_check> = {}) const
     {
         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
...
@@ -544,37 +561,37 @@ struct tile_window_with_static_distribution
             lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+        auto issue = [&](auto iCoord, auto iCoordAccess) {
             // TODO: use structure binding (to be captured later) if compiled in C++20
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
             auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

-            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

-                // read from bottom tensor
-                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                    smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});
+            // read from bottom tensor
+            get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
+                smem, bottom_tensor_thread_coord, 0, bool_constant<oob_conditional_check>{});

-                // move thread coordinate
-                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
-                {
-                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+            // move thread coordinate
+            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+            {
+                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

-                    constexpr auto idx_diff_ps_ys =
-                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+                constexpr auto idx_diff_ps_ys =
+                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

-                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+                move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);

-                    smem += size_per_issue; // Note we manually increase the per-issue offset
-                }
-            });
-        });
+                smem += size_per_issue; // Note we manually increase the per-issue offset
+            }
+        };
+        WINDOW_DISPATCH_ISSUE_2();
     }
-    template <bool oob_conditional_check = true>
+    template <index_t i_access = -1, bool oob_conditional_check = true>
     CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                              number<i_access>                     = {},
                               bool_constant<oob_conditional_check> = {}) const
     {
         using Traits = load_store_traits;
...
@@ -586,62 +603,57 @@ struct tile_window_with_static_distribution
         constexpr auto tile_dstr = TileDstr{};

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+        auto issue = [&](auto iCoord, auto iCoordAccess) {
             /// TODO: use structure binding (to be captured later) if compiled in C++20
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
             auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

-            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

-                // data index [y0, y1, ...]
-                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+            // data index [y0, y1, ...]
+            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

-                // read from distributed tensor
-                // vector_type_t vec;
-                vector_t vec_value;
+            // read from distributed tensor
+            // vector_type_t vec;
+            vector_t vec_value;

-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
-                    constexpr auto idx_ys = generate_array(
-                        [&](auto jj) {
-                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
-                                                            : idx_ys_start[jj];
-                        },
-                        number<NDimY>{});
-                    constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                    vec_value.template get_as<DataType>()(j) =
-                        dstr_tensor.get_thread_buffer().template at<d>();
-                });
+            static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                constexpr auto idx_ys = generate_array(
+                    [&](auto jj) {
+                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                        : idx_ys_start[jj];
+                    },
+                    number<NDimY>{});
+                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                vec_value.template get_as<DataType>()(j) =
+                    dstr_tensor.get_thread_buffer().template at<d>();
+            });

-                // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();
+            // const vector_t vec_value = vec.template get_as<vector_t>().template at<0>();

-                // write into bottom tensor
-                get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
-                    bottom_tensor_thread_coord,
-                    0,
-                    vec_value,
-                    bool_constant<oob_conditional_check>{});
+            // write into bottom tensor
+            get_bottom_tensor_view().template set_vectorized_elements<vector_t>(
+                bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{});

-                // move thread coordinate
-                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
-                {
-                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+            // move thread coordinate
+            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+            {
+                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

-                    constexpr auto idx_diff_ps_ys =
-                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+                constexpr auto idx_diff_ps_ys =
+                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

-                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
-                }
-            });
-        });
+                move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+            }
+        };
+        WINDOW_DISPATCH_ISSUE_2();
     }

-    CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
+    template <index_t i_access = -1>
+    CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                                  number<i_access> = {}) const
     {
         using Traits = load_store_traits;
...
@@ -652,54 +664,53 @@ struct tile_window_with_static_distribution
         static constexpr bool oob_conditional_check = true;

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+        auto issue = [&](auto iCoord, auto iCoordAccess) {
             /// TODO: use structure binding (to be captured later) if compiled in C++20
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
             auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

-            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

-                // data index [y0, y1, ...]
-                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+            // data index [y0, y1, ...]
+            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

-                // read from distributed tensor
-                vector_t vec_value;
+            // read from distributed tensor
+            vector_t vec_value;

-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
-                    constexpr auto idx_ys = generate_array(
-                        [&](auto jj) {
-                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
-                                                            : idx_ys_start[jj];
-                        },
-                        number<NDimY>{});
-                    constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                    vec_value.template get_as<DataType>()(j) =
-                        dstr_tensor.get_thread_buffer().template at<d>();
-                });
+            static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                constexpr auto idx_ys = generate_array(
+                    [&](auto jj) {
+                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                        : idx_ys_start[jj];
+                    },
+                    number<NDimY>{});
+                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                vec_value.template get_as<DataType>()(j) =
+                    dstr_tensor.get_thread_buffer().template at<d>();
+            });

-                // write into bottom tensor
-                get_bottom_tensor_view()
-                    .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
-                        bottom_tensor_thread_coord, 0, vec_value);
+            // write into bottom tensor
+            get_bottom_tensor_view()
+                .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
+                    bottom_tensor_thread_coord, 0, vec_value);

-                // move thread coordinate
-                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
-                {
-                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+            // move thread coordinate
+            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+            {
+                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

-                    constexpr auto idx_diff_ps_ys =
-                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+                constexpr auto idx_diff_ps_ys =
+                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

-                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
-                }
-            });
-        });
+                move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+            }
+        };
+        WINDOW_DISPATCH_ISSUE_2();
     }

-    template <bool oob_conditional_check = true>
+    template <index_t i_access = -1, bool oob_conditional_check = true>
     CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                               number<i_access>                     = {},
                                bool_constant<oob_conditional_check> = {}) const
     {
         using Traits = load_store_traits;
...
@@ -710,55 +721,50 @@ struct tile_window_with_static_distribution
         constexpr auto tile_dstr = TileDstr{};

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+        auto issue = [&](auto iCoord, auto iCoordAccess) {
             /// TODO: use structure binding (to be captured later) if compiled in C++20
             auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
             auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

-            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
-                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+            constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

-                // data index [y0, y1, ...]
-                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);
+            // data index [y0, y1, ...]
+            constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

-                // read from distributed tensor
-                vector_t vec_value;
+            // read from distributed tensor
+            vector_t vec_value;

-                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
-                    constexpr auto idx_ys = generate_array(
-                        [&](auto jj) {
-                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
-                                                            : idx_ys_start[jj];
-                        },
-                        number<NDimY>{});
-                    constexpr index_t d =
-                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
-                    vec_value.template get_as<DataType>()(j) =
-                        dstr_tensor.get_thread_buffer().template at<d>();
-                });
+            static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
+                constexpr auto idx_ys = generate_array(
+                    [&](auto jj) {
+                        return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
+                                                        : idx_ys_start[jj];
+                    },
+                    number<NDimY>{});
+                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);
+                vec_value.template get_as<DataType>()(j) =
+                    dstr_tensor.get_thread_buffer().template at<d>();
+            });

-                // write into bottom tensor
-                get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
-                    bottom_tensor_thread_coord,
-                    0,
-                    vec_value,
-                    bool_constant<oob_conditional_check>{});
+            // write into bottom tensor
+            get_bottom_tensor_view().template update_vectorized_elements<vector_t>(
+                bottom_tensor_thread_coord, 0, vec_value, bool_constant<oob_conditional_check>{});

-                // move thread coordinate
-                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
-                {
-                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);
+            // move thread coordinate
+            if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
+            {
+                constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

-                    constexpr auto idx_diff_ps_ys =
-                        container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+                constexpr auto idx_diff_ps_ys =
+                    container_concat(array<index_t, NDimP>{0}, idx_diff_ys);

-                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
-                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
-                }
-            });
-        });
+                move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                    window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+            }
+        };
+        WINDOW_DISPATCH_ISSUE_2();
     }
// move thread's botom tensor coordiante
...
...
@@ -857,6 +863,8 @@ struct tile_window_with_static_distribution
     array<tuple<WindowAdaptorCoord, BottomTensorCoord>, NumCoord> pre_computed_coords_;
 };

+#undef WINDOW_DISPATCH_ISSUE_2
+
 // TODO: use strategy
 template <typename TensorView_,
           typename WindowLengths_,
...
include/ck_tile/core/tensor/tile_window_linear.hpp
...
@@ -18,6 +18,17 @@
 namespace ck_tile {

+#define WINDOW_DISPATCH_ISSUE()                                          \
+    if constexpr(i_access < 0)                                           \
+    {                                                                    \
+        static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); });      \
+    }                                                                    \
+    else                                                                 \
+    {                                                                    \
+        static_assert(i_access < NumAccess);                             \
+        issue(number<i_access>{});                                       \
+    }
+
 //
 // This version of tile window will pre-cache offset/flags based on need
 //
...
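Unlike the static-distribution window, tile_window_linear keeps a flat access space (it caches one coordinate per non-linear access and maps i_access through AccessMap_NonLinear), so its dispatch needs no coordinate split: i_access indexes traits::NumAccess directly. Written out, the macro's control flow is equivalent to this sketch, where issue is the per-access lambda each member function defines:

    // Non-macro form of WINDOW_DISPATCH_ISSUE(), for illustration only.
    if constexpr(i_access < 0)
        static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); // all accesses
    else
        issue(number<i_access>{}); // exactly one access (bound-checked by the macro)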
@@ -443,8 +454,8 @@ struct tile_window_linear
     CK_TILE_DEVICE constexpr auto get_num_access() const { return traits::NumAccess; }

-    template <bool oob_conditional_check = true>
-    CK_TILE_DEVICE auto load(bool_constant<oob_conditional_check> = {}) const
+    template <index_t i_access = -1, bool oob_conditional_check = true>
+    CK_TILE_DEVICE auto load(number<i_access> = {}, bool_constant<oob_conditional_check> = {}) const
     {
         using vector_t = typename traits::vector_t;
         using SFC_Ys   = typename traits::SFC_Ys;
...
@@ -453,9 +464,8 @@ struct tile_window_linear
         auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumAccess, 1>{}([&](auto i_access) {
-            constexpr auto IAccess       = number<i_access>{};
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess       = number<i_access_>{};
             constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
             auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
...
@@ -494,17 +504,22 @@ struct tile_window_linear
             dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                 number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
 #endif
-        });
+        };
+        WINDOW_DISPATCH_ISSUE();

         return dst_tensor;
     }

-    template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
+    template <typename DstTile,
+              index_t i_access           = -1,
+              bool oob_conditional_check = true,
+              bool pre_nop               = false>
     CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
+                                 number<i_access> = {}, // negative means loop over all num_access
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop>               = {}) const
     {
         using vector_t = typename traits::vector_t;
         using SFC_Ys   = typename traits::SFC_Ys;

         static constexpr index_t YElementSize =
...
@@ -516,11 +531,10 @@ struct tile_window_linear
         auto& dst_vec_tbuf = reinterpret_cast<vectorized_tbuf&>(dst_tensor.get_thread_buffer());

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumAccess, 1>{}([&](auto i_access) {
-            constexpr auto IAccess  = number<i_access>{};
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess  = number<i_access_>{};
             constexpr auto pre_nop_ = [&]() {
-                if constexpr(pre_nop && i_access == 0 &&
+                if constexpr(pre_nop && i_access_ == 0 &&
                              BottomTensorView::buffer_view::get_address_space() ==
                                  address_space_enum::global)
                     return bool_constant<true>{};
...
@@ -550,16 +564,18 @@ struct tile_window_linear
     CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
             asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
 #endif
-        });
-#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
-        asm volatile("; this inline asm is workaround to prevent compiler from using too much "
-                     "scratch memory" ::);
-#endif
+        };
+        WINDOW_DISPATCH_ISSUE();
     }

     // TODO: currently async load only implemented in inline asm
-    template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
+    template <typename LdsTileWindow_,
+              index_t i_access           = -1,
+              bool oob_conditional_check = true,
+              bool pre_nop               = false>
     CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
+                                       number<i_access>                     = {},
                                        bool_constant<oob_conditional_check> = {},
                                        bool_constant<pre_nop>               = {}) const
     {
...
@@ -600,10 +616,10 @@ struct tile_window_linear
         LdsDataType* smem = lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_;

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumAccess, 1>{}([&](auto i_access) {
-            constexpr auto IAccess  = number<i_access>{};
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess  = number<i_access_>{};
             constexpr auto pre_nop_ = [&]() {
-                if constexpr(pre_nop && i_access == 0)
+                if constexpr(pre_nop && i_access_ == 0)
                     return bool_constant<true>{};
                 else
                     return bool_constant<false>{};
...
@@ -618,15 +634,18 @@ struct tile_window_linear
                 smem, bottom_tensor_thread_coord, 0, bottom_tensor_flag, pre_nop_);

             // move thread coordinate
-            if constexpr(i_access != (NumAccess - 1))
+            if constexpr(i_access_ != (NumAccess - 1))
             {
                 m0_inc_with_memory(size_per_issue);
             }
-        });
+        };
+        WINDOW_DISPATCH_ISSUE();
     }

-    template <typename LdsTileWindow_, bool oob_conditional_check = true>
+    template <typename LdsTileWindow_, index_t i_access = -1, bool oob_conditional_check = true>
     CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
+                                   number<i_access>                     = {},
                                    bool_constant<oob_conditional_check> = {}) const
     {
         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
...
@@ -667,8 +686,8 @@ struct tile_window_linear
             lds_tile.get_bottom_tensor_view().get_buffer_view().p_data_ + m0_init_value;

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumAccess, 1>{}([&](auto i_access) {
-            constexpr auto IAccess       = number<i_access>{};
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess       = number<i_access_>{};
             constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
             auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
             auto bottom_tensor_flag         = cached_flags_[IAccess];
...
@@ -682,15 +701,18 @@ struct tile_window_linear
                 bool_constant<oob_conditional_check>{});

             // move thread coordinate
-            if constexpr(i_access != (NumAccess - 1))
+            if constexpr(i_access_ != (NumAccess - 1))
             {
                 smem += size_per_issue; // Note we manually increase the per-issue offset
             }
-        });
+        };
+        WINDOW_DISPATCH_ISSUE();
     }

-    template <bool oob_conditional_check = true>
+    template <index_t i_access = -1, bool oob_conditional_check = true>
     CK_TILE_DEVICE void store(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                              number<i_access>                     = {},
                               bool_constant<oob_conditional_check> = {}) const
     {
...
@@ -700,8 +722,8 @@ struct tile_window_linear
         constexpr auto tile_dstr = TileDstr{};

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumAccess, 1>{}([&](auto i_access) {
-            constexpr auto IAccess       = number<i_access>{};
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess       = number<i_access_>{};
             constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
             auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
             constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
...
@@ -732,13 +754,15 @@ struct tile_window_linear
                 bottom_tensor_flag,
                 vec_value,
                 bool_constant<oob_conditional_check>{});
-        });
+        };
+        WINDOW_DISPATCH_ISSUE();
     }

-    CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor) const
+    template <index_t i_access = -1>
+    CK_TILE_DEVICE void store_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                                  number<i_access> = {}) const
     {
         using vector_t = typename traits::vector_t;
         using SFC_Ys   = typename traits::SFC_Ys;
...
@@ -746,8 +770,8 @@ struct tile_window_linear
         static constexpr bool oob_conditional_check = true;

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumAccess, 1>{}([&](auto i_access) {
-            constexpr auto IAccess       = number<i_access>{};
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess       = number<i_access_>{};
             constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
             auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
             constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
...
@@ -773,11 +797,14 @@ struct tile_window_linear
             get_bottom_tensor_view()
                 .template set_vectorized_elements_raw<vector_t, oob_conditional_check>(
                     bottom_tensor_thread_coord, linear_offset, bottom_tensor_flag, vec_value);
-        });
+        };
+        WINDOW_DISPATCH_ISSUE();
     }

-    template <bool oob_conditional_check = true>
+    template <index_t i_access = -1, bool oob_conditional_check = true>
     CK_TILE_DEVICE void update(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
+                               number<i_access>                     = {},
                                bool_constant<oob_conditional_check> = {}) const
     {
...
@@ -787,8 +814,8 @@ struct tile_window_linear
        constexpr auto tile_dstr = TileDstr{};

-        // loop over thread tensor space [y0, y1, ...]
-        static_for<0, NumAccess, 1>{}([&](auto i_access) {
-            constexpr auto IAccess       = number<i_access>{};
+        auto issue = [&](auto i_access_) {
+            constexpr auto IAccess       = number<i_access_>{};
             constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};
             auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
             constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
...
@@ -820,7 +847,9 @@ struct tile_window_linear
                 bottom_tensor_flag,
                 vec_value,
                 bool_constant<oob_conditional_check>{});
-        });
+        };
+        WINDOW_DISPATCH_ISSUE();
     }

     // move thread's botom tensor coordiante
...
@@ -920,6 +949,8 @@ struct tile_window_linear
     array<bool, traits::NumAccess> cached_flags_;
 };

+#undef WINDOW_DISPATCH_ISSUE
+
 namespace impl {

 template <address_space_enum, index_t len_>
 struct default_linear_bottom_dims_impl
...
include/ck_tile/core/tensor/update_tile.hpp
...
@@ -17,10 +17,12 @@ namespace ck_tile {
 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1>
 CK_TILE_DEVICE void
 update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& tile_window_tmp,
-            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+            const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+            number<i_access> = {})
 {
     using DataType = remove_cvref_t<typename BottomTensorView_::DataType>;
     using TileDstr = remove_cvref_t<TileDistribution_>;
...
@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>&
                                                  tile_window_tmp.get_window_origin(),
                                                  tile_dstr);

-    tile_window.update(dstr_tensor);
+    tile_window.update(dstr_tensor, number<i_access>{});
 }

 template <typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          typename DataType_>
+          typename DataType_,
+          index_t i_access = -1>
 CK_TILE_DEVICE void update_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                                      WindowLengths_,
                                                                      TileDistribution_,
                                                                      NumCoord>& tile_window,
-                                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+                                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+                                number<i_access> = {})
 {
-    tile_window.update(dstr_tensor);
+    tile_window.update(dstr_tensor, number<i_access>{});
 }

 } // namespace ck_tile
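The motivation for exposing per-access issues is software pipelining: a kernel can interleave individual memory issues with compute instead of emitting one monolithic load/store loop. A hedged sketch of how a pipeline might drive the API (num_access is hard-coded here for illustration, and do_compute_step is a hypothetical helper, not part of this commit):

    // Hypothetical pipeline body: one memory issue per compute step.
    constexpr index_t num_access = 8; // illustrative; query via tile_window.get_num_access()
    static_for<0, num_access, 1>{}([&](auto ia) {
        load_tile_raw(tile, tile_window, ia); // single issue #ia
        do_compute_step(ia);                  // hypothetical overlapping work
    });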
include/ck_tile/ops/fmha.hpp
...
@@ -33,8 +33,8 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
-//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
-//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
+// #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
+// #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp
...
@@ -45,33 +45,33 @@ struct BlockFmhaPipelineQRAsyncEx
     static constexpr index_t kK1            = BlockFmhaShape::kK1;
     static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;

     static constexpr index_t Block_M0 = BlockFmhaShape::Block_M0;
     static constexpr index_t Block_N0 = BlockFmhaShape::Block_N0;
     static constexpr index_t Block_K0 = BlockFmhaShape::Block_K0;

     static constexpr index_t BlockWarps_M0 = BlockFmhaShape::BlockWarps_M0;
     static constexpr index_t BlockWarps_N0 = BlockFmhaShape::BlockWarps_N0;
     static constexpr index_t BlockWarps_K0 = BlockFmhaShape::BlockWarps_K0;
     static constexpr index_t Warps_M0      = BlockFmhaShape::Warps_M0;
     static constexpr index_t Warps_N0      = BlockFmhaShape::Warps_N0;
     static constexpr index_t Warps_K0      = BlockFmhaShape::Warps_K0;
     static constexpr index_t Repeat_M0     = BlockFmhaShape::Repeat_M0;
     static constexpr index_t Repeat_N0     = BlockFmhaShape::Repeat_N0;
     static constexpr index_t Repeat_K0     = BlockFmhaShape::Repeat_K0;

     static constexpr index_t Block_M1 = BlockFmhaShape::Block_M1;
     static constexpr index_t Block_N1 = BlockFmhaShape::Block_N1;
     static constexpr index_t Block_K1 = BlockFmhaShape::Block_K1;

     static constexpr index_t BlockWarps_M1 = BlockFmhaShape::BlockWarps_M1;
     static constexpr index_t BlockWarps_N1 = BlockFmhaShape::BlockWarps_N1;
     static constexpr index_t BlockWarps_K1 = BlockFmhaShape::BlockWarps_K1;
     static constexpr index_t Warps_M1      = BlockFmhaShape::Warps_M1;
     static constexpr index_t Warps_N1      = BlockFmhaShape::Warps_N1;
     static constexpr index_t Warps_K1      = BlockFmhaShape::Warps_K1;
     static constexpr index_t Repeat_M1     = BlockFmhaShape::Repeat_M1;
     static constexpr index_t Repeat_N1     = BlockFmhaShape::Repeat_N1;
     static constexpr index_t Repeat_K1     = BlockFmhaShape::Repeat_K1;

     static constexpr index_t UnrollStages = 2; // pipeline unroll the gemm/softmax/gemm

     static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
     // TODO: seq_q always support padding, hdim_q/v support multiple of vector(like 8x)
...
@@ -205,49 +205,47 @@ struct BlockFmhaPipelineQRAsyncEx
                       "wrong!");

         // K tile in LDS
         auto k_lds_store = [&]() {
             auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
             return generate_tuple(
                 [&](auto i_buf) {
                     return make_tile_window(
                         make_tensor_view<address_space_enum::lds>(
                             k_lds_ptr, Policy::template MakeSmemStoreDesc_K<Problem>(i_buf)),
                         Policy::template MakeSmemStoreDesc_K<Problem>(i_buf).get_lengths(),
                         {0, 0, 0});
                 },
                 number<Policy::NumPrefetchK>{});
         }();

         auto k_lds_load = [&]() {
             auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
             return make_tile_window(
                 make_tensor_view<address_space_enum::lds>(
                     k_lds_ptr, Policy::template MakeSmemLoadDesc_K<Problem>()),
                 Policy::template MakeSmemLoadDesc_K<Problem>().get_lengths(),
                 {0, 0});
         }();

         // V tile in LDS
         auto v_lds_store = [&]() {
             auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
             return generate_tuple(
                 [&](auto i_buf) {
                     return make_tile_window(
                         make_tensor_view<address_space_enum::lds>(
                             v_lds_ptr, Policy::template MakeSmemStoreDesc_V<Problem>(i_buf)),
                         Policy::template MakeSmemStoreDesc_V<Problem>(i_buf).get_lengths(),
                         {0, 0, 0});
                 },
                 number<Policy::NumPrefetchV>{});
         }();

         auto v_lds_load = [&]() {
             auto v_lds_ptr = reinterpret_cast<VDataType*>(smem_ptr);
             return make_tile_window(
                 make_tensor_view<address_space_enum::lds>(
                     v_lds_ptr, Policy::template MakeSmemLoadDesc_V<Problem>()),
                 Policy::template MakeSmemLoadDesc_V<Problem>().get_lengths(),
                 {0, 0});
         }();

         // reduction function for softmax
...
@@ -258,22 +256,20 @@ struct BlockFmhaPipelineQRAsyncEx
        constexpr auto warp_gemm_0 = Policy::template GetWarpGemm_0<Problem>();
        constexpr auto warp_gemm_1 = Policy::template GetWarpGemm_1<Problem>();

        auto gemm_0 = [&]() {
            constexpr index_t total_repeats = Repeat_M0 * Repeat_N0 * Repeat_K0;
            // n*k*m, more relaxed ds_read
            static_for<0, total_repeats, 1>{}([&](auto i_r) {
                constexpr index_t i_m = i_r % Repeat_M0;
                constexpr index_t i_k = (i_r / Repeat_M0) % Repeat_K0;
                constexpr index_t i_n = i_r / (Repeat_M0 * Repeat_K0);
            });
        };
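The decomposition above fixes the issue order to m-fastest, then k, then n, which is what the "n*k*m" comment refers to: each ds_read of a K fragment is reused across all m repeats before moving on. A standalone check of that index math, with hypothetical repeat counts of 2 in each dimension:

        #include <cstdio>
        int main() {
            constexpr int Repeat_M0 = 2, Repeat_K0 = 2, Repeat_N0 = 2; // hypothetical
            for (int i_r = 0; i_r < Repeat_M0 * Repeat_N0 * Repeat_K0; ++i_r) {
                const int i_m = i_r % Repeat_M0;               // fastest: m
                const int i_k = (i_r / Repeat_M0) % Repeat_K0; // middle:  k
                const int i_n = i_r / (Repeat_M0 * Repeat_K0); // slowest: n
                std::printf("i_r=%d -> m=%d k=%d n=%d\n", i_r, i_m, i_k, i_n);
            }
            return 0;
        }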
        auto q_dram_window = make_tile_window_raw(q_dram_block_window_tmp.get_bottom_tensor_view(),
                                                  q_dram_block_window_tmp.get_window_lengths(),
                                                  q_dram_block_window_tmp.get_window_origin(),
                                                  Policy::template MakeGlobalDesc_Q<Problem>());

        // TODO: we use async copy for K, which is inline asm; a side effect is that
        // we have to use inline asm for Q as well
...
...
@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx
        __builtin_amdgcn_sched_barrier(0);

        using SaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_0<Problem>());
        auto s_accs = generate_tuple([&](auto) { return SaccBlockTileType{}; }, number<UnrollStages>{});

        // infer Sacc, S, P, M, L, Oacc type
        using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_accs));
...
...
@@ -296,9 +293,10 @@ struct BlockFmhaPipelineQRAsyncEx
        using OaccBlockTileType = decltype(Policy::template MakeBlockGemmAccTile_1<Problem>());

        // init Oacc, M, L
        auto o_accs = generate_tuple([&](auto) { return OaccBlockTileType{}; }, number<UnrollStages>{});
        auto ms     = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});
        auto ls     = generate_tuple([&](auto) { return MLBlockTileType{}; }, number<UnrollStages>{});

        static_for<0, UnrollStages, 1>{}([&](auto i) {
            clear_tile(o_accs(i));
...
...
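With UnrollStages == 2, every accumulator above exists once per stage so two loop iterations can be software-pipelined. Only clear_tile(o_accs(i)) is visible in this hunk; a plausible completion of the per-stage initialization, following the usual flash-attention running-max/running-sum setup (the m/l lines are assumptions, not taken from the commit):

        static_for<0, UnrollStages, 1>{}([&](auto i) {
            clear_tile(o_accs(i));                                      // Oacc = 0
            set_tile(ms(i), -numeric<SMPLComputeDataType>::infinity()); // assumed: running max
            clear_tile(ls(i));                                          // assumed: running sum
        });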
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp  View file @ 02f8c487
...
...
@@ -105,16 +105,16 @@ struct BlockFmhaPipelineQRAsyncEx
    CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_0()
    {
        using AccWarpDescEnc_ = typename decltype(GetWarpGemm_0())::CWarpDstrEncoding;
        using BlockTile_  = sequence<Problem::BlockFmhaShape::Block_M0, Problem::BlockFmhaShape::Block_N0>;
        using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M0, Problem::BlockFmhaShape::BlockWarps_N0>;
        using WarpTile_   = sequence<Problem::BlockFmhaShape::Warp_M0, Problem::BlockFmhaShape::Warp_N0>;

        constexpr auto enc  = make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
        constexpr auto dstr = make_static_tile_distribution(enc);

        auto t = make_static_distributed_tensor<typename Problem::SaccDataType>(dstr);
        return t;
    }
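A usage sketch for the helper (hypothetical call site, not part of this commit): the returned distributed tensor carries only the per-thread register elements implied by the distribution, so it can be cleared and transformed elementwise with the usual tile operations:

        auto acc = Policy::template MakeBlockGemmAccTile_0<Problem>();
        clear_tile(acc);                                        // zero every per-thread element
        tile_elementwise_inout([](auto& x) { x *= 2.f; }, acc); // any elementwise op (illustrative)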
...
...
@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx
    {
        if constexpr(Problem::kHasDropout)
        {
            constexpr index_t kMPerStep =
                Problem::BlockFmhaShape::BlockWarps_M0 * Problem::BlockFmhaShape::Warp_M0;
            constexpr index_t kNPerStep =
                Problem::BlockFmhaShape::BlockWarps_N0 * Problem::BlockFmhaShape::Warp_N0;

            return (kMPerStep + 1) * kNPerStep * sizeof(uint8_t);
        }
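The returned size reserves one extra row of padding per step; the +1 is presumably there to stagger rows across LDS banks. With hypothetical shape values (BlockWarps_M0 = 4, Warp_M0 = 32, BlockWarps_N0 = 1, Warp_N0 = 32) the arithmetic works out as follows:

        #include <cstdint>
        constexpr int kMPerStep = 4 * 32; // BlockWarps_M0 * Warp_M0 = 128 (hypothetical)
        constexpr int kNPerStep = 1 * 32; // BlockWarps_N0 * Warp_N0 = 32  (hypothetical)
        static_assert((kMPerStep + 1) * kNPerStep * sizeof(std::uint8_t) == 4128,
                      "129 rows * 32 cols * 1 byte of dropout staging LDS");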
...
...
@@ -612,16 +614,16 @@ struct BlockFmhaPipelineQRAsyncEx
    CK_TILE_HOST_DEVICE static constexpr auto MakeBlockGemmAccTile_1()
    {
        using AccWarpDescEnc_ = typename decltype(GetWarpGemm_1())::CWarpDstrEncoding;
        using BlockTile_  = sequence<Problem::BlockFmhaShape::Block_M1, Problem::BlockFmhaShape::Block_N1>;
        using BlockWarps_ = sequence<Problem::BlockFmhaShape::BlockWarps_M1, Problem::BlockFmhaShape::BlockWarps_N1>;
        using WarpTile_   = sequence<Problem::BlockFmhaShape::Warp_M1, Problem::BlockFmhaShape::Warp_N1>;

        constexpr auto enc  = make_block_gemm_acc_enc<AccWarpDescEnc_, BlockTile_, BlockWarps_, WarpTile_>();
        constexpr auto dstr = make_static_tile_distribution(enc);

        auto t = make_static_distributed_tensor<typename Problem::OaccDataType>(dstr);
        return t;
    }
};
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp  View file @ 02f8c487
...
...
@@ -43,15 +43,15 @@ struct TileFmhaShape
                                     ck_tile::tensor_layout::gemm::ColumnMajor>;

    // gemm-0 shapes TODO: naming?
    static constexpr index_t Block_M0 = kM0;
    static constexpr index_t Block_N0 = kN0;
    static constexpr index_t Block_K0 = kK0;

    static constexpr index_t BlockWarps_M0 = Gemm0BlockWarps::at(number<0>{});
    static constexpr index_t BlockWarps_N0 = Gemm0BlockWarps::at(number<1>{});
    static constexpr index_t BlockWarps_K0 = Gemm0BlockWarps::at(number<2>{});

    static constexpr index_t Warps_M0 = Gemm0WarpTile::at(number<0>{});
    static constexpr index_t Warps_N0 = Gemm0WarpTile::at(number<1>{});
    static constexpr index_t Warps_K0 = Gemm0WarpTile::at(number<2>{});

    static_assert(Block_M0 % (BlockWarps_M0 * Warps_M0) == 0);
    static_assert(Block_N0 % (BlockWarps_N0 * Warps_N0) == 0);
    static_assert(Block_K0 % (BlockWarps_K0 * Warps_K0) == 0);
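To make the divisibility constraint concrete, a worked example with hypothetical template arguments (kM0 = 128, Gemm0BlockWarps = sequence<4, 1, 1>, Gemm0WarpTile = sequence<32, 32, 16>; none of these values are from the commit):

        constexpr int Block_M0 = 128, BlockWarps_M0 = 4, Warps_M0 = 32; // hypothetical
        static_assert(Block_M0 % (BlockWarps_M0 * Warps_M0) == 0);      // 128 % 128 == 0
        static_assert(Block_M0 / (BlockWarps_M0 * Warps_M0) == 1);      // Repeat_M0 (next hunk)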
...
...
@@ -59,15 +59,15 @@ struct TileFmhaShape
    static constexpr index_t Repeat_N0 = Block_N0 / (BlockWarps_N0 * Warps_N0);
    static constexpr index_t Repeat_K0 = Block_K0 / (BlockWarps_K0 * Warps_K0);

    static constexpr index_t Block_M1 = kM0;
    static constexpr index_t Block_N1 = kN1;
    static constexpr index_t Block_K1 = kK1;

    static constexpr index_t BlockWarps_M1 = Gemm1BlockWarps::at(number<0>{});
    static constexpr index_t BlockWarps_N1 = Gemm1BlockWarps::at(number<1>{});
    static constexpr index_t BlockWarps_K1 = Gemm1BlockWarps::at(number<2>{});

    static constexpr index_t Warps_M1 = Gemm1WarpTile::at(number<0>{});
    static constexpr index_t Warps_N1 = Gemm1WarpTile::at(number<1>{});
    static constexpr index_t Warps_K1 = Gemm1WarpTile::at(number<2>{});

    static_assert(Block_M1 % (BlockWarps_M1 * Warps_M1) == 0);
    static_assert(Block_N1 % (BlockWarps_N1 * Warps_N1) == 0);
    static_assert(Block_K1 % (BlockWarps_K1 * Warps_K1) == 0);
...
...
include/ck_tile/ops/gemm/block/block_gemm_utils.hpp  View file @ 02f8c487
...
...
@@ -7,10 +7,10 @@
namespace ck_tile {

template <typename AccWarpDescEnc,
          typename BlockTile, // seq<M, N>
          typename BlockWarps,
          typename WarpTile>
CK_TILE_DEVICE_HOST constexpr auto make_block_gemm_acc_enc()
{
    constexpr index_t Block_M = BlockTile::at(number<0>{});
...
...
include/ck_tile/ops/image_to_column.hpp  View file @ 02f8c487
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
 #include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp  View file @ 02f8c487
...
...
@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline
    {
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
        __builtin_amdgcn_sched_barrier(0);
-       auto x = load_tile_raw(inp_win, bool_constant<true>{}, bool_constant<true>{});
+       auto x = load_tile_raw(inp_win, number<-1>{}, bool_constant<true>{}, bool_constant<true>{});
        buffer_load_fence(number<0>{});
        __builtin_amdgcn_sched_barrier(0);
#else
...
...
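This is the call-site counterpart of the new single-issue API in load_tile.hpp: the extra number<i_access> argument defaults to -1, which keeps the old issue-everything behavior, while a concrete index issues only that one memory access. A sketch of how a pipeline could interleave single issues with other work (get_num_of_access() is an assumed trait of the window type, and the loop body is illustrative only):

        constexpr auto num_access = decltype(inp_win)::get_num_of_access(); // assumed trait
        static_for<0, num_access, 1>{}([&](auto i_access) {
            load_tile_raw(x, inp_win, i_access, bool_constant<true>{}, bool_constant<true>{});
            // independent ALU work can be scheduled here between single issues
        });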