Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
02f8c487
"profiler/src/main.cpp" did not exist on "67423a22754e7879893827eabe2c25f3bfc5227b"
Commit
02f8c487
authored
Oct 09, 2024
by
carlushuang
Browse files
add single issue api
parent
bafb600b
Changes
12
Show whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
530 additions
and
436 deletions
+530
-436
include/ck_tile/core/tensor/load_tile.hpp
include/ck_tile/core/tensor/load_tile.hpp
+54
-16
include/ck_tile/core/tensor/store_tile.hpp
include/ck_tile/core/tensor/store_tile.hpp
+30
-18
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+253
-245
include/ck_tile/core/tensor/tile_window_linear.hpp
include/ck_tile/core/tensor/tile_window_linear.hpp
+73
-42
include/ck_tile/core/tensor/update_tile.hpp
include/ck_tile/core/tensor/update_tile.hpp
+10
-6
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+2
-2
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp
...ile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp
+67
-69
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp
.../fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp
+22
-20
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+12
-12
include/ck_tile/ops/gemm/block/block_gemm_utils.hpp
include/ck_tile/ops/gemm/block/block_gemm_utils.hpp
+4
-4
include/ck_tile/ops/image_to_column.hpp
include/ck_tile/ops/image_to_column.hpp
+2
-1
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
...k_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
+1
-1
No files found.
include/ck_tile/core/tensor/load_tile.hpp
View file @
02f8c487
...
@@ -21,28 +21,32 @@ template <typename BottomTensorView_,
...
@@ -21,28 +21,32 @@ template <typename BottomTensorView_,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
,
index_t
NumCoord
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load_tile
(
const
tile_window_with_static_distribution
<
BottomTensorView_
,
CK_TILE_DEVICE
auto
load_tile
(
const
tile_window_with_static_distribution
<
BottomTensorView_
,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
bool_constant
<
oob_conditional_check
>
=
{})
{
{
return
tile_window
.
load
(
bool_constant
<
oob_conditional_check
>
{});
return
tile_window
.
load
(
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
typename
LinearBottomDims_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load_tile
(
const
tile_window_linear
<
BottomTensorView_
,
CK_TILE_DEVICE
auto
load_tile
(
const
tile_window_linear
<
BottomTensorView_
,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window
,
LinearBottomDims_
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
bool_constant
<
oob_conditional_check
>
=
{})
{
{
return
tile_window
.
load
(
bool_constant
<
oob_conditional_check
>
{});
return
tile_window
.
load
(
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
}
template
<
typename
T
,
template
<
typename
T
,
...
@@ -50,6 +54,7 @@ template <typename T,
...
@@ -50,6 +54,7 @@ template <typename T,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
,
index_t
NumCoord
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
load_tile_raw
(
T
&
tile
,
CK_TILE_DEVICE
auto
load_tile_raw
(
T
&
tile
,
...
@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
@@ -57,10 +62,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
tile_window
.
load_raw
(
tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
tile_window
.
load_raw
(
tile
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
}
template
<
typename
T
,
template
<
typename
T
,
...
@@ -68,6 +75,7 @@ template <typename T,
...
@@ -68,6 +75,7 @@ template <typename T,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
typename
LinearBottomDims_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
load_tile_raw
(
T
&
tile
,
CK_TILE_DEVICE
auto
load_tile_raw
(
T
&
tile
,
...
@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
@@ -75,10 +83,12 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window
,
LinearBottomDims_
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
tile_window
.
load_raw
(
tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
tile_window
.
load_raw
(
tile
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
}
// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
// for this API we force user to use CK_TILE_LDS_ADDR attribute specified smem
...
@@ -89,6 +99,7 @@ template <typename LdsTileWindow_,
...
@@ -89,6 +99,7 @@ template <typename LdsTileWindow_,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
,
index_t
NumCoord
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
CK_TILE_DEVICE
auto
async_load_tile
(
LdsTileWindow_
&&
lds_tile
,
async_load_tile
(
LdsTileWindow_
&&
lds_tile
,
...
@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile,
...
@@ -96,9 +107,11 @@ async_load_tile(LdsTileWindow_&& lds_tile,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
bool_constant
<
oob_conditional_check
>
=
{})
{
{
return
tile_window
.
async_load
(
lds_tile
,
bool_constant
<
oob_conditional_check
>
{});
return
tile_window
.
async_load
(
lds_tile
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
}
template
<
typename
LdsTileWindow_
,
template
<
typename
LdsTileWindow_
,
...
@@ -106,15 +119,18 @@ template <typename LdsTileWindow_,
...
@@ -106,15 +119,18 @@ template <typename LdsTileWindow_,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
typename
LinearBottomDims_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
async_load_tile
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load_tile
(
LdsTileWindow_
&&
lds_tile
,
const
tile_window_linear
<
BottomTensorView_
,
const
tile_window_linear
<
BottomTensorView_
,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window
,
LinearBottomDims_
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
bool_constant
<
oob_conditional_check
>
=
{})
{
{
return
tile_window
.
async_load
(
lds_tile
,
bool_constant
<
oob_conditional_check
>
{});
return
tile_window
.
async_load
(
lds_tile
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
}
template
<
typename
LdsTileWindow_
,
template
<
typename
LdsTileWindow_
,
...
@@ -122,6 +138,7 @@ template <typename LdsTileWindow_,
...
@@ -122,6 +138,7 @@ template <typename LdsTileWindow_,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
,
index_t
NumCoord
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
CK_TILE_DEVICE
auto
...
@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
...
@@ -130,11 +147,14 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
return
tile_window
.
async_load_raw
(
return
tile_window
.
async_load_raw
(
lds_tile
,
lds_tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
}
template
<
typename
LdsTileWindow_
,
template
<
typename
LdsTileWindow_
,
...
@@ -142,6 +162,7 @@ template <typename LdsTileWindow_,
...
@@ -142,6 +162,7 @@ template <typename LdsTileWindow_,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
typename
LinearBottomDims_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
async_load_tile_raw
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load_tile_raw
(
LdsTileWindow_
&&
lds_tile
,
...
@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
...
@@ -149,27 +170,44 @@ CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window
,
LinearBottomDims_
>&
tile_window
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
return
tile_window
.
async_load_raw
(
return
tile_window
.
async_load_raw
(
lds_tile
,
lds_tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
}
template
<
typename
WindowLengths
>
template
<
typename
WindowLengths
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load_tile
(
const
null_tile_window
<
WindowLengths
>&
)
CK_TILE_DEVICE
auto
load_tile
(
const
null_tile_window
<
WindowLengths
>&
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
{
{
return
null_tensor
{};
return
null_tensor
{};
}
}
template
<
typename
T
,
typename
WindowLengths
>
template
<
typename
T
,
CK_TILE_DEVICE
auto
load_tile_raw
(
T
&
/*null_tile*/
,
const
null_tile_window
<
WindowLengths
>&
)
typename
WindowLengths
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
load_tile_raw
(
T
&
/*null_tile*/
,
const
null_tile_window
<
WindowLengths
>&
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
{
{
}
}
// TODO: this function requires some sub-fileds exist for the target tile window
// TODO: this function requires some sub-fileds exist for the target tile window
template
<
typename
TileWindow
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
template
<
typename
TileWindow
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
load_tile_raw
(
const
TileWindow
&
w
,
CK_TILE_DEVICE
auto
load_tile_raw
(
const
TileWindow
&
w
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
...
@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
...
@@ -178,7 +216,7 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
auto
t
=
make_static_distributed_tensor
<
DataType
>
(
TileDstr
{});
auto
t
=
make_static_distributed_tensor
<
DataType
>
(
TileDstr
{});
load_tile_raw
(
t
,
w
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
load_tile_raw
(
t
,
w
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
return
t
;
return
t
;
}
}
...
...
include/ck_tile/core/tensor/store_tile.hpp
View file @
02f8c487
...
@@ -18,10 +18,12 @@ namespace ck_tile {
...
@@ -18,10 +18,12 @@ namespace ck_tile {
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
store_tile
(
tile_window_with_static_lengths
<
BottomTensorView_
,
WindowLengths_
>&
tile_window_tmp
,
store_tile
(
tile_window_with_static_lengths
<
BottomTensorView_
,
WindowLengths_
>&
tile_window_tmp
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{})
{
{
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView_
::
DataType
>
;
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView_
::
DataType
>
;
using
TileDstr
=
remove_cvref_t
<
TileDistribution_
>
;
using
TileDstr
=
remove_cvref_t
<
TileDistribution_
>
;
...
@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
...
@@ -35,16 +37,18 @@ store_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>& t
tile_window_tmp
.
get_window_origin
(),
tile_window_tmp
.
get_window_origin
(),
tile_dstr
);
tile_dstr
);
tile_window
.
store
(
dstr_tensor
);
tile_window
.
store
(
dstr_tensor
,
number
<
i_access
>
{}
);
}
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
store_tile_raw
(
tile_window_with_static_lengths
<
BottomTensorView_
,
WindowLengths_
>&
tile_window_tmp
,
store_tile_raw
(
tile_window_with_static_lengths
<
BottomTensorView_
,
WindowLengths_
>&
tile_window_tmp
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{})
{
{
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView_
::
DataType
>
;
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView_
::
DataType
>
;
using
TileDstr
=
remove_cvref_t
<
TileDistribution_
>
;
using
TileDstr
=
remove_cvref_t
<
TileDistribution_
>
;
...
@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
...
@@ -58,63 +62,71 @@ store_tile_raw(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_
tile_window_tmp
.
get_window_origin
(),
tile_window_tmp
.
get_window_origin
(),
tile_dstr
);
tile_dstr
);
tile_window
.
store_raw
(
dstr_tensor
);
tile_window
.
store_raw
(
dstr_tensor
,
number
<
i_access
>
{}
);
}
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
,
index_t
NumCoord
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
store_tile
(
tile_window_with_static_distribution
<
BottomTensorView_
,
store_tile
(
tile_window_with_static_distribution
<
BottomTensorView_
,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{})
{
{
tile_window
.
store
(
dstr_tensor
);
tile_window
.
store
(
dstr_tensor
,
number
<
i_access
>
{}
);
}
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
,
index_t
NumCoord
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
store_tile_raw
(
tile_window_with_static_distribution
<
BottomTensorView_
,
store_tile_raw
(
tile_window_with_static_distribution
<
BottomTensorView_
,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{})
{
{
tile_window
.
store_raw
(
dstr_tensor
);
tile_window
.
store_raw
(
dstr_tensor
,
number
<
i_access
>
{}
);
}
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
typename
LinearBottomDims_
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
store_tile
(
CK_TILE_DEVICE
void
store_tile
(
tile_window_linear
<
BottomTensorView_
,
WindowLengths_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window_linear
<
BottomTensorView_
,
WindowLengths_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window
,
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{})
{
{
tile_window
.
store
(
dstr_tensor
);
tile_window
.
store
(
dstr_tensor
,
number
<
i_access
>
{}
);
}
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
typename
LinearBottomDims_
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
store_tile_raw
(
CK_TILE_DEVICE
void
store_tile_raw
(
tile_window_linear
<
BottomTensorView_
,
WindowLengths_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window_linear
<
BottomTensorView_
,
WindowLengths_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window
,
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{})
{
{
tile_window
.
store_raw
(
dstr_tensor
);
tile_window
.
store_raw
(
dstr_tensor
,
number
<
i_access
>
{}
);
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/tensor/tile_window.hpp
View file @
02f8c487
...
@@ -18,6 +18,23 @@
...
@@ -18,6 +18,23 @@
namespace
ck_tile
{
namespace
ck_tile
{
// TODO: NumCoord no need anymore?
#define WINDOW_DISPATCH_ISSUE_2() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumCoord, 1>{}([&](auto iCoord) { \
static_for<0, NumAccessPerCoord, 1>{}( \
[&](auto iCoordAccess) { issue(iCoord, iCoordAccess); }); \
}); \
} \
else \
{ \
static_assert(i_access < (NumCoord * NumAccessPerCoord)); \
constexpr auto iCoordAccess = number<i_access % NumAccessPerCoord>{}; \
constexpr auto iCoord = number<i_access / NumAccessPerCoord>{}; \
issue(iCoord, iCoordAccess); \
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
typename
StaticTileDistribution_
,
...
@@ -283,8 +300,8 @@ struct tile_window_with_static_distribution
...
@@ -283,8 +300,8 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE
constexpr
auto
get_num_access
()
const
{
return
load_store_traits
::
NumAccess
;
}
CK_TILE_DEVICE
constexpr
auto
get_num_access
()
const
{
return
load_store_traits
::
NumAccess
;
}
template
<
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
bool_constant
<
oob_conditional_check
>
=
{})
const
CK_TILE_DEVICE
auto
load
(
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -296,12 +313,11 @@ struct tile_window_with_static_distribution
...
@@ -296,12 +313,11 @@ struct tile_window_with_static_distribution
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
// data index [y0, y1, ...]
...
@@ -316,20 +332,17 @@ struct tile_window_with_static_distribution
...
@@ -316,20 +332,17 @@ struct tile_window_with_static_distribution
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
:
idx_ys_start
[
jj
];
},
},
number
<
NDimY
>
{});
number
<
NDimY
>
{});
constexpr
index_t
d
=
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
vec_value
.
template
get_as
<
DataType
>()[
j
];
vec_value
.
template
get_as
<
DataType
>()[
j
];
});
});
#else
#else
constexpr
index_t
d
=
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
...
@@ -347,14 +360,19 @@ struct tile_window_with_static_distribution
...
@@ -347,14 +360,19 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
});
};
});
WINDOW_DISPATCH_ISSUE_2
();
return
dst_tensor
;
return
dst_tensor
;
}
}
template
<
typename
DstTile
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
template
<
typename
DstTile
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -377,12 +395,11 @@ struct tile_window_with_static_distribution
...
@@ -377,12 +395,11 @@ struct tile_window_with_static_distribution
auto
&
dst_vec_tbuf
=
reinterpret_cast
<
vectorized_tbuf
&>
(
dst_tensor
.
get_thread_buffer
());
auto
&
dst_vec_tbuf
=
reinterpret_cast
<
vectorized_tbuf
&>
(
dst_tensor
.
get_thread_buffer
());
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
pre_nop_
=
[
&
]()
{
constexpr
auto
pre_nop_
=
[
&
]()
{
if
constexpr
(
pre_nop
&&
iCoord
==
0
&&
iCoordAccess
==
0
)
if
constexpr
(
pre_nop
&&
iCoord
==
0
&&
iCoordAccess
==
0
)
...
@@ -393,8 +410,7 @@ struct tile_window_with_static_distribution
...
@@ -393,8 +410,7 @@ struct tile_window_with_static_distribution
// data index [y0, y1, ...]
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
constexpr
index_t
d
=
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
get_bottom_tensor_view
().
template
get_vectorized_elements_raw
<
vector_t
>(
get_bottom_tensor_view
().
template
get_vectorized_elements_raw
<
vector_t
>(
...
@@ -405,8 +421,7 @@ struct tile_window_with_static_distribution
...
@@ -405,8 +421,7 @@ struct tile_window_with_static_distribution
pre_nop_
);
pre_nop_
);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm
volatile
(
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same sympton, reuse this flag
""
);
// this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
#endif
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
...
@@ -419,17 +434,18 @@ struct tile_window_with_static_distribution
...
@@ -419,17 +434,18 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
});
};
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
WINDOW_DISPATCH_ISSUE_2
();
asm
volatile
(
"; this inline asm is workaround to prevent compiler from using too much "
"scratch memory"
::
);
#endif
}
}
// TODO: currently async load only implemented in inline asm
// TODO: currently async load only implemented in inline asm
template
<
typename
LdsTileWindow_
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
template
<
typename
LdsTileWindow_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -470,12 +486,11 @@ struct tile_window_with_static_distribution
...
@@ -470,12 +486,11 @@ struct tile_window_with_static_distribution
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
// TODO: use structure binding (to be captured later) if compiled in C++20
// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
pre_nop_
=
[
&
]()
{
constexpr
auto
pre_nop_
=
[
&
]()
{
if
constexpr
(
pre_nop
&&
iCoord
==
0
&&
iCoordAccess
==
0
)
if
constexpr
(
pre_nop
&&
iCoord
==
0
&&
iCoordAccess
==
0
)
...
@@ -501,12 +516,14 @@ struct tile_window_with_static_distribution
...
@@ -501,12 +516,14 @@ struct tile_window_with_static_distribution
m0_inc_with_memory
(
size_per_issue
);
m0_inc_with_memory
(
size_per_issue
);
}
}
});
};
});
WINDOW_DISPATCH_ISSUE_2
();
}
}
template
<
typename
LdsTileWindow_
,
bool
oob_conditional_check
=
true
>
template
<
typename
LdsTileWindow_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
...
@@ -544,12 +561,11 @@ struct tile_window_with_static_distribution
...
@@ -544,12 +561,11 @@ struct tile_window_with_static_distribution
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
+
m0_init_value
;
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
+
m0_init_value
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
// TODO: use structure binding (to be captured later) if compiled in C++20
// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// read from bottom tensor
// read from bottom tensor
...
@@ -569,12 +585,13 @@ struct tile_window_with_static_distribution
...
@@ -569,12 +585,13 @@ struct tile_window_with_static_distribution
smem
+=
size_per_issue
;
// Note we manually increase the per-issue offset
smem
+=
size_per_issue
;
// Note we manually increase the per-issue offset
}
}
})
;
}
;
}
);
WINDOW_DISPATCH_ISSUE_2
(
);
}
}
template
<
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -586,12 +603,11 @@ struct tile_window_with_static_distribution
...
@@ -586,12 +603,11 @@ struct tile_window_with_static_distribution
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
// data index [y0, y1, ...]
...
@@ -604,13 +620,11 @@ struct tile_window_with_static_distribution
...
@@ -604,13 +620,11 @@ struct tile_window_with_static_distribution
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
:
idx_ys_start
[
jj
];
},
},
number
<
NDimY
>
{});
number
<
NDimY
>
{});
constexpr
index_t
d
=
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
...
@@ -620,10 +634,7 @@ struct tile_window_with_static_distribution
...
@@ -620,10 +634,7 @@ struct tile_window_with_static_distribution
// write into bottom tensor
// write into bottom tensor
get_bottom_tensor_view
().
template
set_vectorized_elements
<
vector_t
>(
get_bottom_tensor_view
().
template
set_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
...
@@ -636,12 +647,13 @@ struct tile_window_with_static_distribution
...
@@ -636,12 +647,13 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
})
;
}
;
}
);
WINDOW_DISPATCH_ISSUE_2
(
);
}
}
CK_TILE_DEVICE
void
template
<
index_t
i_access
=
-
1
>
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
)
const
CK_TILE_DEVICE
void
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -652,12 +664,11 @@ struct tile_window_with_static_distribution
...
@@ -652,12 +664,11 @@ struct tile_window_with_static_distribution
static
constexpr
bool
oob_conditional_check
=
true
;
static
constexpr
bool
oob_conditional_check
=
true
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
// data index [y0, y1, ...]
...
@@ -668,12 +679,10 @@ struct tile_window_with_static_distribution
...
@@ -668,12 +679,10 @@ struct tile_window_with_static_distribution
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
:
idx_ys_start
[
jj
];
},
},
number
<
NDimY
>
{});
number
<
NDimY
>
{});
constexpr
index_t
d
=
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
});
...
@@ -694,12 +703,14 @@ struct tile_window_with_static_distribution
...
@@ -694,12 +703,14 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
});
};
});
WINDOW_DISPATCH_ISSUE_2
();
}
}
template
<
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -710,12 +721,11 @@ struct tile_window_with_static_distribution
...
@@ -710,12 +721,11 @@ struct tile_window_with_static_distribution
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
// data index [y0, y1, ...]
...
@@ -727,13 +737,11 @@ struct tile_window_with_static_distribution
...
@@ -727,13 +737,11 @@ struct tile_window_with_static_distribution
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
:
idx_ys_start
[
jj
];
},
},
number
<
NDimY
>
{});
number
<
NDimY
>
{});
constexpr
index_t
d
=
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
...
@@ -741,10 +749,7 @@ struct tile_window_with_static_distribution
...
@@ -741,10 +749,7 @@ struct tile_window_with_static_distribution
// write into bottom tensor
// write into bottom tensor
get_bottom_tensor_view
().
template
update_vectorized_elements
<
vector_t
>(
get_bottom_tensor_view
().
template
update_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
...
@@ -757,8 +762,9 @@ struct tile_window_with_static_distribution
...
@@ -757,8 +762,9 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
});
};
});
WINDOW_DISPATCH_ISSUE_2
();
}
}
// move thread's bottom tensor coordinate
// move thread's bottom tensor coordinate
...
@@ -857,6 +863,8 @@ struct tile_window_with_static_distribution
...
@@ -857,6 +863,8 @@ struct tile_window_with_static_distribution
array
<
tuple
<
WindowAdaptorCoord
,
BottomTensorCoord
>
,
NumCoord
>
pre_computed_coords_
;
array
<
tuple
<
WindowAdaptorCoord
,
BottomTensorCoord
>
,
NumCoord
>
pre_computed_coords_
;
};
};
#undef WINDOW_DISPATCH_ISSUE_2
// TODO: use strategy
// TODO: use strategy
template
<
typename
TensorView_
,
template
<
typename
TensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp
View file @
02f8c487
...
@@ -18,6 +18,17 @@
...
@@ -18,6 +18,17 @@
namespace
ck_tile
{
namespace
ck_tile
{
// Dispatch helper shared by the load/store/update member functions below.
// i_access < 0 means "issue every access" (compile-time loop over NumAccess);
// a non-negative i_access issues exactly that single access index.
// Relies on `issue`, `i_access` and `NumAccess` being visible at the
// expansion site. Wrapped in do { } while(0) so the macro expands to a
// single statement and remains safe when invoked as `WINDOW_DISPATCH_ISSUE();`
// inside an if/else (avoids the dangling-else / swallowed-semicolon hazard).
#define WINDOW_DISPATCH_ISSUE()                                                 \
    do                                                                          \
    {                                                                           \
        if constexpr(i_access < 0)                                              \
        {                                                                       \
            static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); });         \
        }                                                                       \
        else                                                                    \
        {                                                                       \
            static_assert(i_access < NumAccess,                                 \
                          "i_access is out of range for this tile window");     \
            issue(number<i_access>{});                                          \
        }                                                                       \
    } while(0)
//
//
// This version of tile window will pre-cache offset/flags based on need
// This version of tile window will pre-cache offset/flags based on need
//
//
...
@@ -443,8 +454,8 @@ struct tile_window_linear
...
@@ -443,8 +454,8 @@ struct tile_window_linear
CK_TILE_DEVICE
constexpr
auto
get_num_access
()
const
{
return
traits
::
NumAccess
;
}
CK_TILE_DEVICE
constexpr
auto
get_num_access
()
const
{
return
traits
::
NumAccess
;
}
template
<
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
bool_constant
<
oob_conditional_check
>
=
{})
const
CK_TILE_DEVICE
auto
load
(
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
vector_t
=
typename
traits
::
vector_t
;
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
...
@@ -453,9 +464,8 @@ struct tile_window_linear
...
@@ -453,9 +464,8 @@ struct tile_window_linear
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
i_access_
)
{
static_for
<
0
,
NumAccess
,
1
>
{}([
&
](
auto
i_access
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
IAccess
=
number
<
i_access
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
...
@@ -494,17 +504,22 @@ struct tile_window_linear
...
@@ -494,17 +504,22 @@ struct tile_window_linear
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
number
<
d
/
traits
::
ScalarPerVector
>
{})
=
bit_cast
<
vector_t
>
(
vec_value
);
number
<
d
/
traits
::
ScalarPerVector
>
{})
=
bit_cast
<
vector_t
>
(
vec_value
);
#endif
#endif
});
};
WINDOW_DISPATCH_ISSUE
();
return
dst_tensor
;
return
dst_tensor
;
}
}
template
<
typename
DstTile
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
template
<
typename
DstTile
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
number
<
i_access
>
=
{},
// negative means loop over all num_access
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
using
vector_t
=
typename
traits
::
vector_t
;
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
static
constexpr
index_t
YElementSize
=
static
constexpr
index_t
YElementSize
=
...
@@ -516,11 +531,10 @@ struct tile_window_linear
...
@@ -516,11 +531,10 @@ struct tile_window_linear
auto
&
dst_vec_tbuf
=
reinterpret_cast
<
vectorized_tbuf
&>
(
dst_tensor
.
get_thread_buffer
());
auto
&
dst_vec_tbuf
=
reinterpret_cast
<
vectorized_tbuf
&>
(
dst_tensor
.
get_thread_buffer
());
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
i_access_
)
{
static_for
<
0
,
NumAccess
,
1
>
{}([
&
](
auto
i_access
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
IAccess
=
number
<
i_access
>
{};
constexpr
auto
pre_nop_
=
[
&
]()
{
constexpr
auto
pre_nop_
=
[
&
]()
{
if
constexpr
(
pre_nop
&&
i_access
==
0
&&
if
constexpr
(
pre_nop
&&
i_access
_
==
0
&&
BottomTensorView
::
buffer_view
::
get_address_space
()
==
BottomTensorView
::
buffer_view
::
get_address_space
()
==
address_space_enum
::
global
)
address_space_enum
::
global
)
return
bool_constant
<
true
>
{};
return
bool_constant
<
true
>
{};
...
@@ -550,16 +564,18 @@ struct tile_window_linear
...
@@ -550,16 +564,18 @@ struct tile_window_linear
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same symptom, reuse this flag
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same symptom, reuse this flag
#endif
#endif
});
};
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm
volatile
(
"; this inline asm is workaround to prevent compiler from using too much "
WINDOW_DISPATCH_ISSUE
();
"scratch memory"
::
);
#endif
}
}
// TODO: currently async load only implemented in inline asm
// TODO: currently async load only implemented in inline asm
template
<
typename
LdsTileWindow_
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
template
<
typename
LdsTileWindow_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -600,10 +616,10 @@ struct tile_window_linear
...
@@ -600,10 +616,10 @@ struct tile_window_linear
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumAccess
,
1
>
{}(
[
&
](
auto
i_access
)
{
auto
issue
=
[
&
](
auto
i_access
_
)
{
constexpr
auto
IAccess
=
number
<
i_access
>
{};
constexpr
auto
IAccess
=
number
<
i_access
_
>
{};
constexpr
auto
pre_nop_
=
[
&
]()
{
constexpr
auto
pre_nop_
=
[
&
]()
{
if
constexpr
(
pre_nop
&&
i_access
==
0
)
if
constexpr
(
pre_nop
&&
i_access
_
==
0
)
return
bool_constant
<
true
>
{};
return
bool_constant
<
true
>
{};
else
else
return
bool_constant
<
false
>
{};
return
bool_constant
<
false
>
{};
...
@@ -618,15 +634,18 @@ struct tile_window_linear
...
@@ -618,15 +634,18 @@ struct tile_window_linear
smem
,
bottom_tensor_thread_coord
,
0
,
bottom_tensor_flag
,
pre_nop_
);
smem
,
bottom_tensor_thread_coord
,
0
,
bottom_tensor_flag
,
pre_nop_
);
// move thread coordinate
// move thread coordinate
if
constexpr
(
i_access
!=
(
NumAccess
-
1
))
if
constexpr
(
i_access
_
!=
(
NumAccess
-
1
))
{
{
m0_inc_with_memory
(
size_per_issue
);
m0_inc_with_memory
(
size_per_issue
);
}
}
});
};
WINDOW_DISPATCH_ISSUE
();
}
}
template
<
typename
LdsTileWindow_
,
bool
oob_conditional_check
=
true
>
template
<
typename
LdsTileWindow_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
...
@@ -667,8 +686,8 @@ struct tile_window_linear
...
@@ -667,8 +686,8 @@ struct tile_window_linear
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
+
m0_init_value
;
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
+
m0_init_value
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumAccess
,
1
>
{}(
[
&
](
auto
i_access
)
{
auto
issue
=
[
&
](
auto
i_access
_
)
{
constexpr
auto
IAccess
=
number
<
i_access
>
{};
constexpr
auto
IAccess
=
number
<
i_access
_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
...
@@ -682,15 +701,18 @@ struct tile_window_linear
...
@@ -682,15 +701,18 @@ struct tile_window_linear
bool_constant
<
oob_conditional_check
>
{});
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
// move thread coordinate
if
constexpr
(
i_access
!=
(
NumAccess
-
1
))
if
constexpr
(
i_access
_
!=
(
NumAccess
-
1
))
{
{
smem
+=
size_per_issue
;
// Note we manually increase the per-issue offset
smem
+=
size_per_issue
;
// Note we manually increase the per-issue offset
}
}
});
};
WINDOW_DISPATCH_ISSUE
();
}
}
template
<
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
...
@@ -700,8 +722,8 @@ struct tile_window_linear
...
@@ -700,8 +722,8 @@ struct tile_window_linear
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumAccess
,
1
>
{}(
[
&
](
auto
i_access
)
{
auto
issue
=
[
&
](
auto
i_access
_
)
{
constexpr
auto
IAccess
=
number
<
i_access
>
{};
constexpr
auto
IAccess
=
number
<
i_access
_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
...
@@ -732,13 +754,15 @@ struct tile_window_linear
...
@@ -732,13 +754,15 @@ struct tile_window_linear
bottom_tensor_flag
,
bottom_tensor_flag
,
vec_value
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
bool_constant
<
oob_conditional_check
>
{});
});
};
WINDOW_DISPATCH_ISSUE
();
}
}
CK_TILE_DEVICE
void
template
<
index_t
i_access
=
-
1
>
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
)
const
CK_TILE_DEVICE
void
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{})
const
{
{
using
vector_t
=
typename
traits
::
vector_t
;
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
...
@@ -746,8 +770,8 @@ struct tile_window_linear
...
@@ -746,8 +770,8 @@ struct tile_window_linear
static
constexpr
bool
oob_conditional_check
=
true
;
static
constexpr
bool
oob_conditional_check
=
true
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumAccess
,
1
>
{}(
[
&
](
auto
i_access
)
{
auto
issue
=
[
&
](
auto
i_access
_
)
{
constexpr
auto
IAccess
=
number
<
i_access
>
{};
constexpr
auto
IAccess
=
number
<
i_access
_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
...
@@ -773,11 +797,14 @@ struct tile_window_linear
...
@@ -773,11 +797,14 @@ struct tile_window_linear
get_bottom_tensor_view
()
get_bottom_tensor_view
()
.
template
set_vectorized_elements_raw
<
vector_t
,
oob_conditional_check
>(
.
template
set_vectorized_elements_raw
<
vector_t
,
oob_conditional_check
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
vec_value
);
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
vec_value
);
});
};
WINDOW_DISPATCH_ISSUE
();
}
}
template
<
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
...
@@ -787,8 +814,8 @@ struct tile_window_linear
...
@@ -787,8 +814,8 @@ struct tile_window_linear
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumAccess
,
1
>
{}(
[
&
](
auto
i_access
)
{
auto
issue
=
[
&
](
auto
i_access
_
)
{
constexpr
auto
IAccess
=
number
<
i_access
>
{};
constexpr
auto
IAccess
=
number
<
i_access
_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
...
@@ -820,7 +847,9 @@ struct tile_window_linear
...
@@ -820,7 +847,9 @@ struct tile_window_linear
bottom_tensor_flag
,
bottom_tensor_flag
,
vec_value
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
bool_constant
<
oob_conditional_check
>
{});
});
};
WINDOW_DISPATCH_ISSUE
();
}
}
// move thread's bottom tensor coordinate
// move thread's bottom tensor coordinate
...
@@ -920,6 +949,8 @@ struct tile_window_linear
...
@@ -920,6 +949,8 @@ struct tile_window_linear
array
<
bool
,
traits
::
NumAccess
>
cached_flags_
;
array
<
bool
,
traits
::
NumAccess
>
cached_flags_
;
};
};
#undef WINDOW_DISPATCH_ISSUE
namespace
impl
{
namespace
impl
{
template
<
address_space_enum
,
index_t
len_
>
template
<
address_space_enum
,
index_t
len_
>
struct
default_linear_bottom_dims_impl
struct
default_linear_bottom_dims_impl
...
...
include/ck_tile/core/tensor/update_tile.hpp
View file @
02f8c487
...
@@ -17,10 +17,12 @@ namespace ck_tile {
...
@@ -17,10 +17,12 @@ namespace ck_tile {
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
update_tile
(
tile_window_with_static_lengths
<
BottomTensorView_
,
WindowLengths_
>&
tile_window_tmp
,
update_tile
(
tile_window_with_static_lengths
<
BottomTensorView_
,
WindowLengths_
>&
tile_window_tmp
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{})
{
{
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView_
::
DataType
>
;
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView_
::
DataType
>
;
using
TileDstr
=
remove_cvref_t
<
TileDistribution_
>
;
using
TileDstr
=
remove_cvref_t
<
TileDistribution_
>
;
...
@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>&
...
@@ -34,22 +36,24 @@ update_tile(tile_window_with_static_lengths<BottomTensorView_, WindowLengths_>&
tile_window_tmp
.
get_window_origin
(),
tile_window_tmp
.
get_window_origin
(),
tile_dstr
);
tile_dstr
);
tile_window
.
update
(
dstr_tensor
);
tile_window
.
update
(
dstr_tensor
,
number
<
i_access
>
{}
);
}
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
,
index_t
NumCoord
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
update_tile
(
tile_window_with_static_distribution
<
BottomTensorView_
,
update_tile
(
tile_window_with_static_distribution
<
BottomTensorView_
,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{})
{
{
tile_window
.
update
(
dstr_tensor
);
tile_window
.
update
(
dstr_tensor
,
number
<
i_access
>
{}
);
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/fmha.hpp
View file @
02f8c487
...
@@ -33,8 +33,8 @@
...
@@ -33,8 +33,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
//
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
//#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
//
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp
View file @
02f8c487
...
@@ -205,7 +205,7 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -205,7 +205,7 @@ struct BlockFmhaPipelineQRAsyncEx
"wrong!"
);
"wrong!"
);
// K tile in LDS
// K tile in LDS
auto
k_lds_store
=
[
&
](){
auto
k_lds_store
=
[
&
]()
{
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i_buf
)
{
[
&
](
auto
i_buf
)
{
...
@@ -218,17 +218,16 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -218,17 +218,16 @@ struct BlockFmhaPipelineQRAsyncEx
number
<
Policy
::
NumPrefetchK
>
{});
number
<
Policy
::
NumPrefetchK
>
{});
}();
}();
auto
k_lds_load
=
[
&
](){
auto
k_lds_load
=
[
&
]()
{
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
auto
k_lds_ptr
=
reinterpret_cast
<
KDataType
*>
(
smem_ptr
);
return
make_tile_window
(
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
make_tensor_view
<
address_space_enum
::
lds
>
(
k_lds_ptr
,
Policy
::
template
MakeSmemLoadDesc_K
<
Problem
>()),
k_lds_ptr
,
Policy
::
template
MakeSmemLoadDesc_K
<
Problem
>().
get_lengths
(),
Policy
::
template
MakeSmemLoadDesc_K
<
Problem
>()),
{
0
,
0
});
Policy
::
template
MakeSmemLoadDesc_K
<
Problem
>().
get_lengths
(),
{
0
,
0
});
}();
}();
// V tile in LDS
// V tile in LDS
auto
v_lds_store
=
[
&
](){
auto
v_lds_store
=
[
&
]()
{
auto
v_lds_ptr
=
reinterpret_cast
<
VDataType
*>
(
smem_ptr
);
auto
v_lds_ptr
=
reinterpret_cast
<
VDataType
*>
(
smem_ptr
);
return
generate_tuple
(
return
generate_tuple
(
[
&
](
auto
i_buf
)
{
[
&
](
auto
i_buf
)
{
...
@@ -241,13 +240,12 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -241,13 +240,12 @@ struct BlockFmhaPipelineQRAsyncEx
number
<
Policy
::
NumPrefetchV
>
{});
number
<
Policy
::
NumPrefetchV
>
{});
}();
}();
auto
v_lds_load
=
[
&
](){
auto
v_lds_load
=
[
&
]()
{
auto
v_lds_ptr
=
reinterpret_cast
<
VDataType
*>
(
smem_ptr
);
auto
v_lds_ptr
=
reinterpret_cast
<
VDataType
*>
(
smem_ptr
);
return
make_tile_window
(
return
make_tile_window
(
make_tensor_view
<
address_space_enum
::
lds
>
(
make_tensor_view
<
address_space_enum
::
lds
>
(
v_lds_ptr
,
Policy
::
template
MakeSmemLoadDesc_V
<
Problem
>()),
v_lds_ptr
,
Policy
::
template
MakeSmemLoadDesc_V
<
Problem
>().
get_lengths
(),
Policy
::
template
MakeSmemLoadDesc_V
<
Problem
>()),
{
0
,
0
});
Policy
::
template
MakeSmemLoadDesc_V
<
Problem
>().
get_lengths
(),
{
0
,
0
});
}();
}();
// reduction function for softmax
// reduction function for softmax
...
@@ -258,16 +256,14 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -258,16 +256,14 @@ struct BlockFmhaPipelineQRAsyncEx
constexpr
auto
warp_gemm_0
=
Policy
::
template
GetWarpGemm_0
<
Problem
>();
constexpr
auto
warp_gemm_0
=
Policy
::
template
GetWarpGemm_0
<
Problem
>();
constexpr
auto
warp_gemm_1
=
Policy
::
template
GetWarpGemm_1
<
Problem
>();
constexpr
auto
warp_gemm_1
=
Policy
::
template
GetWarpGemm_1
<
Problem
>();
auto
gemm_0
=
[
&
](){
auto
gemm_0
=
[
&
]()
{
constexpr
index_t
total_repeats
=
Repeat_M0
*
Repeat_N0
*
Repeat_K0
;
constexpr
index_t
total_repeats
=
Repeat_M0
*
Repeat_N0
*
Repeat_K0
;
// n*k*m, more relaxed ds_read
// n*k*m, more relaxed ds_read
static_for
<
0
,
total_repeats
,
1
>
{}(
static_for
<
0
,
total_repeats
,
1
>
{}([
&
](
auto
i_r
)
{
[
&
](
auto
i_r
){
constexpr
index_t
i_m
=
i_r
%
Repeat_M0
;
constexpr
index_t
i_m
=
i_r
%
Repeat_M0
;
constexpr
index_t
i_k
=
(
i_r
/
Repeat_M0
)
%
Repeat_K0
;
constexpr
index_t
i_k
=
(
i_r
/
Repeat_M0
)
%
Repeat_K0
;
constexpr
index_t
i_n
=
i_r
/
(
Repeat_M0
*
Repeat_K0
);
constexpr
index_t
i_n
=
i_r
/
(
Repeat_M0
*
Repeat_K0
);
}
});
);
};
};
auto
q_dram_window
=
make_tile_window_raw
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
auto
q_dram_window
=
make_tile_window_raw
(
q_dram_block_window_tmp
.
get_bottom_tensor_view
(),
...
@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -285,7 +281,8 @@ struct BlockFmhaPipelineQRAsyncEx
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
using
SaccBlockTileType
=
decltype
(
Policy
::
template
MakeBlockGemmAccTile_0
<
Problem
>());
using
SaccBlockTileType
=
decltype
(
Policy
::
template
MakeBlockGemmAccTile_0
<
Problem
>());
auto
s_accs
=
generate_tuple
([
&
](
auto
)
{
return
SaccBlockTileType
{};
},
number
<
UnrollStages
>
{});
auto
s_accs
=
generate_tuple
([
&
](
auto
)
{
return
SaccBlockTileType
{};
},
number
<
UnrollStages
>
{});
// infer Sacc, S, P, M, L, Oacc type
// infer Sacc, S, P, M, L, Oacc type
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_accs
));
using
SBlockTileType
=
decltype
(
cast_tile
<
SMPLComputeDataType
>
(
s_accs
));
...
@@ -296,7 +293,8 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -296,7 +293,8 @@ struct BlockFmhaPipelineQRAsyncEx
using
OaccBlockTileType
=
decltype
(
Policy
::
template
MakeBlockGemmAccTile_1
<
Problem
>());
using
OaccBlockTileType
=
decltype
(
Policy
::
template
MakeBlockGemmAccTile_1
<
Problem
>());
// init Oacc, M, L
// init Oacc, M, L
auto
o_accs
=
generate_tuple
([
&
](
auto
)
{
return
OaccBlockTileType
{};
},
number
<
UnrollStages
>
{});
auto
o_accs
=
generate_tuple
([
&
](
auto
)
{
return
OaccBlockTileType
{};
},
number
<
UnrollStages
>
{});
auto
ms
=
generate_tuple
([
&
](
auto
)
{
return
MLBlockTileType
{};
},
number
<
UnrollStages
>
{});
auto
ms
=
generate_tuple
([
&
](
auto
)
{
return
MLBlockTileType
{};
},
number
<
UnrollStages
>
{});
auto
ls
=
generate_tuple
([
&
](
auto
)
{
return
MLBlockTileType
{};
},
number
<
UnrollStages
>
{});
auto
ls
=
generate_tuple
([
&
](
auto
)
{
return
MLBlockTileType
{};
},
number
<
UnrollStages
>
{});
...
...
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp
View file @
02f8c487
...
@@ -105,14 +105,14 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -105,14 +105,14 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBlockGemmAccTile_0
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBlockGemmAccTile_0
()
{
{
using
AccWarpDescEnc_
=
typename
decltype
(
GetWarpGemm_0
())
::
CWarpDstrEncoding
;
using
AccWarpDescEnc_
=
typename
decltype
(
GetWarpGemm_0
())
::
CWarpDstrEncoding
;
using
BlockTile_
=
sequence
<
Problem
::
BlockFmhaShape
::
Block_M0
,
Problem
::
BlockFmhaShape
::
Block_N0
>
;
using
BlockTile_
=
using
BlockWarps_
=
sequence
<
Problem
::
BlockFmhaShape
::
Block
Warps
_M0
,
Problem
::
BlockFmhaShape
::
Block
Warps
_N0
>
;
sequence
<
Problem
::
BlockFmhaShape
::
Block_M0
,
Problem
::
BlockFmhaShape
::
Block_N0
>
;
using
WarpTile
_
=
sequence
<
Problem
::
BlockFmhaShape
::
Warp_M0
,
Problem
::
BlockFmhaShape
::
Warp_N0
>
;
using
BlockWarps
_
=
sequence
<
Problem
::
BlockFmhaShape
::
BlockWarps_M0
,
constexpr
auto
enc
=
make_block_gemm_acc_enc
<
Problem
::
BlockFmhaShape
::
BlockWarps_N0
>
;
AccWarpDescEnc_
,
using
WarpTile_
=
BlockTile_
,
sequence
<
Problem
::
BlockFmhaShape
::
Warp_M0
,
Problem
::
BlockFmhaShape
::
Warp_N0
>
;
BlockWarps_
,
constexpr
auto
enc
=
WarpTile_
>
();
make_block_gemm_acc_enc
<
AccWarpDescEnc_
,
BlockTile_
,
BlockWarps_
,
WarpTile_
>
();
constexpr
auto
dstr
=
make_static_tile_distribution
(
enc
);
constexpr
auto
dstr
=
make_static_tile_distribution
(
enc
);
auto
t
=
make_static_distributed_tensor
<
typename
Problem
::
SaccDataType
>
(
dstr
);
auto
t
=
make_static_distributed_tensor
<
typename
Problem
::
SaccDataType
>
(
dstr
);
return
t
;
return
t
;
...
@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -443,8 +443,10 @@ struct BlockFmhaPipelineQRAsyncEx
{
{
if
constexpr
(
Problem
::
kHasDropout
)
if
constexpr
(
Problem
::
kHasDropout
)
{
{
constexpr
index_t
kMPerStep
=
Problem
::
BlockFmhaShape
::
BlockWarps_M0
*
Problem
::
BlockFmhaShape
::
Warp_M0
;
constexpr
index_t
kMPerStep
=
constexpr
index_t
kNPerStep
=
Problem
::
BlockFmhaShape
::
BlockWarps_N0
*
Problem
::
BlockFmhaShape
::
Warp_N0
;
Problem
::
BlockFmhaShape
::
BlockWarps_M0
*
Problem
::
BlockFmhaShape
::
Warp_M0
;
constexpr
index_t
kNPerStep
=
Problem
::
BlockFmhaShape
::
BlockWarps_N0
*
Problem
::
BlockFmhaShape
::
Warp_N0
;
return
(
kMPerStep
+
1
)
*
kNPerStep
*
sizeof
(
uint8_t
);
return
(
kMPerStep
+
1
)
*
kNPerStep
*
sizeof
(
uint8_t
);
}
}
...
@@ -612,14 +614,14 @@ struct BlockFmhaPipelineQRAsyncEx
...
@@ -612,14 +614,14 @@ struct BlockFmhaPipelineQRAsyncEx
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBlockGemmAccTile_1
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBlockGemmAccTile_1
()
{
{
using
AccWarpDescEnc_
=
typename
decltype
(
GetWarpGemm_1
())
::
CWarpDstrEncoding
;
using
AccWarpDescEnc_
=
typename
decltype
(
GetWarpGemm_1
())
::
CWarpDstrEncoding
;
using
BlockTile_
=
sequence
<
Problem
::
BlockFmhaShape
::
Block_M1
,
Problem
::
BlockFmhaShape
::
Block_N1
>
;
using
BlockTile_
=
using
BlockWarps_
=
sequence
<
Problem
::
BlockFmhaShape
::
Block
Warps
_M1
,
Problem
::
BlockFmhaShape
::
Block
Warps
_N1
>
;
sequence
<
Problem
::
BlockFmhaShape
::
Block_M1
,
Problem
::
BlockFmhaShape
::
Block_N1
>
;
using
WarpTile
_
=
sequence
<
Problem
::
BlockFmhaShape
::
Warp_M1
,
Problem
::
BlockFmhaShape
::
Warp_N1
>
;
using
BlockWarps
_
=
sequence
<
Problem
::
BlockFmhaShape
::
BlockWarps_M1
,
constexpr
auto
enc
=
make_block_gemm_acc_enc
<
Problem
::
BlockFmhaShape
::
BlockWarps_N1
>
;
AccWarpDescEnc_
,
using
WarpTile_
=
BlockTile_
,
sequence
<
Problem
::
BlockFmhaShape
::
Warp_M1
,
Problem
::
BlockFmhaShape
::
Warp_N1
>
;
BlockWarps_
,
constexpr
auto
enc
=
WarpTile_
>
();
make_block_gemm_acc_enc
<
AccWarpDescEnc_
,
BlockTile_
,
BlockWarps_
,
WarpTile_
>
();
constexpr
auto
dstr
=
make_static_tile_distribution
(
enc
);
constexpr
auto
dstr
=
make_static_tile_distribution
(
enc
);
auto
t
=
make_static_distributed_tensor
<
typename
Problem
::
OaccDataType
>
(
dstr
);
auto
t
=
make_static_distributed_tensor
<
typename
Problem
::
OaccDataType
>
(
dstr
);
return
t
;
return
t
;
...
...
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
02f8c487
include/ck_tile/ops/gemm/block/block_gemm_utils.hpp
View file @
02f8c487
...
@@ -7,7 +7,7 @@
...
@@ -7,7 +7,7 @@
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
AccWarpDescEnc
,
template
<
typename
AccWarpDescEnc
,
typename
BlockTile
,
// seq<M, N>
typename
BlockTile
,
// seq<M, N>
typename
BlockWarps
,
typename
BlockWarps
,
typename
WarpTile
>
typename
WarpTile
>
...
...
include/ck_tile/ops/image_to_column.hpp
View file @
02f8c487
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c)
2018-
2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
View file @
02f8c487
...
@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline
...
@@ -56,7 +56,7 @@ struct TopkSoftmaxWarpPerRowPipeline
{
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
auto
x
=
load_tile_raw
(
inp_win
,
bool_constant
<
true
>
{},
bool_constant
<
true
>
{});
auto
x
=
load_tile_raw
(
inp_win
,
number
<-
1
>
{},
bool_constant
<
true
>
{},
bool_constant
<
true
>
{});
buffer_load_fence
(
number
<
0
>
{});
buffer_load_fence
(
number
<
0
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
#else
#else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment