Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6a25d081
Commit
6a25d081
authored
Oct 09, 2024
by
carlushuang
Browse files
Merge remote-tracking branch 'origin/develop' into ck_tile/fav2_fwd_sept
parents
02f8c487
ceaed8e0
Changes
73
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
244 additions
and
113 deletions
+244
-113
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
+3
-1
example/ck_tile/03_gemm/README.md
example/ck_tile/03_gemm/README.md
+14
-6
example/ck_tile/04_img2col/README.md
example/ck_tile/04_img2col/README.md
+2
-1
include/ck/config.h.in
include/ck/config.h.in
+0
-7
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+6
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
+10
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
.../gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
+5
-4
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
...evice/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
...grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
...device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
...tion/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
+1
-1
include/ck_tile/core/container/array.hpp
include/ck_tile/core/container/array.hpp
+1
-0
include/ck_tile/host/arg_parser.hpp
include/ck_tile/host/arg_parser.hpp
+15
-5
include/ck_tile/host/convolution_parameter.hpp
include/ck_tile/host/convolution_parameter.hpp
+0
-6
include/ck_tile/ops/fmha/block/block_masking.hpp
include/ck_tile/ops/fmha/block/block_masking.hpp
+2
-2
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
+77
-14
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+70
-13
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+22
-29
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
...fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
+8
-9
No files found.
example/ck_tile/02_layernorm2d/layernorm2d_fwd.cpp
View file @
6a25d081
...
...
@@ -35,7 +35,9 @@ float layernorm2d_fwd(layernorm2d_fwd_traits t,
YDataType
,
MeanDataType
,
InvStdDataType
,
Shape
>
;
Shape
,
true
,
true
>
;
using
Kernel
=
ck_tile
::
Layernorm2dFwd
<
PipelineProblem
>
;
...
...
example/ck_tile/03_gemm/README.md
View file @
6a25d081
...
...
@@ -6,7 +6,8 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_gemm_basic -j
```
This will result in an executable
`build/bin/tile_example_gemm_basic`
...
...
@@ -14,10 +15,17 @@ This will result in an executable `build/bin/tile_example_gemm_basic`
## example
```
args:
-m m dimension (default:3328)
-n m dimension (default:4096)
-b batch size (default:1)
-m m dimension (default:1024)
-n n dimension (default:2048)
-k k dimension (default:64)
-e epsilon (default:1e-5)
-v cpu validation or not (default:1)
-prec precision (default:fp16)
-stride_a Tensor A stride (default:0)
-stride_b Tensor B stride (default:0)
-stride_c Tensor C stride (default:0)
-v 0. No validation, 1. Validation on CPU, 2. Validation on GPU (default:2)
-e Absolute error tolerance (default:1e-5)
-prec data type. fp16/bf16/fp8/bf8 (default:fp16)
-warmup number of iterations before benchmark the kernel (default:10)
-repeat number of iterations to benchmark the kernel (default:100)
-timer gpu:gpu timer, cpu:cpu timer (default:gpu)
```
example/ck_tile/04_img2col/README.md
View file @
6a25d081
...
...
@@ -6,7 +6,8 @@ This folder contains example for Image to Column using ck_tile tile-programming
```
# in the root of ck_tile
mkdir build && cd build
sh ../script/cmake-ck-dev.sh ../ <arch> # you can replace this <arch> to gfx90a, gfx942...
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch>
make tile_example_img2col -j
```
This will result in an executable
`build/bin/tile_example_img2col`
include/ck/config.h.in
View file @
6a25d081
...
...
@@ -97,13 +97,6 @@
#cmakedefine CK_ENABLE_DL_KERNELS @CK_ENABLE_DL_KERNELS@
#endif
//
// Instances supports in the current CK build
//
#ifndef CK_ENABLE_INSTANCES_ONLY
#cmakedefine CK_ENABLE_INSTANCES_ONLY @CK_ENABLE_INSTANCES_ONLY@
#endif
//
// CK kernels which support XDL (MI series)
//
...
...
include/ck/host_utility/kernel_launch.hpp
View file @
6a25d081
...
...
@@ -66,6 +66,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
}
else
...
...
@@ -143,6 +146,9 @@ float launch_and_time_kernel_with_preprocess(const StreamConfig& stream_config,
hip_check_error
(
hipEventElapsedTime
(
&
total_time
,
start
,
stop
));
hip_check_error
(
hipEventDestroy
(
start
));
hip_check_error
(
hipEventDestroy
(
stop
));
return
total_time
/
nrepeat
;
}
else
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1_ab_scale.hpp
View file @
6a25d081
...
...
@@ -308,7 +308,7 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
...
@@ -390,9 +390,10 @@ struct BlockwiseGemmXdlops_pipeline_v1_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2_ab_scale.hpp
View file @
6a25d081
...
...
@@ -350,7 +350,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
...
@@ -443,7 +443,7 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
...
@@ -518,9 +518,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
...
...
@@ -575,9 +576,10 @@ struct BlockwiseGemmXdlops_pipeline_v2_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3_ab_scale.hpp
View file @
6a25d081
...
...
@@ -427,7 +427,7 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
template
Run
<
>
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
...
...
@@ -504,9 +504,10 @@ struct BlockwiseGemmXdlops_pipeline_v3_ab_scale<BlockGemmPipelineScheduler::Intr
using
mfma_input_type
=
typename
vector_type
<
ComputeDataType
,
xdlops_gemm
.
K1PerXdlops
>::
type
;
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
xdlops_gemm
.
template
Run
<
>(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf_per_scale
.
GetVectorTypeReference
(
I0
));
});
static_for
<
0
,
xdlops_gemm
.
GetRegSizePerXdlops
(),
1
>
{}([
&
](
auto
t
)
{
constexpr
index_t
c_offset
=
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multi_abd_xdl_fixed_nk.hpp
View file @
6a25d081
...
...
@@ -64,7 +64,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
const
auto
StrideAs
=
gemm_desc_ptr
[
group_id
].
StrideAs
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_splitk_xdl_cshuffle_two_stage.hpp
View file @
6a25d081
...
...
@@ -345,7 +345,7 @@ struct DeviceGroupedGemmMultipleDSplitKXdlCShuffleTwoStage
const
index_t
N
=
gemm_descs
[
i
].
N_
;
const
index_t
K
=
gemm_descs
[
i
].
K_
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
skipped_group_count_
++
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_multiple_d_xdl_cshuffle_tile_loop.hpp
View file @
6a25d081
...
...
@@ -109,7 +109,7 @@ __global__ void
N
=
gemm_desc_ptr
[
group_id
].
N
;
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
{
grid_size_grp
=
0
;
continue
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
View file @
6a25d081
...
...
@@ -68,7 +68,7 @@ __global__ void
const
index_t
N
=
gemm_desc_ptr
[
group_id
].
N
;
const
index_t
K
=
gemm_desc_ptr
[
group_id
].
K
;
if
(
M
*
N
*
K
==
0
)
if
(
M
==
0
||
N
==
0
||
K
==
0
)
return
;
const
auto
StrideA
=
gemm_desc_ptr
[
group_id
].
StrideA
;
...
...
include/ck_tile/core/container/array.hpp
View file @
6a25d081
...
...
@@ -4,6 +4,7 @@
#pragma once
#include <initializer_list>
#include <vector>
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/numeric/integer.hpp"
...
...
include/ck_tile/host/arg_parser.hpp
View file @
6a25d081
...
...
@@ -50,12 +50,22 @@ class ArgParser
}
return
*
this
;
}
void
print
()
void
print
()
const
{
// find max key length
std
::
string
::
size_type
max_key_length
=
11
;
for
(
auto
&
key
:
keys
)
{
if
(
max_key_length
<
key
.
length
())
{
max_key_length
=
key
.
length
();
}
}
printf
(
"args:
\n
"
);
for
(
auto
&
key
:
keys
)
{
auto
value
=
input_map
[
key
]
;
auto
value
=
input_map
.
at
(
key
)
;
std
::
vector
<
std
::
string
>
help_text_lines
;
size_t
pos
=
0
;
for
(
size_t
next_pos
=
value
.
help_text
.
find
(
'\n'
,
pos
);
next_pos
!=
std
::
string
::
npos
;)
...
...
@@ -69,8 +79,7 @@ class ArgParser
std
::
string
(
value
.
help_text
.
begin
()
+
pos
,
value
.
help_text
.
end
()));
std
::
string
default_value
=
std
::
string
(
"(default:"
)
+
value
.
value
+
std
::
string
(
")"
);
std
::
cout
<<
std
::
setw
(
2
)
<<
std
::
setw
(
12
-
value
.
name
.
length
())
<<
"-"
<<
key
std
::
cout
<<
std
::
setw
(
1
+
max_key_length
-
value
.
name
.
length
())
<<
"-"
<<
key
<<
std
::
setw
(
4
)
<<
" "
<<
help_text_lines
[
0
]
<<
" "
<<
default_value
<<
std
::
endl
;
...
...
@@ -78,7 +87,8 @@ class ArgParser
help_next_line
!=
help_text_lines
.
end
();
++
help_next_line
)
{
std
::
cout
<<
std
::
setw
(
17
)
<<
" "
<<
*
help_next_line
<<
std
::
endl
;
std
::
cout
<<
std
::
setw
(
1
+
max_key_length
+
4
)
<<
" "
<<
*
help_next_line
<<
std
::
endl
;
}
}
}
...
...
include/ck_tile/host/convolution_parameter.hpp
View file @
6a25d081
...
...
@@ -13,7 +13,6 @@ namespace conv {
struct
ConvParam
{
ConvParam
();
ConvParam
(
ck_tile
::
index_t
n_dim
,
ck_tile
::
index_t
group_count
,
ck_tile
::
index_t
n_batch
,
...
...
@@ -199,11 +198,6 @@ struct ConvParam
}
};
ConvParam
::
ConvParam
()
:
ConvParam
::
ConvParam
(
2
,
1
,
128
,
256
,
192
,
{
3
,
3
},
{
71
,
71
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
})
{
}
CK_TILE_HOST
std
::
string
get_conv_param_parser_helper_msg
()
{
std
::
string
msg
;
...
...
include/ck_tile/ops/fmha/block/block_masking.hpp
View file @
6a25d081
...
...
@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
{
auto
[
origin_start
,
origin_end
]
=
GetTileRangeAlongX
(
i_y
,
height
,
width
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
x_total
/
num_splits
);
const
index_t
x_per_split
=
ck_tile
::
max
(
1
,
integer_divide_ceil
(
x_total
,
num_splits
)
)
;
const
index_t
split_start
=
x_per_split
*
i_split
;
const
index_t
split_end
=
(
i_split
==
num_splits
-
1
?
x_total
:
split_start
+
x_per_split
)
;
const
index_t
split_end
=
split_start
+
x_per_split
;
return
ck_tile
::
make_tuple
(
ck_tile
::
max
(
origin_start
,
split_start
),
ck_tile
::
min
(
origin_end
,
split_end
));
...
...
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp
View file @
6a25d081
...
...
@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...
...
@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
GenericAttentionMaskEnum
mask_type
;
};
struct
FmhaBwd
Common
Dropout
Kargs
struct
FmhaBwdDropout
SeedOffset
{
void
init_dropout
(
const
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
,
const
float
raw_scale
)
template
<
typename
T
>
union
ValueOrPointer
{
T
val
;
const
T
*
ptr
;
};
ValueOrPointer
<
uint64_t
>
drop_seed
;
ValueOrPointer
<
uint64_t
>
drop_offset
;
bool
is_drop_seed_offset_from_host
;
};
struct
FmhaBwdCommonDropoutKargs
:
FmhaBwdDropoutSeedOffset
{
void
init_dropout
(
float
p_drop
,
uint64_t
seed
,
uint64_t
offset
,
float
raw_scale
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
scale_rp_undrop
=
rp_undrop
*
raw_scale
;
this
->
drop_seed
.
val
=
seed
;
this
->
drop_offset
.
val
=
offset
;
this
->
is_drop_seed_offset_from_host
=
true
;
}
void
init_dropout
(
float
p_drop
,
const
uint64_t
*
seed_ptr
,
const
uint64_t
*
offset_ptr
,
float
raw_scale
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
...
...
@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
rp_undrop
=
1.0
/
p_undrop
;
scale_rp_undrop
=
rp_undrop
*
raw_scale
;
drop_seed
=
std
::
get
<
0
>
(
drop_seed_offset
);
drop_offset
=
std
::
get
<
1
>
(
drop_seed_offset
);
this
->
drop_seed
.
ptr
=
seed_ptr
;
this
->
drop_offset
.
ptr
=
offset_ptr
;
this
->
is_drop_seed_offset_from_host
=
false
;
}
float
rp_undrop
=
1
;
float
scale_rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
};
struct
FmhaBwdBatchModeDropoutKargs
:
FmhaBwdCommonDropoutKargs
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
};
struct
FmhaBwdDeterministicKargs
{
ck_tile
::
index_t
split_stride_dq_acc
=
0
;
...
...
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
,
scale
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
),
scale
);
}
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
...
...
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
ck_tile
::
index_t
window_size_right
,
ck_tile
::
index_t
mask_type
,
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
,
scale
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
,
scale
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
),
scale
);
}
if
constexpr
(
kIsStoreRandval
)
{
kargs
.
rand_val_ptr
=
rand_val_ptr
;
...
...
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
return
FmhaDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_seed
.
val
:
*
kargs
.
drop_seed
.
ptr
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_offset
.
val
:
*
kargs
.
drop_offset
.
ptr
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
};
}
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
6a25d081
...
...
@@ -6,8 +6,11 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include <string>
#include <type_traits>
#include <utility>
#include <variant>
// S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
// S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
...
...
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
ck_tile
::
index_t
batch_stride_lse
=
0
;
};
struct
FmhaFwd
Common
Dropout
Kargs
struct
FmhaFwdDropout
SeedOffset
{
void
init_dropout
(
const
float
p_drop
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
template
<
typename
T
>
union
ValueOrPointer
{
T
val
;
const
T
*
ptr
;
};
ValueOrPointer
<
uint64_t
>
drop_seed
;
ValueOrPointer
<
uint64_t
>
drop_offset
;
bool
is_drop_seed_offset_from_host
;
};
struct
FmhaFwdCommonDropoutKargs
:
FmhaFwdDropoutSeedOffset
{
void
init_dropout
(
float
p_drop
,
uint64_t
seed
,
uint64_t
offset
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
this
->
drop_seed
.
val
=
seed
;
this
->
drop_offset
.
val
=
offset
;
this
->
is_drop_seed_offset_from_host
=
true
;
}
void
init_dropout
(
float
p_drop
,
const
uint64_t
*
seed_ptr
,
const
uint64_t
*
offset_ptr
)
{
float
p_undrop
=
1.0
-
p_drop
;
p_undrop_in_uint8_t
=
uint8_t
(
std
::
floor
(
p_undrop
*
std
::
numeric_limits
<
uint8_t
>::
max
()));
rp_undrop
=
1.0
/
p_undrop
;
drop_seed
=
std
::
get
<
0
>
(
drop_seed_offset
);
drop_offset
=
std
::
get
<
1
>
(
drop_seed_offset
);
this
->
drop_seed
.
ptr
=
seed_ptr
;
this
->
drop_offset
.
ptr
=
offset_ptr
;
this
->
is_drop_seed_offset_from_host
=
false
;
}
float
rp_undrop
=
1
;
uint8_t
p_undrop_in_uint8_t
=
std
::
numeric_limits
<
uint8_t
>::
max
();
bool
is_store_randval
=
false
;
uint64_t
drop_seed
=
1
;
uint64_t
drop_offset
=
0
;
void
*
rand_val_ptr
=
nullptr
;
ck_tile
::
index_t
stride_randval
=
0
;
ck_tile
::
index_t
nhead_stride_randval
=
0
;
};
struct
FmhaFwdBatchModeDropoutKargs
:
FmhaFwdCommonDropoutKargs
{
ck_tile
::
index_t
batch_stride_randval
=
0
;
...
...
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
));
}
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
...
...
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
ck_tile
::
index_t
mask_type
,
float
p_drop
,
bool
s_randval
,
const
std
::
tuple
<
uint64_t
,
uint64_t
>&
drop_seed_offset
)
std
::
variant
<
std
::
pair
<
uint64_t
,
uint64_t
>
,
std
::
pair
<
const
void
*
,
const
void
*>>
drop_seed_offset
)
{
Kargs
kargs
{{
q_ptr
,
k_ptr
,
...
...
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
}
if
constexpr
(
kHasDropout
)
{
kargs
.
init_dropout
(
p_drop
,
drop_seed_offset
);
if
(
drop_seed_offset
.
index
()
==
0
)
// seed & offset come from host
{
const
auto
&
[
seed
,
offset
]
=
std
::
get
<
0
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
seed
,
offset
);
}
else
// seed & offset come from device
{
const
auto
&
[
seed_ptr
,
offset_ptr
]
=
std
::
get
<
1
>
(
drop_seed_offset
);
kargs
.
init_dropout
(
p_drop
,
reinterpret_cast
<
const
uint64_t
*>
(
seed_ptr
),
reinterpret_cast
<
const
uint64_t
*>
(
offset_ptr
));
}
kargs
.
rand_val_ptr
=
rand_val_ptr
;
kargs
.
stride_randval
=
stride_randval
;
kargs
.
nhead_stride_randval
=
nhead_stride_randval
;
...
...
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
return
BlockDropout
{
i_batch_
,
i_nhead_
,
kargs
.
num_head_q
,
kargs
.
drop_seed
,
kargs
.
drop_offset
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_seed
.
val
:
*
kargs
.
drop_seed
.
ptr
,
kargs
.
is_drop_seed_offset_from_host
?
kargs
.
drop_offset
.
val
:
*
kargs
.
drop_offset
.
ptr
,
kargs
.
rp_undrop
,
kargs
.
p_undrop_in_uint8_t
,
kargs
.
is_store_randval
};
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
6a25d081
...
...
@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
o_ptr
;
ck_tile
::
index_t
batch
;
ck_tile
::
index_t
max_seqlen_q
;
ck_tile
::
index_t
seqlen_q
;
ck_tile
::
index_t
hdim_v
;
ck_tile
::
index_t
num_splits
;
...
...
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
;
ck_tile
::
index_t
nhead_stride_o
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
split_stride_lse_acc
;
ck_tile
::
index_t
split_stride_o_acc
;
};
...
...
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
std
::
conditional_t
<
kStoreLSE
,
CommonLSEKargs
,
EmptyKargs
<
0
>>
,
std
::
conditional_t
<
kDoFp8StaticQuant
,
Fp8StaticQuantKargs
,
EmptyKargs
<
1
>>
{
ck_tile
::
index_t
batch_stride_o
;
ck_tile
::
index_t
batch_stride_lse_acc
;
ck_tile
::
index_t
batch_stride_o_acc
;
ck_tile
::
index_t
batch_stride_o
;
};
struct
GroupModeKargs
...
...
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
...
...
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_ptr
,
batch
,
max_seqlen_q
,
seqlen_q
,
hdim_v
,
num_splits
,
...
...
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
{},
// placeholder for fp8_static_quant args
batch_stride_o
,
batch_stride_lse_acc
};
batch_stride_lse_acc
,
batch_stride_o_acc
,
batch_stride_o
};
if
constexpr
(
kStoreLSE
)
{
...
...
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
void
*
lse_ptr
,
void
*
o_ptr
,
ck_tile
::
index_t
batch
,
ck_tile
::
index_t
max_seqlen_q
,
const
void
*
seqstart_q_ptr
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
,
...
...
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile
::
index_t
nhead_stride_o_acc
,
ck_tile
::
index_t
nhead_stride_lse
,
ck_tile
::
index_t
nhead_stride_o
,
ck_tile
::
index_t
batch_stride_o_acc
,
ck_tile
::
index_t
split_stride_lse_acc
,
ck_tile
::
index_t
split_stride_o_acc
)
{
...
...
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_ptr
,
o_ptr
,
batch
,
max_seqlen_q
,
-
1
,
// seqlen will be updated by another pointer
hdim_v
,
num_splits
,
...
...
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc
,
nhead_stride_o_acc
,
nhead_stride_o
,
batch_stride_o_acc
,
split_stride_lse_acc
,
split_stride_o_acc
},
// args for common karg
{},
// placeholder for lse
...
...
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
hdim_v
_
)
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
return
TilePartitioner
::
GridSize
(
batch_size
_
,
nhead
_
,
seqlen_q
_
,
hdim_v
_
);
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_
seqlen_q
,
hdim_v
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
const
long_index_t
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
long_index_t
batch_offset_lse_acc
=
0
;
long_index_t
batch_offset_o_acc
=
0
;
long_index_t
batch_offset_lse
=
0
;
long_index_t
batch_offset_o
=
0
;
...
...
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
// get starting offset for each batch
const
long_index_t
query_start
=
kargs
.
seqstart_q_ptr
[
i_batch
];
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
batch_offset_lse_acc
=
query_start
;
batch_offset_o_acc
=
query_start
*
kargs
.
row_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
query_start
;
}
batch_offset_o
=
query_start
*
kargs
.
row_stride_o
;
// get real # queries & # keys under group mode
const
auto
adjusted_seqstart_q_ptr
=
kargs
.
seqstart_q_ptr
+
i_batch
;
kargs
.
seqlen_q
=
adjusted_seqstart_q_ptr
[
1
]
-
adjusted_seqstart_q_ptr
[
0
];
...
...
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
}
else
{
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
batch_offset_lse_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse_acc
;
batch_offset_o_acc
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o_acc
;
if
constexpr
(
kStoreLSE
)
{
batch_offset_lse
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_lse
;
}
batch_offset_o
=
static_cast
<
long_index_t
>
(
i_batch
)
*
kargs
.
batch_stride_o
;
}
// for simplicity, batch stride we just modify the pointer
...
...
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
auto
o_acc_dram
=
[
&
]()
{
const
auto
o_acc_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
o_acc_ptr
,
make_tuple
(
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
num_splits
,
kargs
.
seqlen_q
,
kargs
.
hdim_v
),
make_tuple
(
kargs
.
split_stride_o_acc
,
kargs
.
row_stride_o_acc
,
1
),
number
<
FmhaPipeline
::
kAlignmentOacc
>
{},
number
<
1
>
{});
...
...
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
make_tuple
(
number
<
1
>
{},
number
<
FmhaPipeline
::
kM0
>
{},
number
<
FmhaPipeline
::
kN1
>
{}),
sequence
<
false
,
kPadSeqLenQ
,
kPadHeadDimV
>
{});
const
index_t
padded_
max_
seqlen_q
=
const
index_t
padded_seqlen_q
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
1
>
{}];
const
index_t
padded_hdim_v
=
o_acc_dram_view
.
get_tensor_descriptor
().
get_lengths
()[
number
<
2
>
{}];
return
transform_tensor_view
(
o_acc_dram_view
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_
max_
seqlen_q
)),
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
num_splits
,
padded_seqlen_q
)),
make_pass_through_transform
(
padded_hdim_v
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
...
...
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
identity
{},
// lse_element_func
composes
(
saturates
<
fp8_t
>
{},
scales
{
kargs
.
scale_o
}),
// o_acc_element_func
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
}
else
...
...
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
o_acc_dram_window
,
lse_dram_window
,
kargs
.
num_splits
,
kargs
.
max_
seqlen_q
,
kargs
.
seqlen_q
,
smem_ptr
);
}
}();
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
View file @
6a25d081
...
...
@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
_
,
ck_tile
::
index_t
nhead
_
,
ck_tile
::
index_t
seqlen_q
_
,
ck_tile
::
index_t
hdim_v
_
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_
seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q
_
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
_
,
kN1
),
nhead
_
,
batch_size
_
);
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_
seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment