Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
d71189ff
Unverified
Commit
d71189ff
authored
Sep 03, 2024
by
Rostyslav Geyyer
Committed by
GitHub
Sep 03, 2024
Browse files
Merge branch 'develop' into lwpck-1815
parents
f84e2020
73b67f29
Changes
74
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1641 additions
and
209 deletions
+1641
-209
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
...grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
...r_operation/gpu/device/impl/device_grouped_conv_utils.hpp
+30
-14
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
+33
-1
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
.../ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
+52
-2
include/ck_tile/core/config.hpp
include/ck_tile/core/config.hpp
+9
-0
include/ck_tile/core/numeric/bfloat16.hpp
include/ck_tile/core/numeric/bfloat16.hpp
+34
-0
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+10
-3
include/ck_tile/core/tensor/tile_window.hpp
include/ck_tile/core/tensor/tile_window.hpp
+52
-1
include/ck_tile/core/utility/type_traits.hpp
include/ck_tile/core/utility/type_traits.hpp
+17
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-0
include/ck_tile/host/host_tensor.hpp
include/ck_tile/host/host_tensor.hpp
+9
-0
include/ck_tile/host/kernel_launch.hpp
include/ck_tile/host/kernel_launch.hpp
+5
-5
include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp
...reference/reference_batched_rotary_position_embedding.hpp
+73
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+6
-2
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
+19
-3
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
+108
-0
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
+279
-0
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
+679
-0
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp
...le/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp
+42
-0
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+182
-177
No files found.
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_fwd_multiple_d_xdl_large_tensor_cshuffle.hpp
View file @
d71189ff
...
...
@@ -713,7 +713,7 @@ struct DeviceGroupedConvFwdMultipleD_Xdl_CShuffle_Large_Tensor
return
false
;
}
}
if
constexpr
(
!
is_NSpatialG
K
_GKSpatial_NSpatialG
C
<
ALayout
,
BLayout
,
ELayout
>
())
if
constexpr
(
!
is_NSpatialG
C
_GKSpatial_NSpatialG
K
<
ALayout
,
BLayout
,
ELayout
>
())
{
return
false
;
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp
View file @
d71189ff
...
...
@@ -12,7 +12,7 @@ namespace device {
// 1d
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NWG
K
_GKXC_NWG
C
()
constexpr
bool
is_NWG
C
_GKXC_NWG
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKXC
>
&&
...
...
@@ -20,7 +20,7 @@ constexpr bool is_NWGK_GKXC_NWGC()
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNW
K
_GKXC_GNW
C
()
constexpr
bool
is_GNW
C
_GKXC_GNW
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKXC
>
&&
...
...
@@ -28,7 +28,7 @@ constexpr bool is_GNWK_GKXC_GNWC()
}
// 2d
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NHWG
K
_GKYXC_NHWG
C
()
constexpr
bool
is_NHWG
C
_GKYXC_NHWG
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NHWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKYXC
>
&&
...
...
@@ -36,15 +36,23 @@ constexpr bool is_NHWGK_GKYXC_NHWGC()
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNHW
K
_GKYXC_GNHW
C
()
constexpr
bool
is_GNHW
C
_GKYXC_GNHW
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNHWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
GNHWK
>
;
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NGCHW_GKYXC_NGKHW
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NGCHW
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NGKHW
>
;
}
// 3d
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NDHWG
K
_GKZYXC_NDHWG
C
()
constexpr
bool
is_NDHWG
C
_GKZYXC_NDHWG
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NDHWGC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
...
...
@@ -52,7 +60,7 @@ constexpr bool is_NDHWGK_GKZYXC_NDHWGC()
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNDHW
K
_GKZYXC_GNDHW
C
()
constexpr
bool
is_GNDHW
C
_GKZYXC_GNDHW
K
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
GNDHWC
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
...
...
@@ -60,19 +68,27 @@ constexpr bool is_GNDHWK_GKZYXC_GNDHWC()
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NSpatialGK_GKSpatial_NSpatialGC
()
constexpr
bool
is_NGCDHW_GKZYXC_NGKDHW
()
{
return
is_same_v
<
InLayout
,
tensor_layout
::
convolution
::
NGCDHW
>
&&
is_same_v
<
WeiLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
&&
is_same_v
<
OutLayout
,
tensor_layout
::
convolution
::
NGKDHW
>
;
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_NSpatialGC_GKSpatial_NSpatialGK
()
{
return
is_NWG
K
_GKXC_NWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NHWG
K
_GKYXC_NHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NDHWG
K
_GKZYXC_NDHWG
C
<
InLayout
,
WeiLayout
,
OutLayout
>
();
return
is_NWG
C
_GKXC_NWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NHWG
C
_GKYXC_NHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_NDHWG
C
_GKZYXC_NDHWG
K
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
typename
InLayout
,
typename
WeiLayout
,
typename
OutLayout
>
constexpr
bool
is_GNSpatial
K
_GKSpatial_GNSpatial
C
()
constexpr
bool
is_GNSpatial
C
_GKSpatial_GNSpatial
K
()
{
return
is_GNW
K
_GKXC_GNW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
K
_GKYXC_GNHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
K
_GKZYXC_GNDHW
C
<
InLayout
,
WeiLayout
,
OutLayout
>
();
return
is_GNW
C
_GKXC_GNW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNHW
C
_GKYXC_GNHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
()
||
is_GNDHW
C
_GKZYXC_GNDHW
K
<
InLayout
,
WeiLayout
,
OutLayout
>
();
}
template
<
index_t
NumATensor
=
1
,
index_t
NumBTensor
=
1
,
index_t
NumDTensor
=
0
,
typename
=
void
>
...
...
include/ck/tensor_operation/gpu/device/tensor_layout.hpp
View file @
d71189ff
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -115,6 +115,23 @@ struct NDHWGC : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"NDHWGC"
;
};
// input tensor
// packed NGCW/NGCHW/NGCDHW
struct
NGCW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGCW"
;
};
struct
NGCHW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGCHW"
;
};
struct
NGCDHW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGCDHW"
;
};
// input tensor
// strided layout
struct
G_NW_C
:
public
BaseTensorLayout
...
...
@@ -325,6 +342,21 @@ struct NDHWGK : public BaseTensorLayout
static
constexpr
const
char
*
name
=
"NDHWGK"
;
};
struct
NGKW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGKW"
;
};
struct
NGKHW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGKHW"
;
};
struct
NGKDHW
:
public
BaseTensorLayout
{
static
constexpr
const
char
*
name
=
"NGKDHW"
;
};
// output tensor
// strided layout
struct
G_NW_K
:
public
BaseTensorLayout
...
...
include/ck/tensor_operation/gpu/grid/gridwise_elementwise_2d.hpp
View file @
d71189ff
...
...
@@ -41,6 +41,55 @@ __global__ void
elementwise_op
);
}
template
<
typename
GridwiseElementwiseFunctor
,
typename
InAGridDescTuple
,
typename
InBGridDescTuple
,
typename
OutAGridDescTuple
,
typename
OutBGridDescTuple
,
typename
InDataTypePointerTuple
,
typename
OutDataTypePointerTuple
,
typename
Block2TileMapA
,
typename
Block2TileMapB
,
typename
ElementwiseOperation
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_elementwise_dual
(
const
InBGridDescTuple
in_grid_desc_tuple_a
,
const
InBGridDescTuple
in_grid_desc_tuple_b
,
const
OutAGridDescTuple
out_grid_desc_tuple_a
,
const
OutBGridDescTuple
out_grid_desc_tuple_b
,
const
InDataTypePointerTuple
p_in_global_tuple_a
,
const
InDataTypePointerTuple
p_in_global_tuple_b
,
const
OutDataTypePointerTuple
p_out_global_tuple_a
,
const
OutDataTypePointerTuple
p_out_global_tuple_b
,
const
Block2TileMapA
block_2_tile_map_a
,
const
Block2TileMapB
block_2_tile_map_b
,
const
ElementwiseOperation
elementwise_op
,
const
index_t
a_grid_size
)
{
if
(
get_block_1d_id
()
<
a_grid_size
)
{
GridwiseElementwiseFunctor
::
Run
(
in_grid_desc_tuple_a
,
out_grid_desc_tuple_a
,
p_in_global_tuple_a
,
p_out_global_tuple_a
,
block_2_tile_map_a
,
elementwise_op
,
get_block_1d_id
());
}
else
{
GridwiseElementwiseFunctor
::
Run
(
in_grid_desc_tuple_b
,
out_grid_desc_tuple_b
,
p_in_global_tuple_b
,
p_out_global_tuple_b
,
block_2_tile_map_b
,
elementwise_op
,
get_block_1d_id
()
-
a_grid_size
);
}
}
template
<
typename
GridwiseElementwiseFunctor
,
typename
InGridDescTuple
,
typename
OutGridDescTuple
,
...
...
@@ -133,7 +182,8 @@ struct GridwiseElementwise
const
InDataTypePointerTuple
&
p_in_global_tuple
,
const
OutDataTypePointerTuple
&
p_out_global_tuple
,
const
Block2TileMap
&
block_2_tile_map
,
const
ElementwiseOperation
&
elementwise_op
)
const
ElementwiseOperation
&
elementwise_op
,
const
index_t
block_id
=
get_block_1d_id
())
{
constexpr
auto
src_datas
=
generate_tuple
(
...
...
@@ -169,7 +219,7 @@ struct GridwiseElementwise
Number
<
NumOutput
>
{});
const
auto
block_work_idx
=
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_
block_
1d_id
()
));
block_2_tile_map
.
CalculateBottomIndex
(
make_multi_index
(
block_
id
));
const
index_t
m0_block_data_idx_on_grid
=
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I0
]
*
M0PerBlock
);
...
...
include/ck_tile/core/config.hpp
View file @
d71189ff
...
...
@@ -46,6 +46,7 @@
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD 0
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE_WITH_NAN 1
#define CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE 2
#define CK_TILE_FLOAT_TO_BFLOAT16_STANDARD_ASM 3
#ifndef CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT
#define CK_TILE_FLOAT_TO_BFLOAT16_DEFAULT CK_TILE_FLOAT_TO_BFLOAT16_TRUNCATE
...
...
@@ -156,6 +157,14 @@
#endif
#endif
#ifndef CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
#if HIP_VERSION_MAJOR == 6 && HIP_VERSION_MINOR == 2 && HIP_VERSION_PATCH >= 41133
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 1
#else
#define CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE 0
#endif
#endif
#ifndef CK_TILE_DEBUG_LOG
#define CK_TILE_DEBUG_LOG 0
#endif
...
...
include/ck_tile/core/numeric/bfloat16.hpp
View file @
d71189ff
...
...
@@ -17,6 +17,7 @@ enum class bf16_rounding_mode
standard
=
0
,
// rtn
truncate_with_nan
,
truncate
,
standard_asm
,
};
template
<
bf16_rounding_mode
rounding
=
...
...
@@ -148,6 +149,37 @@ constexpr uint16_t float_to_bf16_rtn_raw(float f)
return
uint16_t
(
u
.
int32
>>
16
);
}
CK_TILE_HOST
constexpr
uint16_t
float_to_bf16_rtn_asm
(
float
f
)
{
return
float_to_bf16_rtn_raw
(
f
);
}
CK_TILE_DEVICE
uint16_t
float_to_bf16_rtn_asm
(
float
f
)
{
union
{
float
fp32
;
uint32_t
int32
;
}
u
=
{
f
};
static
constexpr
uint32_t
FP32_NAN
=
0x7fff0000
;
static
constexpr
uint32_t
ROUND_BIAS_FOR_BF16
=
0x7fff
;
using
uint32x2_t
=
uint32_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
check_nan
;
uint32_t
tmp
;
asm
volatile
(
"
\n
\
v_cmp_u_f32 %0, %2, %2
\n
\
v_bfe_u32 %1, %2, 16, 1
\n
\
v_add3_u32 %1, %2, %1, %3
\n
\
v_cndmask_b32 %2, %1, %4, %0
\n
\
v_lshrrev_b32 %2, 16, %2
\n
\
"
:
"=s"
(
check_nan
),
"+v"
(
tmp
),
"+v"
(
u
.
fp32
)
:
"v"
(
ROUND_BIAS_FOR_BF16
),
"v"
(
FP32_NAN
));
return
uint16_t
(
u
.
int32
);
}
// Truncate instead of rounding, preserving SNaN
CK_TILE_HOST_DEVICE
constexpr
uint16_t
float_to_bf16_truc_nan_raw
(
float
f
)
...
...
@@ -177,6 +209,8 @@ CK_TILE_HOST_DEVICE constexpr uint16_t float_to_bf16_raw(float f, constant<round
{
if
constexpr
(
rounding
==
bf16_rounding_mode
::
standard
)
return
float_to_bf16_rtn_raw
(
f
);
else
if
constexpr
(
rounding
==
bf16_rounding_mode
::
standard_asm
)
return
float_to_bf16_rtn_asm
(
f
);
else
if
constexpr
(
rounding
==
bf16_rounding_mode
::
truncate_with_nan
)
return
float_to_bf16_truc_nan_raw
(
f
);
else
...
...
include/ck_tile/core/numeric/math.hpp
View file @
d71189ff
...
...
@@ -536,13 +536,20 @@ float log(float x) { return __logf(x); };
CK_TILE_HOST
float
log
(
float
x
)
{
return
std
::
logf
(
x
);
};
CK_TILE_DEVICE
uint
32
_t
sad
(
uint
32
_t
x
,
uint
32
_t
y
,
uint
32
_t
acc
)
CK_TILE_DEVICE
uint
16
_t
sad
_u16
(
uint
16
_t
x
,
uint
16
_t
y
,
uint
16
_t
acc
)
{
// TODO: this is hacky, we use u16
return
__builtin_amdgcn_sad_u16
(
x
,
y
,
acc
);
}
CK_TILE_HOST
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
CK_TILE_DEVICE
uint32_t
sad_u32
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
/// TODO: replace inline asm when intrinsic is available
uint32_t
res
;
asm
volatile
(
"v_sad_u32 %0, %1, %2, %3"
:
"=v"
(
res
)
:
"v"
(
x
),
"v"
(
y
),
"v"
(
acc
));
return
res
;
}
CK_TILE_HOST
uint32_t
sad_u32
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
return
(
x
>
y
?
(
x
-
y
)
:
(
y
-
x
))
+
acc
;
}
...
...
include/ck_tile/core/tensor/tile_window.hpp
View file @
d71189ff
...
...
@@ -214,6 +214,12 @@ struct tile_window_with_static_distribution
CK_TILE_DEVICE
constexpr
auto
get_window_origin
()
const
{
return
window_origin_
;
}
CK_TILE_DEVICE
constexpr
void
set_bottom_tensor_view_data_ptr
(
typename
BottomTensorView
::
DataType
*
data
)
{
bottom_tensor_view_
.
buf_
.
p_data_
=
data
;
}
// move thread's window adaptor coordinate and bottom tensor coordinate
// [p0, p1, ..., y0, y1, ...] ==> [x0, x1, ...] ==> [x0', x1', ...] ==> [offset]
CK_TILE_DEVICE
void
move_window_adaptor_and_bottom_tensor_thread_coordinate
(
...
...
@@ -393,7 +399,8 @@ struct tile_window_with_static_distribution
bottom_tensor_thread_coord
,
bool_constant
<
oob_conditional_check
>
{},
pre_nop_
);
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE || \
CK_TILE_WORKAROUND_ROCM_6_2_SCRATCH_MEMORY_ISSUE
asm
volatile
(
""
);
// this is starting from rocm-6.2, but same sympton, reuse this flag
#endif
...
...
@@ -843,6 +850,17 @@ struct tile_window_with_static_lengths
CK_TILE_DEVICE
constexpr
auto
get_window_origin
()
const
{
return
window_origin_
;
}
CK_TILE_DEVICE
void
set_window_origin
(
const
BottomTensorIndex
&
new_window_origin
)
{
window_origin_
=
new_window_origin
;
}
CK_TILE_DEVICE
constexpr
void
set_bottom_tensor_view_data_ptr
(
typename
BottomTensorView
::
DataType
*
data
)
{
bottom_tensor_view_
.
buf_
.
p_data_
=
data
;
}
// move window-origin
CK_TILE_DEVICE
void
move
(
const
BottomTensorIndex
&
step
)
{
window_origin_
+=
step
;
}
...
...
@@ -871,6 +889,39 @@ make_tile_window(const TensorView_& tensor_view,
tensor_view
,
window_lengths
,
origin
};
}
// duplicate tile window and replace its origin
template
<
typename
TensorView
,
typename
WindowLengths
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window
(
const
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>&
tile_window
,
const
multi_index
<
TensorView
::
get_num_of_dimension
()
>&
origin
)
{
return
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>
{
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
origin
};
}
template
<
typename
TensorView
,
typename
WindowLengths
,
typename
StaticTileDistribution
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window
(
const
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>&
tile_window
,
const
multi_index
<
TensorView
::
get_num_of_dimension
()
>&
origin
,
const
StaticTileDistribution
&
tile_distribution
)
{
return
make_tile_window
(
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
origin
,
tile_distribution
);
}
template
<
typename
TensorView
,
typename
WindowLengths
,
typename
StaticTileDistribution
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window
(
const
tile_window_with_static_lengths
<
TensorView
,
WindowLengths
>&
tile_window
,
const
StaticTileDistribution
&
tile_distribution
)
{
return
make_tile_window
(
tile_window
.
get_bottom_tensor_view
(),
tile_window
.
get_window_lengths
(),
tile_window
.
get_window_origin
(),
tile_distribution
);
}
template
<
typename
TensorView_
,
typename
WindowLengths_
>
CK_TILE_DEVICE
void
move_tile_window
(
tile_window_with_static_lengths
<
TensorView_
,
WindowLengths_
>&
window
,
...
...
include/ck_tile/core/utility/type_traits.hpp
View file @
d71189ff
...
...
@@ -22,6 +22,23 @@ using remove_cvref_t = remove_cv_t<std::remove_reference_t<T>>;
template
<
typename
T
>
using
remove_pointer_t
=
typename
std
::
remove_pointer
<
T
>::
type
;
template
<
typename
From
,
typename
To
>
struct
copy_const
{
static_assert
(
!
std
::
is_const_v
<
From
>
);
using
type
=
To
;
};
template
<
typename
From
,
typename
To
>
struct
copy_const
<
const
From
,
To
>
{
using
type
=
std
::
add_const_t
<
typename
copy_const
<
From
,
To
>::
type
>
;
};
template
<
typename
From
,
typename
To
>
using
copy_const_t
=
typename
copy_const
<
From
,
To
>::
type
;
namespace
detail
{
template
<
class
Default
,
class
AlwaysVoid
,
template
<
class
...
>
class
Op
,
class
...
Args
>
struct
detector
...
...
include/ck_tile/host.hpp
View file @
d71189ff
...
...
@@ -15,6 +15,7 @@
#include "ck_tile/host/reference/reference_batched_elementwise.hpp"
#include "ck_tile/host/reference/reference_batched_gemm.hpp"
#include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
...
...
include/ck_tile/host/host_tensor.hpp
View file @
d71189ff
...
...
@@ -155,7 +155,12 @@ struct HostTensorDescriptor
return
space
;
}
std
::
size_t
get_length
(
std
::
size_t
dim
)
const
{
return
mLens
[
dim
];
}
const
std
::
vector
<
std
::
size_t
>&
get_lengths
()
const
{
return
mLens
;
}
std
::
size_t
get_stride
(
std
::
size_t
dim
)
const
{
return
mStrides
[
dim
];
}
const
std
::
vector
<
std
::
size_t
>&
get_strides
()
const
{
return
mStrides
;
}
template
<
typename
...
Is
>
...
...
@@ -325,8 +330,12 @@ struct HostTensor
{
}
std
::
size_t
get_length
(
std
::
size_t
dim
)
const
{
return
mDesc
.
get_length
(
dim
);
}
decltype
(
auto
)
get_lengths
()
const
{
return
mDesc
.
get_lengths
();
}
std
::
size_t
get_stride
(
std
::
size_t
dim
)
const
{
return
mDesc
.
get_stride
(
dim
);
}
decltype
(
auto
)
get_strides
()
const
{
return
mDesc
.
get_strides
();
}
std
::
size_t
get_num_of_dimension
()
const
{
return
mDesc
.
get_num_of_dimension
();
}
...
...
include/ck_tile/host/kernel_launch.hpp
View file @
d71189ff
...
...
@@ -73,17 +73,17 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
{
// clang-format off
if
(
!
s
.
time_kernel_
)
{
(
callables
(
s
),...);
hip_check_error
(
hipGetLastError
());
(
callables
(
s
),...);
HIP_CHECK_ERROR
(
hipGetLastError
());
return
0
;
}
if
(
s
.
is_gpu_timer_
)
{
gpu_timer
timer
{};
// warmup
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
hip_check_error
(
hipGetLastError
());
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
timer
.
start
(
s
.
stream_id_
);
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
(
callables
(
s
),...);
}
hip_check_error
(
hipGetLastError
());
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
timer
.
stop
(
s
.
stream_id_
);
return
timer
.
duration
()
/
s
.
nrepeat_
;
...
...
@@ -92,10 +92,10 @@ CK_TILE_HOST float launch_kernel(const stream_config& s, Callables... callables)
cpu_timer
timer
{};
// warmup
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
hip_check_error
(
hipGetLastError
());
for
(
int
i
=
0
;
i
<
s
.
cold_niters_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
timer
.
start
(
s
.
stream_id_
);
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
(
callables
(
s
),...);
}
hip_check_error
(
hipGetLastError
());
for
(
int
i
=
0
;
i
<
s
.
nrepeat_
;
i
++
)
{
(
callables
(
s
),...);
}
HIP_CHECK_ERROR
(
hipGetLastError
());
timer
.
stop
(
s
.
stream_id_
);
return
timer
.
duration
()
/
s
.
nrepeat_
;
...
...
include/ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp
0 → 100644
View file @
d71189ff
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
#include <cassert>
#include <thread>
namespace
ck_tile
{
template
<
typename
DataType
,
typename
ComputeDataType
=
float
>
CK_TILE_HOST
void
reference_batched_rotary_position_embedding
(
const
HostTensor
<
DataType
>&
input_bsd
,
const
HostTensor
<
DataType
>&
cos_sd
,
const
HostTensor
<
DataType
>&
sin_sd
,
bool
interleaved
,
HostTensor
<
DataType
>&
output_bsd
,
bool
use_1_row_sin_cos
=
false
)
{
assert
(
cos_sd
.
get_num_of_dimension
()
==
2
&&
sin_sd
.
get_num_of_dimension
()
==
2
);
assert
(
cos_sd
.
get_length
(
0
)
==
sin_sd
.
get_length
(
0
)
&&
cos_sd
.
get_length
(
1
)
==
sin_sd
.
get_length
(
1
));
const
index_t
rotary_dim
=
cos_sd
.
get_length
(
1
)
*
2
;
assert
(
static_cast
<
std
::
size_t
>
(
rotary_dim
)
<=
input_bsd
.
get_length
(
2
));
output_bsd
.
ForEach
([
&
](
auto
&
self
,
auto
i
)
{
const
index_t
i_d
=
i
[
2
];
if
(
rotary_dim
<=
i_d
)
{
self
(
i
)
=
input_bsd
(
i
);
return
;
}
assert
(
i_d
<
rotary_dim
);
const
index_t
i_s
=
i
[
1
];
const
index_t
i_s_cos_sin
=
(
use_1_row_sin_cos
?
0
:
i_s
);
const
ComputeDataType
cos
=
type_convert
<
ComputeDataType
>
(
interleaved
?
cos_sd
(
i_s_cos_sin
,
i_d
/
2
)
:
cos_sd
(
i_s_cos_sin
,
i_d
%
cos_sd
.
get_length
(
1
)));
const
ComputeDataType
sin
=
type_convert
<
ComputeDataType
>
(
interleaved
?
sin_sd
(
i_s_cos_sin
,
i_d
/
2
)
:
sin_sd
(
i_s_cos_sin
,
i_d
%
sin_sd
.
get_length
(
1
)));
const
ComputeDataType
half_rotated_input
=
[
&
]
{
const
index_t
i_b
=
i
[
0
];
if
(
interleaved
)
{
const
bool
is_even
=
(
i_d
%
2
==
0
);
const
index_t
pos
=
i_d
+
(
is_even
?
1
:
-
1
);
const
ComputeDataType
sign
=
(
is_even
?
-
1
:
1
);
return
sign
*
type_convert
<
ComputeDataType
>
(
input_bsd
(
i_b
,
i_s
,
pos
));
}
else
{
const
index_t
half_rdim
=
(
rotary_dim
/
2
);
const
index_t
pos
=
(
i_d
+
half_rdim
)
%
rotary_dim
;
const
ComputeDataType
sign
=
(
pos
<
half_rdim
?
1
:
-
1
);
return
sign
*
type_convert
<
ComputeDataType
>
(
input_bsd
(
i_b
,
i_s
,
pos
));
}
}();
ComputeDataType
result
=
type_convert
<
ComputeDataType
>
(
input_bsd
(
i
))
*
cos
+
half_rotated_input
*
sin
;
self
(
i
)
=
type_convert
<
DataType
>
(
result
);
});
}
}
// namespace ck_tile
include/ck_tile/ops/fmha.hpp
View file @
d71189ff
...
...
@@ -7,7 +7,11 @@
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/block/block_masking.hpp"
#include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
...
...
@@ -21,11 +25,11 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
...
...
include/ck_tile/ops/fmha/block/block_position_encoding.hpp
View file @
d71189ff
...
...
@@ -43,9 +43,12 @@ enum struct AlibiMode
FROM_BOTTOM_RIGHT
=
2
,
};
template
<
typename
DataType
,
bool
RowMajor
=
true
>
template
<
typename
DataType
,
bool
RowMajor
=
true
,
unsigned
LogMaxSadOprndSize
=
16
>
struct
Alibi
{
static_assert
(
1
<=
LogMaxSadOprndSize
&&
LogMaxSadOprndSize
<=
32
,
"for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t"
);
// RowMajor here means if pixel within the same thread are along the row, or col
// this may impact the performance of update(), while the result are the same.
// e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
...
...
@@ -79,6 +82,19 @@ struct Alibi
mode
=
mode_
;
}
CK_TILE_HOST
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
return
sad_u32
(
x
,
y
,
acc
);
}
CK_TILE_DEVICE
uint32_t
sad
(
uint32_t
x
,
uint32_t
y
,
uint32_t
acc
)
{
if
constexpr
(
LogMaxSadOprndSize
<=
16
)
{
return
sad_u16
(
static_cast
<
uint16_t
>
(
x
),
static_cast
<
uint16_t
>
(
y
),
static_cast
<
uint16_t
>
(
acc
));
}
return
sad_u32
(
x
,
y
,
acc
);
}
CK_TILE_HOST_DEVICE
void
update
(
DataType
&
pixel
,
index_t
row_idx
,
index_t
col_idx
)
{
if
constexpr
(
RowMajor
)
...
...
@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
// can convert from the FA style left/right to our generic coordinate
// if left_size < 0 && right_size = 0, it is normal causal mask
// local is left_size >=0 or right_size >=0
template
<
typename
DataType
,
bool
RowMajor
=
true
>
template
<
typename
DataType
,
bool
RowMajor
=
true
,
unsigned
LogMaxSadOprndSize
=
16
>
CK_TILE_HOST_DEVICE
auto
make_alibi_from_lr_mask
(
DataType
slope
,
index_t
window_left_size
,
index_t
window_right_size
,
...
...
@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
AlibiMode
alibi_mode
=
is_causal
?
AlibiMode
::
VERTICAL
:
static_cast
<
AlibiMode
>
(
mask_enum
)
/*either top-left or bottom-right*/
;
return
Alibi
<
DataType
,
RowMajor
>
{
slope
,
y_total
,
x_total
,
alibi_mode
};
return
Alibi
<
DataType
,
RowMajor
,
LogMaxSadOprndSize
>
{
slope
,
y_total
,
x_total
,
alibi_mode
};
}
// https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
...
...
include/ck_tile/ops/fmha/block/block_rotary_embedding.hpp
0 → 100644
View file @
d71189ff
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace
ck_tile
{
// This class is used for codegen pattern matching
enum
class
RotaryEmbeddingEnum
{
NONE
=
0
,
INTERLEAVED
=
1
,
// combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED
=
2
,
// combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
template
<
RotaryEmbeddingEnum
>
struct
RotaryEmbeddingEnumToStr
;
template
<
>
struct
RotaryEmbeddingEnumToStr
<
RotaryEmbeddingEnum
::
NONE
>
{
static
constexpr
const
char
*
name
=
""
;
};
template
<
>
struct
RotaryEmbeddingEnumToStr
<
RotaryEmbeddingEnum
::
INTERLEAVED
>
{
static
constexpr
const
char
*
name
=
"inter"
;
};
template
<
>
struct
RotaryEmbeddingEnumToStr
<
RotaryEmbeddingEnum
::
HALF_ROTATED
>
{
static
constexpr
const
char
*
name
=
"half"
;
};
template
<
RotaryEmbeddingEnum
RotaryEnum
,
typename
ComputeDataType
=
float
>
struct
BlockRotaryEmbedding
{
template
<
typename
DistributedTensor
,
typename
OtherDramBlockWindow
,
typename
RotaryCosDramBlockWindow
,
typename
RotarySinDramBlockWindow
>
CK_TILE_HOST_DEVICE
static
void
apply
(
DistributedTensor
&
tile
,
OtherDramBlockWindow
other_window
,
RotaryCosDramBlockWindow
rotary_cos_window
,
RotarySinDramBlockWindow
rotary_sin_window
,
index_t
rotary_dim
,
index_t
thread_end
)
{
using
DataType
=
typename
remove_cvref_t
<
DistributedTensor
>::
DataType
;
if
constexpr
(
RotaryEnum
==
RotaryEmbeddingEnum
::
INTERLEAVED
)
{
auto
rotary_cos_tile
=
load_tile
(
rotary_cos_window
);
auto
rotary_sin_tile
=
load_tile
(
rotary_sin_window
);
if
(
thread_end
<=
rotary_dim
)
{
constexpr
index_t
thread_buffer_size
=
decltype
(
tile
.
thread_buf_
)
::
size
();
static_for
<
0
,
thread_buffer_size
,
2
>
{}([
&
](
auto
idx
)
{
const
auto
left
=
type_convert
<
ComputeDataType
>
(
tile
.
thread_buf_
[
idx
]);
const
auto
right
=
type_convert
<
ComputeDataType
>
(
tile
.
thread_buf_
[
idx
+
1
]);
const
auto
cos
=
type_convert
<
ComputeDataType
>
(
rotary_cos_tile
.
thread_buf_
[
idx
/
2
]);
const
auto
sin
=
type_convert
<
ComputeDataType
>
(
rotary_sin_tile
.
thread_buf_
[
idx
/
2
]);
tile
.
thread_buf_
[
idx
]
=
type_convert
<
DataType
>
(
left
*
cos
-
right
*
sin
);
tile
.
thread_buf_
[
idx
+
1
]
=
type_convert
<
DataType
>
(
right
*
cos
+
left
*
sin
);
});
}
}
else
if
constexpr
(
RotaryEnum
==
RotaryEmbeddingEnum
::
HALF_ROTATED
)
{
if
(
thread_end
<=
rotary_dim
)
{
const
bool
is_left
=
(
thread_end
<=
(
rotary_dim
/
2
));
move_tile_window
(
other_window
,
{
0
,
is_left
?
rotary_dim
/
2
:
-
(
rotary_dim
/
2
)});
auto
other_tile
=
load_tile
(
other_window
);
move_tile_window
(
rotary_cos_window
,
{
0
,
is_left
?
0
:
-
(
rotary_dim
/
2
)});
auto
rotary_cos_tile
=
load_tile
(
rotary_cos_window
);
move_tile_window
(
rotary_sin_window
,
{
0
,
is_left
?
0
:
-
(
rotary_dim
/
2
)});
auto
rotary_sin_tile
=
load_tile
(
rotary_sin_window
);
constexpr
index_t
thread_buffer_size
=
decltype
(
tile
.
thread_buf_
)
::
size
();
static_for
<
0
,
thread_buffer_size
,
1
>
{}([
&
](
auto
idx
)
{
const
auto
curr
=
type_convert
<
ComputeDataType
>
(
tile
.
thread_buf_
[
idx
]);
const
auto
other
=
type_convert
<
ComputeDataType
>
(
other_tile
.
thread_buf_
[
idx
]);
const
auto
cos
=
type_convert
<
ComputeDataType
>
(
rotary_cos_tile
.
thread_buf_
[
idx
]);
const
auto
sin
=
type_convert
<
ComputeDataType
>
(
rotary_sin_tile
.
thread_buf_
[
idx
]);
tile
.
thread_buf_
[
idx
]
=
type_convert
<
DataType
>
(
curr
*
cos
+
other
*
(
is_left
?
-
sin
:
sin
));
});
}
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/block/page_block_navigator.hpp
0 → 100644
View file @
d71189ff
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace
ck_tile
{
// Degenerate navigator for the case where the whole tensor lives in a single
// page-block (one tensor view). Every method is a thin pass-through to the
// plain tile-window API, with the block index pinned to 0.
template <typename TensorView>
struct TrivialPageBlockNavigator
{
    using DataType     = typename TensorView::DataType;
    using WindowOrigin = multi_index<2>;

    CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& view)
        : tensor_view(view)
    {
    }

    // Create a tile window over the single underlying view.
    // Returns (block_index, tile_window); block_index is always 0 here.
    template <typename WindowLengths>
    CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
                                                        const WindowOrigin& window_origin) const
    {
        auto window = ck_tile::make_tile_window(tensor_view, window_lengths, window_origin);
        return make_tuple(/*block_index=*/0, window);
    }

    // Same as above, but the created window also carries a tile distribution.
    template <typename WindowLengths, typename TileDistribution>
    CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
                                                        const WindowOrigin& window_origin,
                                                        const TileDistribution& tile_distribution) const
    {
        auto window =
            ck_tile::make_tile_window(tensor_view, window_lengths, window_origin, tile_distribution);
        return make_tuple(/*block_index=*/0, window);
    }

    // Move the window by 'step'; with a single block the block index can never
    // change, so 0 is returned unconditionally.
    template <typename TileWindow>
    CK_TILE_HOST_DEVICE static index_t
    move_tile_window(index_t /*block_index*/,
                     TileWindow& tile_window,
                     const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
    {
        ck_tile::move_tile_window(tile_window, step);
        return /*block_index=*/0;
    }

    // Global and local coordinates coincide when there is only one block.
    CK_TILE_HOST_DEVICE static constexpr WindowOrigin
    to_local_window_origin(const WindowOrigin& global_window_origin)
    {
        return global_window_origin;
    }

    CK_TILE_HOST_DEVICE static constexpr WindowOrigin
    to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
    {
        return local_window_origin;
    }

    private:
    TensorView tensor_view;
};
// default page-block navigator, assume that tensor view size is same as page-block size or smaller
// if tile window on last page-block
//
// Maps a "virtual" 2d coordinate space onto physically scattered page-blocks:
// the coordinate along VirtualDim is split into (block_index, local offset) by
// page_block_size, and block_index is translated through physical_block_indices
// to locate the block's storage. The last block may be shorter, hence the
// separate last_view/complete_view descriptors.
template <typename DataType_, index_t VirtualDim, typename TensorView>
struct PageBlockNavigator
{
    using DataType = DataType_;
    static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
    static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");

    using WindowOrigin = multi_index<2>;

    // physical_blocks_:        base pointer of the pool holding all page-blocks
    //                          (void*, const-qualified to match DataType)
    // block_stride_:           element stride between consecutive physical blocks
    // fixed_offset_:           constant element offset added to every block pointer
    //                          (e.g. per-head offset — TODO confirm with callers)
    // physical_block_indices_: logical->physical block index table (page table)
    // num_blocks_:             number of logical blocks
    // page_block_size_:        extent of one block along VirtualDim
    // complete_view_/last_view_: descriptors for a full block vs. the (possibly
    //                          shorter) last block
    CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
                                                     long_index_t block_stride_,
                                                     long_index_t fixed_offset_,
                                                     const int32_t* physical_block_indices_,
                                                     index_t num_blocks_,
                                                     index_t page_block_size_,
                                                     const TensorView& complete_view_,
                                                     const TensorView& last_view_)
        : physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
          block_stride(block_stride_),
          fixed_offset(fixed_offset_),
          physical_block_indices(physical_block_indices_),
          num_blocks(num_blocks_),
          page_block_size(page_block_size_),
          complete_view(complete_view_),
          last_view(last_view_)
    {
    }

    // Create a tile window at a *global* origin; returns (block_index, window)
    // where the window is rebased to block-local coordinates and re-pointed at
    // that block's physical storage.
    template <typename WindowLengths>
    CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
                                              const WindowOrigin& window_origin) const
    {
        const index_t block_index = get_block_index(window_origin);
        const WindowOrigin local_window_origin = to_local_window_origin(window_origin);

        // last block may have a smaller descriptor
        auto new_tile_window =
            ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
                                      window_lengths,
                                      local_window_origin);
        new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));

        return make_tuple(block_index, new_tile_window);
    }

    // Same as above, with an explicit tile distribution for the new window.
    template <typename WindowLengths, typename TileDistribution>
    CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
                                              const WindowOrigin& window_origin,
                                              const TileDistribution& tile_distribution) const
    {
        const index_t block_index = get_block_index(window_origin);
        const WindowOrigin local_window_origin = to_local_window_origin(window_origin);

        auto new_tile_window =
            ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
                                      window_lengths,
                                      local_window_origin,
                                      tile_distribution);
        new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));

        return make_tuple(block_index, new_tile_window);
    }

    // Move the window by 'step' in global coordinates; if the move crosses into
    // another page-block, rebase origin/descriptor/data pointer onto that block.
    // Returns the (possibly new) block index.
    template <typename TileWindow>
    CK_TILE_HOST_DEVICE index_t
    move_tile_window(index_t block_index,
                     TileWindow& tile_window,
                     const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
    {
        ck_tile::move_tile_window(tile_window, step);

        // after the raw move the window origin is local to the *old* block;
        // recompute which block it now falls in and its offset inside it
        const WindowOrigin global_window_origin =
            to_global_window_origin(block_index, tile_window.get_window_origin());
        const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
        const index_t new_block_index = get_block_index(global_window_origin);

        /// TODO: only update necessary attributes
        tile_window.bottom_tensor_view_.desc_ =
            (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
        tile_window.set_window_origin(local_window_origin);
        tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));

        return new_block_index;
    }

    CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
    {
        return block_index == num_blocks - 1;
    }

    // True if the window (at its current local origin) extends past the end of
    // 'block_index' along VirtualDim and there is a following block to spill into.
    template <typename TileWindow>
    CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index, const TileWindow& tile_window) const
    {
        const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
        const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
        return (block_index < num_blocks - 1) && (page_block_size < origin + length);
    }

    // Re-attach 'tile_window' (currently on 'block_index') to 'new_block_index'
    // without changing its global position: shift the local origin by whole
    // blocks and swap descriptor + data pointer.
    template <typename TileWindow>
    CK_TILE_HOST_DEVICE void
    move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
    {
        const multi_index<2> step = [&]() {
            // moving forward by k blocks decreases the local origin by
            // k * page_block_size (hence old - new)
            const index_t origin_diff = (block_index - new_block_index) * page_block_size;
            if constexpr(VirtualDim == 0)
            {
                return make_multi_index(origin_diff, 0);
            }
            else
            {
                return make_multi_index(0, origin_diff);
            }
        }();
        /// TODO: only update necessary attributes
        tile_window.bottom_tensor_view_.desc_ =
            (is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
        tile_window.set_window_origin(tile_window.get_window_origin() + step);
        tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
    }

    // Reduce a global origin modulo page_block_size along VirtualDim; the other
    // dimension is passed through unchanged.
    CK_TILE_HOST_DEVICE WindowOrigin
    to_local_window_origin(const WindowOrigin& global_window_origin) const
    {
        if constexpr(VirtualDim == 0)
        {
            const index_t length = global_window_origin.at(number<0>{});
            const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
            return make_multi_index(length - page_block_size * num_complete_blocks,
                                    global_window_origin.at(number<1>{}));
        }
        else
        {
            const index_t length = global_window_origin.at(number<1>{});
            const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
            return make_multi_index(global_window_origin.at(number<0>{}),
                                    length - page_block_size * num_complete_blocks);
        }
    }

    // Inverse of to_local_window_origin(): add back block_index * page_block_size
    // along VirtualDim.
    CK_TILE_HOST_DEVICE WindowOrigin
    to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
    {
        if constexpr(VirtualDim == 0)
        {
            return make_multi_index(block_index * page_block_size +
                                        local_window_origin.at(number<0>{}),
                                    local_window_origin.at(number<1>{}));
        }
        else
        {
            return make_multi_index(local_window_origin.at(number<0>{}),
                                    block_index * page_block_size +
                                        local_window_origin.at(number<1>{}));
        }
    }

    private:
    // Translate a logical block index into the data pointer of its physical
    // storage via the page table.
    CK_TILE_HOST_DEVICE DataType* get_block_ptr(index_t block_index) const
    {
        return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
    }

    // Logical block index containing a global origin (floor-divide along VirtualDim).
    CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
    {
        return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
    }

    DataType* physical_blocks;                  // base of the physical block pool
    long_index_t block_stride;                  // element stride between physical blocks
    long_index_t fixed_offset;                  // constant offset applied to every block
    const int32_t* physical_block_indices;      // logical -> physical block map
    index_t num_blocks;                         // number of logical blocks
    index_t page_block_size;                    // block extent along VirtualDim
    TensorView complete_view;                   // descriptor for a full block
    TensorView last_view;                       // descriptor for the (shorter) last block
};
// Factory for the single-block case: wrap one tensor view in a
// TrivialPageBlockNavigator.
template <typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
{
    using navigator_t = TrivialPageBlockNavigator<TensorView>;
    return navigator_t(tensor_view);
}
// Factory for the paged case: deduce TensorView and forward all paging
// parameters to a PageBlockNavigator<DataType, VirtualDim, TensorView>.
template <typename DataType, index_t VirtualDim, typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
                                                   long_index_t block_stride,
                                                   long_index_t fixed_offset,
                                                   const int32_t* physical_block_indices,
                                                   index_t num_blocks,
                                                   index_t page_block_size,
                                                   const TensorView& complete_view,
                                                   const TensorView& last_view)
{
    using navigator_t = PageBlockNavigator<DataType, VirtualDim, TensorView>;
    return navigator_t(physical_blocks,
                       block_stride,
                       fixed_offset,
                       physical_block_indices,
                       num_blocks,
                       page_block_size,
                       complete_view,
                       last_view);
}
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp
0 → 100644
View file @
d71189ff
This diff is collapsed.
Click to expand it.
include/ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp
0 → 100644
View file @
d71189ff
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
// Grid partitioner for the fwd append-KV kernel: one block per
// max(q-tile, knew-tile) along x, heads along y, batches along z.
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
struct FmhaFwdAppendKVTilePartitioner
{
    static constexpr ck_tile::index_t kM0 = kM0_; // q tile size along seqlen_q
    static constexpr ck_tile::index_t kN0 = kN0_; // knew tile size along seqlen_knew
    static constexpr ck_tile::index_t kK0 = kK0_;
    static constexpr ck_tile::index_t kN1 = kN1_;
    static_assert(kK0 == kN1);

    // Launch grid: x covers the larger of the q/knew tile counts so a single
    // grid can process both tensors.
    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
                                                ck_tile::index_t nhead,
                                                ck_tile::index_t seqlen_q,
                                                ck_tile::index_t seqlen_knew)
    {
        // TODO: this may need tuning
        const auto num_q_tiles    = ck_tile::integer_divide_ceil(seqlen_q, kM0);
        const auto num_knew_tiles = ck_tile::integer_divide_ceil(seqlen_knew, kN0);
        return dim3(std::max(num_q_tiles, num_knew_tiles), nhead, batch_size);
    }

    // Decompose the block id into (i_tile, i_nhead, i_batch).
    CK_TILE_DEVICE auto operator()()
    {
        const index_t tile_id  = blockIdx.x;
        const index_t head_id  = blockIdx.y;
        const index_t batch_id = blockIdx.z;

        return ck_tile::make_tuple(tile_id, head_id, batch_id);
    }
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
d71189ff
This diff is collapsed.
Click to expand it.
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment