Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
c54a975f
Commit
c54a975f
authored
Oct 10, 2024
by
carlushuang
Browse files
do not support single issue in old tile window
parent
6a25d081
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
261 additions
and
255 deletions
+261
-255
include/ck_tile/core/tensor/load_tile.hpp
include/ck_tile/core/tensor/load_tile.hpp
+2
-1
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+257
-253
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
...k_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
+2
-1
No files found.
include/ck_tile/core/tensor/load_tile.hpp
View file @
c54a975f
...
@@ -216,7 +216,8 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
...
@@ -216,7 +216,8 @@ CK_TILE_DEVICE auto load_tile_raw(const TileWindow& w,
auto
t
=
make_static_distributed_tensor
<
DataType
>
(
TileDstr
{});
auto
t
=
make_static_distributed_tensor
<
DataType
>
(
TileDstr
{});
load_tile_raw
(
t
,
w
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
load_tile_raw
(
t
,
w
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
return
t
;
return
t
;
}
}
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
c54a975f
...
@@ -18,23 +18,8 @@
...
@@ -18,23 +18,8 @@
namespace
ck_tile
{
namespace
ck_tile
{
// TODO: NumCoord no need anymore?
// Note: this tile window do not support single issue
#define WINDOW_DISPATCH_ISSUE_2() \
// you need to use tile_window_linear structure for this purpose
if constexpr(i_access < 0) \
{ \
static_for<0, NumCoord, 1>{}([&](auto iCoord) { \
static_for<0, NumAccessPerCoord, 1>{}( \
[&](auto iCoordAccess) { issue(iCoord, iCoordAccess); }); \
}); \
} \
else \
{ \
static_assert(i_access < (NumCoord * NumAccessPerCoord)); \
constexpr auto iCoordAccess = number<i_access % NumAccessPerCoord>{}; \
constexpr auto iCoord = number<i_access / NumAccessPerCoord>{}; \
issue(iCoord, iCoordAccess); \
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
typename
StaticTileDistribution_
,
...
@@ -300,8 +285,9 @@ struct tile_window_with_static_distribution
...
@@ -300,8 +285,9 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE
constexpr
auto
get_num_access
()
const
{
return
load_store_traits
::
NumAccess
;
}
CK_TILE_DEVICE
constexpr
auto
get_num_access
()
const
{
return
load_store_traits
::
NumAccess
;
}
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
CK_TILE_DEVICE
auto
load
(
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -313,11 +299,12 @@ struct tile_window_with_static_distribution
...
@@ -313,11 +299,12 @@ struct tile_window_with_static_distribution
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
// data index [y0, y1, ...]
...
@@ -332,17 +319,20 @@ struct tile_window_with_static_distribution
...
@@ -332,17 +319,20 @@ struct tile_window_with_static_distribution
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
},
number
<
NDimY
>
{});
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
vec_value
.
template
get_as
<
DataType
>()[
j
];
vec_value
.
template
get_as
<
DataType
>()[
j
];
});
});
#else
#else
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
...
@@ -360,19 +350,18 @@ struct tile_window_with_static_distribution
...
@@ -360,19 +350,18 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
};
});
});
WINDOW_DISPATCH_ISSUE_2
();
return
dst_tensor
;
return
dst_tensor
;
}
}
template
<
typename
DstTile
,
template
<
typename
DstTile
,
index_t
i_access
=
-
1
,
index_t
i_access
_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
number
<
i_access
>
=
{},
number
<
i_access
_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -395,11 +384,12 @@ struct tile_window_with_static_distribution
...
@@ -395,11 +384,12 @@ struct tile_window_with_static_distribution
auto
&
dst_vec_tbuf
=
reinterpret_cast
<
vectorized_tbuf
&>
(
dst_tensor
.
get_thread_buffer
());
auto
&
dst_vec_tbuf
=
reinterpret_cast
<
vectorized_tbuf
&>
(
dst_tensor
.
get_thread_buffer
());
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
pre_nop_
=
[
&
]()
{
constexpr
auto
pre_nop_
=
[
&
]()
{
if
constexpr
(
pre_nop
&&
iCoord
==
0
&&
iCoordAccess
==
0
)
if
constexpr
(
pre_nop
&&
iCoord
==
0
&&
iCoordAccess
==
0
)
...
@@ -410,7 +400,8 @@ struct tile_window_with_static_distribution
...
@@ -410,7 +400,8 @@ struct tile_window_with_static_distribution
// data index [y0, y1, ...]
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
static_assert
(
d
%
Traits
::
ScalarPerVector
==
0
);
get_bottom_tensor_view
().
template
get_vectorized_elements_raw
<
vector_t
>(
get_bottom_tensor_view
().
template
get_vectorized_elements_raw
<
vector_t
>(
...
@@ -421,7 +412,8 @@ struct tile_window_with_static_distribution
...
@@ -421,7 +412,8 @@ struct tile_window_with_static_distribution
pre_nop_
);
pre_nop_
);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same sympton, reuse this flag
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
#endif
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
...
@@ -434,18 +426,17 @@ struct tile_window_with_static_distribution
...
@@ -434,18 +426,17 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
};
});
});
WINDOW_DISPATCH_ISSUE_2
();
}
}
// TODO: currently async load only implemented in inline asm
// TODO: currently async load only implemented in inline asm
template
<
typename
LdsTileWindow_
,
template
<
typename
LdsTileWindow_
,
index_t
i_access
=
-
1
,
index_t
i_access
_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access
>
=
{},
number
<
i_access
_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -486,11 +477,12 @@ struct tile_window_with_static_distribution
...
@@ -486,11 +477,12 @@ struct tile_window_with_static_distribution
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
// TODO: use structure binding (to be captured later) if compiled in C++20
//
/
TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
pre_nop_
=
[
&
]()
{
constexpr
auto
pre_nop_
=
[
&
]()
{
if
constexpr
(
pre_nop
&&
iCoord
==
0
&&
iCoordAccess
==
0
)
if
constexpr
(
pre_nop
&&
iCoord
==
0
&&
iCoordAccess
==
0
)
...
@@ -516,14 +508,15 @@ struct tile_window_with_static_distribution
...
@@ -516,14 +508,15 @@ struct tile_window_with_static_distribution
m0_inc_with_memory
(
size_per_issue
);
m0_inc_with_memory
(
size_per_issue
);
}
}
};
});
});
WINDOW_DISPATCH_ISSUE_2
();
}
}
template
<
typename
LdsTileWindow_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
template
<
typename
LdsTileWindow_
,
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access
>
=
{},
number
<
i_access
_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
...
@@ -561,11 +554,12 @@ struct tile_window_with_static_distribution
...
@@ -561,11 +554,12 @@ struct tile_window_with_static_distribution
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
+
m0_init_value
;
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
+
m0_init_value
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
// TODO: use structure binding (to be captured later) if compiled in C++20
//
/
TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// read from bottom tensor
// read from bottom tensor
...
@@ -585,13 +579,13 @@ struct tile_window_with_static_distribution
...
@@ -585,13 +579,13 @@ struct tile_window_with_static_distribution
smem
+=
size_per_issue
;
// Note we manually increase the per-issue offset
smem
+=
size_per_issue
;
// Note we manually increase the per-issue offset
}
}
}
;
})
;
WINDOW_DISPATCH_ISSUE_2
(
);
}
);
}
}
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
number
<
i_access
_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -603,11 +597,11 @@ struct tile_window_with_static_distribution
...
@@ -603,11 +597,11 @@ struct tile_window_with_static_distribution
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoordAccess
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
// data index [y0, y1, ...]
...
@@ -620,11 +614,13 @@ struct tile_window_with_static_distribution
...
@@ -620,11 +614,13 @@ struct tile_window_with_static_distribution
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
},
number
<
NDimY
>
{});
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
...
@@ -634,7 +630,10 @@ struct tile_window_with_static_distribution
...
@@ -634,7 +630,10 @@ struct tile_window_with_static_distribution
// write into bottom tensor
// write into bottom tensor
get_bottom_tensor_view
().
template
set_vectorized_elements
<
vector_t
>(
get_bottom_tensor_view
().
template
set_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
...
@@ -647,13 +646,13 @@ struct tile_window_with_static_distribution
...
@@ -647,13 +646,13 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
}
;
})
;
WINDOW_DISPATCH_ISSUE_2
(
);
}
);
}
}
template
<
index_t
i_access
=
-
1
>
template
<
index_t
i_access
_unsupport_
=
-
1
>
CK_TILE_DEVICE
void
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{})
const
number
<
i_access
_unsupport_
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -664,11 +663,12 @@ struct tile_window_with_static_distribution
...
@@ -664,11 +663,12 @@ struct tile_window_with_static_distribution
static
constexpr
bool
oob_conditional_check
=
true
;
static
constexpr
bool
oob_conditional_check
=
true
;
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
// data index [y0, y1, ...]
...
@@ -679,10 +679,12 @@ struct tile_window_with_static_distribution
...
@@ -679,10 +679,12 @@ struct tile_window_with_static_distribution
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
},
number
<
NDimY
>
{});
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
});
...
@@ -703,14 +705,13 @@ struct tile_window_with_static_distribution
...
@@ -703,14 +705,13 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
};
});
});
WINDOW_DISPATCH_ISSUE_2
();
}
}
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access
_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
number
<
i_access
_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -721,11 +722,12 @@ struct tile_window_with_static_distribution
...
@@ -721,11 +722,12 @@ struct tile_window_with_static_distribution
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
iCoord
,
auto
iCoord
Access
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
// data index [y0, y1, ...]
...
@@ -737,11 +739,13 @@ struct tile_window_with_static_distribution
...
@@ -737,11 +739,13 @@ struct tile_window_with_static_distribution
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_array
(
constexpr
auto
idx_ys
=
generate_array
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
},
number
<
NDimY
>
{});
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
...
@@ -749,7 +753,10 @@ struct tile_window_with_static_distribution
...
@@ -749,7 +753,10 @@ struct tile_window_with_static_distribution
// write into bottom tensor
// write into bottom tensor
get_bottom_tensor_view
().
template
update_vectorized_elements
<
vector_t
>(
get_bottom_tensor_view
().
template
update_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
...
@@ -762,9 +769,8 @@ struct tile_window_with_static_distribution
...
@@ -762,9 +769,8 @@ struct tile_window_with_static_distribution
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
};
});
});
WINDOW_DISPATCH_ISSUE_2
();
}
}
// move thread's botom tensor coordiante
// move thread's botom tensor coordiante
...
@@ -863,8 +869,6 @@ struct tile_window_with_static_distribution
...
@@ -863,8 +869,6 @@ struct tile_window_with_static_distribution
array
<
tuple
<
WindowAdaptorCoord
,
BottomTensorCoord
>
,
NumCoord
>
pre_computed_coords_
;
array
<
tuple
<
WindowAdaptorCoord
,
BottomTensorCoord
>
,
NumCoord
>
pre_computed_coords_
;
};
};
#undef WINDOW_DISPATCH_ISSUE_2
// TODO: use strategy
// TODO: use strategy
template
<
typename
TensorView_
,
template
<
typename
TensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
...
...
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
View file @
c54a975f
...
@@ -56,7 +56,8 @@ struct TopkSoftmaxWarpPerRowPipeline
...
@@ -56,7 +56,8 @@ struct TopkSoftmaxWarpPerRowPipeline
{
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
auto
x
=
load_tile_raw
(
inp_win
,
number
<-
1
>
{},
bool_constant
<
true
>
{},
bool_constant
<
true
>
{});
auto
x
=
load_tile_raw
(
inp_win
,
number
<-
1
>
{},
bool_constant
<
true
>
{},
bool_constant
<
true
>
{});
buffer_load_fence
(
number
<
0
>
{});
buffer_load_fence
(
number
<
0
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
#else
#else
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment