Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
1f9546e0
Commit
1f9546e0
authored
Dec 12, 2024
by
root
Browse files
Merge branch 'develop' into gemm_bf16_sk_muozturk
parents
78394194
86990558
Changes
472
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
3048 additions
and
265 deletions
+3048
-265
include/ck_tile/core/tensor/tensor_view.hpp
include/ck_tile/core/tensor/tensor_view.hpp
+234
-25
include/ck_tile/core/tensor/tile_distribution.hpp
include/ck_tile/core/tensor/tile_distribution.hpp
+35
-123
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+252
-45
include/ck_tile/core/tensor/tile_window_linear.hpp
include/ck_tile/core/tensor/tile_window_linear.hpp
+1207
-0
include/ck_tile/core/tensor/tile_window_utils.hpp
include/ck_tile/core/tensor/tile_window_utils.hpp
+54
-0
include/ck_tile/core/tensor/update_tile.hpp
include/ck_tile/core/tensor/update_tile.hpp
+53
-3
include/ck_tile/core/utility/amd_address_space.hpp
include/ck_tile/core/utility/amd_address_space.hpp
+37
-0
include/ck_tile/core/utility/functional_with_tuple.hpp
include/ck_tile/core/utility/functional_with_tuple.hpp
+173
-0
include/ck_tile/core/utility/literals.hpp
include/ck_tile/core/utility/literals.hpp
+22
-0
include/ck_tile/core/utility/magic_div.hpp
include/ck_tile/core/utility/magic_div.hpp
+22
-5
include/ck_tile/core/utility/reduce_operator.hpp
include/ck_tile/core/utility/reduce_operator.hpp
+95
-0
include/ck_tile/core/utility/static_counter.hpp
include/ck_tile/core/utility/static_counter.hpp
+116
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+9
-1
include/ck_tile/host/device_memory.hpp
include/ck_tile/host/device_memory.hpp
+35
-0
include/ck_tile/host/fill.hpp
include/ck_tile/host/fill.hpp
+175
-6
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+126
-18
include/ck_tile/host/joinable_thread.hpp
include/ck_tile/host/joinable_thread.hpp
+27
-0
include/ck_tile/host/reference/reference_elementwise.hpp
include/ck_tile/host/reference/reference_elementwise.hpp
+47
-0
include/ck_tile/host/reference/reference_fused_moe.hpp
include/ck_tile/host/reference/reference_fused_moe.hpp
+196
-0
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+133
-39
No files found.
Too many changes to show.
To preserve performance only
472 of 472+
files are displayed.
Plain diff
Email patch
include/ck_tile/core/tensor/tensor_view.hpp
View file @
1f9546e0
...
@@ -16,6 +16,24 @@
...
@@ -16,6 +16,24 @@
namespace
ck_tile
{
namespace
ck_tile
{
/*
* tensor_view
* abstract the underneath memory buffer(global, LDS, etc...)
* and provide a unified get/set function for access
*
* For addressing into the buffer we use 2 variable to control:
* coord : ND tensor coordinate, will calculate the actual offset inside
* linear_offset : 1D offset, will be used in the immediate field of
* the buffer instruction to help reduce register usage
*
* User can use either of the field, or both to indexing into the tensor
*
* We usually provide 2 set of API for buffer get/set, e.g.
* get_vectorized_elements()/get_vectorized_elements_raw()
* the former usually will call intrinsic or normal C function, the later
* usually will call inline-asm function
*
*/
template
<
typename
BufferView_
,
template
<
typename
BufferView_
,
typename
TensorDesc_
,
typename
TensorDesc_
,
memory_operation_enum
DstInMemOp_
=
memory_operation_enum
::
set
>
memory_operation_enum
DstInMemOp_
=
memory_operation_enum
::
set
>
...
@@ -49,22 +67,6 @@ struct tensor_view
...
@@ -49,22 +67,6 @@ struct tensor_view
CK_TILE_HOST_DEVICE
constexpr
auto
&
get_buffer_view
()
{
return
buf_
;
}
CK_TILE_HOST_DEVICE
constexpr
auto
&
get_buffer_view
()
{
return
buf_
;
}
#if 0
CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
{
return buf_.template get<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
{
buf_.template set<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
#endif
// X is vector of DataType.
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
...
@@ -75,14 +77,34 @@ struct tensor_view
...
@@ -75,14 +77,34 @@ struct tensor_view
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
remove_cvref_t
<
X
>
CK_TILE_HOST_DEVICE
constexpr
remove_cvref_t
<
X
>
get_vectorized_elements
(
const
TensorCoord
&
coord
,
get_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
return
buf_
.
template
get
<
X
>(
return
buf_
.
template
get
<
X
>(
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
bool_constant
<
oob_conditional_check
>
{});
bool_constant
<
oob_conditional_check
>
{});
}
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
remove_cvref_t
<
X
>
get_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
// flag
bool_constant
<
oob_conditional_check
>
=
{})
const
{
return
buf_
.
template
get
<
X
>(
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
bool_constant
<
oob_conditional_check
>
{});
}
// X is vector of DataType.
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
...
@@ -94,27 +116,109 @@ struct tensor_view
...
@@ -94,27 +116,109 @@ struct tensor_view
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
void
get_vectorized_elements_raw
(
remove_cvref_t
<
X
>&
dst
,
CK_TILE_HOST_DEVICE
void
get_vectorized_elements_raw
(
remove_cvref_t
<
X
>&
dst
,
const
TensorCoord
&
coord
,
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
return
buf_
.
template
get_raw
<
X
,
oob_conditional_check
,
pre_nop
>(
return
buf_
.
template
get_raw
<
X
,
oob_conditional_check
,
pre_nop
>(
dst
,
dst
,
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
void
get_vectorized_elements_raw
(
remove_cvref_t
<
X
>&
dst
,
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
{
return
buf_
.
template
get_raw
<
X
,
oob_conditional_check
,
pre_nop
>(
dst
,
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
async_get_vectorized_elements
(
CK_TILE_LDS_ADDR
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
,
index_t
linear_offset
)
const
{
return
buf_
.
template
async_get
<
X
>(
smem
,
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
bool_constant
<
oob_conditional_check
>
{});
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
async_get_vectorized_elements
(
CK_TILE_LDS_ADDR
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
)
const
{
return
buf_
.
template
async_get
<
X
>(
smem
,
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
bool_constant
<
oob_conditional_check
>
{});
}
template
<
typename
X
,
template
<
typename
X
,
bool
pre_nop
=
false
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
async_get_vectorized_elements_raw
(
CK_TILE_HOST_DEVICE
constexpr
void
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
,
bool_constant
<
pre_nop
>
=
{})
const
async_get_vectorized_elements_raw
(
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool_constant
<
pre_nop
>
=
{})
const
{
{
return
buf_
.
template
async_get_raw
<
X
>(
return
buf_
.
template
async_get_raw
<
X
>(
smem
,
coord
.
get_offset
(),
true
/*not used*/
,
bool_constant
<
pre_nop
>
{});
smem
,
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
bool_constant
<
pre_nop
>
{});
}
template
<
typename
X
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
async_get_vectorized_elements_raw
(
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
bool_constant
<
pre_nop
>
=
{})
const
{
return
buf_
.
template
async_get_raw
<
X
>(
smem
,
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
}
}
// X is vector of DataType.
// X is vector of DataType.
...
@@ -125,11 +229,15 @@ struct tensor_view
...
@@ -125,11 +229,15 @@ struct tensor_view
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
set_vectorized_elements
(
CK_TILE_HOST_DEVICE
constexpr
void
const
TensorCoord
&
coord
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
set_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
{
buf_
.
template
set
<
X
,
oob_conditional_check
>(
buf_
.
template
set
<
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
x
);
x
);
}
}
...
@@ -140,15 +248,53 @@ struct tensor_view
...
@@ -140,15 +248,53 @@ struct tensor_view
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
set_vectorized_elements_raw
(
CK_TILE_HOST_DEVICE
constexpr
void
const
TensorCoord
&
coord
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
set_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
buf_
.
template
set
<
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
x
);
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
set_vectorized_elements_raw
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
{
buf_
.
template
set_raw
<
X
,
oob_conditional_check
>(
buf_
.
template
set_raw
<
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
x
);
x
);
}
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
set_vectorized_elements_raw
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
buf_
.
template
set_raw
<
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
x
);
}
// X is vector of DataType.
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
...
@@ -157,15 +303,78 @@ struct tensor_view
...
@@ -157,15 +303,78 @@ struct tensor_view
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
update_vectorized_elements
(
CK_TILE_HOST_DEVICE
constexpr
void
const
TensorCoord
&
coord
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
update_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
buf_
.
template
update
<
DstInMemOp
,
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
x
);
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
update_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
{
buf_
.
template
update
<
DstInMemOp
,
X
,
oob_conditional_check
>(
buf_
.
template
update
<
DstInMemOp
,
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
x
);
}
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
update_vectorized_elements_raw
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
{
buf_
.
template
update_raw
<
DstInMemOp
,
X
,
oob_conditional_check
,
pre_nop
>(
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
x
);
x
);
}
}
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
update_vectorized_elements_raw
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
{
buf_
.
template
update_raw
<
DstInMemOp
,
X
,
oob_conditional_check
,
pre_nop
>(
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
x
);
}
CK_TILE_HOST_DEVICE
void
print
()
const
CK_TILE_HOST_DEVICE
void
print
()
const
{
{
printf
(
"tensor_view{"
);
printf
(
"tensor_view{"
);
...
...
include/ck_tile/core/tensor/tile_distribution.hpp
View file @
1f9546e0
...
@@ -17,6 +17,14 @@
...
@@ -17,6 +17,14 @@
namespace
ck_tile
{
namespace
ck_tile
{
namespace
detail
{
template
<
typename
Distribution
>
CK_TILE_HOST_DEVICE
auto
get_partition_index
(
Distribution
)
{
return
Distribution
::
_get_partition_index
();
}
}
// namespace detail
// distributed span
// distributed span
template
<
index_t
...
PartialHsLengths
>
template
<
index_t
...
PartialHsLengths
>
struct
tile_distributed_span
struct
tile_distributed_span
...
@@ -83,6 +91,21 @@ struct tile_distribution
...
@@ -83,6 +91,21 @@ struct tile_distribution
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_p
()
{
return
NDimP
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_p
()
{
return
NDimP
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_r
()
{
return
NDimR
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_r
()
{
return
NDimR
;
}
CK_TILE_HOST_DEVICE
static
auto
_get_partition_index
()
{
// only support warp-tile and block-tile
static_assert
(
NDimP
==
1
or
NDimP
==
2
,
"wrong!"
);
if
constexpr
(
NDimP
==
1
)
{
return
array
<
index_t
,
1
>
{
get_lane_id
()};
}
else
if
constexpr
(
NDimP
==
2
)
{
return
array
<
index_t
,
2
>
{
get_warp_id
(),
get_lane_id
()};
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_lengths
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_lengths
()
{
{
#if 0
#if 0
...
@@ -149,6 +172,16 @@ struct tile_distribution
...
@@ -149,6 +172,16 @@ struct tile_distribution
}
}
#endif
#endif
template
<
typename
PartitionIndex
=
decltype
(
_get_partition_index
())>
CK_TILE_HOST_DEVICE
auto
calculate_index
(
const
PartitionIndex
&
ps_idx
=
_get_partition_index
())
const
{
const
auto
ps_ys_idx
=
container_concat
(
ps_idx
,
array
<
index_t
,
NDimY
>
{
0
});
const
auto
window_adaptor_thread_coord_tmp
=
make_tensor_adaptor_coordinate
(
ps_ys_to_xs_
,
ps_ys_idx
);
return
window_adaptor_thread_coord_tmp
.
get_bottom_index
();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_distributed_spans
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_distributed_spans
()
{
{
constexpr
auto
distributed_spans_impl
=
DstrEncode
::
detail
::
distributed_spans_lengthss_
;
constexpr
auto
distributed_spans_impl
=
DstrEncode
::
detail
::
distributed_spans_lengthss_
;
...
@@ -421,6 +454,7 @@ struct tile_distribution_detail
...
@@ -421,6 +454,7 @@ struct tile_distribution_detail
}
// namespace detail
}
// namespace detail
#if 0
// this returns a constexpr tile_distribution
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
...
@@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
...
@@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
}
#endif
// this returns a static tile_distribution
// this returns a static tile_distribution
template
<
typename
StaticTileDistributionEncoding_
>
template
<
typename
StaticTileDistributionEncoding_
>
...
@@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
...
@@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
//***********************************************************************************
//***********************************************************************************
namespace
detail
{
namespace
detail
{
template
<
typename
Distribution
>
CK_TILE_HOST_DEVICE
auto
get_partition_index
(
Distribution
)
{
// only support warp-tile and block-tile
static_assert
(
Distribution
::
NDimP
==
1
or
Distribution
::
NDimP
==
2
,
"wrong!"
);
if
constexpr
(
Distribution
::
NDimP
==
1
)
{
return
array
<
index_t
,
1
>
{
get_lane_id
()};
}
else
if
constexpr
(
Distribution
::
NDimP
==
2
)
{
return
array
<
index_t
,
2
>
{
get_warp_id
(),
get_lane_id
()};
}
}
template
<
typename
,
typename
,
typename
,
index_t
>
struct
reverse_slice_sequence_impl
;
template
<
index_t
x
,
index_t
...
xs
,
index_t
m
,
index_t
...
ms
,
index_t
id
,
index_t
...
ids
,
index_t
SliceSize
>
struct
reverse_slice_sequence_impl
<
sequence
<
x
,
xs
...
>
,
sequence
<
m
,
ms
...
>
,
sequence
<
id
,
ids
...
>
,
SliceSize
>
{
using
old_scan
=
reverse_slice_sequence_impl
<
sequence
<
xs
...
>
,
sequence
<
ms
...
>
,
sequence
<
ids
...
>
,
SliceSize
>
;
static
constexpr
auto
slice_size
=
old_scan
::
remaining_slice_sizes
::
front
().
value
;
static
constexpr
auto
slice_length
=
std
::
conditional_t
<
m
,
number
<
gcd
(
x
,
slice_size
)
>
,
number
<
x
>>::
value
;
using
dim_lengths
=
typename
sequence_merge
<
sequence
<
slice_length
>
,
typename
old_scan
::
dim_lengths
>::
type
;
using
dim_slices
=
typename
sequence_merge
<
sequence
<
x
/
slice_length
>
,
typename
old_scan
::
dim_slices
>::
type
;
using
remaining_slice_sizes
=
typename
sequence_merge
<
std
::
conditional_t
<
m
,
sequence
<
slice_size
/
slice_length
>
,
sequence
<
slice_size
>>
,
typename
old_scan
::
remaining_slice_sizes
>::
type
;
// the first idx that sliced length not equal to original length
static
constexpr
index_t
_flag
=
slice_length
!=
x
&&
remaining_slice_sizes
{}.
front
().
value
==
1
;
static
constexpr
index_t
_split_flag
=
std
::
conditional_t
<
m
,
number
<
_flag
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
_split_idx
=
std
::
conditional_t
<
_split_flag
,
number
<
id
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
split_flag
=
_split_flag
||
old_scan
::
split_flag
;
static
constexpr
index_t
split_idx
=
std
::
conditional_t
<
old_scan
::
split_flag
,
number
<
old_scan
::
split_idx
>
,
number
<
_split_idx
>>::
value
;
};
template
<
index_t
x
,
index_t
m
,
index_t
id
,
index_t
SliceSize
>
struct
reverse_slice_sequence_impl
<
sequence
<
x
>
,
sequence
<
m
>
,
sequence
<
id
>
,
SliceSize
>
{
static
constexpr
auto
slice_size
=
SliceSize
;
static
constexpr
auto
slice_length
=
std
::
conditional_t
<
m
,
number
<
gcd
(
x
,
slice_size
)
>
,
number
<
x
>>::
value
;
using
dim_lengths
=
sequence
<
slice_length
>
;
using
dim_slices
=
sequence
<
x
/
slice_length
>
;
using
remaining_slice_sizes
=
std
::
conditional_t
<
m
,
sequence
<
slice_size
/
slice_length
>
,
sequence
<
slice_size
>>
;
// the first idx that sliced length not equal to original length
static
constexpr
index_t
_flag
=
slice_length
!=
x
&&
remaining_slice_sizes
{}.
front
().
value
==
1
;
static
constexpr
index_t
split_flag
=
std
::
conditional_t
<
m
,
number
<
_flag
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
split_idx
=
std
::
conditional_t
<
split_flag
,
number
<
id
>
,
number
<
0
>>::
value
;
};
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
template
<
typename
Seq
,
index_t
SliceSize
,
typename
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>
::
type
>
constexpr
auto
reverse_slice_sequence
(
Seq
,
number
<
SliceSize
>
,
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>::
type
{})
{
static_assert
(
Seq
::
size
()
==
Mask
::
size
());
using
sliced_type
=
reverse_slice_sequence_impl
<
Seq
,
Mask
,
typename
arithmetic_sequence_gen
<
0
,
Seq
::
size
(),
1
>::
type
,
SliceSize
>
;
static_assert
(
sliced_type
::
remaining_slice_sizes
::
front
().
value
==
1
,
"can not evenly divide this sequence, please check"
);
return
make_tuple
(
typename
sliced_type
::
dim_lengths
{},
typename
sliced_type
::
dim_slices
{},
number
<
sliced_type
::
split_idx
>
{});
}
//
//
// slice tensor from x_dim, result in split in y_dim, not p_dim.
// slice tensor from x_dim, result in split in y_dim, not p_dim.
// We don't support slice cross p_dim (aka, slice different threads)
// We don't support slice cross p_dim (aka, slice different threads)
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -18,6 +18,8 @@
...
@@ -18,6 +18,8 @@
namespace
ck_tile
{
namespace
ck_tile
{
// Note: this tile window do not support single issue
// you need to use tile_window_linear structure for this purpose
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
typename
StaticTileDistribution_
,
...
@@ -41,6 +43,7 @@ struct tile_window_with_static_distribution
...
@@ -41,6 +43,7 @@ struct tile_window_with_static_distribution
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static_assert
(
NumCoord
==
1
);
// TODO: check WindowLengths and StaticTileDistribution are consistent
// TODO: check WindowLengths and StaticTileDistribution are consistent
...
@@ -189,7 +192,8 @@ struct tile_window_with_static_distribution
...
@@ -189,7 +192,8 @@ struct tile_window_with_static_distribution
constexpr
auto
idx_diff_ys
=
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_step_between
(
number
<
0
>
{},
number
<
iCoord
*
NumAccessPerCoord
>
{});
SFC_Ys
::
get_step_between
(
number
<
0
>
{},
number
<
iCoord
*
NumAccessPerCoord
>
{});
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
...
@@ -222,10 +226,11 @@ struct tile_window_with_static_distribution
...
@@ -222,10 +226,11 @@ struct tile_window_with_static_distribution
// move thread's window adaptor coordinate and bottom tensor coordinate
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template
<
typename
ATopIndex
>
CK_TILE_DEVICE
void
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
CK_TILE_DEVICE
void
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
WindowAdaptorCoord
&
window_adaptor_thread_coord
,
WindowAdaptorCoord
&
window_adaptor_thread_coord
,
BottomTensorCoord
&
bottom_tensor_thread_coord
,
BottomTensorCoord
&
bottom_tensor_thread_coord
,
const
A
daptor
TopIndex
&
idx_diff_adaptor_top
)
const
const
ATopIndex
&
idx_diff_adaptor_top
)
const
{
{
array
<
index_t
,
NDimBottomTensor
>
idx_diff_adaptor_bottom
;
array
<
index_t
,
NDimBottomTensor
>
idx_diff_adaptor_bottom
;
...
@@ -279,20 +284,31 @@ struct tile_window_with_static_distribution
...
@@ -279,20 +284,31 @@ struct tile_window_with_static_distribution
get_container_subset
(
window_adaptor_ps_ys_vector_strides
,
y_dims
));
get_container_subset
(
window_adaptor_ps_ys_vector_strides
,
y_dims
));
}
}
CK_TILE_DEVICE
constexpr
auto
get_num_access
()
const
{
return
load_store_traits
::
NumAccess
;
}
CK_TILE_DEVICE
constexpr
auto
get_num_
of_
access
()
const
{
return
load_store_traits
::
NumAccess
;
}
template
<
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
bool_constant
<
oob_conditional_check
>
=
{})
const
CK_TILE_DEVICE
auto
load
(
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
load
(
dst_tensor
,
number
<
i_access_unsupport_
>
{},
bool_constant
<
oob_conditional_check
>
{});
return
dst_tensor
;
}
template
<
typename
DistributedTensor
,
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
DistributedTensor
&
dst_tensor
,
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
Traits
=
load_store_traits
;
using
vector_t
=
typename
Traits
::
vector_t
;
using
vector_t
=
typename
Traits
::
vector_t
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
/// TODO: use structure binding (to be captured later) if compiled in C++20
...
@@ -308,11 +324,11 @@ struct tile_window_with_static_distribution
...
@@ -308,11 +324,11 @@ struct tile_window_with_static_distribution
// read from bottom tensor
// read from bottom tensor
const
vector_t
vec_value
=
const
vector_t
vec_value
=
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
bool_constant
<
oob_conditional_check
>
{});
bottom_tensor_thread_coord
,
0
,
bool_constant
<
oob_conditional_check
>
{});
#if 1
#if 1
// write into distributed tensor
// write into distributed tensor
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_
array
(
constexpr
auto
idx_ys
=
generate_
tuple
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
:
idx_ys_start
[
jj
];
...
@@ -338,20 +354,23 @@ struct tile_window_with_static_distribution
...
@@ -338,20 +354,23 @@ struct tile_window_with_static_distribution
{
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
});
});
});
});
return
dst_tensor
;
}
}
template
<
typename
DstTile
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
template
<
typename
DstTile
,
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -397,6 +416,7 @@ struct tile_window_with_static_distribution
...
@@ -397,6 +416,7 @@ struct tile_window_with_static_distribution
get_bottom_tensor_view
().
template
get_vectorized_elements_raw
<
vector_t
>(
get_bottom_tensor_view
().
template
get_vectorized_elements_raw
<
vector_t
>(
dst_vec_tbuf
.
template
at
<
d
/
Traits
::
ScalarPerVector
>(),
dst_vec_tbuf
.
template
at
<
d
/
Traits
::
ScalarPerVector
>(),
bottom_tensor_thread_coord
,
bottom_tensor_thread_coord
,
0
/**/
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
oob_conditional_check
>
{},
pre_nop_
);
pre_nop_
);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
...
@@ -409,23 +429,24 @@ struct tile_window_with_static_distribution
...
@@ -409,23 +429,24 @@ struct tile_window_with_static_distribution
{
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
}
});
});
});
});
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
asm
volatile
(
"; this inline asm is workaround to prevent compiler from using too much "
"scratch memory"
::
);
#endif
}
}
// TODO: currently async load only implemented in inline asm
// TODO: currently async load only implemented in inline asm
template
<
typename
LdsTileWindow_
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
template
<
typename
LdsTileWindow_
,
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -467,7 +488,7 @@ struct tile_window_with_static_distribution
...
@@ -467,7 +488,7 @@ struct tile_window_with_static_distribution
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
// TODO: use structure binding (to be captured later) if compiled in C++20
//
/
TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
...
@@ -482,15 +503,16 @@ struct tile_window_with_static_distribution
...
@@ -482,15 +503,16 @@ struct tile_window_with_static_distribution
// read from bottom tensor
// read from bottom tensor
get_bottom_tensor_view
().
template
async_get_vectorized_elements_raw
<
vector_t
>(
get_bottom_tensor_view
().
template
async_get_vectorized_elements_raw
<
vector_t
>(
smem
,
bottom_tensor_thread_coord
,
pre_nop_
);
smem
,
bottom_tensor_thread_coord
,
0
,
pre_nop_
);
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
...
@@ -501,8 +523,81 @@ struct tile_window_with_static_distribution
...
@@ -501,8 +523,81 @@ struct tile_window_with_static_distribution
});
});
}
}
template
<
bool
oob_conditional_check
=
true
>
template
<
typename
LdsTileWindow_
,
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
using
LdsDataType
=
typename
LdsTileWindow
::
DataType
;
// issues * warps * lanes
static_assert
(
LdsTileWindow
::
get_num_of_dimension
()
==
3
);
// TODO: hard coded
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr
index_t
size_per_buf
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{},
number
<
0
>
{}));
constexpr
index_t
size_per_wave
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
1
>
{},
number
<
0
>
{}))
-
size_per_buf
;
constexpr
index_t
size_per_issue
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
1
>
{},
number
<
0
>
{},
number
<
0
>
{}))
-
size_per_buf
;
const
index_t
m0_init_value
=
size_per_buf
+
size_per_wave
*
get_warp_id
();
using
Traits
=
load_store_traits
;
using
vector_t
=
typename
Traits
::
vector_t
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
+
m0_init_value
;
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// read from bottom tensor
get_bottom_tensor_view
().
template
async_get_vectorized_elements
<
vector_t
>(
smem
,
bottom_tensor_thread_coord
,
0
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
smem
+=
size_per_issue
;
// Note we manually increase the per-issue offset
}
});
});
}
template
<
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -515,7 +610,6 @@ struct tile_window_with_static_distribution
...
@@ -515,7 +610,6 @@ struct tile_window_with_static_distribution
// loop over thread tensor space [y0, y1, ...]
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
...
@@ -530,7 +624,7 @@ struct tile_window_with_static_distribution
...
@@ -530,7 +624,7 @@ struct tile_window_with_static_distribution
vector_t
vec_value
;
vector_t
vec_value
;
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_
array
(
constexpr
auto
idx_ys
=
generate_
tuple
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
:
idx_ys_start
[
jj
];
...
@@ -548,15 +642,19 @@ struct tile_window_with_static_distribution
...
@@ -548,15 +642,19 @@ struct tile_window_with_static_distribution
// write into bottom tensor
// write into bottom tensor
get_bottom_tensor_view
().
template
set_vectorized_elements
<
vector_t
>(
get_bottom_tensor_view
().
template
set_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
...
@@ -565,8 +663,9 @@ struct tile_window_with_static_distribution
...
@@ -565,8 +663,9 @@ struct tile_window_with_static_distribution
});
});
}
}
CK_TILE_DEVICE
void
template
<
index_t
i_access_unsupport_
=
-
1
>
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
)
const
CK_TILE_DEVICE
void
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access_unsupport_
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -591,7 +690,7 @@ struct tile_window_with_static_distribution
...
@@ -591,7 +690,7 @@ struct tile_window_with_static_distribution
// read from distributed tensor
// read from distributed tensor
vector_t
vec_value
;
vector_t
vec_value
;
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_
array
(
constexpr
auto
idx_ys
=
generate_
tuple
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
:
idx_ys_start
[
jj
];
...
@@ -606,15 +705,16 @@ struct tile_window_with_static_distribution
...
@@ -606,15 +705,16 @@ struct tile_window_with_static_distribution
// write into bottom tensor
// write into bottom tensor
get_bottom_tensor_view
()
get_bottom_tensor_view
()
.
template
set_vectorized_elements_raw
<
vector_t
,
oob_conditional_check
>(
.
template
set_vectorized_elements_raw
<
vector_t
,
oob_conditional_check
>(
bottom_tensor_thread_coord
,
vec_value
);
bottom_tensor_thread_coord
,
0
,
vec_value
);
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
...
@@ -623,8 +723,9 @@ struct tile_window_with_static_distribution
...
@@ -623,8 +723,9 @@ struct tile_window_with_static_distribution
});
});
}
}
template
<
bool
oob_conditional_check
=
true
>
template
<
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
using
Traits
=
load_store_traits
;
using
Traits
=
load_store_traits
;
...
@@ -650,7 +751,7 @@ struct tile_window_with_static_distribution
...
@@ -650,7 +751,7 @@ struct tile_window_with_static_distribution
vector_t
vec_value
;
vector_t
vec_value
;
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_
array
(
constexpr
auto
idx_ys
=
generate_
tuple
(
[
&
](
auto
jj
)
{
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
:
idx_ys_start
[
jj
];
...
@@ -666,15 +767,86 @@ struct tile_window_with_static_distribution
...
@@ -666,15 +767,86 @@ struct tile_window_with_static_distribution
// write into bottom tensor
// write into bottom tensor
get_bottom_tensor_view
().
template
update_vectorized_elements
<
vector_t
>(
get_bottom_tensor_view
().
template
update_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
}
});
});
}
template
<
index_t
i_access_unsupport_
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
>
CK_TILE_DEVICE
void
update_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access_unsupport_
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
{
using
Traits
=
load_store_traits
;
using
vector_t
=
typename
Traits
::
vector_t
;
using
SFC_Ys
=
typename
Traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
static_for
<
0
,
NumCoord
,
1
>
{}([
&
](
auto
iCoord
)
{
/// TODO: use structure binding (to be captured later) if compiled in C++20
auto
window_adaptor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I0
];
auto
bottom_tensor_thread_coord
=
pre_computed_coords_
[
iCoord
][
I1
];
static_for
<
0
,
NumAccessPerCoord
,
1
>
{}([
&
](
auto
iCoordAccess
)
{
constexpr
auto
iAccess
=
number
<
iCoord
*
NumAccessPerCoord
+
iCoordAccess
>
{};
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
iAccess
);
// read from distributed tensor
vector_t
vec_value
;
static_for
<
0
,
Traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
Traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
// write into bottom tensor
get_bottom_tensor_view
().
template
update_vectorized_elements_raw
<
vector_t
>(
bottom_tensor_thread_coord
,
0
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
// move thread coordinate
if
constexpr
(
iCoordAccess
!=
(
NumAccessPerCoord
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
iAccess
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
...
@@ -746,7 +918,8 @@ struct tile_window_with_static_distribution
...
@@ -746,7 +918,8 @@ struct tile_window_with_static_distribution
constexpr
auto
idx_diff_ys
=
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_step_between
(
number
<
0
>
{},
number
<
iCoord
*
NumAccessPerCoord
>
{});
SFC_Ys
::
get_step_between
(
number
<
0
>
{},
number
<
iCoord
*
NumAccessPerCoord
>
{});
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
array
<
index_t
,
NDimP
>
{
0
},
idx_diff_ys
);
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
window_adaptor_thread_coord
,
bottom_tensor_thread_coord
,
idx_diff_ps_ys
);
...
@@ -798,6 +971,27 @@ make_tile_window(const TensorView_& tensor_view,
...
@@ -798,6 +971,27 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view
,
window_lengths
,
origin
,
tile_distribution
};
tensor_view
,
window_lengths
,
origin
,
tile_distribution
};
}
}
// this version can't be called in a constexpr context
template
<
typename
TensorView_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
index_t
NumCoord
=
1
>
CK_TILE_DEVICE
auto
make_tile_window_raw
(
const
TensorView_
&
tensor_view
,
const
WindowLengths_
&
window_lengths
,
const
multi_index
<
TensorView_
::
get_num_of_dimension
()
>&
origin
,
const
StaticTileDistribution_
&
tile_distribution
,
number
<
NumCoord
>
=
{})
{
auto
w
=
tile_window_with_static_distribution
<
remove_cvref_t
<
TensorView_
>
,
remove_cvref_t
<
WindowLengths_
>
,
remove_cvref_t
<
StaticTileDistribution_
>
,
NumCoord
>
{
tensor_view
,
window_lengths
,
origin
,
tile_distribution
};
w
.
init_raw
();
return
w
;
}
template
<
typename
TensorView_
,
template
<
typename
TensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
typename
StaticTileDistribution_
,
...
@@ -922,6 +1116,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
...
@@ -922,6 +1116,19 @@ make_tile_window(const tile_window_with_static_lengths<TensorView, WindowLengths
tile_distribution
);
tile_distribution
);
}
}
template
<
typename
TensorView
,
typename
WindowLengths
,
typename
StaticTileDistribution
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window_raw
(
const
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>&
tile_window
,
const
StaticTileDistribution
&
tile_distribution
)
{
auto
w
=
make_tile_window
(
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
tile_window
.
get_window_origin
(),
tile_distribution
);
w
.
init_raw
();
return
w
;
}
template
<
typename
TensorView_
,
typename
WindowLengths_
>
template
<
typename
TensorView_
,
typename
WindowLengths_
>
CK_TILE_DEVICE
void
move_tile_window
(
CK_TILE_DEVICE
void
move_tile_window
(
tile_window_with_static_lengths
<
TensorView_
,
WindowLengths_
>&
window
,
tile_window_with_static_lengths
<
TensorView_
,
WindowLengths_
>&
window
,
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
#define WINDOW_DISPATCH_ISSUE() \
if constexpr(i_access < 0) \
{ \
static_for<0, NumAccess, 1>{}([&](auto ia) { issue(ia); }); \
} \
else \
{ \
static_assert(i_access < NumAccess); \
issue(number<i_access>{}); \
}
//
// This version of tile window will pre-cache offset/flags based on need
//
// LinearBottomDims_, e.g seq<0, 1> for 2d tensor, the last one is linear dim
// so last dim can use immediate offset to indexing, can save register
// TODO: if using this struct, better use load_raw()/store_raw(), can control
// the the immediate offset on the fly
// space-filing-curve is non-snaked here!
//
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
typename
LinearBottomDims_
>
struct
tile_window_linear
{
using
BottomTensorView
=
remove_reference_t
<
BottomTensorView_
>
;
using
WindowLengths
=
remove_cvref_t
<
WindowLengths_
>
;
using
TileDstr
=
remove_cvref_t
<
StaticTileDistribution_
>
;
using
WindowAdaptor
=
typename
TileDstr
::
PsYs2XsAdaptor
;
using
BottomTensorDesc
=
typename
BottomTensorView
::
TensorDesc
;
using
DataType
=
remove_cvref_t
<
typename
BottomTensorView
::
DataType
>
;
using
LinearBottomDims
=
remove_cvref_t
<
LinearBottomDims_
>
;
static_assert
(
LinearBottomDims
::
size
()
==
BottomTensorView
::
get_num_of_dimension
());
static
constexpr
index_t
NDimWindowAdaptorTop
=
WindowAdaptor
::
get_num_of_top_dimension
();
static
constexpr
index_t
NDimBottomTensor
=
BottomTensorDesc
::
get_num_of_dimension
();
static
constexpr
index_t
NDimP
=
TileDstr
::
get_num_of_dimension_p
();
static
constexpr
index_t
NDimY
=
TileDstr
::
get_num_of_dimension_y
();
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
// TODO: check WindowLengths and StaticTileDistribution are consistent
static_assert
(
ck_tile
::
is_known_at_compile_time
<
WindowLengths
>::
value
,
"wrong! lengths should be static"
);
static_assert
(
TileDstr
::
is_static
(),
"wrong!"
);
static_assert
(
NDimBottomTensor
==
WindowAdaptor
::
get_num_of_bottom_dimension
(),
"wrong! inconsistent # of diemsnions"
);
using
AdaptorTopIndex
=
array
<
index_t
,
NDimWindowAdaptorTop
>
;
using
BottomTensorIndex
=
array
<
index_t
,
NDimBottomTensor
>
;
using
WindowAdaptorCoord
=
decltype
(
make_tensor_adaptor_coordinate
(
WindowAdaptor
{},
AdaptorTopIndex
{}));
using
BottomTensorCoord
=
decltype
(
make_tensor_coordinate
(
BottomTensorDesc
{},
BottomTensorIndex
{}));
struct
traits
{
private:
// return vector dimension among [y0, y1, ...]
CK_TILE_DEVICE
static
constexpr
auto
get_window_adaptor_ys_safe_vector_length_strides
()
{
// bottom tensor top dimension vector lengths and strides
const
auto
[
bottom_tensor_top_dim_vector_lengths
,
bottom_tensor_top_dim_vector_strides
]
=
BottomTensorDesc
::
get_top_dimension_safe_vector_length_strides
();
// window vector lengths/strides
const
auto
window_adaptor_bottom_dim_vector_lengths
=
bottom_tensor_top_dim_vector_lengths
;
const
auto
window_adaptor_bottom_dim_vector_strides
=
bottom_tensor_top_dim_vector_strides
;
// window adaptor [p0, p1, ..., y0, y1, ...]
array
<
index_t
,
WindowAdaptor
::
get_num_of_hidden_dimension
()
>
window_adaptor_vector_lengths
{
-
1
};
array
<
index_t
,
WindowAdaptor
::
get_num_of_hidden_dimension
()
>
window_adaptor_vector_strides
{
-
1
};
constexpr
auto
window_adaptor_bottom_dims
=
WindowAdaptor
::
get_bottom_dimension_hidden_ids
();
set_container_subset
(
window_adaptor_vector_lengths
,
window_adaptor_bottom_dims
,
window_adaptor_bottom_dim_vector_lengths
);
set_container_subset
(
window_adaptor_vector_strides
,
window_adaptor_bottom_dims
,
window_adaptor_bottom_dim_vector_strides
);
const
auto
[
window_adaptor_ps_ys_vector_lengths
,
window_adaptor_ps_ys_vector_strides
]
=
WindowAdaptor
{}.
get_top_dimension_safe_vector_length_strides
(
window_adaptor_vector_lengths
,
window_adaptor_vector_strides
);
// [y0, y1, ...]
constexpr
auto
y_dims
=
typename
arithmetic_sequence_gen
<
TileDstr
::
get_num_of_dimension_p
(),
NDimWindowAdaptorTop
,
1
>::
type
{};
return
make_tuple
(
get_container_subset
(
window_adaptor_ps_ys_vector_lengths
,
y_dims
),
get_container_subset
(
window_adaptor_ps_ys_vector_strides
,
y_dims
));
}
static
constexpr
auto
get_vector_dim_y_scalar_per_vector
()
{
const
auto
[
ys_vector_lengths
,
ys_vector_strides
]
=
get_window_adaptor_ys_safe_vector_length_strides
();
index_t
VectorDimY_
=
0
;
index_t
ScalarPerVector_
=
1
;
for
(
index_t
i
=
0
;
i
<
NDimY
;
++
i
)
{
if
(
ys_vector_strides
[
i
]
==
1
&&
ys_vector_lengths
[
i
]
>
ScalarPerVector_
)
{
ScalarPerVector_
=
ys_vector_lengths
[
i
];
VectorDimY_
=
i
;
}
}
return
make_tuple
(
VectorDimY_
,
ScalarPerVector_
);
}
public:
static
constexpr
index_t
VectorDimY
=
get_vector_dim_y_scalar_per_vector
().
template
at
<
0
>();
static
constexpr
index_t
ScalarPerVector
=
get_vector_dim_y_scalar_per_vector
().
template
at
<
1
>();
using
vector_t
=
thread_buffer
<
DataType
,
ScalarPerVector
>
;
private:
static
constexpr
auto
scalars_per_access_
=
[]
{
constexpr
auto
scalars_per_access_arr
=
generate_array
(
[
&
](
auto
i
)
{
return
(
i
==
VectorDimY
)
?
ScalarPerVector
:
1
;
},
number
<
NDimY
>
{});
/// TODO: add non-automatic storage argument support to macro TO_SEQUENCE()
constexpr
auto
NDimY_
=
NDimY
;
return
TO_SEQUENCE
(
scalars_per_access_arr
,
NDimY_
);
}();
static
constexpr
auto
get_space_filling_curve
()
{
constexpr
auto
thread_tensor_lengths_ys
=
to_sequence
(
TileDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
// FIXME: need logic to judge dim access order
using
DimAccessOrder
=
typename
arithmetic_sequence_gen
<
0
,
NDimY
,
1
>::
type
;
return
space_filling_curve
<
decltype
(
thread_tensor_lengths_ys
),
DimAccessOrder
,
decltype
(
scalars_per_access_
),
false
/*!!! no snaked curve! */
>
{};
}
public:
using
SFC_Ys
=
decltype
(
get_space_filling_curve
());
static
constexpr
index_t
NumAccess
=
SFC_Ys
::
get_num_of_access
();
static_assert
(
0
<
NumAccess
,
"Wrong! NumAccess should be larger than 0"
);
private:
static
constexpr
auto
get_num_non_linear_access
()
{
constexpr
auto
sfc_access_lens
=
SFC_Ys
::
access_lengths
;
using
ys_to_rhs_major
=
typename
decltype
(
TileDstr
{}.
get_static_tile_distribution_encoding
())
::
Ys2RHsMajor
;
constexpr
auto
non_linear
=
[
&
]()
{
index_t
cnt
=
1
;
static_for
<
0
,
NDimY
,
1
>
{}([
&
](
auto
i_dim_y
)
{
constexpr
auto
rhs_major
=
ys_to_rhs_major
{}[
i_dim_y
];
constexpr
auto
target_h_dim
=
number
<
rhs_major
-
1
>
{};
// no r dim here!
if
constexpr
(
LinearBottomDims
{}[
target_h_dim
]
==
0
)
{
cnt
*=
sfc_access_lens
[
i_dim_y
];
}
});
return
cnt
;
}();
return
non_linear
;
}
// example:
// non_linear_access_map: sequence<0, 0, 0, 0, 1, 1, 1, 1> for 8 access, totally 2 register
// used
// -> histogram : sequence<4, 4>
// -> prefixsum : seqneuce<0, 4, 8>
// non_linear_access_map: sequence<0, 1, 2, 3, 4, 5, 6, 7> for 8 access, totally 8 register
// used, will pre-cache 8
// -> histogram : sequence<1, 1, 1, 1, 1, 1, 1, 1>
// -> prefixsum : seqneuce<0, 1, 2, 3, 4, 5, 6, 7, 8>
// non_linear_access_map: sequence<0, 0, 1, 1, 2, 2, 3, 3> for 8 access, totally 4 register
// used, will pre-cache 4
// -> histogram : sequence<2, 2, 2, 2>
// -> prefixsum : seqneuce<0, 2, 4, 6, 8>
static
constexpr
auto
get_non_linear_access_map
()
{
constexpr
auto
sfc_access_lens
=
SFC_Ys
::
access_lengths
;
using
ys_to_rhs_major
=
typename
decltype
(
TileDstr
{}.
get_static_tile_distribution_encoding
())
::
Ys2RHsMajor
;
constexpr
auto
non_linear_map
=
[
&
]()
{
array
<
index_t
,
NumAccess
>
m_
{
0
};
index_t
cumulative_len_
=
1
;
index_t
cumulative_non_linear_len_
=
1
;
static_for
<
0
,
NDimY
,
1
>
{}([
&
](
auto
i_y
)
{
constexpr
auto
i_dim_y
=
number
<
NDimY
-
i_y
-
1
>
{};
// from right to left
constexpr
auto
rhs_major
=
ys_to_rhs_major
{}[
i_dim_y
];
constexpr
auto
target_h_dim
=
number
<
rhs_major
-
1
>
{};
// no r dim here!
constexpr
auto
is_linear_dim
=
LinearBottomDims
{}[
target_h_dim
];
array
<
index_t
,
NumAccess
>
current_m_
{
0
};
constexpr
auto
current_len_
=
sfc_access_lens
[
i_dim_y
];
// copy cumulative length as current pattern
for
(
auto
i_
=
0
;
i_
<
cumulative_len_
;
i_
++
)
{
current_m_
(
i_
)
=
m_
[
i_
];
}
for
(
auto
j_
=
0
;
j_
<
current_len_
;
j_
++
)
{
auto
j_offset_
=
is_linear_dim
?
0
:
j_
*
cumulative_non_linear_len_
;
for
(
auto
i_
=
0
;
i_
<
cumulative_len_
;
i_
++
)
{
m_
(
j_
*
cumulative_len_
+
i_
)
=
current_m_
[
i_
]
+
j_offset_
;
}
}
cumulative_len_
*=
current_len_
;
if
(
!
is_linear_dim
)
cumulative_non_linear_len_
*=
current_len_
;
});
return
m_
;
}();
return
TO_SEQUENCE
(
non_linear_map
,
NumAccess
);
}
static
constexpr
auto
get_non_linear_access_histogram
()
{
constexpr
auto
m_
=
get_non_linear_access_map
();
// m_.foo();
constexpr
auto
r_
=
typename
arithmetic_sequence_gen
<
0
,
get_num_non_linear_access
()
+
1
,
1
>::
type
{};
constexpr
auto
h_
=
histogram_sorted_sequence
(
m_
,
r_
);
return
h_
;
}
static
constexpr
auto
get_non_linear_access_histogram_prefix_sum
()
{
constexpr
auto
h_
=
get_non_linear_access_histogram
();
constexpr
auto
h_prefix_sum_
=
prefix_sum_sequence
(
h_
);
return
h_prefix_sum_
;
}
public:
static
constexpr
index_t
NumAccess_NonLinear
=
get_num_non_linear_access
();
using
AccessMap_NonLinear
=
decltype
(
get_non_linear_access_map
());
// sequence
using
AccessHistogram_NonLinear
=
decltype
(
get_non_linear_access_histogram
());
using
AccessPrefixSum_NonLinear
=
decltype
(
get_non_linear_access_histogram_prefix_sum
());
};
static
constexpr
index_t
NumAccess
=
traits
::
NumAccess
;
static
constexpr
index_t
NumAccess_NonLinear
=
traits
::
NumAccess_NonLinear
;
using
AccessMap_NonLinear
=
typename
traits
::
AccessMap_NonLinear
;
using
AccessHistogram_NonLinear
=
typename
traits
::
AccessHistogram_NonLinear
;
using
AccessPrefixSum_NonLinear
=
typename
traits
::
AccessPrefixSum_NonLinear
;
CK_TILE_DEVICE
constexpr
tile_window_linear
()
=
default
;
CK_TILE_DEVICE
constexpr
tile_window_linear
(
const
BottomTensorView
&
bottom_tensor_view
,
const
WindowLengths
&
window_lengths
,
const
BottomTensorIndex
&
window_origin
,
const
TileDstr
&
tile_distribution
)
:
bottom_tensor_view_
{
bottom_tensor_view
},
window_lengths_
{
window_lengths
},
window_origin_
{
window_origin
},
tile_dstr_
{
tile_distribution
},
cached_coords_
{},
cached_flags_
{}
{
auto
window_adaptor_thread_coord_tmp
=
make_tensor_adaptor_coordinate
(
tile_distribution
.
get_ps_ys_to_xs_adaptor
(),
container_concat
(
make_tuple
(
get_warp_id
(),
get_lane_id
()),
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimY
>
{})));
BottomTensorIndex
bottom_tensor_thread_origin_idx_tmp
=
window_origin
+
window_adaptor_thread_coord_tmp
.
get_bottom_index
();
auto
bottom_tensor_thread_coord_tmp
=
make_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
bottom_tensor_thread_origin_idx_tmp
);
// future load/store() calls (might allocate more registers)
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
static_for
<
0
,
NumAccess
,
1
>
{}([
&
](
auto
i_access
)
{
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
i_access
]
>
{};
constexpr
auto
need_save_non_linear_coord
=
bool_constant
<
AccessPrefixSum_NonLinear
{}[
non_linear_id
]
==
i_access
>
{};
if
constexpr
(
need_save_non_linear_coord
)
{
cached_coords_
(
non_linear_id
)
=
bottom_tensor_thread_coord_tmp
;
}
// TODO: need pad_tensor_view to check which dim need use flag to check
// cached flag is independent from non-linear-coord
// but need be updated in move_tile, with proper dims
cached_flags_
(
i_access
)
=
coordinate_has_valid_offset_assuming_top_index_is_valid
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
bottom_tensor_thread_coord_tmp
);
if
constexpr
(
i_access
!=
(
NumAccess
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
i_access
);
// tuple of number
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord_tmp
,
bottom_tensor_thread_coord_tmp
,
idx_diff_ps_ys
);
}
});
}
CK_TILE_DEVICE
static
constexpr
index_t
get_num_of_dimension
()
{
return
NDimBottomTensor
;
}
CK_TILE_DEVICE
static
constexpr
bool
has_static_tile_distribution
()
{
return
TileDstr
::
is_static
();
}
CK_TILE_DEVICE
constexpr
auto
get_window_lengths
()
const
{
return
window_lengths_
;
}
CK_TILE_DEVICE
constexpr
auto
get_tile_distribution
()
const
{
return
tile_dstr_
;
}
CK_TILE_DEVICE
constexpr
auto
get_bottom_tensor_view
()
const
{
return
bottom_tensor_view_
;
}
CK_TILE_DEVICE
constexpr
auto
get_window_origin
()
const
{
return
window_origin_
;
}
CK_TILE_DEVICE
constexpr
void
set_bottom_tensor_view_data_ptr
(
typename
BottomTensorView
::
DataType
*
data
)
{
bottom_tensor_view_
.
buf_
.
p_data_
=
data
;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
template
<
typename
ATopIndex
>
CK_TILE_DEVICE
void
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
WindowAdaptorCoord
&
window_adaptor_thread_coord
,
BottomTensorCoord
&
bottom_tensor_thread_coord
,
const
ATopIndex
&
idx_diff_adaptor_top
)
const
{
array
<
index_t
,
NDimBottomTensor
>
idx_diff_adaptor_bottom
;
move_tensor_adaptor_coordinate
(
tile_dstr_
.
get_ps_ys_to_xs_adaptor
(),
window_adaptor_thread_coord
,
idx_diff_adaptor_top
,
idx_diff_adaptor_bottom
);
move_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
bottom_tensor_thread_coord
,
idx_diff_adaptor_bottom
);
}
template
<
index_t
i_access
>
CK_TILE_DEVICE
static
constexpr
auto
get_bottom_linear_coordinate
(
number
<
i_access
>
)
{
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
idx_ys
=
SFC_Ys
::
get_index
(
number
<
i_access
>
{});
using
ys_to_rhs_major
=
typename
decltype
(
TileDstr
{}.
get_static_tile_distribution_encoding
())
::
Ys2RHsMajor
;
constexpr
auto
modified_idx_ys
=
generate_tuple
(
[
&
](
auto
i_dim_y
)
{
constexpr
auto
rhs_major
=
ys_to_rhs_major
{}[
i_dim_y
];
constexpr
auto
target_h_dim
=
number
<
rhs_major
-
1
>
{};
// no r dim here!
if
constexpr
(
LinearBottomDims
{}[
target_h_dim
]
==
0
)
{
return
number
<
0
>
{};
}
else
{
return
number
<
idx_ys
[
i_dim_y
]
>
{};
}
},
number
<
NDimY
>
{});
constexpr
auto
adaptor_
=
TileDstr
{}.
get_ps_ys_to_xs_adaptor
();
constexpr
auto
idx_
=
container_concat
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{}),
modified_idx_ys
);
return
adaptor_
.
calculate_bottom_index
(
idx_
);
}
template
<
index_t
i_access
>
CK_TILE_DEVICE
static
constexpr
index_t
get_bottom_linear_offset
(
number
<
i_access
>
)
{
constexpr
auto
linear_coord
=
get_bottom_linear_coordinate
(
number
<
i_access
>
{});
constexpr
auto
is_pure_linear_tensor
=
reduce_on_sequence
(
LinearBottomDims
{},
multiplies
{},
number
<
1
>
{});
if
constexpr
(
is_pure_linear_tensor
)
{
// this case usually is a LDS window, everything is known at compile tile.
// we directly use BottomTensorView transform to compute the offset, in case padding
auto
bottom_tensor_coord
=
make_tensor_coordinate
(
BottomTensorView
{}.
get_tensor_descriptor
(),
linear_coord
);
return
bottom_tensor_coord
.
get_offset
();
}
else
{
// this case usually is a global window, where last dim can be linear
// we hack here, that use the original TileDstr to compute the linear offset
// ... hoping that there is no extra padding between other dims, which make sense
// since that would introduce runtime length (so can't use linear offset)
constexpr
index_t
linear_offset
=
[
&
]()
{
constexpr
auto
x_idx_
=
linear_coord
;
constexpr
auto
x_len_
=
TileDstr
{}.
get_lengths
();
static_assert
(
x_idx_
.
size
()
==
x_len_
.
size
());
constexpr
index_t
x_dims_
=
x_idx_
.
size
();
index_t
cu_stride_
=
1
;
index_t
cu_offset_
=
0
;
static_for
<
0
,
x_dims_
,
1
>
{}([
&
](
auto
i_
)
{
auto
r_i_
=
number
<
x_dims_
-
i_
-
1
>
{};
cu_offset_
+=
x_idx_
[
r_i_
]
*
cu_stride_
;
cu_stride_
*=
x_len_
[
r_i_
];
});
return
cu_offset_
;
}();
return
linear_offset
;
}
}
CK_TILE_DEVICE
constexpr
auto
get_num_of_access
()
const
{
return
traits
::
NumAccess
;
}
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
dst_tensor
=
make_static_distributed_tensor
<
DataType
>
(
tile_dstr
);
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
// read from bottom tensor
const
vector_t
vec_value
=
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
bool_constant
<
oob_conditional_check
>
{});
#if 1
// data index [y0, y1, ...]
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_index
(
IAccess
);
// write into distributed tensor
static_for
<
0
,
traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
traits
::
VectorDimY
?
(
idx_diff_ys
[
jj
]
+
j
)
:
idx_diff_ys
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
vec_value
.
template
get_as
<
DataType
>()[
j
];
});
#else
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
traits
::
ScalarPerVector
==
0
);
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
number
<
d
/
traits
::
ScalarPerVector
>
{})
=
bit_cast
<
vector_t
>
(
vec_value
);
#endif
};
WINDOW_DISPATCH_ISSUE
();
return
dst_tensor
;
}
template
<
typename
DstTile
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
load
(
DstTile
&
dst_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
// auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
// read from bottom tensor
const
vector_t
vec_value
=
get_bottom_tensor_view
().
template
get_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
bool_constant
<
oob_conditional_check
>
{});
#if 1
// data index [y0, y1, ...]
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_index
(
IAccess
);
// write into distributed tensor
static_for
<
0
,
traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
traits
::
VectorDimY
?
(
idx_diff_ys
[
jj
]
+
j
)
:
idx_diff_ys
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
dst_tensor
.
get_thread_buffer
().
template
at
<
d
>()
=
vec_value
.
template
get_as
<
DataType
>()[
j
];
});
#else
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
traits
::
ScalarPerVector
==
0
);
dst_tensor
.
get_thread_buffer
().
template
get_as
<
vector_t
>()(
number
<
d
/
traits
::
ScalarPerVector
>
{})
=
bit_cast
<
vector_t
>
(
vec_value
);
#endif
};
WINDOW_DISPATCH_ISSUE
();
return
dst_tensor
;
}
template
<
typename
DstTile
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
load_raw
(
DstTile
&
dst_tensor
,
number
<
i_access
>
=
{},
// negative means loop over all num_access
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
static
constexpr
index_t
YElementSize
=
TileDstr
{}.
get_ys_to_d_descriptor
().
get_element_space_size
();
static_assert
(
YElementSize
%
traits
::
ScalarPerVector
==
0
);
using
vectorized_tbuf
=
array
<
vector_t
,
YElementSize
/
traits
::
ScalarPerVector
>
;
constexpr
auto
tile_dstr
=
TileDstr
{};
auto
&
dst_vec_tbuf
=
reinterpret_cast
<
vectorized_tbuf
&>
(
dst_tensor
.
get_thread_buffer
());
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
pre_nop_
=
[
&
]()
{
if
constexpr
(
pre_nop
&&
i_access_
==
0
&&
BottomTensorView
::
buffer_view
::
get_address_space
()
==
address_space_enum
::
global
)
return
bool_constant
<
true
>
{};
else
return
bool_constant
<
false
>
{};
}();
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
IAccess
);
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys_start
);
static_assert
(
d
%
traits
::
ScalarPerVector
==
0
);
get_bottom_tensor_view
().
template
get_vectorized_elements_raw
<
vector_t
>(
dst_vec_tbuf
.
template
at
<
d
/
traits
::
ScalarPerVector
>(),
bottom_tensor_thread_coord
,
linear_offset
/**/
,
bottom_tensor_flag
,
bool_constant
<
oob_conditional_check
>
{},
pre_nop_
);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
};
WINDOW_DISPATCH_ISSUE
();
}
// TODO: currently async load only implemented in inline asm
template
<
typename
LdsTileWindow_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
async_load_raw
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
{
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
using
LdsDataType
=
typename
LdsTileWindow
::
DataType
;
// currently we only support everything is non linear dim
// actually it's not performant if we have linear dim(e.g. fast changing)
static_assert
(
NumAccess_NonLinear
==
NumAccess
);
static_assert
(
BottomTensorView
::
buffer_view
::
get_address_space
()
==
address_space_enum
::
global
);
// issues * warps * lanes
static_assert
(
LdsTileWindow
::
get_num_of_dimension
()
==
3
);
// TODO: hard coded
const
index_t
size_per_buf
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
);
const
index_t
size_per_wave
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
1
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
)
-
size_per_buf
;
const
index_t
size_per_issue
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
1
>
{},
number
<
0
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
)
-
size_per_buf
;
const
index_t
m0_init_value
=
size_per_buf
+
size_per_wave
*
get_warp_id
();
m0_set_with_memory
(
m0_init_value
);
// This should be wave independent
using
vector_t
=
typename
traits
::
vector_t
;
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
;
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
pre_nop_
=
[
&
]()
{
if
constexpr
(
pre_nop
&&
i_access_
==
0
)
return
bool_constant
<
true
>
{};
else
return
bool_constant
<
false
>
{};
}();
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
// get this flag anyway
// read from bottom tensor
get_bottom_tensor_view
().
template
async_get_vectorized_elements_raw
<
vector_t
>(
smem
,
bottom_tensor_thread_coord
,
0
,
bottom_tensor_flag
,
pre_nop_
);
// move thread coordinate
if
constexpr
(
i_access_
!=
(
NumAccess
-
1
))
{
m0_inc_with_memory
(
size_per_issue
);
}
};
WINDOW_DISPATCH_ISSUE
();
}
template
<
typename
LdsTileWindow_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
auto
async_load
(
LdsTileWindow_
&&
lds_tile
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
using
LdsDataType
=
typename
LdsTileWindow
::
DataType
;
// currently we only support everything is non linear dim
// actually it's not performant if we have linear dim(e.g. fast changing)
static_assert
(
NumAccess_NonLinear
==
NumAccess
);
static_assert
(
BottomTensorView
::
buffer_view
::
get_address_space
()
==
address_space_enum
::
global
);
// issues * warps * lanes
static_assert
(
LdsTileWindow
::
get_num_of_dimension
()
==
3
);
// TODO: hard coded
// TODO: LDS offset is not good for intrinsic based implementation(compiler can't figure out
// dependency) hence avoid use offset based solution. size_per_buf should be zero (how to
// check?)
constexpr
index_t
size_per_buf
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{},
number
<
0
>
{}));
constexpr
index_t
size_per_wave
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
1
>
{},
number
<
0
>
{}))
-
size_per_buf
;
constexpr
index_t
size_per_issue
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
1
>
{},
number
<
0
>
{},
number
<
0
>
{}))
-
size_per_buf
;
const
index_t
m0_init_value
=
size_per_buf
+
size_per_wave
*
get_warp_id
();
using
vector_t
=
typename
traits
::
vector_t
;
// TODO: we force CK_TILE_LDS_ADDR
CK_TILE_LDS_ADDR
LdsDataType
*
smem
=
lds_tile
.
get_bottom_tensor_view
().
get_buffer_view
().
p_data_
+
m0_init_value
;
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
// read from bottom tensor
get_bottom_tensor_view
().
template
async_get_vectorized_elements
<
vector_t
>(
smem
,
bottom_tensor_thread_coord
,
0
,
bottom_tensor_flag
,
bool_constant
<
oob_conditional_check
>
{});
// move thread coordinate
if
constexpr
(
i_access_
!=
(
NumAccess
-
1
))
{
smem
+=
size_per_issue
;
// Note we manually increase the per-issue offset
}
};
WINDOW_DISPATCH_ISSUE
();
}
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
store
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
IAccess
);
// read from distributed tensor
vector_t
vec_value
;
static_for
<
0
,
traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
// write into bottom tensor
get_bottom_tensor_view
().
template
set_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
};
WINDOW_DISPATCH_ISSUE
();
}
template
<
index_t
i_access
=
-
1
>
CK_TILE_DEVICE
void
store_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
static
constexpr
bool
oob_conditional_check
=
true
;
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
IAccess
);
// read from distributed tensor
vector_t
vec_value
;
static_for
<
0
,
traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
// write into bottom tensor
get_bottom_tensor_view
()
.
template
set_vectorized_elements_raw
<
vector_t
,
oob_conditional_check
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
vec_value
);
};
WINDOW_DISPATCH_ISSUE
();
}
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
update
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
IAccess
);
// read from distributed tensor
vector_t
vec_value
;
static_for
<
0
,
traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
// write into bottom tensor
get_bottom_tensor_view
().
template
update_vectorized_elements
<
vector_t
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{});
};
WINDOW_DISPATCH_ISSUE
();
}
template
<
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
update_raw
(
const
static_distributed_tensor
<
DataType
,
TileDstr
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
{
using
vector_t
=
typename
traits
::
vector_t
;
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
constexpr
auto
tile_dstr
=
TileDstr
{};
// loop over thread tensor space [y0, y1, ...]
auto
issue
=
[
&
](
auto
i_access_
)
{
constexpr
auto
IAccess
=
number
<
i_access_
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
IAccess
]
>
{};
auto
bottom_tensor_thread_coord
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_offset
=
get_bottom_linear_offset
(
IAccess
);
auto
bottom_tensor_flag
=
cached_flags_
[
IAccess
];
// data index [y0, y1, ...]
constexpr
auto
idx_ys_start
=
SFC_Ys
::
get_index
(
IAccess
);
// read from distributed tensor
vector_t
vec_value
;
static_for
<
0
,
traits
::
ScalarPerVector
,
1
>
{}([
&
](
auto
j
)
{
constexpr
auto
idx_ys
=
generate_tuple
(
[
&
](
auto
jj
)
{
return
jj
==
traits
::
VectorDimY
?
(
idx_ys_start
[
jj
]
+
j
)
:
idx_ys_start
[
jj
];
},
number
<
NDimY
>
{});
constexpr
index_t
d
=
tile_dstr
.
get_ys_to_d_descriptor
().
calculate_offset
(
idx_ys
);
vec_value
.
template
get_as
<
DataType
>()(
j
)
=
dstr_tensor
.
get_thread_buffer
().
template
at
<
d
>();
});
// write into bottom tensor
get_bottom_tensor_view
().
template
update_vectorized_elements_raw
<
vector_t
>(
bottom_tensor_thread_coord
,
linear_offset
,
bottom_tensor_flag
,
vec_value
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
};
WINDOW_DISPATCH_ISSUE
();
}
// move thread's botom tensor coordiante
// [x0', x1', ... ] ==> [offset]
// also move window-origin
CK_TILE_DEVICE
void
move
(
const
BottomTensorIndex
&
step
)
{
window_origin_
+=
step
;
static_for
<
0
,
NumAccess
,
1
>
{}([
&
](
auto
i_access
)
{
constexpr
auto
IAccess
=
number
<
i_access
>
{};
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
i_access
]
>
{};
constexpr
auto
need_update_non_linear_coord
=
bool_constant
<
AccessPrefixSum_NonLinear
{}[
non_linear_id
]
==
i_access
>
{};
if
constexpr
(
need_update_non_linear_coord
)
{
move_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
cached_coords_
(
non_linear_id
),
step
);
}
// move the current coord with linear_coords
auto
tmp_coords
=
cached_coords_
[
non_linear_id
];
constexpr
auto
linear_coord
=
get_bottom_linear_coordinate
(
IAccess
);
move_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
tmp_coords
,
linear_coord
);
cached_flags_
(
IAccess
)
=
coordinate_has_valid_offset_assuming_top_index_is_valid
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
tmp_coords
);
});
}
CK_TILE_DEVICE
void
set_window_origin
(
const
BottomTensorIndex
&
new_window_origin
)
{
window_origin_
=
new_window_origin
;
auto
window_adaptor_thread_coord_tmp
=
make_tensor_adaptor_coordinate
(
TileDstr
{}.
get_ps_ys_to_xs_adaptor
(),
container_concat
(
make_tuple
(
get_warp_id
(),
get_lane_id
()),
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimY
>
{})));
BottomTensorIndex
bottom_tensor_thread_origin_idx_tmp
=
window_origin_
+
window_adaptor_thread_coord_tmp
.
get_bottom_index
();
auto
bottom_tensor_thread_coord_tmp
=
make_tensor_coordinate
(
bottom_tensor_view_
.
get_tensor_descriptor
(),
bottom_tensor_thread_origin_idx_tmp
);
// future load/store() calls (might allocate more registers)
using
SFC_Ys
=
typename
traits
::
SFC_Ys
;
static_for
<
0
,
NumAccess
,
1
>
{}([
&
](
auto
i_access
)
{
constexpr
auto
non_linear_id
=
number
<
AccessMap_NonLinear
{}[
i_access
]
>
{};
constexpr
auto
need_save_non_linear_coord
=
bool_constant
<
AccessPrefixSum_NonLinear
{}[
non_linear_id
]
==
i_access
>
{};
if
constexpr
(
need_save_non_linear_coord
)
{
cached_coords_
(
non_linear_id
)
=
bottom_tensor_thread_coord_tmp
;
}
if
constexpr
(
i_access
!=
(
NumAccess
-
1
))
{
constexpr
auto
idx_diff_ys
=
SFC_Ys
::
get_forward_step
(
i_access
);
// tuple of number
constexpr
auto
idx_diff_ps_ys
=
container_concat
(
generate_tuple
([
&
](
auto
)
{
return
number
<
0
>
{};
},
number
<
NDimP
>
{}),
idx_diff_ys
);
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
window_adaptor_thread_coord_tmp
,
bottom_tensor_thread_coord_tmp
,
idx_diff_ps_ys
);
}
});
}
CK_TILE_HOST_DEVICE
void
init_raw
()
{
bottom_tensor_view_
.
init_raw
();
}
// this is the bottom tensor view
// [x0', x1', ...] ==> [offset]
BottomTensorView
bottom_tensor_view_
;
//
WindowLengths
window_lengths_
;
// origin ([x0', x1', ...]) of window on bottom tensor
BottomTensorIndex
window_origin_
;
// Tile tensor distribution, which contains:
// 1. adaptor for window: [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...]
// 2. thread descriptor for thread tensor in register: [y0, y1, ...] ==> [d]
TileDstr
tile_dstr_
;
// this contains:
array
<
BottomTensorCoord
,
traits
::
NumAccess_NonLinear
>
cached_coords_
;
array
<
bool
,
traits
::
NumAccess
>
cached_flags_
;
};
#undef WINDOW_DISPATCH_ISSUE
namespace
impl
{
template
<
address_space_enum
,
index_t
len_
>
struct
default_linear_bottom_dims_impl
{
using
type
=
typename
uniform_sequence_gen
<
len_
,
0
>::
type
;
};
template
<
index_t
len_
>
struct
default_linear_bottom_dims_impl
<
address_space_enum
::
global
,
len_
>
{
// global default to seq<0,0,....1>
using
type
=
typename
sequence_merge
<
typename
uniform_sequence_gen
<
len_
-
1
,
0
>::
type
,
sequence
<
1
>>::
type
;
};
template
<
index_t
len_
>
struct
default_linear_bottom_dims_impl
<
address_space_enum
::
lds
,
len_
>
{
// lds default to seq<1,1.....1>
using
type
=
typename
uniform_sequence_gen
<
len_
,
1
>::
type
;
};
}
// namespace impl
template
<
typename
TensorView_
>
using
default_linear_bottom_dims
=
typename
impl
::
default_linear_bottom_dims_impl
<
TensorView_
::
buffer_view
::
get_address_space
(),
TensorView_
::
get_num_of_dimension
()
>::
type
;
// if using this API, will create a tile_window_linear
// this structure can have the chance to use immediate value, save register
// need pass in LinearBottomDims_ properly to control which dim is linear
// so to generate a constexpr offset as linear_offset for this dim
// (and finally pass to the immediate offset of buffer/lds instruction)
//
// Note: there is no internal check for which dim is OK to use linear offset
// user must make sure by themselves
//
// e.g.
// 2d global matrix, set LinearBottomDims_=seq<0, 1>, the last dim will generate
// immediate offset if each thread has multiple issue along last dim
//
// 2d LDS buffer, set LinearBottomDims_=seq<1, 1>, then only one vgpr used as offset
// everything else is just using immediate offset.
//
template
<
typename
TensorView_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
typename
LinearBottomDims_
=
default_linear_bottom_dims
<
TensorView_
>
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window_linear
(
const
TensorView_
&
tensor_view
,
const
WindowLengths_
&
window_lengths
,
const
multi_index
<
TensorView_
::
get_num_of_dimension
()
>&
origin
,
const
StaticTileDistribution_
&
tile_distribution
,
LinearBottomDims_
=
{})
{
static_assert
(
LinearBottomDims_
::
size
()
==
TensorView_
::
get_num_of_dimension
());
return
tile_window_linear
<
remove_cvref_t
<
TensorView_
>
,
remove_cvref_t
<
WindowLengths_
>
,
remove_cvref_t
<
StaticTileDistribution_
>
,
remove_cvref_t
<
LinearBottomDims_
>>
{
tensor_view
,
window_lengths
,
origin
,
tile_distribution
};
}
template
<
typename
TileWindow_
,
typename
StaticTileDistribution_
,
typename
LinearBottomDims_
=
default_linear_bottom_dims
<
typename
TileWindow_
::
BottomTensorView
>
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window_linear
(
const
TileWindow_
&
tile_window
,
const
StaticTileDistribution_
&
tile_distribution
,
LinearBottomDims_
=
{})
{
return
make_tile_window_linear
(
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
tile_window
.
get_window_origin
(),
tile_distribution
,
LinearBottomDims_
{});
}
// this version must not be called under a constexpr context
template
<
typename
TensorView_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
typename
LinearBottomDims_
=
default_linear_bottom_dims
<
TensorView_
>
>
CK_TILE_DEVICE
auto
make_tile_window_linear_raw
(
const
TensorView_
&
tensor_view
,
const
WindowLengths_
&
window_lengths
,
const
multi_index
<
TensorView_
::
get_num_of_dimension
()
>&
origin
,
const
StaticTileDistribution_
&
tile_distribution
,
LinearBottomDims_
=
{})
{
static_assert
(
LinearBottomDims_
::
size
()
==
TensorView_
::
get_num_of_dimension
());
auto
w
=
tile_window_linear
<
remove_cvref_t
<
TensorView_
>
,
remove_cvref_t
<
WindowLengths_
>
,
remove_cvref_t
<
StaticTileDistribution_
>
,
remove_cvref_t
<
LinearBottomDims_
>>
{
tensor_view
,
window_lengths
,
origin
,
tile_distribution
};
w
.
init_raw
();
return
w
;
}
template
<
typename
TileWindow_
,
typename
StaticTileDistribution_
,
typename
LinearBottomDims_
=
default_linear_bottom_dims
<
typename
TileWindow_
::
BottomTensorView
>
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window_linear_raw
(
const
TileWindow_
&
tile_window
,
const
StaticTileDistribution_
&
tile_distribution
,
LinearBottomDims_
=
{})
{
return
make_tile_window_linear_raw
(
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
tile_window
.
get_window_origin
(),
tile_distribution
,
LinearBottomDims_
{});
}
template
<
typename
TensorView_
,
typename
WindowLengths_
,
typename
StaticTileDistribution_
,
typename
LinearBottomDims_
>
CK_TILE_DEVICE
void
move_tile_window
(
tile_window_linear
<
TensorView_
,
WindowLengths_
,
StaticTileDistribution_
,
LinearBottomDims_
>&
window
,
const
typename
tile_window_linear
<
TensorView_
,
WindowLengths_
,
StaticTileDistribution_
,
LinearBottomDims_
>::
BottomTensorIndex
&
step
)
{
window
.
move
(
step
);
}
}
// namespace ck_tile
include/ck_tile/core/tensor/tile_window_utils.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#pragma once
namespace
ck_tile
{
// input a lds store tile, extract some information from it
// used to set m0 value for gfx9 serious
template
<
typename
LdsTileWindow_
>
CK_TILE_DEVICE
auto
get_async_store_smem_info
(
LdsTileWindow_
&&
lds_tile
)
{
using
LdsTileWindow
=
remove_cvref_t
<
LdsTileWindow_
>
;
using
LdsDataType
=
typename
LdsTileWindow
::
DataType
;
// issues * warps * lanes
static_assert
(
LdsTileWindow
::
get_num_of_dimension
()
==
3
);
// TODO: hard coded
const
index_t
size_per_buf
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
0
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
);
const
index_t
size_per_wave
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
0
>
{},
number
<
1
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
)
-
size_per_buf
;
const
index_t
size_per_issue
=
lds_tile
.
get_bottom_tensor_view
().
get_tensor_descriptor
().
calculate_offset
(
make_tuple
(
number
<
1
>
{},
number
<
0
>
{},
number
<
0
>
{}))
*
sizeof
(
LdsDataType
)
-
size_per_buf
;
const
index_t
m0_init_value
=
size_per_buf
+
size_per_wave
*
get_warp_id
();
return
make_tuple
(
m0_init_value
,
size_per_issue
);
}
}
// namespace ck_tile
include/ck_tile/core/tensor/update_tile.hpp
View file @
1f9546e0
...
@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
...
@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
,
index_t
NumCoord
,
typename
DataType_
>
typename
DataType_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
update_tile
(
tile_window_with_static_distribution
<
BottomTensorView_
,
update_tile
(
tile_window_with_static_distribution
<
BottomTensorView_
,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{})
{
{
tile_window
.
update
(
dstr_tensor
);
tile_window
.
update
(
dstr_tensor
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
index_t
NumCoord
,
typename
DataType_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
update_tile_raw
(
tile_window_with_static_distribution
<
BottomTensorView_
,
WindowLengths_
,
TileDistribution_
,
NumCoord
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
{
tile_window
.
update_raw
(
dstr_tensor
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
template
<
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
LinearBottomDims_
,
typename
DataType_
,
index_t
i_access
=
-
1
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
>
CK_TILE_DEVICE
auto
update_tile_raw
(
tile_window_linear
<
BottomTensorView_
,
WindowLengths_
,
TileDistribution_
,
LinearBottomDims_
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
,
number
<
i_access
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
{
tile_window
.
update_raw
(
dstr_tensor
,
number
<
i_access
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/utility/amd_address_space.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space
namespace
ck_tile
{
#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))
template
<
typename
T
>
__device__
T
*
cast_pointer_to_generic_address_space
(
T
CK_CONSTANT_ADDRESS_SPACE
*
p
)
{
// cast a pointer in "Constant" address space (4) to "Generic" address space (0)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return
(
T
*
)
p
;
// NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
template
<
typename
T
>
__host__
__device__
T
CK_CONSTANT_ADDRESS_SPACE
*
cast_pointer_to_constant_address_space
(
T
*
p
)
{
// cast a pointer in "Generic" address space (0) to "Constant" address space (4)
// only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
return
(
T
CK_CONSTANT_ADDRESS_SPACE
*
)
p
;
// NOLINT(old-style-cast)
#pragma clang diagnostic pop
}
}
// namespace ck_tile
include/ck_tile/core/utility/functional_with_tuple.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
// This file should not be included inside tuple.hpp!
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include <stdint.h>
#include <utility>
namespace
ck_tile
{
namespace
detail
{
// RemainLengths: sequence<...>
// Orders: sequence<...>
template
<
class
RemainLengths
,
class
RamainUnpacks
,
class
Orders
>
struct
static_uford_impl
{
CK_TILE_HOST_DEVICE
constexpr
static_uford_impl
()
{
static_assert
(
RemainLengths
::
size
()
>
0
,
"wrong! should not get here"
);
static_assert
(
RamainUnpacks
::
size
()
>
0
,
"wrong! should not get here"
);
}
template
<
class
F
,
class
CurrentUnpackIds
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
CurrentUnpackIds
)
const
{
constexpr
index_t
pack_len
=
RamainUnpacks
::
front
();
static_for
<
0
,
RemainLengths
::
front
(),
pack_len
>
{}([
=
](
auto
I
)
{
constexpr
auto
new_pack
=
generate_tuple
(
[
&
](
auto
idx_
)
{
constexpr
auto
i_new_pack
=
number
<
I
+
idx_
%
pack_len
>
{};
constexpr
auto
i_pre_pack
=
number
<
idx_
/
pack_len
>
{};
return
CurrentUnpackIds
{}.
at
(
i_pre_pack
).
push_back
(
i_new_pack
);
},
number
<
CurrentUnpackIds
::
size
()
*
pack_len
>
{});
static_uford_impl
<
decltype
(
RemainLengths
::
pop_front
()),
decltype
(
RamainUnpacks
::
pop_front
()),
Orders
>
{}(
f
,
new_pack
);
});
}
};
template
<
class
Orders
>
struct
static_uford_impl
<
sequence
<>
,
sequence
<>
,
Orders
>
{
template
<
class
F
,
class
PackedId
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
PackedId
)
const
{
constexpr
auto
origin_packs
=
transform_tuples
(
[](
auto
pack_
)
{
return
decltype
(
pack_
)
::
reorder_old_to_new
(
Orders
{});
},
PackedId
{});
unpack
(
f
,
origin_packs
);
}
};
template
<
class
RemainLengths
,
class
RamainUnpacks
,
class
Orders
>
struct
static_uford_one_shot_impl
{
template
<
class
F
,
class
CurrentUnpackIds
,
index_t
current_acc
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
CurrentUnpackIds
,
number
<
current_acc
>
)
const
{
constexpr
auto
r_lens_stride
=
reverse_exclusive_scan_sequence
(
RemainLengths
{},
multiplies
{},
number
<
1
>
{});
constexpr
auto
r_upks_stride
=
reverse_exclusive_scan_sequence
(
RamainUnpacks
{},
multiplies
{},
number
<
1
>
{});
constexpr
index_t
current_stride
=
r_lens_stride
.
front
()
/
r_upks_stride
.
front
();
constexpr
index_t
pack_len
=
RamainUnpacks
::
front
();
constexpr
index_t
current_idx
=
(
current_acc
/
current_stride
)
*
pack_len
;
constexpr
auto
new_pack
=
generate_tuple
(
[
&
](
auto
idx_
)
{
constexpr
auto
i_new_pack
=
number
<
current_idx
+
idx_
%
pack_len
>
{};
constexpr
auto
i_pre_pack
=
number
<
idx_
/
pack_len
>
{};
return
CurrentUnpackIds
{}.
at
(
i_pre_pack
).
push_back
(
i_new_pack
);
},
number
<
CurrentUnpackIds
::
size
()
*
pack_len
>
{});
static_uford_one_shot_impl
<
decltype
(
RemainLengths
::
pop_front
()),
decltype
(
RamainUnpacks
::
pop_front
()),
Orders
>
{}(
f
,
new_pack
,
number
<
current_acc
%
current_stride
>
{});
}
};
template
<
class
Orders
>
struct
static_uford_one_shot_impl
<
sequence
<>
,
sequence
<>
,
Orders
>
{
template
<
class
F
,
class
PackedId
,
index_t
current_acc
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
PackedId
,
number
<
current_acc
>
)
const
{
constexpr
auto
origin_packs
=
transform_tuples
(
[](
auto
pack_
)
{
return
decltype
(
pack_
)
::
reorder_old_to_new
(
Orders
{});
},
PackedId
{});
unpack
(
f
,
origin_packs
);
}
};
}
// namespace detail
// TODO: we may unify static_ford/static_uford in the future
//
// loop over nd space(sequence) with packs
// you must make sure the function passed in has same number of argument
//
// e.g.
// Lengths=seq<2, 3, 4>, Unpacks=<1, 1, 2>
// static_uford<Lengths, Unpacks>{}([&](auto i_0, auto i_1){}); // require 2 args(packs)
//
// loop #0, i_0=seq<0, 0, 0>, i_1=<0, 0, 1>
// loop #1, i_0=seq<0, 0, 2>, i_1=<0, 0, 3>
// loop #2, i_0=seq<0, 1, 0>, i_1=<0, 1, 1>
// loop #3, i_0=seq<0, 1, 2>, i_1=<0, 1, 3>
// loop #4, i_0=seq<0, 2, 0>, i_1=<0, 2, 1>
// loop #5, i_0=seq<0, 2, 2>, i_1=<0, 2, 3>
// loop #6, i_0=seq<1, 0, 0>, i_1=<1, 0, 1>
// ...
template
<
class
Lengths
,
class
Unpacks
=
typename
uniform_sequence_gen
<
Lengths
::
size
(),
1
>
::
type
,
class
Orders
=
typename
arithmetic_sequence_gen
<
0
,
Lengths
::
size
(),
1
>::
type
>
struct
static_uford
{
static
constexpr
index_t
num_packs
=
reduce_on_sequence
(
Unpacks
{},
multiplies
{},
number
<
1
>
{});
CK_TILE_HOST_DEVICE
constexpr
static_uford
()
{
static_assert
(
Lengths
::
size
()
>
0
,
"wrong! Lengths is empty"
);
static_assert
(
Lengths
::
size
()
==
Unpacks
::
size
(),
"wrong! inconsistent size"
);
static_assert
(
Lengths
::
size
()
==
Orders
::
size
(),
"wrong! inconsistent size"
);
static_for
<
0
,
Lengths
::
size
(),
1
>
{}(
[
&
](
auto
i
)
{
static_assert
(
Lengths
{}.
at
(
i
)
%
Unpacks
{}.
at
(
i
)
==
0
);
});
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_access
()
{
using
L_
=
decltype
(
Lengths
{}
/
Unpacks
{});
return
reduce_on_sequence
(
L_
{},
multiplies
{},
number
<
1
>
{});
}
// F signature: F(sequence<...> multi_id...)
// multi_id is the unordered multi-index
template
<
class
F
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
)
const
{
constexpr
auto
ordered_lengths
=
Lengths
::
reorder_new_to_old
(
Orders
{});
constexpr
auto
ordered_unpacks
=
Unpacks
::
reorder_new_to_old
(
Orders
{});
detail
::
static_uford_impl
<
decltype
(
ordered_lengths
),
decltype
(
ordered_unpacks
),
Orders
>
{}(
f
,
make_tuple
(
sequence
<>
{}));
}
// this version is friendly for issue function one by one
template
<
class
F
,
index_t
i_access
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
F
f
,
number
<
i_access
>
)
const
{
static_assert
(
i_access
<
get_num_of_access
());
constexpr
auto
ordered_lengths
=
Lengths
::
reorder_new_to_old
(
Orders
{});
constexpr
auto
ordered_unpacks
=
Unpacks
::
reorder_new_to_old
(
Orders
{});
detail
::
static_uford_one_shot_impl
<
decltype
(
ordered_lengths
),
decltype
(
ordered_unpacks
),
Orders
>
{}(
f
,
make_tuple
(
sequence
<>
{}),
number
<
i_access
>
{});
}
};
}
// namespace ck_tile
include/ck_tile/core/utility/literals.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
namespace
ck_tile
{
namespace
literals
{
// [P0330] Literal Suffix for (signed) size_t (C++23)
// ref: https://wg21.link/p0330r8
inline
constexpr
std
::
size_t
operator
""
_uz
(
unsigned
long
long
size
)
{
return
static_cast
<
std
::
size_t
>
(
size
);
}
inline
constexpr
std
::
size_t
operator
""
_zu
(
unsigned
long
long
size
)
{
return
static_cast
<
std
::
size_t
>
(
size
);
}
}
// namespace literals
}
// namespace ck_tile
include/ck_tile/core/utility/magic_div.hpp
View file @
1f9546e0
...
@@ -59,8 +59,16 @@ struct magic_division32_bit_range
...
@@ -59,8 +59,16 @@ struct magic_division32_bit_range
CK_TILE_DEVICE
static
constexpr
uint32_t
CK_TILE_DEVICE
static
constexpr
uint32_t
do_magic_division
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
do_magic_division
(
uint32_t
dividend
,
uint32_t
multiplier
,
uint32_t
shift
)
{
{
uint32_t
tmp
=
__umulhi
(
dividend
,
multiplier
);
if
(
__builtin_is_constant_evaluated
())
return
(
tmp
+
dividend
)
>>
shift
;
{
uint32_t
tmp
=
(
static_cast
<
uint64_t
>
(
dividend
)
*
multiplier
)
>>
32
;
return
(
tmp
+
dividend
)
>>
shift
;
}
else
{
uint32_t
tmp
=
__umulhi
(
dividend
,
multiplier
);
return
(
tmp
+
dividend
)
>>
shift
;
}
}
}
CK_TILE_HOST
static
constexpr
uint32_t
CK_TILE_HOST
static
constexpr
uint32_t
...
@@ -77,9 +85,18 @@ struct magic_division32_bit_range
...
@@ -77,9 +85,18 @@ struct magic_division32_bit_range
CK_TILE_DEVICE
static
constexpr
int32_t
CK_TILE_DEVICE
static
constexpr
int32_t
do_magic_division
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
do_magic_division
(
int32_t
dividend_i32
,
uint32_t
multiplier
,
uint32_t
shift
)
{
{
uint32_t
dividend_u32
=
bit_cast
<
uint32_t
>
(
dividend_i32
);
if
(
__builtin_is_constant_evaluated
())
uint32_t
tmp
=
__umulhi
(
dividend_u32
,
multiplier
);
{
return
(
tmp
+
dividend_u32
)
>>
shift
;
uint32_t
dividend_u32
=
bit_cast
<
uint32_t
>
(
dividend_i32
);
uint32_t
tmp
=
(
static_cast
<
uint64_t
>
(
dividend_u32
)
*
multiplier
)
>>
32
;
return
(
tmp
+
dividend_u32
)
>>
shift
;
}
else
{
uint32_t
dividend_u32
=
bit_cast
<
uint32_t
>
(
dividend_i32
);
uint32_t
tmp
=
__umulhi
(
dividend_u32
,
multiplier
);
return
(
tmp
+
dividend_u32
)
>>
shift
;
}
}
}
CK_TILE_HOST
static
constexpr
int32_t
CK_TILE_HOST
static
constexpr
int32_t
...
...
include/ck_tile/core/utility/reduce_operator.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace
ck_tile
{
namespace
ReduceOp
{
// y = ReduceOp(y, x);
struct
Add
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>>>
CK_TILE_HOST_DEVICE
constexpr
T
operator
()(
const
T
&
y
,
const
T
x
)
const
{
return
y
+
x
;
}
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
half_t
>
||
std
::
is_same_v
<
T
,
bf16_t
>>>
CK_TILE_HOST_DEVICE
constexpr
T
operator
()(
T
&
y
,
T
x
)
const
{
float
y_
=
type_convert
<
float
>
(
y
);
float
x_
=
type_convert
<
float
>
(
x
);
return
type_convert
<
T
>
(
y_
+
x_
);
}
};
struct
SquareAdd
{
template
<
typename
T
>
CK_TILE_HOST_DEVICE
static
constexpr
T
GetIdentityValue
()
{
return
type_convert
<
T
>
(
0.0
f
);
};
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>>>
CK_TILE_HOST_DEVICE
constexpr
T
operator
()(
const
T
&
y
,
const
T
x
)
const
{
return
y
+
(
x
*
x
);
}
};
struct
Max
{
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>>>
CK_TILE_HOST_DEVICE
static
constexpr
T
GetIdentityValue
()
{
return
numeric
<
T
>::
min
();
};
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>>>
CK_TILE_HOST_DEVICE
constexpr
T
operator
()(
const
T
&
y
,
const
T
x
)
const
{
return
max
(
y
,
x
);
}
};
struct
AbsMax
{
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>>>
CK_TILE_HOST_DEVICE
static
constexpr
T
GetIdentityValue
()
{
return
numeric
<
T
>::
min
();
};
template
<
typename
T
,
typename
=
std
::
enable_if_t
<
std
::
is_same_v
<
T
,
float
>
||
std
::
is_same_v
<
T
,
double
>
||
std
::
is_same_v
<
T
,
int32_t
>
||
std
::
is_same_v
<
T
,
int8_t
>>>
CK_TILE_HOST_DEVICE
constexpr
T
operator
()(
const
T
&
y
,
const
T
x
)
const
{
return
max
(
y
,
abs
(
x
));
}
};
}
// namespace ReduceOp
}
// namespace ck_tile
include/ck_tile/core/utility/static_counter.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
namespace
ck_tile
{
template
<
typename
Context
,
index_t
Start
=
0
,
index_t
Step
=
1
>
struct
static_counter
{
public:
template
<
typename
Unique
>
static
constexpr
index_t
next
()
{
return
next
<
Unique
>
(
0
)
*
Step
+
Start
;
}
template
<
unsigned
long
long
>
static
constexpr
index_t
next
()
{
struct
Unique
{
};
return
next
<
Unique
>
(
0
)
*
Step
+
Start
;
}
template
<
typename
Unique
>
static
constexpr
index_t
current
()
{
return
current
<
Unique
>
(
0
)
*
Step
+
Start
;
}
template
<
unsigned
long
long
>
static
constexpr
index_t
current
()
{
struct
Unique
{
};
return
current
<
Unique
>
(
0
)
*
Step
+
Start
;
}
private:
template
<
index_t
I
>
struct
slot
{
_Pragma
(
"GCC diagnostic push"
);
_Pragma
(
"GCC diagnostic ignored
\"
-Wundefined-internal
\"
"
);
friend
constexpr
bool
slot_allocated
(
slot
<
I
>
);
_Pragma
(
"GCC diagnostic pop"
);
};
template
<
index_t
I
>
struct
allocate_slot
{
friend
constexpr
bool
slot_allocated
(
slot
<
I
>
)
{
return
true
;
}
enum
{
value
=
I
};
};
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template
<
typename
Unique
,
index_t
I
=
0
,
bool
=
slot_allocated
(
slot
<
I
>())
>
static
constexpr
index_t
next
(
index_t
)
{
return
next
<
Unique
,
I
+
1
>
(
0
);
}
// ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
// allocate_slot<I>.
template
<
typename
Unique
,
index_t
I
=
0
>
static
constexpr
index_t
next
(
double
)
{
return
allocate_slot
<
I
>::
value
;
}
// If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
// the overload set...
template
<
typename
Unique
,
index_t
I
=
Start
,
bool
=
slot_allocated
(
slot
<
I
>())
>
static
constexpr
index_t
current
(
index_t
)
{
return
current
<
Unique
,
I
+
1
>
(
0
);
}
// ...And this function will be used, instead, which will return the current counter, or assert
// in case next() hasn't been called yet.
template
<
typename
Unique
,
index_t
I
=
Start
>
static
constexpr
index_t
current
(
double
)
{
static_assert
(
I
!=
0
,
"You must invoke next() first"
);
return
I
-
1
;
}
};
namespace
impl
{
template
<
int
I
>
struct
static_counter_uniq_
;
}
#define MAKE_SC() \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}
#define MAKE_SC_WITH(start_, step_) \
ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}
#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()
// Usage:
// constexpr auto c = MAKE_SC()
// NEXT_SC(c) // -> constexpr 0
// NEXT_SC(c) // -> constexpr 1
// NEXT_SC(c) // -> constexpr 2
}
// namespace ck_tile
include/ck_tile/host.hpp
View file @
1f9546e0
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
#include "ck_tile/host/reference/reference_batched_dropout.hpp"
...
@@ -19,10 +20,17 @@
...
@@ -19,10 +20,17 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d.hpp"
#include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_moe_sorting.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp"
#include "ck_tile/host/reference/reference_rowwise_quantization2d.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp"
#include "ck_tile/host/timer.hpp"
include/ck_tile/host/device_memory.hpp
View file @
1f9546e0
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
#include <stdint.h>
#include <stdint.h>
#include <stdexcept>
#include <stdexcept>
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
typename
T
>
template
<
typename
T
>
...
@@ -36,6 +37,19 @@ struct DeviceMem
...
@@ -36,6 +37,19 @@ struct DeviceMem
mpDeviceBuf
=
nullptr
;
mpDeviceBuf
=
nullptr
;
}
}
}
}
template
<
typename
T
>
DeviceMem
(
const
HostTensor
<
T
>&
t
)
:
mMemSize
(
t
.
get_element_space_size_in_bytes
())
{
if
(
mMemSize
!=
0
)
{
HIP_CHECK_ERROR
(
hipMalloc
(
static_cast
<
void
**>
(
&
mpDeviceBuf
),
mMemSize
));
}
else
{
mpDeviceBuf
=
nullptr
;
}
ToDevice
(
t
.
data
());
}
void
Realloc
(
std
::
size_t
mem_size
)
void
Realloc
(
std
::
size_t
mem_size
)
{
{
if
(
mpDeviceBuf
)
if
(
mpDeviceBuf
)
...
@@ -92,6 +106,27 @@ struct DeviceMem
...
@@ -92,6 +106,27 @@ struct DeviceMem
HIP_CHECK_ERROR
(
hipMemcpy
(
p
,
mpDeviceBuf
,
cpySize
,
hipMemcpyDeviceToHost
));
HIP_CHECK_ERROR
(
hipMemcpy
(
p
,
mpDeviceBuf
,
cpySize
,
hipMemcpyDeviceToHost
));
}
}
}
}
// construct a host tensor with type T
template
<
typename
T
>
HostTensor
<
T
>
ToHost
(
std
::
size_t
cpySize
)
{
// TODO: host tensor could be slightly larger than the device tensor
// we just copy all data from GPU buffer
std
::
size_t
host_elements
=
(
cpySize
+
sizeof
(
T
)
-
1
)
/
sizeof
(
T
);
HostTensor
<
T
>
h_
({
host_elements
});
if
(
mpDeviceBuf
)
{
HIP_CHECK_ERROR
(
hipMemcpy
(
h_
.
data
(),
mpDeviceBuf
,
cpySize
,
hipMemcpyDeviceToHost
));
}
return
h_
;
}
template
<
typename
T
>
HostTensor
<
T
>
ToHost
()
{
return
ToHost
<
T
>
(
mMemSize
);
}
void
SetZero
()
const
void
SetZero
()
const
{
{
if
(
mpDeviceBuf
)
if
(
mpDeviceBuf
)
...
...
include/ck_tile/host/fill.hpp
View file @
1f9546e0
...
@@ -10,8 +10,10 @@
...
@@ -10,8 +10,10 @@
#include <random>
#include <random>
#include <type_traits>
#include <type_traits>
#include <utility>
#include <utility>
#include <unordered_set>
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -21,13 +23,44 @@ struct FillUniformDistribution
...
@@ -21,13 +23,44 @@ struct FillUniformDistribution
float
a_
{
-
5.
f
};
float
a_
{
-
5.
f
};
float
b_
{
5.
f
};
float
b_
{
5.
f
};
std
::
optional
<
uint32_t
>
seed_
{
11939
};
std
::
optional
<
uint32_t
>
seed_
{
11939
};
// ATTENTION: threaded does not guarantee the distribution between thread
bool
threaded
=
false
;
template
<
typename
ForwardIter
>
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
{
std
::
mt19937
gen
(
seed_
.
has_value
()
?
*
seed_
:
std
::
random_device
{}());
if
(
threaded
)
std
::
uniform_real_distribution
<
float
>
dis
(
a_
,
b_
);
{
std
::
generate
(
first
,
last
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
});
uint32_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
auto
total
=
static_cast
<
std
::
size_t
>
(
std
::
distance
(
first
,
last
));
auto
work_per_thread
=
static_cast
<
std
::
size_t
>
((
total
+
num_thread
-
1
)
/
num_thread
);
std
::
vector
<
joinable_thread
>
threads
(
num_thread
);
for
(
std
::
size_t
it
=
0
;
it
<
num_thread
;
++
it
)
{
std
::
size_t
iw_begin
=
it
*
work_per_thread
;
std
::
size_t
iw_end
=
std
::
min
((
it
+
1
)
*
work_per_thread
,
total
);
auto
thread_f
=
[
this
,
total
,
iw_begin
,
iw_end
,
&
first
]
{
if
(
iw_begin
>
total
||
iw_end
>
total
)
return
;
// need to make each thread unique, add an offset to current seed
std
::
mt19937
gen
(
seed_
.
has_value
()
?
(
*
seed_
+
iw_begin
)
:
std
::
random_device
{}());
std
::
uniform_real_distribution
<
float
>
dis
(
a_
,
b_
);
std
::
generate
(
first
+
iw_begin
,
first
+
iw_end
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
});
};
threads
[
it
]
=
joinable_thread
(
thread_f
);
}
}
else
{
std
::
mt19937
gen
(
seed_
.
has_value
()
?
*
seed_
:
std
::
random_device
{}());
std
::
uniform_real_distribution
<
float
>
dis
(
a_
,
b_
);
std
::
generate
(
first
,
last
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
});
}
}
}
template
<
typename
ForwardRange
>
template
<
typename
ForwardRange
>
...
@@ -41,19 +74,117 @@ struct FillUniformDistribution
...
@@ -41,19 +74,117 @@ struct FillUniformDistribution
}
}
};
};
namespace
impl
{
// clang-format off
template
<
index_t
bytes
>
struct
RawIntegerType_
{};
template
<
>
struct
RawIntegerType_
<
1
>
{
using
type
=
uint8_t
;};
template
<
>
struct
RawIntegerType_
<
2
>
{
using
type
=
uint16_t
;};
template
<
>
struct
RawIntegerType_
<
4
>
{
using
type
=
uint32_t
;};
template
<
>
struct
RawIntegerType_
<
8
>
{
using
type
=
uint64_t
;};
// clang-format on
template
<
typename
T
>
using
RawIntegerType
=
typename
RawIntegerType_
<
sizeof
(
T
)
>::
type
;
}
// namespace impl
// Note: this struct will have no const-ness will generate random
template
<
typename
T
>
struct
FillUniformDistribution_Unique
{
float
a_
{
-
5.
f
};
float
b_
{
5.
f
};
std
::
optional
<
uint32_t
>
seed_
{
11939
};
std
::
mt19937
gen_
{};
std
::
unordered_set
<
impl
::
RawIntegerType
<
T
>>
set_
{};
FillUniformDistribution_Unique
(
float
a
=
-
5.
f
,
float
b
=
5.
f
,
std
::
optional
<
uint32_t
>
seed
=
{
11939
})
:
a_
(
a
),
b_
(
b
),
seed_
(
seed
),
gen_
{
seed_
.
has_value
()
?
*
seed_
:
std
::
random_device
{}()},
set_
{}
{
}
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
{
std
::
mt19937
&
gen
=
gen_
;
std
::
uniform_real_distribution
<
float
>
dis
(
a_
,
b_
);
auto
&
set
=
set_
;
std
::
generate
(
first
,
last
,
[
&
dis
,
&
gen
,
&
set
]()
{
T
v
=
static_cast
<
T
>
(
0
);
do
{
v
=
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
}
while
(
set
.
count
(
bit_cast
<
impl
::
RawIntegerType
<
T
>>
(
v
))
==
1
);
set
.
insert
(
bit_cast
<
impl
::
RawIntegerType
<
T
>>
(
v
));
return
v
;
});
}
template
<
typename
ForwardRange
>
auto
operator
()(
ForwardRange
&&
range
)
->
std
::
void_t
<
decltype
(
std
::
declval
<
FillUniformDistribution_Unique
&>
()(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
))))
>
{
(
*
this
)(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
)));
}
void
clear
()
{
set_
.
clear
();
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
FillNormalDistribution
struct
FillNormalDistribution
{
{
float
mean_
{
0.
f
};
float
mean_
{
0.
f
};
float
variance_
{
1.
f
};
float
variance_
{
1.
f
};
std
::
optional
<
uint32_t
>
seed_
{
11939
};
std
::
optional
<
uint32_t
>
seed_
{
11939
};
// ATTENTION: threaded does not guarantee the distribution between thread
bool
threaded
=
false
;
template
<
typename
ForwardIter
>
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
{
std
::
mt19937
gen
(
seed_
.
has_value
()
?
*
seed_
:
std
::
random_device
{}());
if
(
threaded
)
std
::
normal_distribution
<
float
>
dis
(
mean_
,
std
::
sqrt
(
variance_
));
{
std
::
generate
(
first
,
last
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
});
uint32_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
auto
total
=
static_cast
<
std
::
size_t
>
(
std
::
distance
(
first
,
last
));
auto
work_per_thread
=
static_cast
<
std
::
size_t
>
((
total
+
num_thread
-
1
)
/
num_thread
);
std
::
vector
<
joinable_thread
>
threads
(
num_thread
);
for
(
std
::
size_t
it
=
0
;
it
<
num_thread
;
++
it
)
{
std
::
size_t
iw_begin
=
it
*
work_per_thread
;
std
::
size_t
iw_end
=
std
::
min
((
it
+
1
)
*
work_per_thread
,
total
);
auto
thread_f
=
[
this
,
total
,
iw_begin
,
iw_end
,
&
first
]
{
if
(
iw_begin
>
total
||
iw_end
>
total
)
return
;
// need to make each thread unique, add an offset to current seed
std
::
mt19937
gen
(
seed_
.
has_value
()
?
(
*
seed_
+
iw_begin
)
:
std
::
random_device
{}());
std
::
normal_distribution
<
float
>
dis
(
mean_
,
std
::
sqrt
(
variance_
));
std
::
generate
(
first
+
iw_begin
,
first
+
iw_end
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
});
};
threads
[
it
]
=
joinable_thread
(
thread_f
);
}
}
else
{
std
::
mt19937
gen
(
seed_
.
has_value
()
?
*
seed_
:
std
::
random_device
{}());
std
::
normal_distribution
<
float
>
dis
(
mean_
,
std
::
sqrt
(
variance_
));
std
::
generate
(
first
,
last
,
[
&
dis
,
&
gen
]()
{
return
ck_tile
::
type_convert
<
T
>
(
dis
(
gen
));
});
}
}
}
template
<
typename
ForwardRange
>
template
<
typename
ForwardRange
>
...
@@ -167,6 +298,44 @@ struct FillMonotonicSeq
...
@@ -167,6 +298,44 @@ struct FillMonotonicSeq
}
}
};
};
template
<
typename
T
,
bool
IsAscending
=
true
>
struct
FillStepRange
{
float
start_value_
{
0
};
float
end_value_
{
3
};
float
step_
{
1
};
template
<
typename
ForwardIter
>
void
operator
()(
ForwardIter
first
,
ForwardIter
last
)
const
{
std
::
generate
(
first
,
last
,
[
=
,
n
=
start_value_
]()
mutable
{
auto
tmp
=
n
;
n
+=
step_
;
if
constexpr
(
IsAscending
)
{
if
(
n
>
end_value_
)
n
=
start_value_
;
}
else
{
if
(
n
<
end_value_
)
n
=
start_value_
;
}
return
type_convert
<
T
>
(
tmp
);
});
}
template
<
typename
ForwardRange
>
auto
operator
()(
ForwardRange
&&
range
)
const
->
std
::
void_t
<
decltype
(
std
::
declval
<
const
FillStepRange
&>
()(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
))))
>
{
(
*
this
)(
std
::
begin
(
std
::
forward
<
ForwardRange
>
(
range
)),
std
::
end
(
std
::
forward
<
ForwardRange
>
(
range
)));
}
};
template
<
typename
T
>
template
<
typename
T
>
struct
FillConstant
struct
FillConstant
{
{
...
...
include/ck_tile/host/host_tensor.hpp
View file @
1f9546e0
...
@@ -8,11 +8,13 @@
...
@@ -8,11 +8,13 @@
#include <iostream>
#include <iostream>
#include <iomanip>
#include <iomanip>
#include <numeric>
#include <numeric>
#include <thread>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include <functional>
#include <fstream>
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/joinable_thread.hpp"
#include "ck_tile/host/ranges.hpp"
#include "ck_tile/host/ranges.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -212,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
...
@@ -212,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
return
HostTensorDescriptor
(
new_lengths
,
new_strides
);
return
HostTensorDescriptor
(
new_lengths
,
new_strides
);
}
}
struct
joinable_thread
:
std
::
thread
{
template
<
typename
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
{
}
joinable_thread
(
joinable_thread
&&
)
=
default
;
joinable_thread
&
operator
=
(
joinable_thread
&&
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
template
<
typename
F
,
typename
...
Xs
>
template
<
typename
F
,
typename
...
Xs
>
struct
ParallelTensorFunctor
struct
ParallelTensorFunctor
{
{
...
@@ -545,6 +530,28 @@ struct HostTensor
...
@@ -545,6 +530,28 @@ struct HostTensor
typename
Data
::
size_type
size
()
const
{
return
mData
.
size
();
}
typename
Data
::
size_type
size
()
const
{
return
mData
.
size
();
}
// return a slice of this tensor
// for simplicity we just copy the data and return a new tensor
auto
slice
(
std
::
vector
<
size_t
>
s_begin
,
std
::
vector
<
size_t
>
s_end
)
const
{
assert
(
s_begin
.
size
()
==
s_end
.
size
());
assert
(
s_begin
.
size
()
==
get_num_of_dimension
());
std
::
vector
<
size_t
>
s_len
(
s_begin
.
size
());
std
::
transform
(
s_end
.
begin
(),
s_end
.
end
(),
s_begin
.
begin
(),
s_len
.
begin
(),
std
::
minus
<
size_t
>
{});
HostTensor
<
T
>
sliced_tensor
(
s_len
);
sliced_tensor
.
ForEach
([
&
](
auto
&
self
,
auto
idx
)
{
std
::
vector
<
size_t
>
src_idx
(
idx
.
size
());
std
::
transform
(
idx
.
begin
(),
idx
.
end
(),
s_begin
.
begin
(),
src_idx
.
begin
(),
std
::
plus
<
size_t
>
{});
self
(
idx
)
=
operator
()(
src_idx
);
});
return
sliced_tensor
;
}
template
<
typename
U
=
T
>
template
<
typename
U
=
T
>
auto
AsSpan
()
const
auto
AsSpan
()
const
{
{
...
@@ -567,6 +574,107 @@ struct HostTensor
...
@@ -567,6 +574,107 @@ struct HostTensor
size
()
*
FromSize
/
ToSize
};
size
()
*
FromSize
/
ToSize
};
}
}
friend
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
HostTensor
<
T
>&
t
)
{
os
<<
t
.
mDesc
;
os
<<
"["
;
for
(
typename
Data
::
size_type
idx
=
0
;
idx
<
t
.
mData
.
size
();
++
idx
)
{
if
(
0
<
idx
)
{
os
<<
", "
;
}
if
constexpr
(
std
::
is_same_v
<
T
,
bf16_t
>
||
std
::
is_same_v
<
T
,
fp16_t
>
)
{
os
<<
type_convert
<
float
>
(
t
.
mData
[
idx
])
<<
" #### "
;
}
else
{
os
<<
t
.
mData
[
idx
];
}
}
os
<<
"]"
;
return
os
;
}
// read data from a file, as dtype
// the file could dumped from torch as (targeting tensor is t here)
// numpy.savetxt("f.txt", t.view(-1).numpy())
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy()) # from cuda to cpu to save
// numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d") # save as int
// will output f.txt, each line is a value
// dtype=float or int, internally will cast to real type
void
loadtxt
(
std
::
string
file_name
,
std
::
string
dtype
=
"float"
)
{
std
::
ifstream
file
(
file_name
);
if
(
file
.
is_open
())
{
std
::
string
line
;
index_t
cnt
=
0
;
while
(
std
::
getline
(
file
,
line
))
{
if
(
cnt
>=
static_cast
<
index_t
>
(
mData
.
size
()))
{
throw
std
::
runtime_error
(
std
::
string
(
"data read from file:"
)
+
file_name
+
" is too big"
);
}
if
(
dtype
==
"float"
)
{
mData
[
cnt
]
=
type_convert
<
T
>
(
std
::
stof
(
line
));
}
else
if
(
dtype
==
"int"
||
dtype
==
"int32"
)
{
mData
[
cnt
]
=
type_convert
<
T
>
(
std
::
stoi
(
line
));
}
cnt
++
;
}
file
.
close
();
if
(
cnt
<
static_cast
<
index_t
>
(
mData
.
size
()))
{
std
::
cerr
<<
"Warning! reading from file:"
<<
file_name
<<
", does not match the size of this tensor"
<<
std
::
endl
;
}
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw
std
::
runtime_error
(
std
::
string
(
"unable to open file:"
)
+
file_name
);
}
}
// can save to a txt file and read from torch as:
// torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
void
savetxt
(
std
::
string
file_name
,
std
::
string
dtype
=
"float"
)
{
std
::
ofstream
file
(
file_name
);
if
(
file
.
is_open
())
{
for
(
auto
&
itm
:
mData
)
{
if
(
dtype
==
"float"
)
file
<<
type_convert
<
float
>
(
itm
)
<<
std
::
endl
;
else
if
(
dtype
==
"int"
)
file
<<
type_convert
<
int
>
(
itm
)
<<
std
::
endl
;
else
// TODO: we didn't implement operator<< for all custom
// data types, here fall back to float in case compile error
file
<<
type_convert
<
float
>
(
itm
)
<<
std
::
endl
;
}
file
.
close
();
}
else
{
// Print an error message to the standard error
// stream if the file cannot be opened.
throw
std
::
runtime_error
(
std
::
string
(
"unable to open file:"
)
+
file_name
);
}
}
Descriptor
mDesc
;
Descriptor
mDesc
;
Data
mData
;
Data
mData
;
};
};
...
...
include/ck_tile/host/joinable_thread.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <thread>
#include <utility>
namespace
ck_tile
{
struct
joinable_thread
:
std
::
thread
{
template
<
typename
...
Xs
>
joinable_thread
(
Xs
&&
...
xs
)
:
std
::
thread
(
std
::
forward
<
Xs
>
(
xs
)...)
{
}
joinable_thread
(
joinable_thread
&&
)
=
default
;
joinable_thread
&
operator
=
(
joinable_thread
&&
)
=
default
;
~
joinable_thread
()
{
if
(
this
->
joinable
())
this
->
join
();
}
};
}
// namespace ck_tile
include/ck_tile/host/reference/reference_elementwise.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <thread>
namespace
ck_tile
{
template
<
typename
ADataType
,
typename
BDataType
,
typename
ComputeDataType
,
typename
ElementOp
>
CK_TILE_HOST
void
reference_unary_elementwise
(
const
HostTensor
<
ADataType
>&
a
,
HostTensor
<
BDataType
>&
b
,
ElementOp
element_op
)
{
// TODO: imeplement gpu version reference function
auto
f
=
[
&
](
auto
i
)
{
auto
v_a
=
type_convert
<
ComputeDataType
>
(
a
.
mData
[
i
]);
auto
v_b
=
element_op
(
v_a
);
b
.
mData
[
i
]
=
ck_tile
::
type_convert
<
BDataType
>
(
v_b
);
};
make_ParallelTensorFunctor
(
f
,
b
.
get_element_space_size
())(
std
::
thread
::
hardware_concurrency
());
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ComputeDataType
,
typename
ElementOp
>
CK_TILE_HOST
void
reference_binary_elementwise
(
const
HostTensor
<
ADataType
>&
a
,
const
HostTensor
<
BDataType
>&
b
,
HostTensor
<
CDataType
>&
c
,
ElementOp
element_op
)
{
// TODO: imeplement gpu version reference function
auto
f
=
[
&
](
auto
i
)
{
auto
v_a
=
type_convert
<
ComputeDataType
>
(
a
.
mData
[
i
]);
auto
v_b
=
type_convert
<
ComputeDataType
>
(
b
.
mData
[
i
]);
auto
v_c
=
element_op
(
v_a
,
v_b
);
c
.
mData
[
i
]
=
ck_tile
::
type_convert
<
CDataType
>
(
v_c
);
};
make_ParallelTensorFunctor
(
f
,
c
.
get_element_space_size
())(
std
::
thread
::
hardware_concurrency
());
}
}
// namespace ck_tile
include/ck_tile/host/reference/reference_fused_moe.hpp
0 → 100644
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
// tok-0 tok-1 tok-2 tok-3 tok-4
// topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float
// number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
// (only for reference) exp-0 exp-1 exp-2 exp-3 exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6,
// 0, 1, 2, 5]
// |- exp-0 -|- exp-1 -|- exp-2 -|- exp-3 -|- exp-4
// -|- exp-5 -|
// sorted_weight_ptr : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *,
// c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
///
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
template
<
typename
AccDataType
,
// you only need to explcitly set this one
typename
Activation
,
// ck_tile::element_wise::Gelu
typename
ADataType
,
typename
GDataType
,
typename
DDataType
,
typename
ODataType
,
typename
AScaleDataType
,
typename
GScaleDataType
,
typename
DScaleDataType
,
typename
YSmoothScaleDataType
,
typename
TopkWeightDataType
,
typename
IndexDataType
>
void
reference_fused_moe
(
const
ck_tile
::
HostTensor
<
ADataType
>&
a_host
,
// [tokens, hidden_size]
const
ck_tile
::
HostTensor
<
GDataType
>&
g_host
,
// [experts, interme_size_0, hidden_size]
const
ck_tile
::
HostTensor
<
DDataType
>&
d_host
,
// [experts, hidden_size, interme_size_1]
const
ck_tile
::
HostTensor
<
AScaleDataType
>&
sa_host
,
// [tokens, 1],
const
ck_tile
::
HostTensor
<
GScaleDataType
>&
sg_host
,
// [experts, 1, interme_size_0]
const
ck_tile
::
HostTensor
<
DScaleDataType
>&
sd_host
,
// [experts, 1, hidden_size],
const
ck_tile
::
HostTensor
<
YSmoothScaleDataType
>&
sy_host
,
// [experts, 1, interme_size_0]
ck_tile
::
HostTensor
<
ODataType
>&
o_host
,
// [tokens, hidden_size]
const
ck_tile
::
HostTensor
<
IndexDataType
>&
sorted_token_ids_host
,
// [max_num_tokens_padded]
const
ck_tile
::
HostTensor
<
TopkWeightDataType
>&
sorted_weight_host
,
// [max_num_tokens_padded]
const
ck_tile
::
HostTensor
<
IndexDataType
>&
sorted_expert_ids_host
,
// [(max_num_tokens_padded + block_size - 1) / block_size]
const
ck_tile
::
HostTensor
<
IndexDataType
>&
num_sorted_tiles_host
,
// [1]
const
ck_tile
::
HostTensor
<
IndexDataType
>&
token_ids_host
,
// [tokens, topk] --> ugly!!! remove in the future
ck_tile
::
index_t
block_m
,
ck_tile
::
index_t
tokens
,
ck_tile
::
index_t
experts
,
ck_tile
::
index_t
hidden_size
,
ck_tile
::
index_t
intermediate_size
,
// this size is for gate/up
ck_tile
::
index_t
topk
,
ck_tile
::
index_t
gate_only
)
{
assert
(
sorted_token_ids_host
.
get_num_of_dimension
()
==
1
);
assert
(
sorted_weight_host
.
get_num_of_dimension
()
==
1
);
assert
(
sorted_expert_ids_host
.
get_num_of_dimension
()
==
1
);
assert
(
num_sorted_tiles_host
.
get_element_size
()
==
1
);
ck_tile
::
index_t
num_sorted_tiles
=
num_sorted_tiles_host
.
mData
[
0
]
/
block_m
;
ck_tile
::
index_t
intermediate_size_0
=
intermediate_size
;
ck_tile
::
index_t
intermediate_size_1
=
intermediate_size
/
(
gate_only
?
1
:
2
);
// TODO: better remove this in the future, or modify the token_id value
auto
get_topk_id
=
[
&
](
ck_tile
::
index_t
token_id_
,
ck_tile
::
index_t
expert_id_
)
{
for
(
ck_tile
::
index_t
i_
=
0
;
i_
<
topk
;
i_
++
)
{
if
(
token_ids_host
(
token_id_
,
i_
)
==
expert_id_
)
return
i_
;
}
throw
std
::
runtime_error
(
"not correct token/expert pair
\n
"
);
return
-
1
;
// TODO: not correct!!
};
ck_tile
::
HostTensor
<
AccDataType
>
out_topk_tokens
({
tokens
,
topk
,
hidden_size
});
int
max_num_tokens_padded
=
topk
*
tokens
+
experts
*
block_m
-
topk
;
// assert();
auto
f
=
[
&
](
auto
i_flatten
)
{
ck_tile
::
index_t
i_tile
=
i_flatten
/
block_m
;
if
(
i_tile
>=
num_sorted_tiles
)
return
;
ck_tile
::
index_t
i_expert
=
sorted_expert_ids_host
.
mData
[
i_tile
];
ck_tile
::
index_t
i_token
=
sorted_token_ids_host
.
mData
[
i_flatten
];
if
(
i_token
>=
tokens
)
return
;
ck_tile
::
index_t
i_topk
=
get_topk_id
(
i_token
,
i_expert
);
// TODO: ugly
auto
weight
=
sorted_weight_host
.
mData
[
i_flatten
];
ck_tile
::
HostTensor
<
AccDataType
>
acc_0
({
1
,
intermediate_size_0
});
// first gemm
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_0
;
i_n
++
)
{
AccDataType
acc
=
static_cast
<
AccDataType
>
(
0
);
for
(
ck_tile
::
index_t
i_k
=
0
;
i_k
<
hidden_size
;
i_k
++
)
{
acc
+=
type_convert
<
AccDataType
>
(
a_host
(
i_token
,
i_k
))
*
type_convert
<
AccDataType
>
(
g_host
(
i_expert
,
i_n
,
i_k
));
}
acc_0
(
0
,
i_n
)
=
acc
;
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
}
ck_tile
::
HostTensor
<
AccDataType
>
y
({
1
,
intermediate_size_1
});
if
(
gate_only
)
{
if
(
intermediate_size_1
!=
intermediate_size_0
)
throw
std
::
runtime_error
(
"intermediate_size not correct, 0:"
+
std
::
to_string
(
intermediate_size_0
)
+
", 1:"
+
std
::
to_string
(
intermediate_size_1
));
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_1
;
i_n
++
)
{
Activation
{}(
y
(
0
,
i_n
),
acc_0
(
0
,
i_n
));
// printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
}
}
else
{
if
(
intermediate_size_1
*
2
!=
intermediate_size_0
)
throw
std
::
runtime_error
(
"intermediate_size not correct, 0:"
+
std
::
to_string
(
intermediate_size_0
)
+
", 1:"
+
std
::
to_string
(
intermediate_size_1
));
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
intermediate_size_1
;
i_n
++
)
{
AccDataType
tmp
;
Activation
{}(
tmp
,
acc_0
(
0
,
i_n
));
y
(
0
,
i_n
)
=
tmp
*
acc_0
(
0
,
i_n
+
intermediate_size_1
);
// TODO: elementwise mul
}
}
// second gemm, loop along gemm-n
ck_tile
::
HostTensor
<
AccDataType
>
acc_1
({
1
,
hidden_size
});
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
hidden_size
;
i_n
++
)
{
AccDataType
acc
=
static_cast
<
AccDataType
>
(
0
);
for
(
ck_tile
::
index_t
i_k
=
0
;
i_k
<
intermediate_size_1
;
i_k
++
)
{
acc
+=
y
(
0
,
i_k
)
*
type_convert
<
AccDataType
>
(
d_host
(
i_expert
,
i_n
,
i_k
));
}
acc_1
(
0
,
i_n
)
=
acc
*
weight
;
// multiple weight here
}
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
hidden_size
;
i_n
++
)
{
out_topk_tokens
(
i_token
,
i_topk
,
i_n
)
=
acc_1
(
0
,
i_n
);
}
};
// make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
make_ParallelTensorFunctor
(
f
,
max_num_tokens_padded
)(
1
);
// reduce
auto
r
=
[
&
](
auto
i_token
)
{
for
(
ck_tile
::
index_t
i_n
=
0
;
i_n
<
hidden_size
;
i_n
++
)
{
AccDataType
acc
=
type_convert
<
AccDataType
>
(
0
);
for
(
ck_tile
::
index_t
i_topk
=
0
;
i_topk
<
topk
;
i_topk
++
)
{
acc
+=
out_topk_tokens
(
i_token
,
i_topk
,
i_n
);
}
o_host
(
i_token
,
i_n
)
=
type_convert
<
ODataType
>
(
acc
);
}
};
make_ParallelTensorFunctor
(
r
,
tokens
)(
std
::
thread
::
hardware_concurrency
());
(
void
)
num_sorted_tiles_host
;
(
void
)
sa_host
;
(
void
)
sg_host
;
(
void
)
sd_host
;
(
void
)
sy_host
;
}
}
// namespace ck_tile
include/ck_tile/host/reference/reference_gemm.hpp
View file @
1f9546e0
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <cstdlib>
#include <thread>
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include <thread>
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -14,55 +15,36 @@ template <typename ADataType,
...
@@ -14,55 +15,36 @@ template <typename ADataType,
typename
BDataType
,
typename
BDataType
,
typename
AccDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
,
typename
AElementOp
=
ck_tile
::
identity
,
typename
AElementOp
=
ck_tile
::
identity
,
typename
BElementOp
=
ck_tile
::
identity
,
typename
BElementOp
=
ck_tile
::
identity
,
typename
ACCElementOp
=
ck_tile
::
identity
>
typename
ACCElementOp
=
ck_tile
::
identity
>
CK_TILE_HOST
void
reference_gemm
(
const
HostTensor
<
ADataType
>&
a_m_k
,
CK_TILE_HOST
void
reference_gemm
(
const
HostTensor
<
ADataType
>&
a_m_k
,
const
HostTensor
<
BDataType
>&
b_
n_k
,
const
HostTensor
<
BDataType
>&
b_
k_n
,
HostTensor
<
CDataType
>&
c_m_n
,
HostTensor
<
CDataType
>&
c_m_n
,
const
AElementOp
&
a_element_op
=
{},
const
AElementOp
&
a_element_op
=
{},
const
BElementOp
&
b_element_op
=
{},
const
BElementOp
&
b_element_op
=
{},
const
ACCElementOp
&
acc_element_op
=
{})
const
ACCElementOp
&
acc_element_op
=
{})
{
{
const
int
N
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
const
std
::
size_t
M
=
a_m_k
.
get_length
(
0
);
?
b_n_k
.
mDesc
.
get_lengths
()[
0
]
const
std
::
size_t
N
=
b_k_n
.
get_length
(
1
);
:
b_n_k
.
mDesc
.
get_lengths
()[
1
];
const
std
::
size_t
K
=
a_m_k
.
get_length
(
1
);
const
int
K
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_m_k
.
mDesc
.
get_lengths
()[
1
]
auto
f_mn
=
[
&
](
auto
m
,
auto
n
)
{
:
a_m_k
.
mDesc
.
get_lengths
()[
0
];
AccDataType
v_acc
=
0
;
const
int
M
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
a_m_k
.
mDesc
.
get_lengths
()[
0
]
for
(
std
::
size_t
k
=
0
;
k
<
K
;
++
k
)
:
a_m_k
.
mDesc
.
get_lengths
()[
1
];
auto
f
=
[
&
](
auto
m
)
{
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
AccDataType
v_acc
=
0
;
ADataType
v_a
=
a_element_op
(
a_m_k
(
m
,
k
));
BDataType
v_b
=
b_element_op
(
b_k_n
(
k
,
n
));
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
v_acc
+=
ADataType
v_a
=
(
std
::
is_same_v
<
LayoutA
,
tensor_layout
::
gemm
::
RowMajor
>
)
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
?
a_element_op
(
a_m_k
(
m
,
k
))
:
a_element_op
(
a_m_k
(
k
,
m
));
BDataType
v_b
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
b_element_op
(
b_n_k
(
n
,
k
))
:
b_element_op
(
b_n_k
(
k
,
n
));
v_acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
v_a
)
*
ck_tile
::
type_convert
<
AccDataType
>
(
v_b
);
}
CDataType
&
c_ref
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
c_m_n
(
m
,
n
)
:
c_m_n
(
n
,
m
);
c_ref
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
}
}
c_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
CDataType
>
(
acc_element_op
(
v_acc
));
};
};
make_ParallelTensorFunctor
(
f
,
M
)(
std
::
thread
::
hardware_concurrency
());
make_ParallelTensorFunctor
(
f
_mn
,
M
,
N
)(
std
::
thread
::
hardware_concurrency
());
}
}
template
<
typename
ADataType
,
template
<
typename
ADataType
,
...
@@ -201,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
...
@@ -201,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
return
;
return
;
}
}
template
<
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CDataType
,
typename
LayoutA
,
typename
LayoutB
,
typename
LayoutC
>
void
reference_batched_gemm_gpu
(
DeviceMem
&
a_device
,
DeviceMem
&
b_device
,
DeviceMem
&
c_device
,
index_t
M
,
index_t
N
,
index_t
K
,
index_t
stride_a
,
index_t
stride_b
,
index_t
stride_c
,
index_t
batch_stride_A
,
index_t
batch_stride_B
,
index_t
batch_stride_C
,
index_t
batch_count
)
{
ADataType
*
d_A
;
BDataType
*
d_B
;
CDataType
*
d_C
;
hipError_t
errA
=
hipMalloc
(
&
d_A
,
batch_count
*
M
*
K
*
sizeof
(
ADataType
));
hipError_t
errB
=
hipMalloc
(
&
d_B
,
batch_count
*
N
*
K
*
sizeof
(
BDataType
));
hipError_t
errC
=
hipMalloc
(
&
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
));
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for A: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for B: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
return
;
// Early exit on error
}
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error allocating device memory for C: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
return
;
// Early exit on error
}
errA
=
hipMemcpy
(
d_A
,
a_device
.
GetDeviceBuffer
(),
batch_count
*
M
*
K
*
sizeof
(
ADataType
),
hipMemcpyHostToDevice
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying A to device: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipMemcpy
(
d_B
,
b_device
.
GetDeviceBuffer
(),
batch_count
*
N
*
K
*
sizeof
(
BDataType
),
hipMemcpyHostToDevice
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying B to device: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
int
totalElements
=
M
*
N
;
int
numThreadsPerBlock
=
256
;
// Common choice for threads per block
int
numBlocks
=
(
totalElements
+
numThreadsPerBlock
-
1
)
/
numThreadsPerBlock
;
for
(
index_t
batch_id
=
0
;
batch_id
<
batch_count
;
++
batch_id
)
{
ADataType
*
d_ATemp
=
d_A
+
batch_id
*
batch_stride_A
;
BDataType
*
d_BTemp
=
d_B
+
batch_id
*
batch_stride_B
;
CDataType
*
d_CTemp
=
d_C
+
batch_id
*
batch_stride_C
;
naive_gemm_kernel
<
ADataType
,
BDataType
,
AccDataType
,
CDataType
,
LayoutA
,
LayoutB
,
LayoutC
>
<<<
numBlocks
,
numThreadsPerBlock
>>>
(
d_ATemp
,
d_BTemp
,
d_CTemp
,
M
,
N
,
K
,
stride_a
,
stride_b
,
stride_c
);
}
errC
=
hipMemcpy
(
c_device
.
GetDeviceBuffer
(),
d_C
,
batch_count
*
M
*
N
*
sizeof
(
CDataType
),
hipMemcpyDeviceToHost
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error copying C to device: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
errA
=
hipFree
(
d_A
);
if
(
errA
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the A memory: "
<<
hipGetErrorString
(
errA
)
<<
std
::
endl
;
}
errB
=
hipFree
(
d_B
);
if
(
errB
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the B memory: "
<<
hipGetErrorString
(
errB
)
<<
std
::
endl
;
}
errC
=
hipFree
(
d_C
);
if
(
errC
!=
hipSuccess
)
{
std
::
cerr
<<
"Error free the C memory: "
<<
hipGetErrorString
(
errC
)
<<
std
::
endl
;
}
return
;
}
}
// namespace ck_tile
}
// namespace ck_tile
Prev
1
…
15
16
17
18
19
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment