gaoqiong / composable_kernel_ROCM · Commit 66593407 (Unverified)
Authored Aug 05, 2024 by Po Yen Chen · Committed by GitHub on Aug 04, 2024
[CK_TILE] Pick bugfixes for ROCm 6.2 compiler issues (#1430)
Parent: 00626ca8
Showing 11 changed files with 747 additions and 266 deletions (+747 −266)
example/ck_tile/01_fmha/generate.py                                       +17  −6
include/ck_tile/core/arch/amd_buffer_addressing.hpp                       +435 −199
include/ck_tile/core/arch/arch.hpp                                        +3   −5
include/ck_tile/core/config.hpp                                           +18  −0
include/ck_tile/core/tensor/buffer_view.hpp                               +34  −11
include/ck_tile/core/tensor/load_tile.hpp                                 +13  −6
include/ck_tile/core/tensor/null_tile_window.hpp                          +2   −0
include/ck_tile/core/tensor/tensor_view.hpp                               +15  −9
include/ck_tile/core/tensor/tile_elementwise.hpp                          +89  −10
include/ck_tile/core/tensor/tile_window.hpp                               +94  −11
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp  +27  −9
example/ck_tile/01_fmha/generate.py

@@ -355,6 +355,9 @@ class FmhaFwdApiPool:
                 per_hdim_case = per_hdim_case + FMHA_FWD_API_PER_HDIM_CASE.format(F_if=if_j, F_hdim=hdim, F_inner_dispatch=inners)
             if_i = 'if' if i == 0 else 'else if'
             per_dtypes = per_dtypes + FMHA_FWD_API_PER_DTYPE.format(F_if=if_i, F_dtype=dtype, F_hdim_case=per_hdim_case)
+        if not per_dtypes:
+            # empty string we add some ignore to suppress warning in api
+            per_dtypes += ' (void)t ; (void)s ; (void)a;'
         return FMHA_FWD_KERNEL_HEADER + FMHA_FWD_API.format(F_dispatch=per_dtypes)
 
 @dataclass

@@ -489,7 +492,8 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
         pipelines = []
         if dtype in ['fp16', 'bf16']:
             for mask, bias, lse in itertools.product(get_mask_map(mask_impl).keys(), BIAS_MAP.keys(), ["t", "f"]):
-                if hdim == 256:
+                # if hdim=32, fallback to 'qr' pipeline to workaround rocm 6.2 compiler problem (missing s_waitcnt)
+                if hdim == 256 or hdim == 32:
                 # if True:
                     pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
                     pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))

@@ -497,11 +501,18 @@ def get_blobs(kernel_filter : Optional[str], receipt, mask_impl) -> Tuple[FmhaFw
                     pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
                     pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
                 else:
-                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))
-                    pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
-                    if receipt == 1:
+                    if bias == "bias":
+                        # TODO: rocm 6.2 compiler problem if using qr_async for bias case
+                        pipelines.append(FmhaFwdPipeline('qr', 'row', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'col', 'f', 'f', 'f', 'f', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
+                    else:
+                        pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 'f', 't', 't', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr_async', 'row', 't', 't', 't', 't', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 'f', 't', 't', bias, lse, squant, mask))
+                        pipelines.append(FmhaFwdPipeline('qr_async', 'col', 't', 't', 't', 't', bias, lse, squant, mask))
+                    if receipt == 1 and bias != "bias":
                         pipelines.append(FmhaFwdPipeline('qr', 'row', 't', 't', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
                         pipelines.append(FmhaFwdPipeline('qr', 'col', 't', 'f', 't', 't', bias, lse, squant, mask)) # TODO: cover arbitraty hdim
         elif dtype in ['fp8', 'bf8']:
include/ck_tile/core/arch/amd_buffer_addressing.hpp

(diff collapsed in the page capture; contents not shown)
include/ck_tile/core/arch/arch.hpp

@@ -79,14 +79,12 @@ CK_TILE_DEVICE void block_sync_lds_direct_load()
     " ::);
 }
 
-CK_TILE_DEVICE void s_nop()
+CK_TILE_DEVICE void s_nop(index_t cnt = 0)
 {
 #if 1
-    asm volatile("\
-    s_nop 0 \n \
-    " ::);
+    asm volatile("s_nop %0" : : "n"(cnt) :);
 #else
-    __builtin_amdgcn_sched_barrier(0);
+    __builtin_amdgcn_sched_barrier(cnt);
 #endif
 }
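Note on the new s_nop: the inline-asm constraint "n" requires an integral constant expression, so cnt must be a literal or constexpr value at every call site; a runtime count would not compile. A minimal stand-alone sketch of the constraint (hypothetical helper name, not part of the commit; assumes hipcc targeting an AMD GPU):

    // wait_two_cycles is a made-up name for illustration only
    __device__ void wait_two_cycles()
    {
        // "n" folds the operand into the instruction as an immediate,
        // exactly like the rewritten s_nop(index_t cnt = 0) above
        asm volatile("s_nop %0" : : "n"(2) :);
    }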
include/ck_tile/core/config.hpp

@@ -18,6 +18,7 @@
 #define __gfx11__
 #endif
 
+#include "hip/hip_version.h"
 #ifndef CK_TILE_DONT_USE_HIP_RUNTIME_HEADERS
 #include "hip/hip_runtime.h"
 #include "hip/hip_fp16.h"

@@ -144,6 +145,15 @@
 #define CK_TILE_WORKAROUND_SWDEV_XXXXXX_INT8_DS_WRITE_ISSUE 1
 #endif
 
+#ifndef CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+#if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091) || \
+    (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133)
+#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 1
+#else
+#define CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE 0
+#endif
+#endif
+
 #ifndef CK_TILE_DEBUG_LOG
 #define CK_TILE_DEBUG_LOG 0
 #endif

@@ -167,7 +177,15 @@
 #define CK_TILE_USE_SUBDWORD_TILE_CAST 0
 #endif
 
+#ifndef CK_TILE_USE_PK_FP16_TILE_CAST
+#define CK_TILE_USE_PK_FP16_TILE_CAST 0
+#endif
+
 // TODO: better solve this inside compiler
 #ifndef CK_TILE_FMHA_FWD_FAST_EXP2
 #define CK_TILE_FMHA_FWD_FAST_EXP2 0
 #endif
+
+#ifndef CK_TILE_BUFFER_LOAD_RAW_BF16_WA
+#define CK_TILE_BUFFER_LOAD_RAW_BF16_WA 1
+#endif
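Note: the new scratch-memory workaround macro is derived purely from the HIP version macros pulled in via hip/hip_version.h. A hypothetical host-side probe (not part of the commit) showing which path a given ROCm install would select:

    #include <cstdio>
    #include <hip/hip_version.h>

    int main()
    {
        std::printf("HIP %d.%d, patch %d\n", HIP_VERSION_MAJOR, HIP_VERSION_MINOR, HIP_VERSION_PATCH);
    #if(HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 1 && HIP_VERSION_PATCH >= 40091) || \
        (HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133)
        std::printf("CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE would default to 1\n");
    #else
        std::printf("CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE would default to 0\n");
    #endif
    }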
include/ck_tile/core/tensor/buffer_view.hpp

@@ -68,6 +68,8 @@ struct buffer_view<address_space_enum::generic,
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::generic;

@@ -223,25 +225,36 @@ struct buffer_view<address_space_enum::global,
     T* p_data_ = nullptr;
     BufferSizeType buffer_size_;
+    int32x4_t cached_buf_res_;
     remove_cvref_t<T> invalid_element_value_ = T{0};
 
     CK_TILE_HOST_DEVICE constexpr buffer_view()
-        : p_data_{}, buffer_size_{}, invalid_element_value_{}
+        : p_data_{}, buffer_size_{}, cached_buf_res_{0}, invalid_element_value_{}
     {
     }
 
     CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data, BufferSizeType buffer_size)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{0}
+        : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0}, invalid_element_value_{0}
     {
     }
 
     CK_TILE_HOST_DEVICE constexpr buffer_view(T* p_data,
                                               BufferSizeType buffer_size,
                                               T invalid_element_value)
-        : p_data_{p_data}, buffer_size_{buffer_size}, invalid_element_value_{invalid_element_value}
+        : p_data_{p_data}, buffer_size_{buffer_size}, cached_buf_res_{0},
+          invalid_element_value_{invalid_element_value}
     {
     }
 
+    // this is non constexpr intentially (will call some intrinsic internally)
+    // Must call for buffers that need *_raw load/store
+    CK_TILE_HOST_DEVICE void init_raw()
+    {
+        cached_buf_res_ = make_wave_buffer_resource(p_data_, buffer_size_ * sizeof(type));
+    }
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::global;

@@ -332,12 +345,15 @@ struct buffer_view<address_space_enum::global,
     // i is offset of T, not X. i should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
     CK_TILE_DEVICE constexpr auto
-    get_raw(remove_cvref_t<X>& dst, index_t i, bool is_valid_element) const
+    get_raw(remove_cvref_t<X>& dst,
+            index_t i,
+            bool is_valid_element,
+            bool_constant<pre_nop> = {}) const
     {
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -348,18 +364,21 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
-        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check>(
-            dst, p_data_, i, buffer_size_, is_valid_element);
+        amd_buffer_load_raw<remove_cvref_t<T>, t_per_x, Coherence, oob_conditional_check, pre_nop>(
+            dst, cached_buf_res_, i, is_valid_element, bool_constant<pre_nop>{});
     }
 
     // i is offset of T, not X. i should be aligned to X
     template <typename X,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                typename vector_traits<remove_cvref_t<T>>::scalar_type>::value,
                   bool>::type = false>
     CK_TILE_DEVICE constexpr auto
-    async_get(remove_cvref_t<T>* smem, index_t i, bool /*is_valid_element*/) const
+    async_get_raw(remove_cvref_t<T>* smem,
+                  index_t i,
+                  bool /*is_valid_element*/,
+                  bool_constant<pre_nop> = {}) const
     {
         // X is vector of T
         constexpr index_t scalar_per_t_vector = vector_traits<remove_cvref_t<T>>::vector_size;

@@ -370,8 +389,8 @@ struct buffer_view<address_space_enum::global,
         constexpr index_t t_per_x = scalar_per_x_vector / scalar_per_t_vector;
 
-        amd_async_buffer_load_with_oob<remove_cvref_t<T>, t_per_x, Coherence>(
-            smem, p_data_, i, buffer_size_);
+        amd_async_buffer_load_with_oob_raw<remove_cvref_t<T>, t_per_x, Coherence>(
+            smem, cached_buf_res_, i, bool_constant<pre_nop>{});
     }
 
     // i is offset of T, not X. i should be aligned to X

@@ -626,6 +645,8 @@ struct buffer_view<address_space_enum::lds,
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::lds;

@@ -908,6 +929,8 @@ struct buffer_view<address_space_enum::vgpr,
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() {}
+
     CK_TILE_DEVICE static constexpr address_space_enum get_address_space()
     {
         return address_space_enum::vgpr;
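Note: the thread running through this file is the cached buffer resource. init_raw() builds the 128-bit descriptor once (via make_wave_buffer_resource) and every *_raw accessor then reuses cached_buf_res_ instead of rebuilding it per access. A simplified, self-contained analogue of the pattern (toy types, not the CK_TILE API):

    #include <cstdint>
    #include <cstddef>

    // stand-in for the int32x4_t buffer resource descriptor
    struct toy_resource { std::uint32_t word[4]; };

    struct toy_buffer_view
    {
        const float* p_data_      = nullptr;
        std::size_t  buffer_size_ = 0;
        toy_resource cached_buf_res_{}; // filled once, reused on every raw access

        // kept out of the constructor on purpose, mirroring the non-constexpr
        // init_raw() above (the real one calls an intrinsic internally)
        void init_raw()
        {
            cached_buf_res_.word[0] = static_cast<std::uint32_t>(
                reinterpret_cast<std::uintptr_t>(p_data_));
            cached_buf_res_.word[2] = static_cast<std::uint32_t>(
                buffer_size_ * sizeof(float)); // range in bytes
        }

        // hot path: no descriptor rebuild, just consume cached_buf_res_
        float get_raw(std::size_t i) const { return p_data_[i]; }
    };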
include/ck_tile/core/tensor/load_tile.hpp

@@ -36,30 +36,37 @@ template <typename T,
           typename WindowLengths_,
           typename TileDistribution_,
           index_t NumCoord,
-          bool oob_conditional_check = true>
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
 CK_TILE_DEVICE auto load_tile_raw(T& tile,
                                   const tile_window_with_static_distribution<BottomTensorView_,
                                                                              WindowLengths_,
                                                                              TileDistribution_,
                                                                              NumCoord>& tile_window,
-                                  bool_constant<oob_conditional_check> = {})
+                                  bool_constant<oob_conditional_check> = {},
+                                  bool_constant<pre_nop> = {})
 {
-    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{});
+    tile_window.load_raw(tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }
 
 template <typename LdsTileWindow_,
           typename BottomTensorView_,
           typename WindowLengths_,
           typename TileDistribution_,
-          index_t NumCoord>
+          index_t NumCoord,
+          bool oob_conditional_check = true,
+          bool pre_nop = false>
 CK_TILE_DEVICE auto
 async_load_tile_raw(LdsTileWindow_&& lds_tile,
                     const tile_window_with_static_distribution<BottomTensorView_,
                                                                WindowLengths_,
                                                                TileDistribution_,
-                                                               NumCoord>& tile_window)
+                                                               NumCoord>& tile_window,
+                    bool_constant<oob_conditional_check> = {},
+                    bool_constant<pre_nop> = {})
 {
-    return tile_window.async_load(lds_tile);
+    return tile_window.async_load_raw(
+        lds_tile, bool_constant<oob_conditional_check>{}, bool_constant<pre_nop>{});
 }
 
 CK_TILE_DEVICE auto async_load_fence(index_t cnt = 0)
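Note: both wrappers now forward compile-time tags instead of runtime bools. A self-contained sketch of the tag-dispatch style (std::bool_constant standing in for ck_tile::bool_constant; function name is hypothetical):

    #include <type_traits>

    template <bool pre_nop = false>
    void load_raw_like(std::bool_constant<pre_nop> = {})
    {
        // the flag is a type, so this branch is resolved during template
        // instantiation and costs nothing at run time
        if constexpr(pre_nop)
        {
            // would emit the leading s_nop before the first access
        }
        // ... issue the load ...
    }

    int main()
    {
        load_raw_like();                           // defaulted: bool_constant<false>
        load_raw_like(std::bool_constant<true>{}); // opt in at the call site
    }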
include/ck_tile/core/tensor/null_tile_window.hpp

@@ -35,6 +35,8 @@ struct null_tile_window
     CK_TILE_DEVICE constexpr auto get_window_origin() const { return BottomTensorIndex{}; }
 
+    CK_TILE_DEVICE void init_raw() {}
+
     WindowLengths window_lengths_;
 };
include/ck_tile/core/tensor/tensor_view.hpp

@@ -33,6 +33,8 @@ struct tensor_view
     {
     }
 
+    CK_TILE_HOST_DEVICE void init_raw() { buf_.init_raw(); }
+
     CK_TILE_HOST_DEVICE constexpr auto& get_tensor_descriptor() const { return desc_; }
 
     CK_TILE_HOST_DEVICE static constexpr index_t get_num_of_dimension()

@@ -82,30 +84,34 @@ struct tensor_view
     // "coord" is coordinate of DataType, not X. "coord" should be aligned to X
     template <typename X,
               bool oob_conditional_check = true,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                  typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                   bool>::type = false>
     CK_TILE_HOST_DEVICE void
     get_vectorized_elements_raw(remove_cvref_t<X>& dst,
                                 const TensorCoord& coord,
-                                bool_constant<oob_conditional_check> = {}) const
+                                bool_constant<oob_conditional_check> = {},
+                                bool_constant<pre_nop> = {}) const
     {
-        return buf_.template get_raw<X, oob_conditional_check>(
+        return buf_.template get_raw<X, oob_conditional_check, pre_nop>(
             dst,
             coord.get_offset(),
-            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord));
+            coordinate_has_valid_offset_assuming_top_index_is_valid(desc_, coord),
+            bool_constant<pre_nop>{});
     }
 
     template <typename X,
+              bool pre_nop = false,
               typename std::enable_if<
                   std::is_same_v<typename vector_traits<remove_cvref_t<X>>::scalar_type,
                                  typename vector_traits<remove_cvref_t<DataType>>::scalar_type>,
                   bool>::type = false>
     CK_TILE_HOST_DEVICE constexpr void
-    async_get_vectorized_elements(remove_cvref_t<DataType>* smem, const TensorCoord& coord) const
+    async_get_vectorized_elements_raw(remove_cvref_t<DataType>* smem,
+                                      const TensorCoord& coord,
+                                      bool_constant<pre_nop> = {}) const
     {
-        return buf_.template async_get<X>(smem, coord.get_offset(), true /*not used*/);
+        return buf_.template async_get_raw<X>(
+            smem, coord.get_offset(), true /*not used*/, bool_constant<pre_nop>{});
     }
 
     // X is vector of DataType.
include/ck_tile/core/tensor/tile_elementwise.hpp

@@ -76,23 +76,63 @@ CK_TILE_DEVICE void set_tile(null_tensor&, const T&)
 // TODO: prefer to use per-dword value to set a tensor, in case compiler not doing well with
 // sub-dword tensor...
-template <typename DstrTensors, index_t v>
-CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>)
+template <typename DstrTensors, index_t v, bool skip_subdword_opt = false>
+CK_TILE_DEVICE void set_tile(DstrTensors& dstr_tensor, number<v>, bool_constant<skip_subdword_opt> = {})
 {
-    constexpr index_t tensor_bytes =
-        DstrTensors::get_thread_buffer_size() * sizeof(typename DstrTensors::DataType);
+    using elem_type = typename DstrTensors::DataType;
+    constexpr index_t elem_size = sizeof(elem_type);
 
-    if constexpr(v == 0 && tensor_bytes % 4 == 0)
+    constexpr index_t tensor_bytes = DstrTensors::get_thread_buffer_size() * elem_size;
+
+    // # bytes per write = 4
+    if constexpr(v == 0 && tensor_bytes % 4 == 0 && !skip_subdword_opt)
     {
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        auto& buffer = dstr_tensor.get_thread_buffer();
+
+        static_for<0, tensor_bytes / 4, 1>{}([&](auto i_write) {
+            if constexpr(elem_size == 1)
+            {
+                // # elements per write = 4
+                constexpr auto values = ext_vector_t<elem_type, 4>{0, 0, 0, 0};
+
+                buffer[i_write * 4 + 0] = values.x;
+                buffer[i_write * 4 + 1] = values.y;
+                buffer[i_write * 4 + 2] = values.z;
+                buffer[i_write * 4 + 3] = values.w;
+            }
+            else if constexpr(elem_size == 2)
+            {
+                // # elements per write = 2
+                constexpr auto values = ext_vector_t<elem_type, 2>{0, 0};
+
+                buffer[i_write * 2 + 0] = values.x;
+                buffer[i_write * 2 + 1] = values.y;
+            }
+            else if constexpr(elem_size == 4)
+            {
+                // # elements per write = 1
+                constexpr elem_type value = 0;
+
+                buffer[i_write] = value;
+            }
+            else
+            {
+                static_assert(false, "type not supported");
+            }
+        });
+#else
         using dvec_t = array<index_t, tensor_bytes / 4>;
         auto& tensor = reinterpret_cast<dvec_t&>(dstr_tensor.get_thread_buffer());
         for(auto i = 0; i < tensor.size(); i++)
             tensor.get(i) = v;
+#endif
     }
     else
     {
-        tile_elementwise_inout(
-            [](auto& x) { x = type_convert<typename DstrTensors::DataType, index_t>(v); },
-            dstr_tensor);
+        tile_elementwise_inout([](auto& x) { x = type_convert<elem_type, index_t>(v); },
+                               dstr_tensor);
     }
 }
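Note: the workaround branch replaces the reinterpret_cast'd dword loop with explicit element stores grouped four/two/one per 32-bit write, which keeps the clear out of scratch memory on the affected compilers. A stripped-down analogue of the elem_size == 1 grouping (plain array instead of CK's thread buffer):

    #include <cstdint>

    // clear N single-byte elements, four at a time, one conceptual dword per step
    template <int N>
    void clear_byte_buffer(std::uint8_t (&buf)[N])
    {
        static_assert(N % 4 == 0, "pad to a whole number of dwords");
        for(int i_write = 0; i_write < N / 4; ++i_write)
        {
            buf[i_write * 4 + 0] = 0;
            buf[i_write * 4 + 1] = 0;
            buf[i_write * 4 + 2] = 0;
            buf[i_write * 4 + 3] = 0;
        }
    }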
@@ -110,7 +150,7 @@ CK_TILE_DEVICE void clear_tile(DstrTensors& dstr_tensor)
 namespace impl {
 
 // TODO: this is ugly
 template <typename OutDataType, typename InTensor>
-CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
+CK_TILE_DEVICE auto cast_tile_pk_fp8_fp32(const InTensor& in_dstr_tensors)
 {
 #if defined(__gfx94__)
     // This API is designed to use the _pk_ serious of function

@@ -156,6 +196,37 @@ CK_TILE_DEVICE auto cast_tile_pk_fp8x4(const InTensor& in_dstr_tensors)
 #endif
 }
 
+template <typename OutDataType, typename InTensor>
+CK_TILE_DEVICE auto cast_tile_pk_fp16_fp32(const InTensor& in_dstr_tensors)
+{
+#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx94__)
+    // This API is designed to use the _pk_ serious of function
+    constexpr auto in_tile_dstr = InTensor::get_tile_distribution();
+
+    constexpr index_t thread_buffer_size = InTensor::get_thread_buffer_size();
+    static_assert(thread_buffer_size % 2 == 0);
+    constexpr index_t thread_buffer_size_pk = thread_buffer_size / 2;
+
+    auto out_dstr_tensor = make_static_distributed_tensor<OutDataType>(in_tile_dstr);
+
+    // TODO: this is rtz cvt, need be very careful
+    for(index_t i = 0; i < thread_buffer_size_pk; i++)
+    {
+        auto o = __builtin_amdgcn_cvt_pkrtz(in_dstr_tensors.get_thread_buffer()[2 * i + 0],
+                                            in_dstr_tensors.get_thread_buffer()[2 * i + 1]);
+
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 0) = o.x;
+        out_dstr_tensor.get_thread_buffer().at(2 * i + 1) = o.y;
+    }
+
+    return out_dstr_tensor;
+#else
+    // fallback
+    return tile_elementwise_in(type_convert<OutDataType, typename InTensor::DataType>,
+                               in_dstr_tensors);
+#endif
+}
+
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
 // this function assume either src or dst (or both) date type is under 1 dword
 // we pack subdword value into 1 dword to avoid compiler's default subdword behavior(which is buggy)

@@ -229,8 +300,16 @@ CK_TILE_DEVICE auto cast_tile(const SrcTensor& src_tensor)
                             float> &&
                          (SrcTensor::get_thread_buffer_size() % 4 == 0))
     {
-        return impl::cast_tile_pk_fp8x4<DstType, SrcTensor>(src_tensor);
+        return impl::cast_tile_pk_fp8_fp32<DstType, SrcTensor>(src_tensor);
     }
+#if CK_TILE_USE_PK_FP16_TILE_CAST
+    else if constexpr(std::is_same_v<DstType, fp16_t> &&
+                      std::is_same_v<typename SrcTensor::DataType, float> &&
+                      (SrcTensor::get_thread_buffer_size() % 2 == 0))
+    {
+        return impl::cast_tile_pk_fp16_fp32<DstType, SrcTensor>(src_tensor);
+    }
+#endif
 #if CK_TILE_USE_SUBDWORD_TILE_CAST
     else if constexpr(sizeof(DstType) < 4 || sizeof(typename SrcTensor::DataType) < 4)
     {
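Note: as the TODO warns, __builtin_amdgcn_cvt_pkrtz rounds toward zero while the element-wise fallback uses ordinary round-to-nearest, so the two branches of cast_tile_pk_fp16_fp32 are not guaranteed bit-identical. A scalar sketch of the packed loop's shape (assumes a compiler with _Float16; each plain cast stands in for one lane of the builtin):

    #include <cstddef>

    void cast_fp32_to_fp16_pairs(const float* in, _Float16* out, std::size_t n)
    {
        // n % 2 == 0, mirroring static_assert(thread_buffer_size % 2 == 0)
        for(std::size_t i = 0; i < n; i += 2)
        {
            out[i + 0] = static_cast<_Float16>(in[i + 0]); // o.x of one pkrtz
            out[i + 1] = static_cast<_Float16>(in[i + 1]); // o.y of one pkrtz
        }
    }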
include/ck_tile/core/tensor/tile_window.hpp

@@ -344,9 +344,10 @@ struct tile_window_with_static_distribution
         return dst_tensor;
     }
 
-    template <typename DstTile, bool oob_conditional_check = true>
+    template <typename DstTile, bool oob_conditional_check = true, bool pre_nop = false>
     CK_TILE_DEVICE void load_raw(DstTile& dst_tensor,
-                                 bool_constant<oob_conditional_check> = {}) const
+                                 bool_constant<oob_conditional_check> = {},
+                                 bool_constant<pre_nop> = {}) const
     {
         using Traits = load_store_traits;

@@ -373,7 +374,13 @@ struct tile_window_with_static_distribution
             auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();
 
                 // data index [y0, y1, ...]
                 constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

@@ -384,8 +391,12 @@ struct tile_window_with_static_distribution
                 get_bottom_tensor_view().template get_vectorized_elements_raw<vector_t>(
                     dst_vec_tbuf.template at<d / Traits::ScalarPerVector>(),
                     bottom_tensor_thread_coord,
-                    bool_constant<oob_conditional_check>{});
+                    bool_constant<oob_conditional_check>{},
+                    pre_nop_);
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+                asm volatile(""); // this is starting from rocm-6.2, but same sympton, reuse this flag
+#endif
 
                 // move thread coordinate
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                 {
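Note: pre_nop_ is an immediately-invoked constexpr lambda whose two branches return different types (bool_constant<true> vs bool_constant<false>), so only the very first access of the window carries the nop and the choice propagates on through template dispatch. A self-contained sketch of the idiom (std::bool_constant standing in; function name is hypothetical):

    #include <type_traits>

    template <bool pre_nop, int iCoord, int iCoordAccess>
    constexpr auto make_pre_nop_flag()
    {
        // a plain bool could not do this: the two returns have distinct types
        constexpr auto pre_nop_ = []() {
            if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
                return std::bool_constant<true>{};
            else
                return std::bool_constant<false>{};
        }();
        return pre_nop_;
    }

    static_assert(decltype(make_pre_nop_flag<true, 0, 0>())::value);  // first access nops
    static_assert(!decltype(make_pre_nop_flag<true, 0, 1>())::value); // later ones do not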
@@ -399,12 +410,17 @@ struct tile_window_with_static_distribution
                 }
             });
         });
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+        asm volatile("; this inline asm is workaround to prevent compiler from using too much "
+                     "scratch memory" ::);
+#endif
     }
 
     // TODO: currently async load only implemented in inline asm
-    template <typename LdsTileWindow_, bool oob_conditional_check = true>
-    CK_TILE_DEVICE auto async_load(LdsTileWindow_&& lds_tile,
-                                   bool_constant<oob_conditional_check> = {}) const
+    template <typename LdsTileWindow_, bool oob_conditional_check = true, bool pre_nop = false>
+    CK_TILE_DEVICE auto async_load_raw(LdsTileWindow_&& lds_tile,
+                                       bool_constant<oob_conditional_check> = {},
+                                       bool_constant<pre_nop> = {}) const
     {
         using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
         // using LdsTensorView = typename LdsTileWindow::BottomTensorView;

@@ -449,11 +465,17 @@ struct tile_window_with_static_distribution
             auto bottom_tensor_thread_coord = pre_computed_coords_[iCoord][I1];
 
             static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                 constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};
+                constexpr auto pre_nop_ = [&]() {
+                    if constexpr(pre_nop && iCoord == 0 && iCoordAccess == 0)
+                        return bool_constant<true>{};
+                    else
+                        return bool_constant<false>{};
+                }();
 
                 // read from bottom tensor
-                get_bottom_tensor_view().template async_get_vectorized_elements<vector_t>(
-                    smem, bottom_tensor_thread_coord);
+                get_bottom_tensor_view().template async_get_vectorized_elements_raw<vector_t>(
+                    smem, bottom_tensor_thread_coord, pre_nop_);
 
                 // move thread coordinate
                 if constexpr(iCoordAccess != (NumAccessPerCoord - 1))

@@ -608,6 +630,67 @@ struct tile_window_with_static_distribution
         });
     }
 
+    CK_TILE_DEVICE void set_window_origin(const BottomTensorIndex& new_window_origin)
+    {
+        window_origin_ = new_window_origin;
+
+#if 0 // debug
+        // TODO: this use more register for FA, but less register for GEMM
+        // need investigation
+        // only support warp-tile and block-tile
+        static_assert(NDimP == 1 or NDimP == 2, "wrong!");
+
+        WindowAdaptorCoord window_adaptor_thread_coord_tmp;
+
+        if constexpr(NDimP == 1)
+        {
+            window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+                tile_dstr_.get_ps_ys_to_xs_adaptor(), AdaptorTopIndex{get_lane_id(), 0});
+        }
+        else if constexpr(NDimP == 2)
+        {
+            window_adaptor_thread_coord_tmp =
+                make_tensor_adaptor_coordinate(tile_dstr_.get_ps_ys_to_xs_adaptor(),
+                                               AdaptorTopIndex{get_warp_id(), get_lane_id(), 0});
+        }
+#else
+        // TODO: this use less register for FA, but more register for GEMM
+        // need investigation
+        const auto window_adaptor_thread_coord_tmp = make_tensor_adaptor_coordinate(
+            tile_dstr_.get_ps_ys_to_xs_adaptor(),
+            container_concat(detail::get_partition_index(tile_dstr_), array<index_t, NDimY>{0}));
+#endif
+
+        BottomTensorIndex bottom_tensor_thread_origin_idx_tmp =
+            window_origin_ + window_adaptor_thread_coord_tmp.get_bottom_index();
+
+        const auto bottom_tensor_thread_coord_tmp = make_tensor_coordinate(
+            bottom_tensor_view_.get_tensor_descriptor(), bottom_tensor_thread_origin_idx_tmp);
+
+        // pre-compute NumCoord (WindowAdaptorCoord, BottomTensorCoord) bundles to speed up
+        // future load/store() calls (might allocate more registers)
+        using Traits = load_store_traits;
+        using SFC_Ys = typename Traits::SFC_Ys;
+
+        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
+            auto window_adaptor_thread_coord = window_adaptor_thread_coord_tmp;
+            auto bottom_tensor_thread_coord  = bottom_tensor_thread_coord_tmp;
+
+            constexpr auto idx_diff_ys =
+                SFC_Ys::get_step_between(number<0>{}, number<iCoord * NumAccessPerCoord>{});
+
+            constexpr auto idx_diff_ps_ys =
+                container_concat(array<index_t, NDimP>{0}, idx_diff_ys);
+
+            move_window_adaptor_and_bottom_tensor_thread_coordinate(
+                window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
+
+            pre_computed_coords_(iCoord) =
+                make_tuple(window_adaptor_thread_coord, bottom_tensor_thread_coord);
+        });
+    }
+
+    CK_TILE_HOST_DEVICE void init_raw() { bottom_tensor_view_.init_raw(); }
+
     // this is the bottom tensor view
     // [x0', x1', ...] ==> [offset]
     BottomTensorView bottom_tensor_view_;
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp

@@ -78,6 +78,12 @@ struct BlockFmhaPipelineQRKSVSAsync
             return Problem::kBlockPerCu;
         else
         {
+            // minimize occupancy
+            if constexpr(BiasEnum != BlockAttentionBiasEnum::NO_BIAS)
+            {
+                return 1;
+            }
+
             if constexpr(kK0BlockLength <= 32)
             {
                 if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
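Note: the early return 1 pins the occupancy hint to one block per CU for every bias-enabled variant. In ck_tile-style launch plumbing this value feeds the second argument of __launch_bounds__, trading latency hiding for the largest per-block register/LDS budget. A hedged sketch of that plumbing (simplified; this kentry wrapper is an assumption, not the exact CK_TILE code):

    template <int MaxThreadPerBlock, int MinBlockPerCu, typename Kernel>
    __global__ void __launch_bounds__(MaxThreadPerBlock, MinBlockPerCu) kentry()
    {
        // with MinBlockPerCu == 1 the compiler may let one resident block
        // use nearly all registers, which is what the bias path opts into
        Kernel{}();
    }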
@@ -212,11 +218,14 @@ struct BlockFmhaPipelineQRKSVSAsync
             q_dram_block_window_tmp.get_window_lengths(),
             q_dram_block_window_tmp.get_window_origin(),
             Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
+        q_dram_window.init_raw();
 
         // TODO: we use async Copy for K, which is inline asm
         // a side effect is we have to use inline asm for q as well
         auto q = decltype(load_tile(q_dram_window)){};
-        set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
+        // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
+        //       however, q would be cleared in the constructor of static distributed tensor
+        // set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
         load_tile_raw(q, q_dram_window);
         __builtin_amdgcn_sched_barrier(0);

@@ -285,6 +294,16 @@ struct BlockFmhaPipelineQRKSVSAsync
             k_dram_block_window.get_window_origin(),
             Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
                                                                     // load
+        k_dram_window.init_raw();
+
+        constexpr auto k_oob_ck = bool_constant<true>{};
+        constexpr auto k_pre_np = [&]() {
+            if constexpr(kPadSeqLenK &&
+                         (BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                          (BiasEnum != BlockAttentionBiasEnum::NO_BIAS)))
+                return bool_constant<true>{};
+            else
+                return bool_constant<false>{};
+        }();
 
         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
         auto bias_dram_window  = make_tile_window(
             bias_dram_block_window_tmp.get_bottom_tensor_view(),

@@ -299,7 +318,7 @@ struct BlockFmhaPipelineQRKSVSAsync
             Policy::template MakeVDramTileDistribution<Problem>());
 
         // prefetch K tile
-        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
+        async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
         move_tile_window(k_dram_window, {0, kK0});
         __builtin_amdgcn_sched_barrier(0);

@@ -322,7 +341,9 @@ struct BlockFmhaPipelineQRKSVSAsync
         {
             static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
-                async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
-                                    k_dram_window);
+                async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
+                                    k_dram_window,
+                                    k_oob_ck,
+                                    k_pre_np);
                 if constexpr(i_k0 < k0_loops - 1)
                     move_tile_window(k_dram_window, {0, kK0});

@@ -609,16 +630,13 @@ struct BlockFmhaPipelineQRKSVSAsync
         {
             // move K tile windows
             move_tile_window(k_dram_block_window, {kN0, 0});
-            k_dram_window =
-                make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
-                                 k_dram_block_window.get_window_lengths(),
-                                 k_dram_block_window.get_window_origin(),
-                                 Policy::template MakeKDramTileDistribution<Problem>());
+            k_dram_window.set_window_origin(k_dram_block_window.get_window_origin());
 
             if constexpr(k1_loops >= 2 &&
                          LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
                 __builtin_amdgcn_s_barrier();
-            async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
+            async_load_tile_raw(
+                k_lds_store(LdsSeq.at(number<0>{})), k_dram_window, k_oob_ck, k_pre_np);
             move_tile_window(k_dram_window, {0, kK0});
         }
 
         // tail