Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
a4522ae3
Commit
a4522ae3
authored
Nov 06, 2024
by
illsilin
Browse files
sync from public repo
parents
1f127242
e0594d08
Changes
425
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2402 additions
and
285 deletions
+2402
-285
include/ck_tile/core/algorithm/coordinate_transform.hpp
include/ck_tile/core/algorithm/coordinate_transform.hpp
+104
-0
include/ck_tile/core/algorithm/indexing_adaptor.hpp
include/ck_tile/core/algorithm/indexing_adaptor.hpp
+60
-0
include/ck_tile/core/algorithm/space_filling_curve.hpp
include/ck_tile/core/algorithm/space_filling_curve.hpp
+7
-5
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+195
-18
include/ck_tile/core/arch/utility.hpp
include/ck_tile/core/arch/utility.hpp
+43
-0
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+24
-1
include/ck_tile/core/container/sequence.hpp
include/ck_tile/core/container/sequence.hpp
+122
-0
include/ck_tile/core/container/tuple.hpp
include/ck_tile/core/container/tuple.hpp
+49
-5
include/ck_tile/core/numeric/int8.hpp
include/ck_tile/core/numeric/int8.hpp
+104
-0
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+930
-44
include/ck_tile/core/numeric/type_convert.hpp
include/ck_tile/core/numeric/type_convert.hpp
+4
-0
include/ck_tile/core/tensor/buffer_view.hpp
include/ck_tile/core/tensor/buffer_view.hpp
+121
-57
include/ck_tile/core/tensor/load_tile.hpp
include/ck_tile/core/tensor/load_tile.hpp
+83
-4
include/ck_tile/core/tensor/null_tile_window.hpp
include/ck_tile/core/tensor/null_tile_window.hpp
+7
-0
include/ck_tile/core/tensor/shuffle_tile.hpp
include/ck_tile/core/tensor/shuffle_tile.hpp
+1
-1
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+14
-0
include/ck_tile/core/tensor/store_tile.hpp
include/ck_tile/core/tensor/store_tile.hpp
+29
-2
include/ck_tile/core/tensor/sweep_tile.hpp
include/ck_tile/core/tensor/sweep_tile.hpp
+278
-0
include/ck_tile/core/tensor/tensor_view.hpp
include/ck_tile/core/tensor/tensor_view.hpp
+192
-25
include/ck_tile/core/tensor/tile_distribution.hpp
include/ck_tile/core/tensor/tile_distribution.hpp
+35
-123
No files found.
include/ck_tile/core/algorithm/coordinate_transform.hpp
View file @
a4522ae3
...
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
...
@@ -23,6 +23,7 @@ enum struct coord_transform_enum
replicate
,
replicate
,
xor_t
,
xor_t
,
offset
,
offset
,
indexing
,
};
};
template
<
index_t
NDimLow
,
index_t
NDimUp
>
template
<
index_t
NDimLow
,
index_t
NDimUp
>
...
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
...
@@ -1526,6 +1527,88 @@ struct offset : public base_transform<1, 1>
}
}
};
};
// Coordinate transform with one upper and one lower dimension whose index
// mapping is delegated to a user-supplied IndexingAdaptor.
//
// UpLength        : type of the (single) upper dimension length.
// IndexingAdaptor : must provide calculate_lower_index(), update_lower_index()
//                   and a static is_known_at_compile_time(); those are the only
//                   members of the adaptor used here.
template <typename UpLength, typename IndexingAdaptor>
struct indexing : public base_transform<1, 1>
{
    static constexpr index_t NDimUp = 1;

    using LowerIndex = multi_index<1>;
    using UpperIndex = multi_index<1>;

    using UpLengths = decltype(make_tuple(UpLength{}));

    UpLengths up_lengths_;     // upper length, stored as a 1-element tuple
    IndexingAdaptor iadaptor_; // performs the actual upper -> lower mapping

    CK_TILE_HOST_DEVICE constexpr indexing() = default;

    // Stores the upper length and a copy of the adaptor.
    CK_TILE_HOST_DEVICE constexpr indexing(const UpLength& up_length,
                                           const IndexingAdaptor& iadaptor)
        : up_lengths_{make_tuple(up_length)}, iadaptor_{iadaptor}
    {
    }

    CK_TILE_HOST_DEVICE static constexpr auto get_type_enum()
    {
        return coord_transform_enum::indexing;
    }

    CK_TILE_HOST_DEVICE constexpr const auto& get_upper_lengths() const { return up_lengths_; }

    // Computes idx_low from idx_up by forwarding to the adaptor.
    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& idx_up) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        iadaptor_.calculate_lower_index(idx_low, idx_up);
    }

    // Incremental update: forwards all four arguments to the adaptor unchanged.
    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& idx_low,
                                                const UpIdx& idx_up) const
    {
        // TODO: nothing changed here
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == NDimUp &&
                          LowIdx::size() == 1 && UpIdx::size() == NDimUp,
                      "wrong! inconsistent # of dimension");

        iadaptor_.update_lower_index(idx_diff_low, idx_diff_up, idx_low, idx_up);
    }

    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_always_mapped_to_valid_lower_index()
    {
        return true;
    }

    template <typename UpIdx>
    CK_TILE_HOST_DEVICE static constexpr bool
    is_valid_upper_index_mapped_to_valid_lower_index(const UpIdx& /* idx_up */)
    {
        return true;
    }

    // Compile-time known only if both the upper length and the adaptor are.
    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<UpLengths>::value &&
               IndexingAdaptor::is_known_at_compile_time();
    }

    // Debug print. The label names this transform ("indexing"); it previously
    // read "embed{", which misidentified the transform in printed output.
    CK_TILE_HOST_DEVICE void print() const
    {
        printf("indexing{");

        //
        printf("up_lengths_: ");
        print(up_lengths_);
        printf(", ");

        printf("}");
    }
};
//*******************************************************************************************************
//*******************************************************************************************************
template
<
typename
LowLength
>
template
<
typename
LowLength
>
...
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
...
@@ -1646,3 +1729,24 @@ CK_TILE_HOST_DEVICE constexpr auto make_offset_transform(const LowLength& low_le
}
}
}
// namespace ck_tile
}
// namespace ck_tile
#include "ck_tile/core/algorithm/indexing_adaptor.hpp"
namespace
ck_tile
{
// Builds an `indexing` transform from an upper length and a set of indices.
// By default the simplest adaptor, indexing_adaptor_onshot_cached, is used; it
// is constructed from `indices` (with cv/ref qualifiers stripped from its type).
template <typename UpLength, typename Indices>
CK_TILE_HOST_DEVICE constexpr auto make_indexing_transform(const UpLength& up_lengths,
                                                           const Indices& indices)
{
    using adaptor_t   = indexing_adaptor_onshot_cached<remove_cvref_t<Indices>>;
    using transform_t = indexing<UpLength, adaptor_t>;

    return transform_t{up_lengths, adaptor_t{indices}};
}
// Builds an `indexing` transform from an upper length and a caller-provided
// adaptor object; the adaptor is passed to the transform's constructor as is.
template <typename UpLength, typename IndexingAdaptor>
CK_TILE_HOST_DEVICE constexpr auto
make_indexing_transform_with_adaptor(const UpLength& up_lengths, const IndexingAdaptor& iadaptor)
{
    using transform_t = indexing<UpLength, IndexingAdaptor>;

    return transform_t{up_lengths, iadaptor};
}
}
// namespace ck_tile
include/ck_tile/core/algorithm/indexing_adaptor.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/multi_index.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
// Pre-defined indexing adaptor used for indexing (scatter/gather).
// This version caches the index in a thread register (which is also preferred
// in real scenarios). However, it is the user's responsibility to make sure
// each thread provides only one index, which means moving the coordinate will
// not change the index on this dimension.
template <typename IndexingType>
struct indexing_adaptor_onshot_cached
{
    // The single index handed out for every lower-index computation.
    IndexingType cached_idx_;

    CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached() = default;

    CK_TILE_HOST_DEVICE constexpr indexing_adaptor_onshot_cached(const IndexingType& idx)
        : cached_idx_(idx)
    {
    }

    // Writes the cached index into idx_low[0]; the upper index is ignored.
    template <typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE constexpr void calculate_lower_index(LowIdx& idx_low,
                                                             const UpIdx& /*idx_up*/) const
    {
        static_assert(LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto dim0 = number<0>{};
        idx_low(dim0)       = cached_idx_;
    }

    // Copies the upper diff into the lower diff. Neither idx_low nor the cached
    // index is modified.
    template <typename LowIdxDiff, typename UpIdxDiff, typename LowIdx, typename UpIdx>
    CK_TILE_HOST_DEVICE void update_lower_index(LowIdxDiff& idx_diff_low,
                                                const UpIdxDiff& idx_diff_up,
                                                LowIdx& /*idx_low*/,
                                                const UpIdx& /*idx_up*/) const
    {
        // TODO: nothing changed here
        static_assert(LowIdxDiff::size() == 1 && UpIdxDiff::size() == 1 &&
                          LowIdx::size() == 1 && UpIdx::size() == 1,
                      "wrong! inconsistent # of dimension");

        constexpr auto dim0 = number<0>{};

        // forward the diff to the lower dimension without changing the actual index
        idx_diff_low(dim0) = idx_diff_up[dim0];
    }

    // Known at compile time exactly when the cached index type is.
    CK_TILE_HOST_DEVICE static constexpr bool is_known_at_compile_time()
    {
        return ck_tile::is_known_at_compile_time<IndexingType>::value;
    }
};
}
// namespace ck_tile
include/ck_tile/core/algorithm/space_filling_curve.hpp
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18-2023
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
24
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -81,8 +81,10 @@ struct space_filling_curve
...
@@ -81,8 +81,10 @@ struct space_filling_curve
return
get_step_between
(
number
<
AccessIdx1d
>
{},
number
<
AccessIdx1d
-
1
>
{});
return
get_step_between
(
number
<
AccessIdx1d
>
{},
number
<
AccessIdx1d
-
1
>
{});
}
}
// Do not use this function directly!
// TODO: can refactor into generic lambda in the future
template
<
index_t
AccessIdx1d
>
template
<
index_t
AccessIdx1d
>
static
CK_TILE_HOST_DEVICE
constexpr
Index
get_index
(
number
<
AccessIdx1d
>
)
static
CK_TILE_HOST_DEVICE
constexpr
Index
_
get_index
(
number
<
AccessIdx1d
>
)
{
{
#if 0
#if 0
/*
/*
...
@@ -153,11 +155,11 @@ struct space_filling_curve
...
@@ -153,11 +155,11 @@ struct space_filling_curve
return
idx_md
;
return
idx_md
;
}
}
// FIXME: re
name this function
// FIXME: re
turn tuple of number<>, which is compile time only variable
template
<
index_t
AccessIdx1d
>
template
<
index_t
AccessIdx1d
>
static
CK_TILE_HOST_DEVICE
constexpr
auto
get_index
_tuple_of_number
(
number
<
AccessIdx1d
>
)
static
CK_TILE_HOST_DEVICE
constexpr
auto
get_index
(
number
<
AccessIdx1d
>
)
{
{
constexpr
auto
idx
=
get_index
(
number
<
AccessIdx1d
>
{});
constexpr
auto
idx
=
_
get_index
(
number
<
AccessIdx1d
>
{});
return
generate_tuple
([
&
](
auto
i
)
{
return
number
<
idx
[
i
]
>
{};
},
number
<
nDim
>
{});
return
generate_tuple
([
&
](
auto
i
)
{
return
number
<
idx
[
i
]
>
{};
},
number
<
nDim
>
{});
}
}
...
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
a4522ae3
...
@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
...
@@ -621,6 +621,99 @@ CK_TILE_DEVICE void buffer_load_fence(index_t cnt = 0)
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
asm
volatile
(
"s_waitcnt vmcnt(%0)"
:
:
"n"
(
cnt
)
:
"memory"
);
}
}
namespace impl {
// Maps an access width of N bytes to the payload type that the smem_load
// inline asm binds as its output operand. Only N = 16, 8, 4, 2, 1 are defined;
// the 2- and 1-byte widths use a full `float` payload.
// clang-format off
template<index_t N, typename T> struct smem_load_trait;

template<typename T> struct smem_load_trait<16, T> { using payload_t = fp32x4_t; };
template<typename T> struct smem_load_trait<8 , T> { using payload_t = fp32x2_t; };
template<typename T> struct smem_load_trait<4 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<2 , T> { using payload_t = float; };
template<typename T> struct smem_load_trait<1 , T> { using payload_t = float; };
// clang-format on
} // namespace impl
// NOTE: smem load/store does not need pre_nop for software to guarantee the dependency, happy :)
template
<
index_t
>
struct
smem_load
;
// 16-byte smem load: issues `ds_read_b128`, using v_offset as the address
// operand and i_offset as the immediate `offset:` field.
template <>
struct smem_load<16>
{
    // value    : destination, must be exactly 16 bytes; written in place by the asm
    // v_offset : passed in a vector register ("v" constraint)
    // i_offset : passed as an immediate ("n" constraint)
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 16);
        using payload_type = typename impl::smem_load_trait<16, T>::payload_t;
        // the output operand aliases `value`, so the result is written directly
        asm volatile("ds_read_b128 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<payload_type&>(value))
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
// 8-byte smem load: issues `ds_read_b64`, using v_offset as the address
// operand and i_offset as the immediate `offset:` field.
template <>
struct smem_load<8>
{
    // value    : destination, must be exactly 8 bytes; written in place by the asm
    // v_offset : passed in a vector register ("v" constraint)
    // i_offset : passed as an immediate ("n" constraint)
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 8);
        using payload_type = typename impl::smem_load_trait<8, T>::payload_t;
        // the output operand aliases `value`, so the result is written directly
        asm volatile("ds_read_b64 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<payload_type&>(value))
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
// 4-byte smem load: issues `ds_read_b32`, using v_offset as the address
// operand and i_offset as the immediate `offset:` field.
template <>
struct smem_load<4>
{
    // value    : destination, must be exactly 4 bytes; written in place by the asm
    // v_offset : passed in a vector register ("v" constraint)
    // i_offset : passed as an immediate ("n" constraint)
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using payload_type = typename impl::smem_load_trait<4, T>::payload_t;
        // the output operand aliases `value`, so the result is written directly
        asm volatile("ds_read_b32 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<payload_type&>(value))
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
// 2-byte smem load: issues `ds_read_u16`, using v_offset as the address
// operand and i_offset as the immediate `offset:` field.
template <>
struct smem_load<2>
{
    // value    : destination; must be a 4-byte object even though only 16 bits
    //            are loaded (see the note below)
    // v_offset : passed in a vector register ("v" constraint)
    // i_offset : passed as an immediate ("n" constraint)
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4); // subdword is buggy, use dword buf and convert manually
        // Look up the trait for this width (2) rather than the 1-byte entry;
        // both resolve to the same payload type, so the emitted code is unchanged.
        using mbuf_t = typename impl::smem_load_trait<2, T>::payload_t;
        // the output operand aliases `value`, so the result is written directly
        asm volatile("ds_read_u16 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<mbuf_t&>(value))
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
// 1-byte smem load: issues `ds_read_u8`, using v_offset as the address
// operand and i_offset as the immediate `offset:` field.
template <>
struct smem_load<1>
{
    // value    : destination; must be a 4-byte object even though only 8 bits
    //            are loaded
    // v_offset : passed in a vector register ("v" constraint)
    // i_offset : passed as an immediate ("n" constraint)
    template <typename T>
    CK_TILE_DEVICE void operator()(T& value, index_t v_offset, index_t i_offset)
    {
        static_assert(sizeof(T) == 4);
        using payload_type = typename impl::smem_load_trait<1, T>::payload_t;
        // the output operand aliases `value`, so the result is written directly
        asm volatile("ds_read_u8 %0, %1 offset:%2"
                     : "=v"(reinterpret_cast<payload_type&>(value))
                     : "v"(v_offset), "n"(i_offset)
                     : "memory");
    }
};
// clang-format off
// clang-format off
namespace
impl
{
namespace
impl
{
...
@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
...
@@ -976,6 +1069,16 @@ llvm_amdgcn_raw_buffer_atomic_max_fp64(double vdata,
int
soffset
,
// dst_wave_addr_offset
int
soffset
,
// dst_wave_addr_offset
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
int
glc_slc
)
__asm
(
"llvm.amdgcn.raw.buffer.atomic.fmax.f64"
);
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_load_lds
(
int32x4_t
rsrc
,
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
,
index_t
size
,
index_t
voffset
,
index_t
soffset
,
index_t
offset
,
index_t
aux
)
__asm
(
"llvm.amdgcn.raw.buffer.load.lds"
);
template
<
bool
pre_nop
=
false
>
template
<
bool
pre_nop
=
false
>
CK_TILE_DEVICE
void
async_buffer_load_dword_v
(
void
*
smem
,
CK_TILE_DEVICE
void
async_buffer_load_dword_v
(
void
*
smem
,
int32x4_t
rsrc
,
int32x4_t
rsrc
,
...
@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
...
@@ -1313,6 +1416,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
int32x4_t
src_wave_buffer_resource
,
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_addr_offset
,
index_t
src_thread_addr_offset
,
index_t
src_wave_addr_offset
,
index_t
src_wave_addr_offset
,
index_t
src_linear_addr_offset
,
index_t
flag
=
0
,
index_t
flag
=
0
,
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
...
@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
...
@@ -1327,7 +1431,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
,
0
,
src_linear_addr_offset
,
flag
,
flag
,
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
...
@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
...
@@ -1337,7 +1441,7 @@ CK_TILE_DEVICE void amd_buffer_load_raw_impl(thread_buffer<T, N>& dst,
src_wave_buffer_resource
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
src_wave_addr_offset
,
src_wave_addr_offset
,
0
,
src_linear_addr_offset
,
flag
,
flag
,
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
...
@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
...
@@ -1365,6 +1469,43 @@ CK_TILE_DEVICE void amd_async_buffer_load_impl(T* smem,
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
// Asynchronous 4-byte buffer load written straight to `smem` through
// llvm_amdgcn_raw_buffer_load_lds.
//
// smem                      : destination pointer
// src_wave_buffer_resource  : buffer resource descriptor of the source
// src_thread_addr_offset    : per-thread byte offset (voffset operand)
// src_wave_addr_offset      : soffset operand
// src_immediate_addr_offset : immediate offset operand
// flag                      : only read when oob_conditional_check is true;
//                             non-zero uses the thread offset, zero substitutes
//                             src_wave_buffer_resource[2] as the voffset
//                             (presumably so the hardware range check rejects
//                             the access -- TODO confirm)
//
// Only sizeof(T) * N == 4 is implemented.
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = true>
CK_TILE_DEVICE void amd_async_buffer_load(CK_TILE_LDS_ADDR T* smem,
                                          int32x4_t src_wave_buffer_resource,
                                          index_t src_thread_addr_offset,
                                          index_t src_wave_addr_offset,
                                          index_t src_immediate_addr_offset = 0,
                                          index_t flag                      = 0,
                                          bool_constant<oob_conditional_check> = {})
{
    static_assert(sizeof(T) * N == 4, "wrong! not implemented vector size");

    if constexpr(oob_conditional_check)
    {
        // Fix: this previously read `flag ? v_offset : ...`, i.e. the variable
        // was initialized from itself, so a non-zero flag used an indeterminate
        // offset instead of the thread offset.
        index_t v_offset = flag ? src_thread_addr_offset : src_wave_buffer_resource[2];
        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
                                        smem,
                                        sizeof(uint32_t),
                                        v_offset,
                                        src_wave_addr_offset,
                                        src_immediate_addr_offset,
                                        static_cast<index_t>(coherence));
    }
    else
    {
        llvm_amdgcn_raw_buffer_load_lds(src_wave_buffer_resource,
                                        smem,
                                        sizeof(uint32_t),
                                        src_thread_addr_offset,
                                        src_wave_addr_offset,
                                        src_immediate_addr_offset,
                                        static_cast<index_t>(coherence));
    }
}
template
<
index_t
N
,
template
<
index_t
N
,
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
amd_buffer_coherence_enum
coherence
=
amd_buffer_coherence_enum
::
coherence_default
>
CK_TILE_DEVICE
void
amd_buffer_store_impl_with_bytes
(
const
thread_buffer
<
int8_t
,
N
>
src_thread_data
,
CK_TILE_DEVICE
void
amd_buffer_store_impl_with_bytes
(
const
thread_buffer
<
int8_t
,
N
>
src_thread_data
,
...
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
...
@@ -1685,6 +1826,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
int32x4_t
dst_wave_buffer_resource
,
int32x4_t
dst_wave_buffer_resource
,
index_t
dst_thread_addr_offset
,
index_t
dst_thread_addr_offset
,
index_t
dst_wave_addr_offset
,
index_t
dst_wave_addr_offset
,
index_t
dst_linear_addr_offset
,
index_t
is_valid_element
=
1
)
index_t
is_valid_element
=
1
)
{
{
constexpr
index_t
bytes
=
sizeof
(
T
)
*
N
;
constexpr
index_t
bytes
=
sizeof
(
T
)
*
N
;
...
@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
...
@@ -1698,7 +1840,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
,
0
,
dst_linear_addr_offset
,
is_valid_element
);
is_valid_element
);
}
}
else
else
...
@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
...
@@ -1707,7 +1849,7 @@ CK_TILE_DEVICE void amd_buffer_store_raw_impl(const thread_buffer<T, N>& dst_thr
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
dst_wave_addr_offset
,
dst_wave_addr_offset
,
0
);
dst_linear_addr_offset
);
}
}
}
}
...
@@ -2014,6 +2156,7 @@ template <typename T,
...
@@ -2014,6 +2156,7 @@ template <typename T,
CK_TILE_DEVICE
void
amd_buffer_load_raw
(
thread_buffer
<
T
,
N
>&
dst
,
CK_TILE_DEVICE
void
amd_buffer_load_raw
(
thread_buffer
<
T
,
N
>&
dst
,
const
T
*
p_src_wave
,
const
T
*
p_src_wave
,
index_t
src_thread_element_offset
,
index_t
src_thread_element_offset
,
index_t
src_linear_element_offset
,
index_t
src_element_space_size
,
index_t
src_element_space_size
,
index_t
is_valid_element
=
0
,
index_t
is_valid_element
=
0
,
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
...
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
...
@@ -2022,12 +2165,14 @@ CK_TILE_DEVICE void amd_buffer_load_raw(thread_buffer<T, N>& dst,
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_linear_addr_offset
=
src_linear_element_offset
*
sizeof
(
T
);
amd_buffer_load_raw_impl
<
T
,
N
,
coherence
,
oob_conditional_check
,
pre_nop
>
(
amd_buffer_load_raw_impl
<
T
,
N
,
coherence
,
oob_conditional_check
,
pre_nop
>
(
dst
,
dst
,
src_wave_buffer_resource
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
0
,
0
,
src_linear_addr_offset
,
is_valid_element
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
...
@@ -2041,16 +2186,19 @@ template <typename T,
...
@@ -2041,16 +2186,19 @@ template <typename T,
CK_TILE_DEVICE
void
amd_buffer_load_raw
(
thread_buffer
<
T
,
N
>&
dst
,
CK_TILE_DEVICE
void
amd_buffer_load_raw
(
thread_buffer
<
T
,
N
>&
dst
,
const
int32x4_t
src_wave_buffer_resource
,
const
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_element_offset
,
index_t
src_thread_element_offset
,
index_t
src_linear_element_offset
,
index_t
is_valid_element
=
0
,
index_t
is_valid_element
=
0
,
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_linear_addr_offset
=
src_linear_element_offset
*
sizeof
(
T
);
amd_buffer_load_raw_impl
<
T
,
N
,
coherence
,
oob_conditional_check
,
pre_nop
>
(
amd_buffer_load_raw_impl
<
T
,
N
,
coherence
,
oob_conditional_check
,
pre_nop
>
(
dst
,
dst
,
src_wave_buffer_resource
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
src_thread_addr_offset
,
0
,
0
,
src_linear_addr_offset
,
is_valid_element
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
...
@@ -2066,6 +2214,7 @@ template <typename T,
...
@@ -2066,6 +2214,7 @@ template <typename T,
CK_TILE_DEVICE
void
amd_async_buffer_load_with_oob_raw
(
T
*
smem
,
CK_TILE_DEVICE
void
amd_async_buffer_load_with_oob_raw
(
T
*
smem
,
const
T
*
p_src_wave
,
const
T
*
p_src_wave
,
index_t
src_thread_element_offset
,
index_t
src_thread_element_offset
,
index_t
src_linear_element_offset
,
index_t
src_element_space_size
,
index_t
src_element_space_size
,
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
...
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
...
@@ -2073,9 +2222,14 @@ CK_TILE_DEVICE void amd_async_buffer_load_with_oob_raw(T* smem,
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
make_wave_buffer_resource
(
p_src_wave
,
src_element_space_size
*
sizeof
(
T
));
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_linear_addr_offset
=
src_linear_element_offset
*
sizeof
(
T
);
amd_async_buffer_load_impl
<
T
,
N
,
coherence
>
(
amd_async_buffer_load_impl
<
T
,
N
,
coherence
>
(
smem
,
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
,
bool_constant
<
pre_nop
>
{});
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
src_linear_addr_offset
,
bool_constant
<
pre_nop
>
{});
}
}
// This version support buffer resource as input arg
// This version support buffer resource as input arg
...
@@ -2086,12 +2240,42 @@ template <typename T,
...
@@ -2086,12 +2240,42 @@ template <typename T,
CK_TILE_DEVICE
void
amd_async_buffer_load_with_oob_raw
(
T
*
smem
,
CK_TILE_DEVICE
void
amd_async_buffer_load_with_oob_raw
(
T
*
smem
,
const
int32x4_t
src_wave_buffer_resource
,
const
int32x4_t
src_wave_buffer_resource
,
index_t
src_thread_element_offset
,
index_t
src_thread_element_offset
,
index_t
src_linear_element_offset
,
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_thread_addr_offset
=
src_thread_element_offset
*
sizeof
(
T
);
index_t
src_linear_addr_offset
=
src_linear_element_offset
*
sizeof
(
T
);
amd_async_buffer_load_impl
<
T
,
N
,
coherence
>
(
amd_async_buffer_load_impl
<
T
,
N
,
coherence
>
(
smem
,
smem
,
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
0
,
bool_constant
<
pre_nop
>
{});
src_wave_buffer_resource
,
src_thread_addr_offset
,
0
,
src_linear_addr_offset
,
bool_constant
<
pre_nop
>
{});
}
// This version takes a buffer resource as the input argument.
// Element offsets are converted to byte offsets (multiplied by sizeof(T)) and
// forwarded to amd_async_buffer_load with a wave offset of 0; the linear offset
// is passed as the immediate offset and is_valid_element as the flag.
template <typename T,
          index_t N,
          amd_buffer_coherence_enum coherence = amd_buffer_coherence_enum::coherence_default,
          bool oob_conditional_check          = false>
CK_TILE_DEVICE void amd_async_buffer_load_with_oob(CK_TILE_LDS_ADDR T* smem,
                                                   const int32x4_t src_wave_buffer_resource,
                                                   index_t src_thread_element_offset,
                                                   index_t src_linear_element_offset,
                                                   bool is_valid_element,
                                                   bool_constant<oob_conditional_check> = {})
{
    index_t thread_byte_offset = src_thread_element_offset * sizeof(T);
    index_t linear_byte_offset = src_linear_element_offset * sizeof(T);

    amd_async_buffer_load<T, N, coherence>(smem,
                                           src_wave_buffer_resource,
                                           thread_byte_offset,
                                           0,
                                           linear_byte_offset,
                                           is_valid_element,
                                           bool_constant<oob_conditional_check>{});
}
}
// buffer_store requires:
// buffer_store requires:
...
@@ -2146,6 +2330,7 @@ template <typename T,
...
@@ -2146,6 +2330,7 @@ template <typename T,
CK_TILE_DEVICE
void
amd_buffer_store_raw
(
const
thread_buffer
<
T
,
N
>&
src_thread_data
,
CK_TILE_DEVICE
void
amd_buffer_store_raw
(
const
thread_buffer
<
T
,
N
>&
src_thread_data
,
T
*
p_dst_wave
,
T
*
p_dst_wave
,
const
index_t
dst_thread_element_offset
,
const
index_t
dst_thread_element_offset
,
const
index_t
dst_linear_element_offset
,
const
bool
dst_thread_element_valid
,
const
bool
dst_thread_element_valid
,
const
index_t
dst_element_space_size
)
const
index_t
dst_element_space_size
)
{
{
...
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
...
@@ -2153,11 +2338,13 @@ CK_TILE_DEVICE void amd_buffer_store_raw(const thread_buffer<T, N>& src_thread_d
make_wave_buffer_resource
(
p_dst_wave
,
dst_element_space_size
*
sizeof
(
T
));
make_wave_buffer_resource
(
p_dst_wave
,
dst_element_space_size
*
sizeof
(
T
));
index_t
dst_thread_addr_offset
=
dst_thread_element_offset
*
sizeof
(
T
);
index_t
dst_thread_addr_offset
=
dst_thread_element_offset
*
sizeof
(
T
);
index_t
dst_linear_addr_offset
=
dst_linear_element_offset
*
sizeof
(
T
);
amd_buffer_store_raw_impl
<
T
,
N
,
coherence
,
oob_conditional_check
>
(
src_thread_data
,
amd_buffer_store_raw_impl
<
T
,
N
,
coherence
,
oob_conditional_check
>
(
src_thread_data
,
dst_wave_buffer_resource
,
dst_wave_buffer_resource
,
dst_thread_addr_offset
,
dst_thread_addr_offset
,
0
,
0
,
dst_linear_addr_offset
,
dst_thread_element_valid
);
dst_thread_element_valid
);
}
}
...
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
...
@@ -2221,16 +2408,6 @@ CK_TILE_DEVICE void amd_buffer_atomic_max(const thread_buffer<T, N>& src_thread_
#endif
#endif
}
}
// Direct loads from global to LDS.
CK_TILE_DEVICE_EXTERN
void
llvm_amdgcn_raw_buffer_load_lds
(
int32x4_t
rsrc
,
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
,
index_t
size
,
index_t
voffset
,
index_t
soffset
,
index_t
offset
,
index_t
aux
)
__asm
(
"llvm.amdgcn.raw.buffer.load.lds"
);
template
<
typename
T
,
index_t
NumElemsPerThread
>
template
<
typename
T
,
index_t
NumElemsPerThread
>
CK_TILE_DEVICE
void
amd_direct_load_global_to_lds
(
const
T
*
global_base_ptr
,
CK_TILE_DEVICE
void
amd_direct_load_global_to_lds
(
const
T
*
global_base_ptr
,
const
index_t
global_offset
,
const
index_t
global_offset
,
...
...
include/ck_tile/core/arch/utility.hpp
View file @
a4522ae3
...
@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
...
@@ -59,4 +59,47 @@ CK_TILE_DEVICE T warp_shuffle_down(const T& v_local, uint32_t lane_delta)
#endif
#endif
}
}
// Returns the value of `v_local` held by lane `src_lane`, fetched with
// __builtin_amdgcn_ds_bpermute (the lane id is shifted left by 2 to form the
// builtin's first argument). Three cases by sizeof(T):
//   - smaller than 4 bytes : the value is placed in a 4-byte union first
//   - exactly 4 bytes      : the value is bit-cast to int32_t and back
//   - larger               : must be a multiple of 4 bytes; permuted one
//                            int32_t at a time
template <typename T>
CK_TILE_DEVICE T warp_shuffle(const T& v_local, uint32_t src_lane)
{
#if 0
    return __shfl(v_local, src_lane);
#elif 1
    if constexpr(sizeof(int32_t) > sizeof(T))
    {
        // widen the sub-dword value to one dword through a union
        union dword_slot
        {
            int32_t x;
            T v;
        };
        dword_slot local_slot;
        local_slot.v = v_local;
        dword_slot remote_slot;
        remote_slot.x =
            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(local_slot));

        return remote_slot.v;
    }
    else if constexpr(sizeof(int32_t) == sizeof(T))
    {
        const int32_t remote_bits =
            __builtin_amdgcn_ds_bpermute(src_lane << 2, bit_cast<int32_t>(v_local));

        return bit_cast<T>(remote_bits);
    }
    else
    {
        static_assert(sizeof(T) % sizeof(int32_t) == 0, "wrong!");
        constexpr index_t num_dwords = sizeof(T) / sizeof(int32_t);
        using dword_buffer           = thread_buffer<int32_t, num_dwords>;

        auto local_dwords  = bit_cast<dword_buffer>(v_local);
        auto remote_dwords = dword_buffer{};

        // permute each dword of the value independently
        static_for<0, num_dwords, 1>{}([&](auto i_dword) {
            int32_t fetched = __builtin_amdgcn_ds_bpermute(
                src_lane << 2, bit_cast<int32_t>(local_dwords[i_dword]));
            remote_dwords(i_dword) = fetched;
        });

        return bit_cast<T>(remote_dwords);
    }
#endif
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/config.hpp
View file @
a4522ae3
...
@@ -32,13 +32,28 @@
...
@@ -32,13 +32,28 @@
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_DEVICE inline __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_HOST_DEVICE inline __host__ __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_DEVICE_EXTERN __device__
#define CK_TILE_HOST_DEVICE_EXTERN __host__ __device__
#else
#else
#define CK_TILE_HOST inline
#define CK_TILE_HOST inline
#define CK_TILE_DEVICE inline
#define CK_TILE_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_HOST_DEVICE inline
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_DEVICE_EXTERN
#define CK_TILE_HOST_DEVICE_EXTERN
#endif
#endif
// implementing the "memory address space" attribute
// https://llvm.org/docs/AMDGPUUsage.html#amdgpu-address-spaces-table
// Each CK_TILE_*_ADDR macro expands to an address_space(N) attribute when the
// guard below is defined, and to nothing otherwise.
// NOTE(review): the guard tests `__HIPCC_` (single trailing underscore). The HIP
// compiler macro is normally spelled `__HIPCC__`, so this looks like a typo that
// leaves the attribute branch disabled -- confirm before changing, since callers
// may currently rely on the empty expansion.
#ifdef __HIPCC_
#define CK_TILE_GENERIC_ADDR __attribute__((address_space(0)))
#define CK_TILE_GLOBAL_ADDR __attribute__((address_space(1)))
#define CK_TILE_LDS_ADDR __attribute__((address_space(3)))
#define CK_TILE_BUF_RES_ADDR __attribute__((address_space(8)))
#else
#define CK_TILE_GENERIC_ADDR
#define CK_TILE_GLOBAL_ADDR
#define CK_TILE_LDS_ADDR
#define CK_TILE_BUF_RES_ADDR
#endif
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#ifndef CK_TILE_USE_CUSTOM_DATA_TYPE
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#define CK_TILE_USE_CUSTOM_DATA_TYPE 0 // custom data type will generate extra move/bfi code
#endif
#endif
...
@@ -157,8 +172,11 @@
...
@@ -157,8 +172,11 @@
#endif
#endif
#endif
#endif
// workaround for ROCm 6.2 and later
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 3 && HIP_VERSION_PATCH >= 42131) || \
(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR > 3)
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#else
#else
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
...
@@ -200,3 +218,8 @@
...
@@ -200,3 +218,8 @@
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
#endif
#endif
// workaround: compiler not emitting reciprocal instruction from __frcp_rn()
#ifndef CK_TILE_WORKAROUND_SWDEV_383542
#define CK_TILE_WORKAROUND_SWDEV_383542 1
#endif
include/ck_tile/core/container/sequence.hpp
View file @
a4522ae3
...
@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
...
@@ -1111,4 +1111,126 @@ CK_TILE_HOST_DEVICE constexpr auto generate_array(F&& f, number<N>)
typename
arithmetic_sequence_gen
<
0
,
N
,
1
>::
type
{});
typename
arithmetic_sequence_gen
<
0
,
N
,
1
>::
type
{});
}
}
namespace
impl
{
template
<
typename
,
typename
,
typename
,
index_t
>
struct
reverse_slice_sequence_impl
;
template
<
index_t
x
,
index_t
...
xs
,
index_t
m
,
index_t
...
ms
,
index_t
id
,
index_t
...
ids
,
index_t
SliceSize
>
struct
reverse_slice_sequence_impl
<
sequence
<
x
,
xs
...
>
,
sequence
<
m
,
ms
...
>
,
sequence
<
id
,
ids
...
>
,
SliceSize
>
{
using
old_scan
=
reverse_slice_sequence_impl
<
sequence
<
xs
...
>
,
sequence
<
ms
...
>
,
sequence
<
ids
...
>
,
SliceSize
>
;
static
constexpr
auto
slice_size
=
old_scan
::
remaining_slice_sizes
::
front
().
value
;
static
constexpr
auto
slice_length
=
std
::
conditional_t
<
m
,
number
<
gcd
(
x
,
slice_size
)
>
,
number
<
x
>>::
value
;
using
dim_lengths
=
typename
sequence_merge
<
sequence
<
slice_length
>
,
typename
old_scan
::
dim_lengths
>::
type
;
using
dim_slices
=
typename
sequence_merge
<
sequence
<
x
/
slice_length
>
,
typename
old_scan
::
dim_slices
>::
type
;
using
remaining_slice_sizes
=
typename
sequence_merge
<
std
::
conditional_t
<
m
,
sequence
<
slice_size
/
slice_length
>
,
sequence
<
slice_size
>>
,
typename
old_scan
::
remaining_slice_sizes
>::
type
;
// the first idx that sliced length not equal to original length
static
constexpr
index_t
_flag
=
slice_length
!=
x
&&
remaining_slice_sizes
{}.
front
().
value
==
1
;
static
constexpr
index_t
_split_flag
=
std
::
conditional_t
<
m
,
number
<
_flag
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
_split_idx
=
std
::
conditional_t
<
_split_flag
,
number
<
id
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
split_flag
=
_split_flag
||
old_scan
::
split_flag
;
static
constexpr
index_t
split_idx
=
std
::
conditional_t
<
old_scan
::
split_flag
,
number
<
old_scan
::
split_idx
>
,
number
<
_split_idx
>>::
value
;
};
template
<
index_t
x
,
index_t
m
,
index_t
id
,
index_t
SliceSize
>
struct
reverse_slice_sequence_impl
<
sequence
<
x
>
,
sequence
<
m
>
,
sequence
<
id
>
,
SliceSize
>
{
static
constexpr
auto
slice_size
=
SliceSize
;
static
constexpr
auto
slice_length
=
std
::
conditional_t
<
m
,
number
<
gcd
(
x
,
slice_size
)
>
,
number
<
x
>>::
value
;
using
dim_lengths
=
sequence
<
slice_length
>
;
using
dim_slices
=
sequence
<
x
/
slice_length
>
;
using
remaining_slice_sizes
=
std
::
conditional_t
<
m
,
sequence
<
slice_size
/
slice_length
>
,
sequence
<
slice_size
>>
;
// the first idx that sliced length not equal to original length
static
constexpr
index_t
_flag
=
slice_length
!=
x
&&
remaining_slice_sizes
{}.
front
().
value
==
1
;
static
constexpr
index_t
split_flag
=
std
::
conditional_t
<
m
,
number
<
_flag
>
,
number
<
0
>>::
value
;
static
constexpr
index_t
split_idx
=
std
::
conditional_t
<
split_flag
,
number
<
id
>
,
number
<
0
>>::
value
;
};
}
// namespace impl
// clang-format off
// input: a sequence (with an optional mask) and SliceSize, the number of elements per slice
// output: the per-dimension lengths of each slice, and the number of slices along each dimension
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// returns tuple<slice_lengths, slice_nums, slice_index>; slice_index is the index at which
// the slices start to be split (scanning right -> left),
// i.e. the first index whose sliced length differs from the original length
// clang-format on
template
<
typename
Seq
,
index_t
SliceSize
,
typename
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>
::
type
>
constexpr
auto
reverse_slice_sequence
(
Seq
,
number
<
SliceSize
>
,
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>::
type
{})
{
static_assert
(
Seq
::
size
()
==
Mask
::
size
());
using
sliced_type
=
impl
::
reverse_slice_sequence_impl
<
Seq
,
Mask
,
typename
arithmetic_sequence_gen
<
0
,
Seq
::
size
(),
1
>::
type
,
SliceSize
>
;
static_assert
(
sliced_type
::
remaining_slice_sizes
::
front
().
value
==
1
,
"can not evenly divide this sequence, please check"
);
return
make_tuple
(
typename
sliced_type
::
dim_lengths
{},
typename
sliced_type
::
dim_slices
{},
number
<
sliced_type
::
split_idx
>
{});
}
// Left-to-right counterpart of reverse_slice_sequence: the sequence and mask are reversed,
// sliced by reverse_slice_sequence, and the resulting lengths / slice counts are reversed back.
// The split index reported for the reversed sequence is mirrored into the original ordering.
template <typename Seq,
          index_t SliceSize,
          typename Mask = typename uniform_sequence_gen<Seq::size(), 1>::type>
constexpr auto slice_sequence(Seq,
                              number<SliceSize>,
                              Mask = typename uniform_sequence_gen<Seq::size(), 1>::type{})
{
    constexpr auto reversed_result =
        reverse_slice_sequence(Seq{}.reverse(), number<SliceSize>{}, Mask{}.reverse());

    return make_tuple(
        // per-dimension slice lengths, back in the caller's dimension order
        reversed_result[number<0>{}].reverse(),
        // per-dimension slice counts, back in the caller's dimension order
        reversed_result[number<1>{}].reverse(),
        // index i in the reversed sequence corresponds to size - 1 - i in the original one
        number<Seq::size() - reversed_result[number<2>{}] - 1>{});
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/container/tuple.hpp
View file @
a4522ae3
...
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
...
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
size
(),
1
>::
type
{});
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
size
(),
1
>::
type
{});
}
}
namespace
detail
{
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
embed_tuples_impl
(
F
f
,
const
X
&
x
,
sequence
<
Is
...
>
)
{
return
concat_tuple
(
f
(
x
.
at
(
number
<
Is
>
{}))...);
}
}
// namespace detail
// F must return a tuple for every element it is applied to
// e.g. x : tuple<X, Y>, and f returns tuple<Z, W> for each element
// this function then returns the concatenation of all results, i.e. tuple<Z, W, Z, W>
template <typename F, typename X>
CK_TILE_HOST_DEVICE constexpr auto embed_tuples(F f, const X& x)
{
    // one compile-time index per element of x: 0, 1, ..., X::size() - 1
    using element_indices = typename arithmetic_sequence_gen<0, X::size(), 1>::type;

    // the impl applies f to every indexed element and concatenates the returned tuples
    return detail::embed_tuples_impl(f, x, element_indices{});
}
// By default unroll to the flatten
// By default unroll to the flatten
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
>
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
>
CK_TILE_HOST_DEVICE
constexpr
auto
unroll_nested_tuple
(
const
tuple
<>&
t
)
CK_TILE_HOST_DEVICE
constexpr
auto
unroll_nested_tuple
(
const
tuple
<>&
t
)
...
@@ -603,7 +623,7 @@ template <typename... Ys,
...
@@ -603,7 +623,7 @@ template <typename... Ys,
false
>
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
+=
(
tuple
<
Ys
...
>&
y
,
const
X
&
x
)
CK_TILE_HOST_DEVICE
constexpr
auto
operator
+=
(
tuple
<
Ys
...
>&
y
,
const
X
&
x
)
{
{
static_assert
(
X
::
S
ize
()
==
sizeof
...(
Ys
),
"wrong! size not the same"
);
static_assert
(
X
::
s
ize
()
==
sizeof
...(
Ys
),
"wrong! size not the same"
);
constexpr
index_t
NSize
=
sizeof
...(
Ys
);
constexpr
index_t
NSize
=
sizeof
...(
Ys
);
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
i
)
{
y
[
i
]
+=
x
[
i
];
});
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
i
)
{
y
[
i
]
+=
x
[
i
];
});
return
y
;
return
y
;
...
@@ -615,7 +635,7 @@ template <typename... Ys,
...
@@ -615,7 +635,7 @@ template <typename... Ys,
false
>
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
-=
(
tuple
<
Ys
...
>&
y
,
const
X
&
x
)
CK_TILE_HOST_DEVICE
constexpr
auto
operator
-=
(
tuple
<
Ys
...
>&
y
,
const
X
&
x
)
{
{
static_assert
(
X
::
S
ize
()
==
sizeof
...(
Ys
),
"wrong! size not the same"
);
static_assert
(
X
::
s
ize
()
==
sizeof
...(
Ys
),
"wrong! size not the same"
);
constexpr
index_t
NSize
=
sizeof
...(
Ys
);
constexpr
index_t
NSize
=
sizeof
...(
Ys
);
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
i
)
{
y
[
i
]
-=
x
[
i
];
});
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
i
)
{
y
[
i
]
-=
x
[
i
];
});
return
y
;
return
y
;
...
@@ -627,7 +647,7 @@ template <typename... Xs,
...
@@ -627,7 +647,7 @@ template <typename... Xs,
false
>
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
+
(
const
tuple
<
Xs
...
>&
x
,
const
Y
&
y
)
CK_TILE_HOST_DEVICE
constexpr
auto
operator
+
(
const
tuple
<
Xs
...
>&
x
,
const
Y
&
y
)
{
{
static_assert
(
Y
::
S
ize
()
==
sizeof
...(
Xs
),
"wrong! size not the same"
);
static_assert
(
Y
::
s
ize
()
==
sizeof
...(
Xs
),
"wrong! size not the same"
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
tuple
<
Xs
...
>
r
;
tuple
<
Xs
...
>
r
;
...
@@ -635,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
...
@@ -635,13 +655,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator+(const tuple<Xs...>& x, const Y& y)
return
r
;
return
r
;
}
}
template
<
typename
...
Xs
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
+
(
const
tuple
<
Xs
...
>&
x
,
const
tuple
<
Ys
...
>&
y
)
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong!"
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
return
generate_tuple
([
&
](
auto
i
)
{
return
x
[
i
]
+
y
[
i
];
},
number
<
NSize
>
{});
}
template
<
typename
...
Xs
,
template
<
typename
...
Xs
,
typename
Y
,
typename
Y
,
std
::
enable_if_t
<!
std
::
is_integral
<
Y
>
::
value
&&
!
std
::
is_floating_point
<
Y
>::
value
,
bool
>
=
std
::
enable_if_t
<!
std
::
is_integral
<
Y
>
::
value
&&
!
std
::
is_floating_point
<
Y
>::
value
,
bool
>
=
false
>
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
-
(
const
tuple
<
Xs
...
>&
x
,
const
Y
&
y
)
CK_TILE_HOST_DEVICE
constexpr
auto
operator
-
(
const
tuple
<
Xs
...
>&
x
,
const
Y
&
y
)
{
{
static_assert
(
Y
::
S
ize
()
==
sizeof
...(
Xs
),
"wrong! size not the same"
);
static_assert
(
Y
::
s
ize
()
==
sizeof
...(
Xs
),
"wrong! size not the same"
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
tuple
<
Xs
...
>
r
;
tuple
<
Xs
...
>
r
;
...
@@ -649,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
...
@@ -649,13 +677,21 @@ CK_TILE_HOST_DEVICE constexpr auto operator-(const tuple<Xs...>& x, const Y& y)
return
r
;
return
r
;
}
}
template
<
typename
...
Xs
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
-
(
const
tuple
<
Xs
...
>&
x
,
const
tuple
<
Ys
...
>&
y
)
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong!"
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
return
generate_tuple
([
&
](
auto
i
)
{
return
x
[
i
]
-
y
[
i
];
},
number
<
NSize
>
{});
}
template
<
typename
...
Xs
,
template
<
typename
...
Xs
,
typename
Y
,
typename
Y
,
std
::
enable_if_t
<!
std
::
is_integral
<
Y
>
::
value
&&
!
std
::
is_floating_point
<
Y
>::
value
,
bool
>
=
std
::
enable_if_t
<!
std
::
is_integral
<
Y
>
::
value
&&
!
std
::
is_floating_point
<
Y
>::
value
,
bool
>
=
false
>
false
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
*
(
const
tuple
<
Xs
...
>&
x
,
const
Y
&
y
)
CK_TILE_HOST_DEVICE
constexpr
auto
operator
*
(
const
tuple
<
Xs
...
>&
x
,
const
Y
&
y
)
{
{
static_assert
(
Y
::
S
ize
()
==
sizeof
...(
Xs
),
"wrong! size not the same"
);
static_assert
(
Y
::
s
ize
()
==
sizeof
...(
Xs
),
"wrong! size not the same"
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
tuple
<
Xs
...
>
r
;
tuple
<
Xs
...
>
r
;
...
@@ -686,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
...
@@ -686,6 +722,14 @@ CK_TILE_HOST_DEVICE constexpr auto operator*(const tuple<Xs...>& x, Y a)
return
a
*
x
;
return
a
*
x
;
}
}
template
<
typename
...
Xs
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
*
(
const
tuple
<
Xs
...
>&
x
,
const
tuple
<
Ys
...
>&
y
)
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong!"
);
constexpr
index_t
NSize
=
sizeof
...(
Xs
);
return
generate_tuple
([
&
](
auto
i
)
{
return
x
[
i
]
*
y
[
i
];
},
number
<
NSize
>
{});
}
template
<
typename
...
Xs
,
typename
...
Ys
>
template
<
typename
...
Xs
,
typename
...
Ys
>
CK_TILE_HOST_DEVICE
constexpr
auto
operator
/
(
const
tuple
<
Xs
...
>&
x
,
const
tuple
<
Ys
...
>&
y
)
CK_TILE_HOST_DEVICE
constexpr
auto
operator
/
(
const
tuple
<
Xs
...
>&
x
,
const
tuple
<
Ys
...
>&
y
)
{
{
...
...
include/ck_tile/core/numeric/int8.hpp
0 → 100644
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

// Keep the once-guard ahead of every include: if one of the headers below ends up including
// this file again, the nested inclusion is skipped immediately instead of walking the include
// list a second time before the guard is seen.
#pragma once

#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/utility/bit_cast.hpp"
#include "ck_tile/core/utility/random.hpp"
#include <stdint.h>
#include <type_traits>
namespace
ck_tile
{
// use int8_t directly for int8 arithmetic
// here one can use ck_tile::int8_t to access original int8_t
using
int8_t
=
int8_t
;
// limits
template
<
class
T
>
struct
numeric
;
// Limits for the 8-bit signed integer type.
// min / lowest / max / zero carry real values; the floating-point style queries
// (epsilon, round_error, infinity, NaNs, denorm_min) all return the placeholder 1
// and are marked "not used".
template <>
struct numeric<int8_t>
{
    // minimum finite value (for float types this slot is the minimum positive normalized value)
    CK_TILE_HOST_DEVICE static constexpr int8_t min() { return static_cast<int8_t>(-128); }

    // minimum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t lowest() { return static_cast<int8_t>(-128); }

    // maximum finite value
    CK_TILE_HOST_DEVICE static constexpr int8_t max() { return static_cast<int8_t>(127); }

    // difference between 1.0 and the next representable value (float concept)
    CK_TILE_HOST_DEVICE static constexpr int8_t epsilon()
    {
        return static_cast<int8_t>(1); // not used
    }

    CK_TILE_HOST_DEVICE static constexpr int8_t round_error()
    {
        return static_cast<int8_t>(1); // not used
    }

    // positive infinity value
    CK_TILE_HOST_DEVICE static constexpr int8_t infinity()
    {
        return static_cast<int8_t>(1); // not used
    }

    // quiet NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t quiet_NaN()
    {
        return static_cast<int8_t>(1); // not used
    }

    // signaling NaN
    CK_TILE_HOST_DEVICE static constexpr int8_t signaling_NaN()
    {
        return static_cast<int8_t>(1); // not used
    }

    // smallest positive subnormal value
    CK_TILE_HOST_DEVICE static constexpr int8_t denorm_min()
    {
        return static_cast<int8_t>(1); // not used
    }

    CK_TILE_HOST_DEVICE static constexpr int8_t zero() { return static_cast<int8_t>(0); }
};
#if 0
template <typename T>
struct numeric_traits;
template <>
struct numeric_traits<int8_t>
{
static constexpr int exp = 5;
static constexpr int mant = 10;
static constexpr int bias = 15;
static constexpr uint16_t nan_mask = 0x7C00;
static constexpr uint16_t head_mask = 0xFC00;
static constexpr uint16_t mant_mask = 0x3FF;
static constexpr uint16_t exp_mask = 0x1F;
static constexpr uint32_t Inf = 0x7C00;
static constexpr uint32_t NegInf = 0xFC00;
static constexpr uint32_t NaN = 0x7C01;
static constexpr uint32_t Neg0 = 0x8000;
using bitwise_type = uint16_t;
};
#endif
CK_TILE_HOST_DEVICE
constexpr
float
int8_to_float
(
const
int8_t
&
x
)
{
return
static_cast
<
float
>
(
x
);
}
CK_TILE_HOST_DEVICE
constexpr
int8_t
float_to_int8
(
const
float
&
x
)
{
return
static_cast
<
int8_t
>
(
x
);
}
}
// namespace ck_tile
include/ck_tile/core/numeric/math.hpp
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -487,55 +487,12 @@ struct log2e<float>
...
@@ -487,55 +487,12 @@ struct log2e<float>
template
<
typename
T
=
double
>
template
<
typename
T
=
double
>
constexpr
T
log2e_v
=
log2e
<
T
>::
value
;
constexpr
T
log2e_v
=
log2e
<
T
>::
value
;
// math
CK_TILE_HOST_DEVICE
float
abs
(
const
float
&
x
)
{
union
{
float
f32
;
uint32_t
u32
;
}
y
;
y
.
f32
=
x
;
y
.
u32
=
y
.
u32
&
0x7fffffff
;
return
y
.
f32
;
}
CK_TILE_HOST_DEVICE
bool
isnan
(
const
float
&
x
)
{
uint32_t
xx
=
bit_cast
<
uint32_t
>
(
x
);
return
(
xx
&
0x7fffffff
)
>
0x7F800000
;
}
CK_TILE_HOST
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
CK_TILE_HOST
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
CK_TILE_DEVICE
float
sqrt
(
float
x
)
{
return
__builtin_amdgcn_sqrtf
(
x
);
};
CK_TILE_DEVICE
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
CK_TILE_DEVICE
float
exp
(
float
x
)
{
return
__ocml_exp_f32
(
x
);
};
CK_TILE_HOST
float
exp
(
float
x
)
{
return
std
::
expf
(
x
);
}
CK_TILE_DEVICE
CK_TILE_DEVICE
float
exp2
(
float
x
)
{
return
exp2f
(
x
);
};
float
exp2
(
float
x
)
{
return
exp2f
(
x
);
};
CK_TILE_HOST
CK_TILE_HOST
float
exp2
(
float
x
)
{
return
std
::
exp2f
(
x
);
};
float
exp2
(
float
x
)
{
return
std
::
exp2f
(
x
);
};
CK_TILE_DEVICE
float
log
(
float
x
)
{
return
__logf
(
x
);
};
CK_TILE_HOST
float
log
(
float
x
)
{
return
std
::
logf
(
x
);
};
CK_TILE_DEVICE
uint16_t
sad_u16
(
uint16_t
x
,
uint16_t
y
,
uint16_t
acc
)
CK_TILE_DEVICE
uint16_t
sad_u16
(
uint16_t
x
,
uint16_t
y
,
uint16_t
acc
)
{
{
return
__builtin_amdgcn_sad_u16
(
x
,
y
,
acc
);
return
__builtin_amdgcn_sad_u16
(
x
,
y
,
acc
);
...
@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
...
@@ -554,4 +511,933 @@ CK_TILE_HOST uint32_t sad_u32(uint32_t x, uint32_t y, uint32_t acc)
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
}
}
///////////////////////////////////////////////////////////////
}
// namespace ck_tile
// the functions below need the data types to be defined first
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#ifndef __HIP_DEVICE_COMPILE__
#include <cmath>
#endif
namespace
ck_tile
{
#if CK_TILE_WORKAROUND_SWDEV_383542
extern
"C"
CK_TILE_DEVICE
float
__ocml_native_recip_f32
(
float
);
#endif
// math functions for the host, some are implemented by calling C++ std functions
CK_TILE_HOST
float
abs
(
float
x
)
{
return
std
::
abs
(
x
);
};
CK_TILE_HOST
double
abs
(
double
x
)
{
return
std
::
abs
(
x
);
};
CK_TILE_HOST
int8_t
abs
(
int8_t
x
)
{
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_HOST
int32_t
abs
(
int32_t
x
)
{
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_HOST
fp16_t
abs
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
fp16_t
abs_x
=
bit_cast
<
fp16_t
>
(
abs_xx
);
return
abs_x
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
}
#endif
CK_TILE_HOST
bool
isnan
(
float
x
)
{
return
std
::
isnan
(
x
);
};
CK_TILE_HOST
bool
isnan
(
double
x
)
{
return
std
::
isnan
(
x
);
};
CK_TILE_HOST
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_HOST
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_HOST
bool
isnan
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_HOST
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
CK_TILE_HOST
fp16_t
sqrt
(
fp16_t
x
)
{
return
static_cast
<
fp16_t
>
(
std
::
sqrt
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_HOST
float
sqrt
(
float
x
)
{
return
std
::
sqrt
(
x
);
};
CK_TILE_HOST
double
sqrt
(
double
x
)
{
return
std
::
sqrt
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
tanh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
tanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
tanh
<
float
>
(
float
x
)
{
return
std
::
tanhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
tanh
<
double
>
(
double
x
)
{
return
std
::
tanh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
acos
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
acosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
acos
<
float
>
(
float
x
)
{
return
std
::
acosf
(
x
);
};
template
<
>
CK_TILE_HOST
double
acos
<
double
>
(
double
x
)
{
return
std
::
acos
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
neg
(
T
x
)
{
return
type_convert
<
T
>
(
-
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
neg
<
float
>
(
float
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
double
neg
<
double
>
(
double
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
int32_t
neg
<
int32_t
>
(
int32_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_HOST
int8_t
neg
<
int8_t
>
(
int8_t
x
)
{
return
-
x
;
};
template
<
typename
T
>
CK_TILE_HOST
T
atan
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
atanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
atan
<
float
>
(
float
x
)
{
return
std
::
atanf
(
x
);
};
template
<
>
CK_TILE_HOST
double
atan
<
double
>
(
double
x
)
{
return
std
::
atan
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
sin
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
sinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
sin
<
float
>
(
float
x
)
{
return
std
::
sinf
(
x
);
};
template
<
>
CK_TILE_HOST
double
sin
<
double
>
(
double
x
)
{
return
std
::
sin
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
asin
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
asinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
asin
<
float
>
(
float
x
)
{
return
std
::
asinf
(
x
);
};
template
<
>
CK_TILE_HOST
double
asin
<
double
>
(
double
x
)
{
return
std
::
asin
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
asinh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
asinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
asinh
<
float
>
(
float
x
)
{
return
std
::
asinhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
asinh
<
double
>
(
double
x
)
{
return
std
::
asinh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
cos
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
cosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
cos
<
float
>
(
float
x
)
{
return
std
::
cosf
(
x
);
};
template
<
>
CK_TILE_HOST
double
cos
<
double
>
(
double
x
)
{
return
std
::
cos
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
acosh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
acoshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
acosh
<
float
>
(
float
x
)
{
return
std
::
acoshf
(
x
);
};
template
<
>
CK_TILE_HOST
double
acosh
<
double
>
(
double
x
)
{
return
std
::
acosh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
tan
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
tanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
tan
<
float
>
(
float
x
)
{
return
std
::
tanf
(
x
);
};
template
<
>
CK_TILE_HOST
double
tan
<
double
>
(
double
x
)
{
return
std
::
tan
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
atanh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
atanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
atanh
<
float
>
(
float
x
)
{
return
std
::
atanhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
atanh
<
double
>
(
double
x
)
{
return
std
::
atanh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
sinh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
sinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
sinh
<
float
>
(
float
x
)
{
return
std
::
sinhf
(
x
);
};
template
<
>
CK_TILE_HOST
double
sinh
<
double
>
(
double
x
)
{
return
std
::
sinh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
ceil
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
ceilf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
ceil
<
float
>
(
float
x
)
{
return
std
::
ceilf
(
x
);
};
template
<
>
CK_TILE_HOST
double
ceil
<
double
>
(
double
x
)
{
return
std
::
ceil
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
cosh
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
coshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
cosh
<
float
>
(
float
x
)
{
return
std
::
coshf
(
x
);
};
template
<
>
CK_TILE_HOST
double
cosh
<
double
>
(
double
x
)
{
return
std
::
cosh
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
floor
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
floorf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_HOST
float
floor
<
float
>
(
float
x
)
{
return
std
::
floorf
(
x
);
};
template
<
>
CK_TILE_HOST
double
floor
<
double
>
(
double
x
)
{
return
std
::
floor
(
x
);
};
template
<
typename
T
>
CK_TILE_HOST
T
rcp
(
T
x
)
{
return
type_convert
<
T
>
(
1.
f
/
type_convert
<
float
>
(
x
));
};
template
<
typename
T
>
CK_TILE_HOST
T
exp
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
expf
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
exp
<
float
>
(
float
x
)
{
return
std
::
expf
(
x
);
}
template
<
>
CK_TILE_HOST
double
exp
<
double
>
(
double
x
)
{
return
std
::
exp
(
x
);
}
template
<
typename
T
>
CK_TILE_HOST
T
log
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
logf
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
log
<
float
>
(
float
x
)
{
return
std
::
logf
(
x
);
}
template
<
>
CK_TILE_HOST
double
log
<
double
>
(
double
x
)
{
return
std
::
log
(
x
);
}
template
<
typename
T
>
CK_TILE_HOST
T
pow
(
T
x
,
T
gamma
)
{
return
type_convert
<
T
>
(
std
::
powf
(
type_convert
<
float
>
(
x
),
type_convert
<
float
>
(
gamma
)));
}
template
<
>
CK_TILE_HOST
float
pow
<
float
>
(
float
x
,
float
gamma
)
{
return
std
::
powf
(
x
,
gamma
);
}
template
<
>
CK_TILE_HOST
double
pow
<
double
>
(
double
x
,
double
gamma
)
{
return
std
::
pow
(
x
,
gamma
);
}
template
<
typename
T
>
CK_TILE_HOST
T
expm1
(
T
x
)
{
return
type_convert
<
T
>
(
std
::
expm1f
(
type_convert
<
float
>
(
x
)));
}
template
<
>
CK_TILE_HOST
float
expm1
<
float
>
(
float
x
)
{
return
std
::
expm1f
(
x
);
}
template
<
>
CK_TILE_HOST
double
expm1
<
double
>
(
double
x
)
{
return
std
::
expm1
(
x
);
}
// math functions for the HIP kernel, some are implemented by calling hip builtin functions
// Device-side absolute value of a float: clears the top bit (the sign bit) of the
// 32-bit pattern and returns the remaining bits reinterpreted as float.
// Uses bit_cast, like the fp16_t overload in this file, instead of writing one union
// member and reading another, which is undefined behaviour in C++.
CK_TILE_DEVICE float abs(float x)
{
    const uint32_t magnitude_bits = bit_cast<uint32_t>(x) & 0x7fffffff;
    return bit_cast<float>(magnitude_bits);
}
CK_TILE_DEVICE
double
abs
(
double
x
)
{
return
::
abs
(
x
);
};
CK_TILE_DEVICE
int8_t
abs
(
int8_t
x
)
{
int8_t
sgn
=
x
>>
(
8
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
CK_TILE_DEVICE
int32_t
abs
(
int32_t
x
)
{
int32_t
sgn
=
x
>>
(
32
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE
int4_t
abs
(
int4_t
x
)
{
int4_t
sgn
=
x
>>
(
4
-
1
);
return
(
x
^
sgn
)
-
sgn
;
};
#endif
CK_TILE_DEVICE
fp16_t
abs
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
uint16_t
abs_xx
=
xx
&
0x7fff
;
fp16_t
abs_x
=
bit_cast
<
fp16_t
>
(
abs_xx
);
return
abs_x
;
};
CK_TILE_DEVICE
bool
isnan
(
float
x
)
{
return
::
isnan
(
x
);
};
CK_TILE_DEVICE
bool
isnan
(
double
x
)
{
return
::
isnan
(
x
);
};
CK_TILE_DEVICE
bool
isnan
(
int8_t
x
)
{
(
void
)
x
;
return
false
;
};
CK_TILE_DEVICE
bool
isnan
(
int32_t
x
)
{
(
void
)
x
;
return
false
;
};
#ifdef CK_TILE_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
CK_TILE_DEVICE
bool
isnan
(
int4_t
x
)
{
(
void
)
x
;
return
false
;
};
#endif
CK_TILE_DEVICE
bool
isnan
(
fp16_t
x
)
{
uint16_t
xx
=
bit_cast
<
uint16_t
>
(
x
);
return
(
xx
&
0x7FFF
)
>
0x7C00
;
};
CK_TILE_DEVICE
fp16_t
sqrt
(
fp16_t
x
)
{
return
static_cast
<
fp16_t
>
(
__builtin_amdgcn_sqrtf
(
static_cast
<
float
>
(
x
)));
};
CK_TILE_DEVICE
float
sqrt
(
float
x
)
{
return
__builtin_amdgcn_sqrtf
(
x
);
};
CK_TILE_DEVICE
double
sqrt
(
double
x
)
{
return
__builtin_amdgcn_sqrt
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
tanh
(
T
x
)
{
return
type_convert
<
T
>
(
::
tanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
tanh
<
float
>
(
float
x
)
{
return
::
tanhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
tanh
<
double
>
(
double
x
)
{
return
::
tanh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
acos
(
T
x
)
{
return
type_convert
<
T
>
(
::
acosf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
acos
<
float
>
(
float
x
)
{
return
::
acosf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
acos
<
double
>
(
double
x
)
{
return
::
acos
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
neg
(
T
x
)
{
return
type_convert
<
T
>
(
-
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
neg
<
float
>
(
float
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
double
neg
<
double
>
(
double
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
int32_t
neg
<
int32_t
>
(
int32_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
int8_t
neg
<
int8_t
>
(
int8_t
x
)
{
return
-
x
;
};
template
<
>
CK_TILE_DEVICE
fp16_t
neg
<
fp16_t
>
(
fp16_t
x
)
{
return
-
x
;
};
template
<
typename
T
>
CK_TILE_DEVICE
T
atan
(
T
x
)
{
return
type_convert
<
T
>
(
::
atanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
atan
<
float
>
(
float
x
)
{
return
::
atanf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
atan
<
double
>
(
double
x
)
{
return
::
atan
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
sin
(
T
x
)
{
return
type_convert
<
T
>
(
::
sinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
sin
<
float
>
(
float
x
)
{
return
::
sinf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
sin
<
double
>
(
double
x
)
{
return
::
sin
(
x
);
};
template
<
>
CK_TILE_DEVICE
fp16_t
sin
<
fp16_t
>
(
fp16_t
x
)
{
return
__ocml_sin_f16
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
asin
(
T
x
)
{
return
type_convert
<
T
>
(
::
asinf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
asin
<
float
>
(
float
x
)
{
return
::
asinf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
asin
<
double
>
(
double
x
)
{
return
::
asin
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
asinh
(
T
x
)
{
return
type_convert
<
T
>
(
::
asinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
asinh
<
float
>
(
float
x
)
{
return
::
asinhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
asinh
<
double
>
(
double
x
)
{
return
::
asinh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
acosh
(
T
x
)
{
return
type_convert
<
T
>
(
::
acoshf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
acosh
<
float
>
(
float
x
)
{
return
::
acoshf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
acosh
<
double
>
(
double
x
)
{
return
::
acosh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
tan
(
T
x
)
{
return
type_convert
<
T
>
(
::
tanf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
tan
<
float
>
(
float
x
)
{
return
::
tanf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
tan
<
double
>
(
double
x
)
{
return
::
tan
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
atanh
(
T
x
)
{
return
type_convert
<
T
>
(
::
atanhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
atanh
<
float
>
(
float
x
)
{
return
::
atanhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
atanh
<
double
>
(
double
x
)
{
return
::
atanh
(
x
);
};
template
<
typename
T
>
CK_TILE_DEVICE
T
sinh
(
T
x
)
{
return
type_convert
<
T
>
(
::
sinhf
(
type_convert
<
float
>
(
x
)));
};
template
<
>
CK_TILE_DEVICE
float
sinh
<
float
>
(
float
x
)
{
return
::
sinhf
(
x
);
};
template
<
>
CK_TILE_DEVICE
double
sinh
<
double
>
(
double
x
)
{
return
::
sinh
(
x
);
};
// Generic ceiling: the argument is converted to float, rounded up with the
// global ::ceilf, and the result is converted back to T.
template <typename T>
CK_TILE_DEVICE T ceil(T x)
{
    const float x_f32 = type_convert<float>(x);
    return type_convert<T>(::ceilf(x_f32));
}

// float specialization: calls ::ceilf directly, skipping the conversions.
template <>
CK_TILE_DEVICE float ceil<float>(float x)
{
    return ::ceilf(x);
}

// double specialization: evaluated in double precision with ::ceil.
template <>
CK_TILE_DEVICE double ceil<double>(double x)
{
    return ::ceil(x);
}

// fp16_t specialization: uses the half-precision routine __ocml_ceil_f16
// instead of a round trip through float.
template <>
CK_TILE_DEVICE fp16_t ceil<fp16_t>(fp16_t x)
{
    return __ocml_ceil_f16(x);
}
// Generic hyperbolic cosine: the argument is converted to float, evaluated
// with the global ::coshf, and the result is converted back to T.
template <typename T>
CK_TILE_DEVICE T cosh(T x)
{
    const float x_f32 = type_convert<float>(x);
    return type_convert<T>(::coshf(x_f32));
}

// float specialization: calls ::coshf directly, skipping the conversions.
template <>
CK_TILE_DEVICE float cosh<float>(float x)
{
    return ::coshf(x);
}

// double specialization: evaluated in double precision with ::cosh.
template <>
CK_TILE_DEVICE double cosh<double>(double x)
{
    return ::cosh(x);
}
// Generic floor: the argument is converted to float, rounded down with the
// global ::floorf, and the result is converted back to T.
template <typename T>
CK_TILE_DEVICE T floor(T x)
{
    const float x_f32 = type_convert<float>(x);
    return type_convert<T>(::floorf(x_f32));
}

// float specialization: calls ::floorf directly, skipping the conversions.
template <>
CK_TILE_DEVICE float floor<float>(float x)
{
    return ::floorf(x);
}

// double specialization: evaluated in double precision with ::floor.
template <>
CK_TILE_DEVICE double floor<double>(double x)
{
    return ::floor(x);
}

// fp16_t specialization: uses the half-precision routine __ocml_floor_f16
// instead of a round trip through float.
template <>
CK_TILE_DEVICE fp16_t floor<fp16_t>(fp16_t x)
{
    return __ocml_floor_f16(x);
}
// Reciprocal of x.
// The intrinsic used is chosen at compile time by
// CK_TILE_WORKAROUND_SWDEV_383542: when the workaround is enabled the AMDGCN
// builtin is used, otherwise __frcp_rn.
// NOTE(review): both intrinsics are single-precision; T is presumably float,
// other types would go through implicit conversion - confirm against callers.
template <typename T>
CK_TILE_DEVICE T rcp(T x)
{
#if CK_TILE_WORKAROUND_SWDEV_383542
    // return __ocml_native_recip_f32(x);
    return __builtin_amdgcn_rcpf(x);
#else
    return __frcp_rn(x);
#endif
}
// Generic exponential: the argument is converted to float, evaluated with
// __ocml_exp_f32, and the result is converted back to T.
template <typename T>
CK_TILE_DEVICE T exp(T x)
{
    return type_convert<T>(__ocml_exp_f32(type_convert<float>(x)));
}

// fp16_t specialization: uses the half-precision routine __ocml_exp_f16
// instead of a round trip through float.
template <>
CK_TILE_DEVICE fp16_t exp<fp16_t>(fp16_t x)
{
    return __ocml_exp_f16(x);
}

// float specialization: calls __ocml_exp_f32 directly, skipping the
// conversions.
template <>
CK_TILE_DEVICE float exp<float>(float x)
{
    return __ocml_exp_f32(x);
}

// double specialization: evaluated in double precision.
// The call is qualified with :: on purpose. An unqualified exp(x) here is
// looked up inside this namespace first and would bind back to the exp
// template declared above (i.e. to this very specialization) instead of the
// global math function. The other double specializations in this file use
// the same ::name(x) form.
template <>
CK_TILE_DEVICE double exp<double>(double x)
{
    return ::exp(x);
}
// Generic natural logarithm: the argument is converted to float, evaluated
// with the __logf intrinsic, and the result is converted back to T.
template <typename T>
CK_TILE_DEVICE T log(T x)
{
    return type_convert<T>(__logf(type_convert<float>(x)));
}

// fp16_t specialization: uses the half-precision routine __ocml_log_f16
// instead of a round trip through float.
template <>
CK_TILE_DEVICE fp16_t log<fp16_t>(fp16_t x)
{
    return __ocml_log_f16(x);
}

// float specialization: calls __logf directly, skipping the conversions.
template <>
CK_TILE_DEVICE float log<float>(float x)
{
    return __logf(x);
}

// double specialization: evaluated in double precision.
// The call is qualified with :: on purpose. An unqualified log(x) here is
// looked up inside this namespace first and would bind back to the log
// template declared above (i.e. to this very specialization) instead of the
// global math function. The other double specializations in this file use
// the same ::name(x) form.
template <>
CK_TILE_DEVICE double log<double>(double x)
{
    return ::log(x);
}
// Generic power function, x raised to gamma: both arguments are converted to
// float, evaluated with powf, and the result is converted back to T.
template <typename T>
CK_TILE_DEVICE T pow(T x, T gamma)
{
    return type_convert<T>(powf(type_convert<float>(x), type_convert<float>(gamma)));
}

// float specialization: calls powf directly, skipping the conversions.
template <>
CK_TILE_DEVICE float pow<float>(float x, float gamma)
{
    return powf(x, gamma);
}

// double specialization: evaluated in double precision.
// The call is qualified with :: on purpose. An unqualified pow(x, gamma)
// here is looked up inside this namespace first and would bind back to the
// pow template declared above (i.e. to this very specialization) instead of
// the global math function. The other double specializations in this file
// use the same ::name(x) form.
template <>
CK_TILE_DEVICE double pow<double>(double x, double gamma)
{
    return ::pow(x, gamma);
}
// Generic exp(x) - 1: the argument is converted to float, evaluated with
// expm1f, and the result is converted back to T.
template <typename T>
CK_TILE_DEVICE T expm1(T x)
{
    return type_convert<T>(expm1f(type_convert<float>(x)));
}

// float specialization: calls expm1f directly, skipping the conversions.
template <>
CK_TILE_DEVICE float expm1<float>(float x)
{
    return expm1f(x);
}

// double specialization: evaluated in double precision.
// The call is qualified with :: on purpose. An unqualified expm1(x) here is
// looked up inside this namespace first and would bind back to the expm1
// template declared above (i.e. to this very specialization) instead of the
// global math function. The other double specializations in this file use
// the same ::name(x) form.
template <>
CK_TILE_DEVICE double expm1<double>(double x)
{
    return ::expm1(x);
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/numeric/type_convert.hpp
View file @
a4522ae3
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/int8.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
...
@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT
(
fp8_t
,
fp8
,
float
,
float
)
CK_TILE_TYPE_CONVERT
(
fp8_t
,
fp8
,
float
,
float
)
CK_TILE_TYPE_CONVERT
(
bf8_t
,
bf8
,
float
,
float
)
CK_TILE_TYPE_CONVERT
(
bf8_t
,
bf8
,
float
,
float
)
CK_TILE_TYPE_CONVERT
(
float
,
float
,
int8_t
,
int8
)
CK_TILE_TYPE_CONVERT
(
int8_t
,
int8
,
float
,
float
)
#undef CK_TILE_TYPE_CONVERT
#undef CK_TILE_TYPE_CONVERT
#endif
#endif
...
...
include/ck_tile/core/tensor/buffer_view.hpp
View file @
a4522ae3
...
@@ -91,8 +91,10 @@ struct buffer_view<address_space_enum::generic,
...
@@ -91,8 +91,10 @@ struct buffer_view<address_space_enum::generic,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
CK_TILE_DEVICE
constexpr
auto
get
(
index_t
i
,
get
(
index_t
i
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
index_t
linear_offset
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -107,11 +109,11 @@ struct buffer_view<address_space_enum::generic,
...
@@ -107,11 +109,11 @@ struct buffer_view<address_space_enum::generic,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X
tmp
;
X
tmp
;
__builtin_memcpy
(
&
tmp
,
&
(
p_data_
[
i
]),
sizeof
(
X
));
__builtin_memcpy
(
&
tmp
,
&
(
p_data_
[
i
+
linear_offset
]),
sizeof
(
X
));
return
tmp
;
return
tmp
;
#else
#else
return
*
c_style_pointer_cast
<
const
X
*>
(
&
p_data_
[
i
]);
return
*
c_style_pointer_cast
<
const
X
*>
(
&
p_data_
[
i
+
linear_offset
]);
#endif
#endif
}
}
else
else
...
@@ -134,17 +136,17 @@ struct buffer_view<address_space_enum::generic,
...
@@ -134,17 +136,17 @@ struct buffer_view<address_space_enum::generic,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
update
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
update
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
if
constexpr
(
Op
==
memory_operation_enum
::
set
)
if
constexpr
(
Op
==
memory_operation_enum
::
set
)
{
{
this
->
template
set
<
X
>(
i
,
is_valid_element
,
x
);
this
->
template
set
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
);
}
}
// FIXME: remove memory_operation_enum::add
// FIXME: remove memory_operation_enum::add
else
if
constexpr
(
Op
==
memory_operation_enum
::
add
)
else
if
constexpr
(
Op
==
memory_operation_enum
::
add
)
{
{
auto
tmp
=
this
->
template
get
<
X
>(
i
,
is_valid_element
);
auto
tmp
=
this
->
template
get
<
X
>(
i
,
linear_offset
,
is_valid_element
);
this
->
template
set
<
X
>(
i
,
is_valid_element
,
x
+
tmp
);
this
->
template
set
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
+
tmp
);
}
}
}
}
...
@@ -154,7 +156,7 @@ struct buffer_view<address_space_enum::generic,
...
@@ -154,7 +156,7 @@ struct buffer_view<address_space_enum::generic,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
set
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -169,9 +171,9 @@ struct buffer_view<address_space_enum::generic,
...
@@ -169,9 +171,9 @@ struct buffer_view<address_space_enum::generic,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X
tmp
=
x
;
X
tmp
=
x
;
__builtin_memcpy
(
&
(
p_data_
[
i
]),
&
tmp
,
sizeof
(
X
));
__builtin_memcpy
(
&
(
p_data_
[
i
+
linear_offset
]),
&
tmp
,
sizeof
(
X
));
#else
#else
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
+
linear_offset
])
=
x
;
#endif
#endif
}
}
}
}
...
@@ -276,8 +278,10 @@ struct buffer_view<address_space_enum::global,
...
@@ -276,8 +278,10 @@ struct buffer_view<address_space_enum::global,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
CK_TILE_DEVICE
constexpr
auto
get
(
index_t
i
,
get
(
index_t
i
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
index_t
linear_offset
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -303,7 +307,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -303,7 +307,7 @@ struct buffer_view<address_space_enum::global,
t_per_x
,
t_per_x
,
Coherence
,
Coherence
,
oob_conditional_check
>
(
oob_conditional_check
>
(
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
p_data_
,
i
+
linear_offset
,
is_valid_element
,
buffer_size_
);
}
}
else
else
{
{
...
@@ -311,8 +315,11 @@ struct buffer_view<address_space_enum::global,
...
@@ -311,8 +315,11 @@ struct buffer_view<address_space_enum::global,
remove_cvref_t
<
T
>
,
remove_cvref_t
<
T
>
,
t_per_x
,
t_per_x
,
Coherence
,
Coherence
,
oob_conditional_check
>
(
oob_conditional_check
>
(
p_data_
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
,
invalid_element_value_
);
i
+
linear_offset
,
is_valid_element
,
buffer_size_
,
invalid_element_value_
);
}
}
}
}
else
else
...
@@ -322,11 +329,11 @@ struct buffer_view<address_space_enum::global,
...
@@ -322,11 +329,11 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X
tmp
;
X
tmp
;
__builtin_memcpy
(
&
tmp
,
&
(
p_data_
[
i
]),
sizeof
(
X
));
__builtin_memcpy
(
&
tmp
,
&
(
p_data_
[
i
+
linear_offset
]),
sizeof
(
X
));
return
tmp
;
return
tmp
;
#else
#else
return
*
c_style_pointer_cast
<
const
X
*>
(
&
p_data_
[
i
]);
return
*
c_style_pointer_cast
<
const
X
*>
(
&
p_data_
[
i
+
linear_offset
]);
#endif
#endif
}
}
else
else
...
@@ -352,7 +359,8 @@ struct buffer_view<address_space_enum::global,
...
@@ -352,7 +359,8 @@ struct buffer_view<address_space_enum::global,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
get_raw
(
remove_cvref_t
<
X
>&
dst
,
CK_TILE_DEVICE
constexpr
auto
get_raw
(
remove_cvref_t
<
X
>&
dst
,
index_t
i
,
index_t
v_offset
,
index_t
i_offset
,
bool
is_valid_element
,
bool
is_valid_element
,
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -366,7 +374,38 @@ struct buffer_view<address_space_enum::global,
...
@@ -366,7 +374,38 @@ struct buffer_view<address_space_enum::global,
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_load_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
,
oob_conditional_check
,
pre_nop
>
(
amd_buffer_load_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
,
oob_conditional_check
,
pre_nop
>
(
dst
,
cached_buf_res_
,
i
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
dst
,
cached_buf_res_
,
v_offset
,
i_offset
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
}
// i is offset of T, not X. i should be aligned to X
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
typename
std
::
enable_if
<
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
async_get
(
CK_TILE_LDS_ADDR
remove_cvref_t
<
T
>*
smem
,
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
// X is vector of T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_x_vector
=
vector_traits
<
remove_cvref_t
<
X
>>::
vector_size
;
static_assert
(
scalar_per_x_vector
%
scalar_per_t_vector
==
0
,
"wrong! X should contain multiple T"
);
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_async_buffer_load_with_oob
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
>
(
smem
,
cached_buf_res_
,
i
,
linear_offset
,
is_valid_element
,
bool_constant
<
oob_conditional_check
>
{});
}
}
// i is offset of T, not X. i should be aligned to X
// i is offset of T, not X. i should be aligned to X
...
@@ -378,6 +417,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -378,6 +417,7 @@ struct buffer_view<address_space_enum::global,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
async_get_raw
(
remove_cvref_t
<
T
>*
smem
,
CK_TILE_DEVICE
constexpr
auto
async_get_raw
(
remove_cvref_t
<
T
>*
smem
,
index_t
i
,
index_t
i
,
index_t
linear_offset
,
bool
/*is_valid_element*/
,
bool
/*is_valid_element*/
,
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
...
@@ -391,7 +431,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -391,7 +431,7 @@ struct buffer_view<address_space_enum::global,
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_async_buffer_load_with_oob_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
>
(
amd_async_buffer_load_with_oob_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
>
(
smem
,
cached_buf_res_
,
i
,
bool_constant
<
pre_nop
>
{});
smem
,
cached_buf_res_
,
i
,
linear_offset
,
bool_constant
<
pre_nop
>
{});
}
}
// i is offset of T, not X. i should be aligned to X
// i is offset of T, not X. i should be aligned to X
...
@@ -401,25 +441,25 @@ struct buffer_view<address_space_enum::global,
...
@@ -401,25 +441,25 @@ struct buffer_view<address_space_enum::global,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
update
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
update
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
if
constexpr
(
Op
==
memory_operation_enum
::
set
)
if
constexpr
(
Op
==
memory_operation_enum
::
set
)
{
{
this
->
template
set
<
X
>(
i
,
is_valid_element
,
x
);
this
->
template
set
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
);
}
}
else
if
constexpr
(
Op
==
memory_operation_enum
::
atomic_add
)
else
if
constexpr
(
Op
==
memory_operation_enum
::
atomic_add
)
{
{
this
->
template
atomic_add
<
X
>(
i
,
is_valid_element
,
x
);
this
->
template
atomic_add
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
);
}
}
else
if
constexpr
(
Op
==
memory_operation_enum
::
atomic_max
)
else
if
constexpr
(
Op
==
memory_operation_enum
::
atomic_max
)
{
{
this
->
template
atomic_max
<
X
>(
i
,
is_valid_element
,
x
);
this
->
template
atomic_max
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
);
}
}
// FIXME: remove memory_operation_enum::add
// FIXME: remove memory_operation_enum::add
else
if
constexpr
(
Op
==
memory_operation_enum
::
add
)
else
if
constexpr
(
Op
==
memory_operation_enum
::
add
)
{
{
auto
tmp
=
this
->
template
get
<
X
>(
i
,
is_valid_element
);
auto
tmp
=
this
->
template
get
<
X
>(
i
,
linear_offset
,
is_valid_element
);
this
->
template
set
<
X
>(
i
,
is_valid_element
,
x
+
tmp
);
this
->
template
set
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
+
tmp
);
// tmp += x;
// tmp += x;
// this->template set<X>(i, is_valid_element, tmp);
// this->template set<X>(i, is_valid_element, tmp);
}
}
...
@@ -432,7 +472,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -432,7 +472,7 @@ struct buffer_view<address_space_enum::global,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
set
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -453,7 +493,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -453,7 +493,7 @@ struct buffer_view<address_space_enum::global,
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
>
(
amd_buffer_store
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
>
(
x
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
x
,
p_data_
,
i
+
linear_offset
,
is_valid_element
,
buffer_size_
);
}
}
else
else
{
{
...
@@ -462,9 +502,9 @@ struct buffer_view<address_space_enum::global,
...
@@ -462,9 +502,9 @@ struct buffer_view<address_space_enum::global,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X
tmp
=
x
;
X
tmp
=
x
;
__builtin_memcpy
(
&
(
p_data_
[
i
]),
&
tmp
,
sizeof
(
X
));
__builtin_memcpy
(
&
(
p_data_
[
i
+
linear_offset
]),
&
tmp
,
sizeof
(
X
));
#else
#else
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
+
linear_offset
])
=
x
;
#endif
#endif
}
}
}
}
...
@@ -477,7 +517,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -477,7 +517,7 @@ struct buffer_view<address_space_enum::global,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
set_raw
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
set_raw
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -489,7 +529,7 @@ struct buffer_view<address_space_enum::global,
...
@@ -489,7 +529,7 @@ struct buffer_view<address_space_enum::global,
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
constexpr
index_t
t_per_x
=
scalar_per_x_vector
/
scalar_per_t_vector
;
amd_buffer_store_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
,
oob_conditional_check
>
(
amd_buffer_store_raw
<
remove_cvref_t
<
T
>
,
t_per_x
,
Coherence
,
oob_conditional_check
>
(
x
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
x
,
p_data_
,
i
,
linear_offset
,
is_valid_element
,
buffer_size_
);
}
}
template
<
typename
X
,
template
<
typename
X
,
...
@@ -497,7 +537,8 @@ struct buffer_view<address_space_enum::global,
...
@@ -497,7 +537,8 @@ struct buffer_view<address_space_enum::global,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
atomic_add
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
atomic_add
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
using
scalar_t
=
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
;
using
scalar_t
=
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
;
...
@@ -532,13 +573,13 @@ struct buffer_view<address_space_enum::global,
...
@@ -532,13 +573,13 @@ struct buffer_view<address_space_enum::global,
if
constexpr
(
use_amd_buffer_addressing
)
if
constexpr
(
use_amd_buffer_addressing
)
{
{
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
amd_buffer_atomic_add
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
x
,
p_data_
,
i
+
linear_offset
,
is_valid_element
,
buffer_size_
);
}
}
else
else
{
{
if
(
is_valid_element
)
if
(
is_valid_element
)
{
{
atomic_add_g
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
&
p_data_
[
i
],
x
);
atomic_add_g
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
&
p_data_
[
i
+
linear_offset
],
x
);
}
}
}
}
}
}
...
@@ -548,7 +589,8 @@ struct buffer_view<address_space_enum::global,
...
@@ -548,7 +589,8 @@ struct buffer_view<address_space_enum::global,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
atomic_max
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
atomic_max
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -572,11 +614,11 @@ struct buffer_view<address_space_enum::global,
...
@@ -572,11 +614,11 @@ struct buffer_view<address_space_enum::global,
if
constexpr
(
use_amd_buffer_addressing
)
if
constexpr
(
use_amd_buffer_addressing
)
{
{
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
amd_buffer_atomic_max
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
x
,
p_data_
,
i
,
is_valid_element
,
buffer_size_
);
x
,
p_data_
,
i
+
linear_offset
,
is_valid_element
,
buffer_size_
);
}
}
else
if
(
is_valid_element
)
else
if
(
is_valid_element
)
{
{
atomic_max_g
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
&
p_data_
[
i
],
x
);
atomic_max_g
<
remove_cvref_t
<
T
>
,
t_per_x
>
(
&
p_data_
[
i
+
linear_offset
],
x
);
}
}
}
}
...
@@ -668,8 +710,10 @@ struct buffer_view<address_space_enum::lds,
...
@@ -668,8 +710,10 @@ struct buffer_view<address_space_enum::lds,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
CK_TILE_DEVICE
constexpr
auto
get
(
index_t
i
,
get
(
index_t
i
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
index_t
linear_offset
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -684,14 +728,14 @@ struct buffer_view<address_space_enum::lds,
...
@@ -684,14 +728,14 @@ struct buffer_view<address_space_enum::lds,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X
tmp
;
X
tmp
;
__builtin_memcpy
(
&
tmp
,
&
(
p_data_
[
i
]),
sizeof
(
X
));
__builtin_memcpy
(
&
tmp
,
&
(
p_data_
[
i
+
linear_offset
]),
sizeof
(
X
));
return
tmp
;
return
tmp
;
#else
#else
using
buf_t
=
ext_vector_t
<
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
,
using
buf_t
=
ext_vector_t
<
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
,
scalar_per_t_vector
*
scalar_per_x_vector
>
;
scalar_per_t_vector
*
scalar_per_x_vector
>
;
// using buf_t = ushort __attribute__((ext_vector_type(8)));
// using buf_t = ushort __attribute__((ext_vector_type(8)));
auto
rtn
=
*
c_style_pointer_cast
<
const
buf_t
*>
(
&
p_data_
[
i
]);
auto
rtn
=
*
c_style_pointer_cast
<
const
buf_t
*>
(
&
p_data_
[
i
+
linear_offset
]);
return
bit_cast
<
X
>
(
rtn
);
return
bit_cast
<
X
>
(
rtn
);
#endif
#endif
}
}
...
@@ -708,6 +752,23 @@ struct buffer_view<address_space_enum::lds,
...
@@ -708,6 +752,23 @@ struct buffer_view<address_space_enum::lds,
}
}
}
}
// i is offset of T, not X. i should be aligned to X
template
<
typename
X
,
bool
oob_conditional_check
=
true
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
get_raw
(
remove_cvref_t
<
X
>&
dst
,
index_t
v_offset
,
index_t
i_offset
,
bool
/*is_valid_element*/
,
bool_constant
<
pre_nop
>
=
{})
const
{
smem_load
<
sizeof
(
X
)
>
{}(
dst
,
v_offset
*
sizeof
(
T
),
i_offset
*
sizeof
(
T
));
}
// i is offset of T, not X. i should be aligned to X
// i is offset of T, not X. i should be aligned to X
template
<
memory_operation_enum
Op
,
template
<
memory_operation_enum
Op
,
typename
X
,
typename
X
,
...
@@ -715,17 +776,17 @@ struct buffer_view<address_space_enum::lds,
...
@@ -715,17 +776,17 @@ struct buffer_view<address_space_enum::lds,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
update
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
update
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
if
constexpr
(
Op
==
memory_operation_enum
::
set
)
if
constexpr
(
Op
==
memory_operation_enum
::
set
)
{
{
this
->
template
set
<
X
>(
i
,
is_valid_element
,
x
);
this
->
template
set
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
);
}
}
// FIXME: remove memory_operation_enum::add
// FIXME: remove memory_operation_enum::add
else
if
constexpr
(
Op
==
memory_operation_enum
::
add
)
else
if
constexpr
(
Op
==
memory_operation_enum
::
add
)
{
{
auto
tmp
=
this
->
template
get
<
X
>(
i
,
is_valid_element
);
auto
tmp
=
this
->
template
get
<
X
>(
i
,
linear_offset
,
is_valid_element
);
this
->
template
set
<
X
>(
i
,
is_valid_element
,
x
+
tmp
);
this
->
template
set
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
+
tmp
);
}
}
}
}
...
@@ -735,7 +796,7 @@ struct buffer_view<address_space_enum::lds,
...
@@ -735,7 +796,7 @@ struct buffer_view<address_space_enum::lds,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
set
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -751,6 +812,7 @@ struct buffer_view<address_space_enum::lds,
...
@@ -751,6 +812,7 @@ struct buffer_view<address_space_enum::lds,
bool
constexpr
workaround_int8_ds_write_issue
=
false
;
bool
constexpr
workaround_int8_ds_write_issue
=
false
;
#endif
#endif
i
+=
linear_offset
;
// simplicity
if
constexpr
(
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
,
if
constexpr
(
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
,
int8_t
>::
value
&&
int8_t
>::
value
&&
workaround_int8_ds_write_issue
)
workaround_int8_ds_write_issue
)
...
@@ -952,8 +1014,10 @@ struct buffer_view<address_space_enum::vgpr,
...
@@ -952,8 +1014,10 @@ struct buffer_view<address_space_enum::vgpr,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
constexpr
auto
CK_TILE_DEVICE
constexpr
auto
get
(
index_t
i
,
get
(
index_t
i
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
index_t
/*linear_offset*/
,
bool
is_valid_element
,
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -995,17 +1059,17 @@ struct buffer_view<address_space_enum::vgpr,
...
@@ -995,17 +1059,17 @@ struct buffer_view<address_space_enum::vgpr,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
update
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
update
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
if
constexpr
(
Op
==
memory_operation_enum
::
set
)
if
constexpr
(
Op
==
memory_operation_enum
::
set
)
{
{
this
->
template
set
<
X
>(
i
,
is_valid_element
,
x
);
this
->
template
set
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
);
}
}
// FIXME: remove memory_operation_enum::add
// FIXME: remove memory_operation_enum::add
else
if
constexpr
(
Op
==
memory_operation_enum
::
add
)
else
if
constexpr
(
Op
==
memory_operation_enum
::
add
)
{
{
auto
tmp
=
this
->
template
get
<
X
>(
i
,
is_valid_element
);
auto
tmp
=
this
->
template
get
<
X
>(
i
,
linear_offset
,
is_valid_element
);
this
->
template
set
<
X
>(
i
,
is_valid_element
,
x
+
tmp
);
this
->
template
set
<
X
>(
i
,
linear_offset
,
is_valid_element
,
x
+
tmp
);
}
}
}
}
...
@@ -1015,7 +1079,7 @@ struct buffer_view<address_space_enum::vgpr,
...
@@ -1015,7 +1079,7 @@ struct buffer_view<address_space_enum::vgpr,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
typename
vector_traits
<
remove_cvref_t
<
T
>>::
scalar_type
>::
value
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_DEVICE
void
set
(
index_t
i
,
bool
is_valid_element
,
const
X
&
x
)
CK_TILE_DEVICE
void
set
(
index_t
i
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
)
{
{
// X contains multiple T
// X contains multiple T
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
constexpr
index_t
scalar_per_t_vector
=
vector_traits
<
remove_cvref_t
<
T
>>::
vector_size
;
...
@@ -1030,9 +1094,9 @@ struct buffer_view<address_space_enum::vgpr,
...
@@ -1030,9 +1094,9 @@ struct buffer_view<address_space_enum::vgpr,
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
#if CK_TILE_EXPERIMENTAL_USE_MEMCPY_FOR_VECTOR_ACCESS
X
tmp
=
x
;
X
tmp
=
x
;
__builtin_memcpy
(
&
(
p_data_
[
i
]),
&
tmp
,
sizeof
(
X
));
__builtin_memcpy
(
&
(
p_data_
[
i
+
linear_offset
]),
&
tmp
,
sizeof
(
X
));
#else
#else
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
])
=
x
;
*
c_style_pointer_cast
<
X
*>
(
&
p_data_
[
i
+
linear_offset
])
=
x
;
#endif
#endif
}
}
}
}
...
...
include/ck_tile/core/tensor/load_tile.hpp
View file @
a4522ae3
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,6 +12,7 @@
...
@@ -12,6 +12,7 @@
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tile_window.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
#include "ck_tile/core/tensor/null_tensor.hpp"
...
@@ -28,9 +29,48 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
...
@@ -28,9 +29,48 @@ CK_TILE_DEVICE auto load_tile(const tile_window_with_static_distribution<BottomT
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
bool_constant
<
oob_conditional_check
>
=
{})
bool_constant
<
oob_conditional_check
>
=
{})
{
{
return
tile_window
.
load
(
bool_constant
<
oob_conditional_check
>
{});
return
tile_window
.
load
(
number
<-
1
>
{},
bool_constant
<
oob_conditional_check
>
{});
}
}
// Load a tile through a linear tile window.
//
// Thin free-function wrapper: forwards to tile_window.load() with number<-1>{} in the
// leading slot plus the caller's out-of-bounds policy, and returns whatever load() returns.
//
// oob_conditional_check : compile-time flag, forwarded as bool_constant<...>
// tile_window           : the window to read from
//
// NOTE(review): number<-1> presumably means "no specific access index" - confirm against
// tile_window_linear::load().
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(const tile_window_linear<BottomTensorView_,
                                                       WindowLengths_,
                                                       TileDistribution_,
                                                       LinearBottomDims_>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    using access_tag_t = number<-1>;
    using oob_check_t  = bool_constant<oob_conditional_check>;

    return tile_window.load(access_tag_t{}, oob_check_t{});
}
// Load a tile into a caller-provided distributed tensor.
//
// Forwards to tile_window.load(dst_tile, bool_constant<oob_conditional_check>{}) and
// returns its result unchanged.
//
// dst_tile              : destination tensor, passed through by reference
// tile_window           : statically distributed window to read from
// oob_conditional_check : compile-time flag, forwarded as bool_constant<...>
template <typename DistributedTensor_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          bool oob_conditional_check = true>
CK_TILE_DEVICE auto load_tile(DistributedTensor_& dst_tile,
                              const tile_window_with_static_distribution<BottomTensorView_,
                                                                         WindowLengths_,
                                                                         TileDistribution_,
                                                                         NumCoord>& tile_window,
                              bool_constant<oob_conditional_check> = {})
{
    using oob_check_t = bool_constant<oob_conditional_check>;

    return tile_window.load(dst_tile, oob_check_t{});
}
/**
* @brief Loads a tile of data using inline assembly.
*
* @note Bear in mind that when loading data this way you have to manually initialize your
* thread buffer and synchronize the load afterwards, to make sure it has completed before
* the loaded data is used from registers
* @see `tile_window_with_static_distribution::init_raw()` and `buffer_view.hpp`
* @see `buffer_load_fence()`
*/
template
<
typename
T
,
template
<
typename
T
,
typename
BottomTensorView_
,
typename
BottomTensorView_
,
typename
WindowLengths_
,
typename
WindowLengths_
,
...
@@ -46,7 +86,27 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
...
@@ -46,7 +86,27 @@ CK_TILE_DEVICE auto load_tile_raw(T& tile,
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
tile_window
.
load_raw
(
tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
tile_window
.
load_raw
(
tile
,
number
<-
1
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
// Raw load of a tile through a linear tile window.
//
// Forwards to tile_window.load_raw() with number<-1>{} in the slot after the destination,
// followed by the out-of-bounds and pre_nop policies. Nothing is returned (the body has no
// return statement, so the deduced return type is void).
//
// tile                  : destination object handed to load_raw()
// tile_window           : the window to read from
// oob_conditional_check : compile-time flag, forwarded as bool_constant<...>
// pre_nop               : compile-time flag, forwarded as bool_constant<...>
//
// NOTE(review): number<-1> presumably means "no specific access index" - confirm against
// tile_window_linear::load_raw().
template <typename T,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                  const tile_window_linear<BottomTensorView_,
                                                           WindowLengths_,
                                                           TileDistribution_,
                                                           LinearBottomDims_>& tile_window,
                                  bool_constant<oob_conditional_check> = {},
                                  bool_constant<pre_nop>               = {})
{
    using access_tag_t = number<-1>;
    using oob_check_t  = bool_constant<oob_conditional_check>;
    using pre_nop_t    = bool_constant<pre_nop>;

    tile_window.load_raw(tile, access_tag_t{}, oob_check_t{}, pre_nop_t{});
}
}
template
<
typename
LdsTileWindow_
,
template
<
typename
LdsTileWindow_
,
...
@@ -66,7 +126,26 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
...
@@ -66,7 +126,26 @@ async_load_tile_raw(LdsTileWindow_&& lds_tile,
bool_constant
<
pre_nop
>
=
{})
bool_constant
<
pre_nop
>
=
{})
{
{
return
tile_window
.
async_load_raw
(
return
tile_window
.
async_load_raw
(
lds_tile
,
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
lds_tile
,
number
<-
1
>
{},
bool_constant
<
oob_conditional_check
>
{},
bool_constant
<
pre_nop
>
{});
}
// Asynchronous raw load through a linear tile window.
//
// Forwards to tile_window.async_load_raw() with the LDS tile, number<-1>{} and the two
// compile-time policies, and returns whatever async_load_raw() returns.
//
// lds_tile              : destination window; passed on as an lvalue, exactly as received
// tile_window           : the window to read from
// oob_conditional_check : compile-time flag, forwarded as bool_constant<...>
// pre_nop               : compile-time flag, forwarded as bool_constant<...>
//
// NOTE(review): number<-1> presumably means "no specific access index" - confirm against
// tile_window_linear::async_load_raw().
template <typename LdsTileWindow_,
          typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto async_load_tile_raw(LdsTileWindow_&& lds_tile,
                                        const tile_window_linear<BottomTensorView_,
                                                                 WindowLengths_,
                                                                 TileDistribution_,
                                                                 LinearBottomDims_>& tile_window,
                                        bool_constant<oob_conditional_check> = {},
                                        bool_constant<pre_nop>               = {})
{
    using access_tag_t = number<-1>;
    using oob_check_t  = bool_constant<oob_conditional_check>;
    using pre_nop_t    = bool_constant<pre_nop>;

    return tile_window.async_load_raw(lds_tile, access_tag_t{}, oob_check_t{}, pre_nop_t{});
}
}
CK_TILE_DEVICE
auto
async_load_fence
(
index_t
cnt
=
0
)
CK_TILE_DEVICE
auto
async_load_fence
(
index_t
cnt
=
0
)
...
...
include/ck_tile/core/tensor/null_tile_window.hpp
View file @
a4522ae3
...
@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
...
@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
return
null_tile_window
<
remove_cvref_t
<
WindowLengths
>>
{
window_lengths
};
return
null_tile_window
<
remove_cvref_t
<
WindowLengths
>>
{
window_lengths
};
}
}
// make_tile_window overload for an argument that is already a null_tile_window.
// The tile-distribution argument is unnamed and unused; the incoming window `t` is
// simply returned (by value, since the return type is a plain `auto`).
template
<
typename
WindowLengths
,
typename
StaticTileDistribution
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window
(
const
null_tile_window
<
WindowLengths
>&
t
,
const
StaticTileDistribution
&
)
{
return
t
;
}
template
<
typename
WindowLengths
>
template
<
typename
WindowLengths
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
move_tile_window
(
null_tile_window
<
WindowLengths
>&
,
move_tile_window
(
null_tile_window
<
WindowLengths
>&
,
...
...
include/ck_tile/core/tensor/shuffle_tile.hpp
View file @
a4522ae3
...
@@ -109,7 +109,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
...
@@ -109,7 +109,7 @@ CK_TILE_DEVICE void shuffle_tile_impl_in_thread(OutTensor& out_tensor, const InT
// get input vectors
// get input vectors
static_for
<
0
,
num_vec_in
,
1
>
{}([
&
](
auto
i
)
{
static_for
<
0
,
num_vec_in
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
idx_y_in
=
generate_
array
(
constexpr
auto
idx_y_in
=
generate_
tuple
(
[
&
](
auto
ii
)
{
[
&
](
auto
ii
)
{
return
ii
==
y_dim_vec_out
?
idx_y_start
[
ii
]
+
i
:
idx_y_start
[
ii
];
return
ii
==
y_dim_vec_out
?
idx_y_start
[
ii
]
+
i
:
idx_y_start
[
ii
];
},
},
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
a4522ae3
...
@@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
...
@@ -187,4 +187,18 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
});
});
}
}
// Helper meant to be used inside span loops.
//
// Given the Y lengths of a span and an unpack count XUnpacks:
//   1. multiplies all entries of YLengths together,
//   2. statically asserts that this product is divisible by XUnpacks,
//   3. calls slice_sequence() on YLengths with (product / XUnpacks) as the slice size,
//   4. returns element number<1> of the slice_sequence() result.
//
// Both parameters are unnamed tag values; only their types are used.
//
// NOTE(review): element 1 of the slice_sequence() result is presumably the per-Y-dimension
// unpack sequence - confirm against slice_sequence().
template <typename YLengths, index_t XUnpacks>
CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks_from_x_unpacks(YLengths, number<XUnpacks>)
{
    constexpr auto total_y = reduce_on_sequence(YLengths{}, multiplies{}, number<1>{});
    constexpr auto n_packs = number<XUnpacks>{};
    static_assert(total_y % n_packs == 0);

    constexpr auto per_pack = total_y / n_packs;
    constexpr auto sliced   = slice_sequence(YLengths{}, number<per_pack>{});
    constexpr auto result   = sliced[number<1>{}];

    return result;
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/tensor/store_tile.hpp
View file @
a4522ae3
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
#include "ck_tile/core/tensor/tile_window_linear.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -72,7 +73,7 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
...
@@ -72,7 +73,7 @@ store_tile(tile_window_with_static_distribution<BottomTensorView_,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
{
{
tile_window
.
store
(
dstr_tensor
);
tile_window
.
store
(
dstr_tensor
,
number
<-
1
>
{}
);
}
}
template
<
typename
BottomTensorView_
,
template
<
typename
BottomTensorView_
,
...
@@ -87,7 +88,33 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
...
@@ -87,7 +88,33 @@ store_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
NumCoord
>&
tile_window
,
NumCoord
>&
tile_window
,
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
const
static_distributed_tensor
<
DataType_
,
TileDistribution_
>&
dstr_tensor
)
{
{
tile_window
.
store_raw
(
dstr_tensor
);
tile_window
.
store_raw
(
dstr_tensor
,
number
<-
1
>
{});
}
// Store a distributed tensor through a linear tile window.
//
// Forwards to tile_window.store(dstr_tensor, number<-1>{}).
//
// tile_window : destination window (non-const reference)
// dstr_tensor : tensor holding the data to store
//
// NOTE(review): number<-1> presumably means "no specific access index" - confirm against
// tile_window_linear::store().
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile(tile_window_linear<BottomTensorView_,
                              WindowLengths_,
                              TileDistribution_,
                              LinearBottomDims_>& tile_window,
           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using access_tag_t = number<-1>;

    tile_window.store(dstr_tensor, access_tag_t{});
}
// Raw store of a distributed tensor through a linear tile window.
//
// Forwards to tile_window.store_raw(dstr_tensor, number<-1>{}).
//
// tile_window : destination window (non-const reference)
// dstr_tensor : tensor holding the data to store
//
// NOTE(review): number<-1> presumably means "no specific access index" - confirm against
// tile_window_linear::store_raw().
template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_>
CK_TILE_DEVICE void
store_tile_raw(tile_window_linear<BottomTensorView_,
                                  WindowLengths_,
                                  TileDistribution_,
                                  LinearBottomDims_>& tile_window,
               const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
{
    using access_tag_t = number<-1>;

    tile_window.store_raw(dstr_tensor, access_tag_t{});
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/tensor/sweep_tile.hpp
View file @
a4522ae3
...
@@ -8,6 +8,7 @@
...
@@ -8,6 +8,7 @@
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/functional_with_tuple.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
...
@@ -27,4 +28,281 @@ CK_TILE_DEVICE void sweep_tile_span(TileDistributedSpan_, const F& f)
});
});
}
}
// Span sweep with unpacking: like a plain span sweep, but the functor may receive several
// indices per call (multi-argument functor).
//
// The span's Impl type and the Unpacks sequence parameterize a static_uford; every raw
// index it produces is wrapped with detail::make_tile_distributed_index() before being
// handed to f. Unpacks defaults to a sequence of 1s with one entry per span dimension.
template <
    typename TileDistributedSpan_, // tile_distributed_span<...>
    typename F,                    // signature: F(tile_distributed_index<...>)
    typename Unpacks = typename uniform_sequence_gen<TileDistributedSpan_::Impl::size(), 1>::type>
CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks = {})
{
    using span_type    = remove_cvref_t<TileDistributedSpan_>;
    using span_lengths = typename span_type::Impl;

    static_uford<span_lengths, Unpacks>{}([&](auto... raw_idx) {
        f(detail::make_tile_distributed_index(raw_idx)...);
    });
}
namespace
impl
{
template <typename, typename, typename>
struct sweep_tile_impl;

// Recursive step of the tile sweep: handles span (X dimension) number I and then hands
// over to the specialization for the remaining dimensions Is...
//
// DistributedTensor : provides get_distributed_spans()
// UnpacksPerXDim    : sequence-like type; entry I is the unpack count for this dimension
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
{
    // Y-unpack description for span I, derived from entry I of UnpacksPerXDim.
    CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
    {
        constexpr auto all_spans = DistributedTensor::get_distributed_spans();
        constexpr auto lengths_y = typename decltype(all_spans[number<I>{}])::Impl{};
        constexpr auto unpacks_x = number<UnpacksPerXDim{}.at(number<I>{})>{};
        constexpr auto unpacks_y = get_y_unpacks_from_x_unpacks(lengths_y, unpacks_x);
        return unpacks_y;
    }

    // Number of accesses of this level multiplied by that of all remaining levels.
    CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
    {
        using inner_levels = sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>;

        constexpr auto all_spans  = DistributedTensor::get_distributed_spans();
        constexpr auto this_level = static_uford<typename decltype(all_spans[number<I>{}])::Impl,
                                                 decltype(get_y_unpacks())>{};

        return this_level.get_num_of_access() * inner_levels{}.get_num_of_access();
    }

    // Full sweep: visit every index group of span I; for each one, append the new indices
    // to every partial index tuple in span_idx and recurse into the remaining dimensions.
    template <typename F, typename SpanIdx>
    CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
    {
        using inner_levels = sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>;

        constexpr auto all_spans = DistributedTensor::get_distributed_spans();

        sweep_tile_uspan(
            all_spans[number<I>{}],
            [&](auto... level_idx) {
                // sweep_tile_uspan already delivers tile_distributed_index objects
                const auto extended_idx = embed_tuples(
                    [&](auto partial) {
                        return make_tuple(concat_tuple(partial, make_tuple(level_idx))...);
                    },
                    span_idx);
                inner_levels{}(f, extended_idx);
            },
            get_y_unpacks());
    }

    // Single-access sweep: i_access is split into the access index of this level
    // (i_access / stride) and the access index passed to the remaining levels
    // (i_access % stride), where stride is the access count of the remaining levels.
    template <typename F, typename SpanIdx, index_t i_access>
    CK_TILE_HOST_DEVICE constexpr void
    operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
    {
        using inner_levels = sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>;

        constexpr auto all_spans  = DistributedTensor::get_distributed_spans();
        constexpr auto this_level = static_uford<typename decltype(all_spans[number<I>{}])::Impl,
                                                 decltype(get_y_unpacks())>{};

        constexpr auto inner_stride = inner_levels{}.get_num_of_access();
        constexpr auto outer_access = number<i_access / inner_stride>{};
        constexpr auto inner_access = number<i_access % inner_stride>{};

        this_level(
            [&](auto... level_idx) {
                // static_uford delivers raw indices, so they are wrapped here
                const auto extended_idx = embed_tuples(
                    [&](auto partial) {
                        return make_tuple(concat_tuple(
                            partial,
                            make_tuple(detail::make_tile_distributed_index(level_idx)))...);
                    },
                    span_idx);
                inner_levels{}(f, extended_idx, inner_access);
            },
            outer_access);
    }
};
// Terminal case of the sweep recursion: no dimensions are left, so the accumulated index
// tuple is unpacked into the user functor.
template <typename DistributedTensor, typename UnpacksPerXDim>
struct sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<>>
{
    // An empty dimension list contributes a factor of one to the access count.
    CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const { return 1; }

    // Invoke f with the elements of span_idx as its arguments.
    template <typename F, typename SpanIdx>
    CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, const SpanIdx& span_idx) const
    {
        unpack(f, span_idx);
    }

    // Access-indexed variant; the access index is ignored at this level.
    template <typename F, typename SpanIdx, index_t i_access>
    CK_TILE_HOST_DEVICE constexpr void
    operator()(const F& f, const SpanIdx& span_idx, number<i_access>) const
    {
        (*this)(f, span_idx);
    }
};
template <typename, typename, typename>
struct sweep_tile_impl_0;

// Entry point of the sweep recursion. It handles the first dimension I and creates the
// initial index tuples, then continues with sweep_tile_impl for the remaining dimensions.
// TODO: support empty tuple to remove this "entry-point" like function
template <typename DistributedTensor, typename UnpacksPerXDim, index_t I, index_t... Is>
struct sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, sequence<I, Is...>>
{
    // Y-unpack description for span I, derived from entry I of UnpacksPerXDim.
    CK_TILE_HOST_DEVICE constexpr auto get_y_unpacks() const
    {
        constexpr auto all_spans = DistributedTensor::get_distributed_spans();
        constexpr auto lengths_y = typename decltype(all_spans[number<I>{}])::Impl{};
        constexpr auto unpacks_x = number<UnpacksPerXDim{}.at(number<I>{})>{};
        constexpr auto unpacks_y = get_y_unpacks_from_x_unpacks(lengths_y, unpacks_x);
        return unpacks_y;
    }

    // Number of accesses of this level multiplied by that of all remaining levels.
    CK_TILE_HOST_DEVICE constexpr index_t get_num_of_access() const
    {
        using inner_levels = sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>;

        constexpr auto all_spans  = DistributedTensor::get_distributed_spans();
        constexpr auto this_level = static_uford<typename decltype(all_spans[number<I>{}])::Impl,
                                                 decltype(get_y_unpacks())>{};

        return this_level.get_num_of_access() * inner_levels{}.get_num_of_access();
    }

    // Full sweep: each index delivered for span I starts its own one-element tuple, and
    // the tuple of those tuples is passed on to the remaining dimensions.
    template <typename F>
    CK_TILE_HOST_DEVICE constexpr void operator()(const F& f) const
    {
        using inner_levels = sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>;

        constexpr auto all_spans = DistributedTensor::get_distributed_spans();

        sweep_tile_uspan(
            all_spans[number<I>{}],
            [&](auto... level_idx) {
                // sweep_tile_uspan already delivers tile_distributed_index objects
                constexpr auto first_idx = make_tuple(make_tuple(level_idx)...);
                inner_levels{}(f, first_idx);
            },
            get_y_unpacks());
    }

    // Single-access sweep: i_access is split into the access index of this level
    // (i_access / stride) and the access index passed to the remaining levels
    // (i_access % stride), where stride is the access count of the remaining levels.
    template <typename F, index_t i_access>
    CK_TILE_HOST_DEVICE constexpr void operator()(const F& f, number<i_access>) const
    {
        using inner_levels = sweep_tile_impl<DistributedTensor, UnpacksPerXDim, sequence<Is...>>;

        constexpr auto all_spans  = DistributedTensor::get_distributed_spans();
        constexpr auto this_level = static_uford<typename decltype(all_spans[number<I>{}])::Impl,
                                                 decltype(get_y_unpacks())>{};

        constexpr auto inner_stride = inner_levels{}.get_num_of_access();
        constexpr auto outer_access = number<i_access / inner_stride>{};
        constexpr auto inner_access = number<i_access % inner_stride>{};

        this_level(
            [&](auto... level_idx) {
                // static_uford delivers raw indices, so they are wrapped here
                constexpr auto first_idx =
                    make_tuple(make_tuple(detail::make_tile_distributed_index(level_idx))...);
                inner_levels{}(f, first_idx, inner_access);
            },
            outer_access);
    }
};
}
// namespace impl
/*
 * Enhanced sweep-tile utility that lets the caller control the unpack count along each
 * X dimension. The arguments of the lambda are distributed indices, which can be plugged
 * directly into the distributed tensor as setter/getter.
 *
 * Example: y has the type DistributedTensor, r is a row scale.
 *
 * // sweep the tile one element at a time
 * sweep_tile<DistributedTensor>([&](auto idx) {
 *     constexpr auto row_id = make_tuple(idx[number<0>{}]);
 *     y(idx) = y(idx) * r(row_id);
 * });
 *
 * // sweep the tile with 2 pixels from the last dim per function call
 * sweep_tile<DistributedTensor>(
 *     [&](auto idx_0, auto idx_1) {
 *         constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
 *         y(idx_0) = y(idx_0) * r(row_id);
 *         y(idx_1) = y(idx_1) * r(row_id);
 *     },
 *     sequence<1, 2>{});
 *
 * // sweep the tile with 2x2 pixels per function call
 * sweep_tile<DistributedTensor>(
 *     [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
 *         constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
 *         constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
 *         y(idx_00) = y(idx_00) * r(row_id0);
 *         y(idx_01) = y(idx_01) * r(row_id0);
 *         y(idx_10) = y(idx_10) * r(row_id1);
 *         y(idx_11) = y(idx_11) * r(row_id1);
 *     },
 *     sequence<2, 2>{});
 *
 * TODO: do we need constexpr? The lambda function could be non-constexpr.
 */
template <typename DistributedTensor,
          typename F,
          typename UnpacksPerXDim =
              typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE constexpr void sweep_tile(const F& f, UnpacksPerXDim = {})
{
    constexpr auto spans = DistributedTensor::get_distributed_spans();

    // one entry per span: 0, 1, ..., spans.size() - 1
    using span_ids = typename arithmetic_sequence_gen<0, spans.size(), 1>::type;

    impl::sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, span_ids>{}(f);
}
// Convenience overload of sweep_tile: the tensor argument is unnamed and only serves to
// deduce DistributedTensor. The call is forwarded to the functor-only overload with all
// three template arguments given explicitly.
template <typename DistributedTensor,
          typename F,
          typename UnpacksPerXDim =
              typename uniform_sequence_gen<DistributedTensor::get_num_of_dimension(), 1>::type>
CK_TILE_HOST_DEVICE constexpr void
sweep_tile(const DistributedTensor&, const F& f, UnpacksPerXDim = {})
{
    using unpacks_t = UnpacksPerXDim;

    sweep_tile<DistributedTensor, F, unpacks_t>(f, unpacks_t{});
}
/*
 * Sweep-tile object that can also issue the lambda one access at a time.
 * The struct stores a copy of the functor but does not hold the distributed tensor.
 * Calling it without arguments does the same as sweep_tile().
 */
template <typename DistributedTensor_,
          typename F_,
          typename UnpacksPerXDim_ =
              typename uniform_sequence_gen<DistributedTensor_::get_num_of_dimension(), 1>::type>
struct tile_sweeper
{
    using DistributedTensor = remove_cvref_t<DistributedTensor_>;
    using F                 = remove_cvref_t<F_>;
    using UnpacksPerXDim    = remove_cvref_t<UnpacksPerXDim_>;

    // Construct from the functor alone; the unpack argument is a type tag only.
    CK_TILE_HOST_DEVICE tile_sweeper(const F& f_, UnpacksPerXDim = {}) : f(f_) {}

    // Construct from a tensor and a functor; the tensor argument is unnamed and not stored.
    CK_TILE_HOST_DEVICE tile_sweeper(const DistributedTensor&, const F& f_, UnpacksPerXDim = {})
        : f(f_)
    {
    }

    // Access count reported by the sweep entry point for this tensor / unpack combination.
    CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_access()
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();

        using span_ids = typename arithmetic_sequence_gen<0, spans.size(), 1>::type;

        constexpr auto entry =
            impl::sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, span_ids>{};
        return entry.get_num_of_access();
    }

    // Run the complete sweep.
    CK_TILE_HOST_DEVICE void operator()() const
    {
        sweep_tile<DistributedTensor>(f, UnpacksPerXDim{});
    }

    // Run the sweep for the single access number i_access.
    template <index_t i_access>
    CK_TILE_HOST_DEVICE void operator()(number<i_access>) const
    {
        constexpr auto spans = DistributedTensor::get_distributed_spans();

        using span_ids = typename arithmetic_sequence_gen<0, spans.size(), 1>::type;

        impl::sweep_tile_impl_0<DistributedTensor, UnpacksPerXDim, span_ids>{}(
            f, number<i_access>{});
    }

    F f;
};
// A guide for the functor-only constructor cannot be provided, because T would have to be
// given explicitly and partial deduction is not allowed:
// template <typename T, typename F, typename U>
// CK_TILE_HOST_DEVICE_EXTERN tile_sweeper(const F&, U = {})->tile_sweeper<T, F, U>;
// Deduction guide for the (tensor, functor[, unpacks]) constructor: T and F are deduced
// from the first two arguments; U defaults to a sequence of 1s with
// T::get_num_of_dimension() entries.
template
<
typename
T
,
typename
F
,
typename
U
=
typename
uniform_sequence_gen
<
T
::
get_num_of_dimension
(),
1
>
::
type
>
CK_TILE_HOST_DEVICE_EXTERN
tile_sweeper
(
const
T
&
,
const
F
&
,
U
=
{})
->
tile_sweeper
<
T
,
F
,
U
>
;
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/tensor/tensor_view.hpp
View file @
a4522ae3
...
@@ -16,6 +16,24 @@
...
@@ -16,6 +16,24 @@
namespace
ck_tile
{
namespace
ck_tile
{
/*
* tensor_view
* Abstracts the underlying memory buffer (global, LDS, etc.)
* and provides unified get/set functions for accessing it.
*
* Two variables control the addressing into the buffer:
* coord : ND tensor coordinate, from which the actual offset is calculated
* linear_offset : 1D offset, which is used in the immediate field of
* the buffer instruction to help reduce register usage
*
* The user can use either of the two fields, or both, to index into the tensor.
*
* We usually provide 2 sets of APIs for buffer get/set, e.g.
* get_vectorized_elements()/get_vectorized_elements_raw():
* the former usually calls an intrinsic or a normal C function, the latter
* usually calls an inline-asm function
*
*/
template
<
typename
BufferView_
,
template
<
typename
BufferView_
,
typename
TensorDesc_
,
typename
TensorDesc_
,
memory_operation_enum
DstInMemOp_
=
memory_operation_enum
::
set
>
memory_operation_enum
DstInMemOp_
=
memory_operation_enum
::
set
>
...
@@ -49,22 +67,6 @@ struct tensor_view
...
@@ -49,22 +67,6 @@ struct tensor_view
CK_TILE_HOST_DEVICE
constexpr
auto
&
get_buffer_view
()
{
return
buf_
;
}
CK_TILE_HOST_DEVICE
constexpr
auto
&
get_buffer_view
()
{
return
buf_
;
}
#if 0
CK_TILE_HOST_DEVICE constexpr DataType get_element(const TensorCoord& coord) const
{
return buf_.template get<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
}
CK_TILE_HOST_DEVICE constexpr void set_element(const TensorCoord& coord, const DataType& x)
{
buf_.template set<DataType>(
coord.get_offset(),
coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
x);
}
#endif
// X is vector of DataType.
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
...
@@ -75,14 +77,34 @@ struct tensor_view
...
@@ -75,14 +77,34 @@ struct tensor_view
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
remove_cvref_t
<
X
>
CK_TILE_HOST_DEVICE
constexpr
remove_cvref_t
<
X
>
get_vectorized_elements
(
const
TensorCoord
&
coord
,
get_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool_constant
<
oob_conditional_check
>
=
{})
const
bool_constant
<
oob_conditional_check
>
=
{})
const
{
{
return
buf_
.
template
get
<
X
>(
return
buf_
.
template
get
<
X
>(
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
bool_constant
<
oob_conditional_check
>
{});
bool_constant
<
oob_conditional_check
>
{});
}
}
// Vectorized read with a caller-supplied validity flag.
// X is a vector of DataType (enforced by the enable_if on the scalar types).
//
// coord            : tensor coordinate; its get_offset() is passed to the buffer view
// linear_offset    : additional 1D offset passed to the buffer view
// is_valid_element : validity flag forwarded as-is (not recomputed from coord here)
//
// Returns the result of buf_.get<X>() as remove_cvref_t<X>.
template <typename X,
          bool oob_conditional_check = true,
          typename std::enable_if<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool>::type = false>
CK_TILE_HOST_DEVICE constexpr remove_cvref_t<X>
get_vectorized_elements(const TensorCoord& coord,
                        index_t linear_offset,
                        bool is_valid_element, // flag
                        bool_constant<oob_conditional_check> = {}) const
{
    using oob_check_t = bool_constant<oob_conditional_check>;

    return buf_.template get<X>(
        coord.get_offset(), linear_offset, is_valid_element, oob_check_t{});
}
// X is vector of DataType.
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
...
@@ -94,12 +116,90 @@ struct tensor_view
...
@@ -94,12 +116,90 @@ struct tensor_view
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
void
get_vectorized_elements_raw
(
remove_cvref_t
<
X
>&
dst
,
CK_TILE_HOST_DEVICE
void
get_vectorized_elements_raw
(
remove_cvref_t
<
X
>&
dst
,
const
TensorCoord
&
coord
,
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
oob_conditional_check
>
=
{},
bool_constant
<
pre_nop
>
=
{})
const
bool_constant
<
pre_nop
>
=
{})
const
{
{
return
buf_
.
template
get_raw
<
X
,
oob_conditional_check
,
pre_nop
>(
return
buf_
.
template
get_raw
<
X
,
oob_conditional_check
,
pre_nop
>(
dst
,
dst
,
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
bool_constant
<
pre_nop
>
{});
}
// Raw vectorized read into dst with a caller-supplied validity flag.
// X is a vector of DataType (enforced by the enable_if on the scalar types).
//
// dst              : destination, written by buf_.get_raw()
// coord            : tensor coordinate; its get_offset() is passed to the buffer view
// linear_offset    : additional 1D offset passed to the buffer view
// is_valid_element : validity flag forwarded as-is (not recomputed from coord here)
//
// oob_conditional_check reaches the buffer view only as a template argument; pre_nop is
// passed both as a template argument and as a bool_constant value.
template <typename X,
          bool oob_conditional_check = true,
          bool pre_nop               = false,
          typename std::enable_if<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool>::type = false>
CK_TILE_HOST_DEVICE void get_vectorized_elements_raw(remove_cvref_t<X>& dst,
                                                     const TensorCoord& coord,
                                                     index_t linear_offset,
                                                     bool is_valid_element,
                                                     bool_constant<oob_conditional_check> = {},
                                                     bool_constant<pre_nop>               = {}) const
{
    using pre_nop_t = bool_constant<pre_nop>;

    return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
        dst, coord.get_offset(), linear_offset, is_valid_element, pre_nop_t{});
}
// Asynchronous vectorized read into the memory pointed to by smem.
// X is a vector of DataType (enforced by the enable_if on the scalar types).
//
// smem          : destination pointer (declared with CK_TILE_LDS_ADDR)
// coord         : tensor coordinate; its get_offset() is passed to the buffer view
// linear_offset : additional 1D offset passed to the buffer view
//
// The validity flag is computed here from desc_ and coord via
// coordinate_has_valid_offset_assuming_top_index_is_valid().
template <typename X,
          bool oob_conditional_check = true,
          typename std::enable_if<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
                              const TensorCoord& coord,
                              index_t linear_offset) const
{
    using oob_check_t = bool_constant<oob_conditional_check>;

    return buf_.template async_get<X>(
        smem,
        coord.get_offset(),
        linear_offset,
        coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
        oob_check_t{});
}
// Asynchronous vectorized read into the memory pointed to by smem, with a caller-supplied
// validity flag.
// X is a vector of DataType (enforced by the enable_if on the scalar types).
//
// smem             : destination pointer (declared with CK_TILE_LDS_ADDR)
// coord            : tensor coordinate; its get_offset() is passed to the buffer view
// linear_offset    : additional 1D offset passed to the buffer view
// is_valid_element : validity flag forwarded as-is (not recomputed from coord here)
template <typename X,
          bool oob_conditional_check = true,
          typename std::enable_if<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
async_get_vectorized_elements(CK_TILE_LDS_ADDR remove_cvref_t<DataType>* smem,
                              const TensorCoord& coord,
                              index_t linear_offset,
                              bool is_valid_element) const
{
    using oob_check_t = bool_constant<oob_conditional_check>;

    return buf_.template async_get<X>(
        smem, coord.get_offset(), linear_offset, is_valid_element, oob_check_t{});
}
template
<
typename
X
,
bool
pre_nop
=
false
,
typename
std
::
enable_if
<
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
async_get_vectorized_elements_raw
(
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool_constant
<
pre_nop
>
=
{})
const
{
return
buf_
.
template
async_get_raw
<
X
>(
smem
,
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
bool_constant
<
pre_nop
>
{});
bool_constant
<
pre_nop
>
{});
}
}
...
@@ -110,11 +210,15 @@ struct tensor_view
...
@@ -110,11 +210,15 @@ struct tensor_view
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
async_get_vectorized_elements_raw
(
CK_TILE_HOST_DEVICE
constexpr
void
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
,
bool_constant
<
pre_nop
>
=
{})
const
async_get_vectorized_elements_raw
(
remove_cvref_t
<
DataType
>*
smem
,
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
bool_constant
<
pre_nop
>
=
{})
const
{
{
return
buf_
.
template
async_get_raw
<
X
>(
return
buf_
.
template
async_get_raw
<
X
>(
smem
,
coord
.
get_offset
(),
true
/*not used*/
,
bool_constant
<
pre_nop
>
{});
smem
,
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
bool_constant
<
pre_nop
>
{});
}
}
// X is vector of DataType.
// X is vector of DataType.
...
@@ -125,11 +229,15 @@ struct tensor_view
...
@@ -125,11 +229,15 @@ struct tensor_view
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
set_vectorized_elements
(
CK_TILE_HOST_DEVICE
constexpr
void
const
TensorCoord
&
coord
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
set_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
{
buf_
.
template
set
<
X
,
oob_conditional_check
>(
buf_
.
template
set
<
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
x
);
x
);
}
}
...
@@ -140,15 +248,53 @@ struct tensor_view
...
@@ -140,15 +248,53 @@ struct tensor_view
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
set_vectorized_elements_raw
(
CK_TILE_HOST_DEVICE
constexpr
void
const
TensorCoord
&
coord
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
set_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
bool
is_valid_element
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
buf_
.
template
set
<
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
linear_offset
,
is_valid_element
,
x
);
}
// Raw write of a vector X through the underlying buffer view.
//
// X must have the same scalar type as DataType (enforced by the enable_if
// constraint below).
//
// coord                 : tensor coordinate; its get_offset() is forwarded.
// linear_offset         : additional offset forwarded to the buffer as-is.
// x                     : value to store.
// oob_conditional_check : compile-time flag forwarded as a template argument
//                         of buf_.set_raw.
//
// The validity flag is computed here from desc_ and coord via
// coordinate_has_valid_offset_assuming_top_index_is_valid.
template <typename X,
          bool oob_conditional_check = true,
          typename std::enable_if<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool>::type = false>
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements_raw(const TensorCoord& coord,
                            index_t linear_offset,
                            const X& x,
                            bool_constant<oob_conditional_check> = {})
{
    // One call, each argument passed exactly once: offset, linear offset,
    // validity flag, value.
    buf_.template set_raw<X, oob_conditional_check>(
        coord.get_offset(),
        linear_offset,
        coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
        x);
}
// Raw write of a vector X where the caller supplies the validity flag.
//
// Same constraint as the other accessors: X and DataType must share a scalar
// type. Everything is forwarded to buf_.set_raw<X, oob_conditional_check> in
// the order (coord.get_offset(), linear_offset, is_valid_element, x); no
// validity computation happens in this overload.
template <typename X,
          bool oob_conditional_check = true,
          std::enable_if_t<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool> = false>
CK_TILE_HOST_DEVICE constexpr void
set_vectorized_elements_raw(const TensorCoord& coord,
                            index_t linear_offset,
                            bool is_valid_element,
                            const X& x,
                            bool_constant<oob_conditional_check> = {})
{
    buf_.template set_raw<X, oob_conditional_check>(
        coord.get_offset(), linear_offset, is_valid_element, x);
}
// X is vector of DataType.
// X is vector of DataType.
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
// "coord" is coordinate of DataType, not X. "coord" should be aligned to X
template
<
typename
X
,
template
<
typename
X
,
...
@@ -157,15 +303,36 @@ struct tensor_view
...
@@ -157,15 +303,36 @@ struct tensor_view
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
std
::
is_same_v
<
typename
vector_traits
<
remove_cvref_t
<
X
>
>::
scalar_type
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
typename
vector_traits
<
remove_cvref_t
<
DataType
>>::
scalar_type
>
,
bool
>::
type
=
false
>
bool
>::
type
=
false
>
CK_TILE_HOST_DEVICE
constexpr
void
update_vectorized_elements
(
CK_TILE_HOST_DEVICE
constexpr
void
const
TensorCoord
&
coord
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
update_vectorized_elements
(
const
TensorCoord
&
coord
,
index_t
linear_offset
,
const
X
&
x
,
bool_constant
<
oob_conditional_check
>
=
{})
{
{
buf_
.
template
update
<
DstInMemOp
,
X
,
oob_conditional_check
>(
buf_
.
template
update
<
DstInMemOp
,
X
,
oob_conditional_check
>(
coord
.
get_offset
(),
coord
.
get_offset
(),
linear_offset
,
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
coordinate_has_valid_offset_assuming_top_index_is_valid
(
desc_
,
coord
),
x
);
x
);
}
}
// Update (using the view's DstInMemOp) of a vector X where the caller supplies
// the validity flag.
//
// X and DataType must share a scalar type. The call is forwarded to
// buf_.update<DstInMemOp, X, oob_conditional_check> with the arguments
// (coord.get_offset(), linear_offset, is_valid_element, x); this overload does
// not derive validity from the descriptor.
template <typename X,
          bool oob_conditional_check = true,
          std::enable_if_t<
              std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                             typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
              bool> = false>
CK_TILE_HOST_DEVICE constexpr void
update_vectorized_elements(const TensorCoord& coord,
                           index_t linear_offset,
                           bool is_valid_element,
                           const X& x,
                           bool_constant<oob_conditional_check> = {})
{
    buf_.template update<DstInMemOp, X, oob_conditional_check>(
        coord.get_offset(), linear_offset, is_valid_element, x);
}
CK_TILE_HOST_DEVICE
void
print
()
const
CK_TILE_HOST_DEVICE
void
print
()
const
{
{
printf
(
"tensor_view{"
);
printf
(
"tensor_view{"
);
...
...
include/ck_tile/core/tensor/tile_distribution.hpp
View file @
a4522ae3
...
@@ -17,6 +17,14 @@
...
@@ -17,6 +17,14 @@
namespace
ck_tile
{
namespace
ck_tile
{
namespace detail {

// Free-function shim: returns whatever Distribution::_get_partition_index()
// yields. The argument is only used to deduce the Distribution type.
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
    auto partition_idx = Distribution::_get_partition_index();
    return partition_idx;
}

} // namespace detail
// distributed span
// distributed span
template
<
index_t
...
PartialHsLengths
>
template
<
index_t
...
PartialHsLengths
>
struct
tile_distributed_span
struct
tile_distributed_span
...
@@ -83,6 +91,21 @@ struct tile_distribution
...
@@ -83,6 +91,21 @@ struct tile_distribution
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_p
()
{
return
NDimP
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_p
()
{
return
NDimP
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_r
()
{
return
NDimR
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_dimension_r
()
{
return
NDimR
;
}
// Builds the P-dimension index of the calling thread.
//   NDimP == 1 -> array<index_t, 1>{lane id}
//   NDimP == 2 -> array<index_t, 2>{warp id, lane id}
// Any other NDimP is rejected at compile time.
CK_TILE_HOST_DEVICE static auto _get_partition_index()
{
    // only support warp-tile and block-tile
    static_assert(NDimP == 1 || NDimP == 2, "wrong!");

    if constexpr(NDimP == 2)
    {
        return array<index_t, 2>{get_warp_id(), get_lane_id()};
    }
    else
    {
        // The static_assert above leaves NDimP == 1 as the only other case.
        return array<index_t, 1>{get_lane_id()};
    }
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_lengths
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_lengths
()
{
{
#if 0
#if 0
...
@@ -149,6 +172,16 @@ struct tile_distribution
...
@@ -149,6 +172,16 @@ struct tile_distribution
}
}
#endif
#endif
// Returns the bottom index of the ps_ys_to_xs_ adaptor for the given partition
// index, with all NDimY Y-components set to zero.
//
// ps_idx defaults to _get_partition_index().
template <typename PartitionIndex = decltype(_get_partition_index())>
CK_TILE_HOST_DEVICE auto
calculate_index(const PartitionIndex& ps_idx = _get_partition_index()) const
{
    // Top index = [ps..., 0, 0, ...] : the P part followed by NDimY zeros.
    const auto top_idx = container_concat(ps_idx, array<index_t, NDimY>{0});

    const auto adaptor_coord = make_tensor_adaptor_coordinate(ps_ys_to_xs_, top_idx);

    return adaptor_coord.get_bottom_index();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_distributed_spans
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
get_distributed_spans
()
{
{
constexpr
auto
distributed_spans_impl
=
DstrEncode
::
detail
::
distributed_spans_lengthss_
;
constexpr
auto
distributed_spans_impl
=
DstrEncode
::
detail
::
distributed_spans_lengthss_
;
...
@@ -421,6 +454,7 @@ struct tile_distribution_detail
...
@@ -421,6 +454,7 @@ struct tile_distribution_detail
}
// namespace detail
}
// namespace detail
#if 0
// this returns a constexpr tile_distribution
// this returns a constexpr tile_distribution
template <typename StaticTileDistributionEncoding_>
template <typename StaticTileDistributionEncoding_>
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistributionEncoding_)
...
@@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
...
@@ -457,6 +491,7 @@ CK_TILE_HOST_DEVICE constexpr auto make_tile_distribution(StaticTileDistribution
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
detail::tile_distribution_detail<remove_cvref_t<decltype(rh_major_minor_to_hidden_ids)>>>{
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
ps_ys_to_xs_adaptor, ys_to_d_descriptor};
}
}
#endif
// this returns a static tile_distribution
// this returns a static tile_distribution
template
<
typename
StaticTileDistributionEncoding_
>
template
<
typename
StaticTileDistributionEncoding_
>
...
@@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
...
@@ -499,129 +534,6 @@ CK_TILE_HOST_DEVICE constexpr auto make_static_tile_distribution(StaticTileDistr
//***********************************************************************************
//***********************************************************************************
namespace
detail
{
namespace
detail
{
// Builds the P-dimension index of the calling thread for a Distribution type.
//   Distribution::NDimP == 1 -> array<index_t, 1>{lane id}
//   Distribution::NDimP == 2 -> array<index_t, 2>{warp id, lane id}
// The argument is only used to deduce the Distribution type.
template <typename Distribution>
CK_TILE_HOST_DEVICE auto get_partition_index(Distribution)
{
    // only support warp-tile and block-tile
    static_assert(Distribution::NDimP == 1 || Distribution::NDimP == 2, "wrong!");

    if constexpr(Distribution::NDimP == 2)
    {
        return array<index_t, 2>{get_warp_id(), get_lane_id()};
    }
    else
    {
        // The static_assert above leaves NDimP == 1 as the only other case.
        return array<index_t, 1>{get_lane_id()};
    }
}
// Primary template (declaration only); the work is done by the partial
// specializations on sequence<...> arguments.
//   Seq       : sequence of dimension lengths
//   Mask      : per-dimension mask sequence
//   Ids       : per-dimension index sequence
//   SliceSize : requested size per slice
template <typename Seq, typename Mask, typename Ids, index_t SliceSize>
struct reverse_slice_sequence_impl;
// Recursive case: handles the leading dimension (x, m, id) after the remaining
// dimensions (xs..., ms..., ids...) have been scanned by old_scan. The slice
// size still to be distributed is read from the front of old_scan's
// remaining_slice_sizes.
template <index_t x,
          index_t... xs,
          index_t m,
          index_t... ms,
          index_t id,
          index_t... ids,
          index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x, xs...>,
                                   sequence<m, ms...>,
                                   sequence<id, ids...>,
                                   SliceSize>
{
    // Result of scanning everything to the right of this dimension.
    using old_scan =
        reverse_slice_sequence_impl<sequence<xs...>, sequence<ms...>, sequence<ids...>, SliceSize>;

    // Slice size left over for this dimension.
    static constexpr auto slice_size = old_scan::remaining_slice_sizes::front().value;

    // Masked-in dimensions take gcd(x, slice_size); masked-out ones keep x.
    static constexpr auto slice_length = number<(m ? gcd(x, slice_size) : x)>::value;

    // Prepend this dimension's results to those of the already-scanned tail.
    using dim_lengths =
        typename sequence_merge<sequence<slice_length>, typename old_scan::dim_lengths>::type;

    using dim_slices =
        typename sequence_merge<sequence<x / slice_length>, typename old_scan::dim_slices>::type;

    // Only masked-in dimensions consume part of the slice size.
    using remaining_slice_sizes =
        typename sequence_merge<sequence<(m ? slice_size / slice_length : slice_size)>,
                                typename old_scan::remaining_slice_sizes>::type;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        (slice_length != x) && (remaining_slice_sizes::front().value == 1);

    // A masked-out dimension never counts as the split point.
    static constexpr index_t _split_flag = m ? _flag : 0;
    static constexpr index_t _split_idx  = _split_flag ? id : 0;

    // A split found further right (in old_scan) takes precedence over this one.
    static constexpr index_t split_flag = _split_flag || old_scan::split_flag;
    static constexpr index_t split_idx  = old_scan::split_flag ? old_scan::split_idx : _split_idx;
};
// Terminal case: a single dimension (x, m, id). The slice size to distribute
// is SliceSize itself.
template <index_t x, index_t m, index_t id, index_t SliceSize>
struct reverse_slice_sequence_impl<sequence<x>, sequence<m>, sequence<id>, SliceSize>
{
    static constexpr auto slice_size = SliceSize;

    // Masked-in dimension takes gcd(x, slice_size); masked-out keeps x.
    static constexpr auto slice_length = number<(m ? gcd(x, slice_size) : x)>::value;

    using dim_lengths = sequence<slice_length>;
    using dim_slices  = sequence<x / slice_length>;

    // Only a masked-in dimension consumes part of the slice size.
    using remaining_slice_sizes = sequence<(m ? slice_size / slice_length : slice_size)>;

    // the first idx that sliced length not equal to original length
    static constexpr index_t _flag =
        (slice_length != x) && (remaining_slice_sizes::front().value == 1);

    // A masked-out dimension never counts as the split point.
    static constexpr index_t split_flag = m ? _flag : 0;
    static constexpr index_t split_idx  = split_flag ? id : 0;
};
// clang-format off
// input a sequence(with optional mask), and the SliceSize : size per slice
// output the sequence each slice, and number of slices
//
// e.g. <2, 1, 4, 2>, 8 -> lengths:<1, 1, 4, 2> , nums: <2, 1, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 4, 1, 2>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 1> : 16 slices , slice_idx: 2
// <4, 2, 4, 1, 6>, 4 -> lengths:<1, 1, 2, 1, 2> , nums: <4, 2, 2, 1, 3> : 48 slices , slice_idx: 2
// <4, 2, 5, 1, 2>, 10 -> lengths:<1, 1, 5, 1, 2> , nums: <4, 2, 1, 1, 1> : 8 slices , slice_idx: 1
//
// <4, 2, 8>, 64 -> lengths:<4, 2, 8> , nums: <1, 1, 1> : 1 slices , slice_idx: 0
// <4, 2, 8>, 32 -> lengths:<2, 2, 8> , nums: <2, 1, 1> : 2 slices , slice_idx: 0
// <4, 2, 8>, 16 -> lengths:<1, 2, 8> , nums: <4, 1, 1> : 4 slices , slice_idx: 0
// <4, 2, 8>, 8 -> lengths:<1, 1, 8> , nums: <4, 2, 1> : 8 slices , slice_idx: 1
// <4, 2, 8>, 4 -> lengths:<1, 1, 4> , nums: <4, 2, 2> : 16 slices , slice_idx: 2
// <4, 2, 8>, 2 -> lengths:<1, 1, 2> , nums: <4, 2, 4> : 32 slices , slice_idx: 2
// <4, 2, 8>, 1 -> lengths:<1, 1, 1> , nums: <4, 2, 8> : 64 slices , slice_idx: 2
//
// <4, 2, 1, 4, 2> / 4 ->
// mask:<1, 1, 1, 0, 1>, -> lengths:<1, 2, 1, 4, 2> , nums: <4, 1, 1, 1, 1> : 8 slices , slice_idx: 0
//
// return tuple<slice_lengths, slice_nums, slice_index>, slice_index is at which index will start
// have split slices (right -> left)
// or the first index that sliced length is different from the original length
// clang-format on
template
<
typename
Seq
,
index_t
SliceSize
,
typename
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>
::
type
>
constexpr
auto
reverse_slice_sequence
(
Seq
,
number
<
SliceSize
>
,
Mask
=
typename
uniform_sequence_gen
<
Seq
::
size
(),
1
>::
type
{})
{
static_assert
(
Seq
::
size
()
==
Mask
::
size
());
using
sliced_type
=
reverse_slice_sequence_impl
<
Seq
,
Mask
,
typename
arithmetic_sequence_gen
<
0
,
Seq
::
size
(),
1
>::
type
,
SliceSize
>
;
static_assert
(
sliced_type
::
remaining_slice_sizes
::
front
().
value
==
1
,
"can not evenly divide this sequence, please check"
);
return
make_tuple
(
typename
sliced_type
::
dim_lengths
{},
typename
sliced_type
::
dim_slices
{},
number
<
sliced_type
::
split_idx
>
{});
}
//
//
// slice tensor from x_dim, result in split in y_dim, not p_dim.
// slice tensor from x_dim, result in split in y_dim, not p_dim.
// We don't support slice cross p_dim (aka, slice different threads)
// We don't support slice cross p_dim (aka, slice different threads)
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
…
22
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment