gaoqiong / composable_kernel_ROCM / Commits / 7ffb0921

Commit 7ffb0921, authored Oct 07, 2024 by Adam Osewski

Merge branch 'develop' into aosewski/ck_tile_universal_gemm_p1

Parents: 4cf45f1b, 0023f01a

64 files changed in this merge; the 20 files shown below account for 784 additions and 207 deletions (+784, -207).
Changed files:

  include/ck_tile/ops/fmha/block/block_masking.hpp                                       +2   -2
  include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp                                    +77  -14
  include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp                                    +70  -13
  include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp                    +22  -29
  include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp          +8   -9
  include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp                            +21  -23
  include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp                  +2   -2
  include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp  +6   -0
  include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp           +5   -5
  include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp          +5   -5
  include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp         +4   -5
  include/ck_tile/ops/image_to_column.hpp                                                +9   -0
  include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp                  +224 -0
  include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp         +27  -0
  include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp            +32  -0
  include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp                      +247 -86
  include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp             +13  -9
  library/src/tensor_operation_instance/gpu/CMakeLists.txt                               +4   -5
  script/cmake-ck-dev.sh                                                                 +3   -0
  script/cmake-ck-release.sh                                                             +3   -0
include/ck_tile/ops/fmha/block/block_masking.hpp

@@ -308,9 +308,9 @@ struct SimplifiedGenericAttentionMask
     {
         auto [origin_start, origin_end] = GetTileRangeAlongX(i_y, height, width);

-        const index_t x_per_split = ck_tile::max(1, x_total / num_splits);
+        const index_t x_per_split = ck_tile::max(1, integer_divide_ceil(x_total, num_splits));
         const index_t split_start = x_per_split * i_split;
-        const index_t split_end   = (i_split == num_splits - 1 ? x_total : split_start + x_per_split);
+        const index_t split_end   = split_start + x_per_split;

         return ck_tile::make_tuple(ck_tile::max(origin_start, split_start),
                                    ck_tile::min(origin_end, split_end));
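Note: the hunk above swaps floor division for integer_divide_ceil when sizing each split, so every split gets the same width and the trailing min/max clamp trims the last one, instead of the last split absorbing the whole remainder. A self-contained sketch of the two policies, with hypothetical x_total/num_splits values and the usual (a + b - 1) / b ceiling division assumed:

#include <algorithm>
#include <cstdio>

int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int x_total = 10, num_splits = 4;
    for(int i_split = 0; i_split < num_splits; ++i_split)
    {
        // old: floor division, last split takes the remainder -> [0,2) [2,4) [4,6) [6,10)
        const int w_old   = std::max(1, x_total / num_splits);
        const int old_beg = w_old * i_split;
        const int old_end = (i_split == num_splits - 1 ? x_total : old_beg + w_old);
        // new: ceiling division, equal widths, clamped -> [0,3) [3,6) [6,9) [9,10)
        const int w_new   = std::max(1, integer_divide_ceil(x_total, num_splits));
        const int new_beg = w_new * i_split;
        const int new_end = std::min(x_total, new_beg + w_new);
        std::printf("split %d: old [%d,%d) new [%d,%d)\n", i_split, old_beg, old_end, new_beg, new_end);
    }
}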
include/ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp

@@ -6,8 +6,11 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

 #include <string>
 #include <type_traits>
+#include <utility>
+#include <variant>

 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
@@ -194,11 +197,39 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::GenericAttentionMaskEnum mask_type;
     };

-    struct FmhaBwdCommonDropoutKargs
+    struct FmhaBwdDropoutSeedOffset
     {
-        void init_dropout(const float p_drop,
-                          const std::tuple<uint64_t, uint64_t>& drop_seed_offset,
-                          const float raw_scale)
+        template <typename T>
+        union ValueOrPointer
+        {
+            T val;
+            const T* ptr;
+        };
+
+        ValueOrPointer<uint64_t> drop_seed;
+        ValueOrPointer<uint64_t> drop_offset;
+        bool is_drop_seed_offset_from_host;
+    };
+
+    struct FmhaBwdCommonDropoutKargs : FmhaBwdDropoutSeedOffset
+    {
+        void init_dropout(float p_drop, uint64_t seed, uint64_t offset, float raw_scale)
+        {
+            float p_undrop = 1.0 - p_drop;
+            p_undrop_in_uint8_t =
+                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
+            rp_undrop       = 1.0 / p_undrop;
+            scale_rp_undrop = rp_undrop * raw_scale;
+
+            this->drop_seed.val                 = seed;
+            this->drop_offset.val               = offset;
+            this->is_drop_seed_offset_from_host = true;
+        }
+
+        void init_dropout(float p_drop,
+                          const uint64_t* seed_ptr,
+                          const uint64_t* offset_ptr,
+                          float raw_scale)
         {
             float p_undrop = 1.0 - p_drop;
             p_undrop_in_uint8_t =
@@ -206,23 +237,25 @@ struct FmhaBwdDQDKDVKernel
             rp_undrop       = 1.0 / p_undrop;
             scale_rp_undrop = rp_undrop * raw_scale;

-            drop_seed   = std::get<0>(drop_seed_offset);
-            drop_offset = std::get<1>(drop_seed_offset);
+            this->drop_seed.ptr                 = seed_ptr;
+            this->drop_offset.ptr               = offset_ptr;
+            this->is_drop_seed_offset_from_host = false;
         }

         float rp_undrop             = 1;
         float scale_rp_undrop       = 1;
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        uint64_t drop_seed          = 1;
-        uint64_t drop_offset        = 0;
         void* rand_val_ptr          = nullptr;

         ck_tile::index_t stride_randval       = 0;
         ck_tile::index_t nhead_stride_randval = 0;
     };

     struct FmhaBwdBatchModeDropoutKargs : FmhaBwdCommonDropoutKargs
     {
         ck_tile::index_t batch_stride_randval = 0;
     };

     struct FmhaBwdDeterministicKargs
     {
         ck_tile::index_t split_stride_dq_acc = 0;
@@ -327,7 +360,8 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t window_size_right,
         ck_tile::index_t mask_type,
         float p_drop,
-        const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+        std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+            drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -405,7 +439,20 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
-            kargs.init_dropout(p_drop, drop_seed_offset, scale);
+            if(drop_seed_offset.index() == 0) // seed & offset come from host
+            {
+                const auto& [seed, offset] = std::get<0>(drop_seed_offset);
+                kargs.init_dropout(p_drop, seed, offset, scale);
+            }
+            else // seed & offset come from device
+            {
+                const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
+                kargs.init_dropout(p_drop,
+                                   reinterpret_cast<const uint64_t*>(seed_ptr),
+                                   reinterpret_cast<const uint64_t*>(offset_ptr),
+                                   scale);
+            }
             if constexpr(kIsStoreRandval)
             {
                 kargs.rand_val_ptr = rand_val_ptr;
@@ -471,7 +518,8 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t window_size_right,
         ck_tile::index_t mask_type,
         float p_drop,
-        const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+        std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+            drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -539,7 +587,20 @@ struct FmhaBwdDQDKDVKernel
         }
         if constexpr(kHasDropout)
         {
-            kargs.init_dropout(p_drop, drop_seed_offset, scale);
+            if(drop_seed_offset.index() == 0) // seed & offset come from host
+            {
+                const auto& [seed, offset] = std::get<0>(drop_seed_offset);
+                kargs.init_dropout(p_drop, seed, offset, scale);
+            }
+            else // seed & offset come from device
+            {
+                const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
+                kargs.init_dropout(p_drop,
+                                   reinterpret_cast<const uint64_t*>(seed_ptr),
+                                   reinterpret_cast<const uint64_t*>(offset_ptr),
+                                   scale);
+            }
             if constexpr(kIsStoreRandval)
             {
                 kargs.rand_val_ptr = rand_val_ptr;
@@ -958,8 +1019,10 @@ struct FmhaBwdDQDKDVKernel
             return FmhaDropout{i_batch_,
                                i_nhead_,
                                kargs.num_head_q,
-                               kargs.drop_seed,
-                               kargs.drop_offset,
+                               kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
+                                                                   : *kargs.drop_seed.ptr,
+                               kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
+                                                                   : *kargs.drop_offset.ptr,
                                kargs.rp_undrop,
                                kargs.p_undrop_in_uint8_t};
         }
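Note: the std::variant karg lets callers pass the dropout seed/offset either as host-side values (alternative 0) or as device pointers (alternative 1); the branch on index() happens once, on the host, and picks the matching init_dropout overload. A minimal sketch of the same dispatch, where make_kargs is a hypothetical stand-in for the kernel's karg factory:

#include <cstdint>
#include <utility>
#include <variant>

using DropSeedOffset = std::variant<std::pair<uint64_t, uint64_t>,        // 0: host values
                                    std::pair<const void*, const void*>>; // 1: device pointers

void make_kargs(float p_drop, DropSeedOffset drop_seed_offset)
{
    if(drop_seed_offset.index() == 0) // seed & offset are known on the host
    {
        const auto& [seed, offset] = std::get<0>(drop_seed_offset);
        // -> value overload: kargs.init_dropout(p_drop, seed, offset, scale);
        (void)p_drop; (void)seed; (void)offset;
    }
    else // seed & offset live in device memory; the kernel reads them at run time
    {
        const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
        // -> pointer overload: kargs.init_dropout(p_drop,
        //        reinterpret_cast<const uint64_t*>(seed_ptr),
        //        reinterpret_cast<const uint64_t*>(offset_ptr), scale);
        (void)seed_ptr; (void)offset_ptr;
    }
}

// usage: make_kargs(0.1f, std::make_pair(uint64_t{42}, uint64_t{0})); // host path
//        make_kargs(0.1f, std::make_pair(d_seed_ptr, d_offset_ptr));  // device path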
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp

@@ -6,8 +6,11 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/common.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"

 #include <string>
 #include <type_traits>
+#include <utility>
+#include <variant>

 // S[seqlen_q, seqlen_k] = Q[seqlen_q, hdim_q] @ K[seqlen_k, hdim_q]
 // S'[seqlen_q, seqlen_k] = S[seqlen_q, seqlen_k] * Scale[1]
@@ -170,29 +173,55 @@ struct FmhaFwdKernel
         ck_tile::index_t batch_stride_lse = 0;
     };

-    struct FmhaFwdCommonDropoutKargs
+    struct FmhaFwdDropoutSeedOffset
     {
-        void init_dropout(const float p_drop,
-                          const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+        template <typename T>
+        union ValueOrPointer
+        {
+            T val;
+            const T* ptr;
+        };
+
+        ValueOrPointer<uint64_t> drop_seed;
+        ValueOrPointer<uint64_t> drop_offset;
+        bool is_drop_seed_offset_from_host;
+    };
+
+    struct FmhaFwdCommonDropoutKargs : FmhaFwdDropoutSeedOffset
+    {
+        void init_dropout(float p_drop, uint64_t seed, uint64_t offset)
+        {
+            float p_undrop = 1.0 - p_drop;
+            p_undrop_in_uint8_t =
+                uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
+            rp_undrop = 1.0 / p_undrop;
+
+            this->drop_seed.val                 = seed;
+            this->drop_offset.val               = offset;
+            this->is_drop_seed_offset_from_host = true;
+        }
+
+        void init_dropout(float p_drop, const uint64_t* seed_ptr, const uint64_t* offset_ptr)
         {
             float p_undrop = 1.0 - p_drop;
             p_undrop_in_uint8_t =
                 uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
             rp_undrop = 1.0 / p_undrop;

-            drop_seed   = std::get<0>(drop_seed_offset);
-            drop_offset = std::get<1>(drop_seed_offset);
+            this->drop_seed.ptr                 = seed_ptr;
+            this->drop_offset.ptr               = offset_ptr;
+            this->is_drop_seed_offset_from_host = false;
         }

         float rp_undrop             = 1;
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
         bool is_store_randval       = false;
-        uint64_t drop_seed          = 1;
-        uint64_t drop_offset        = 0;
         void* rand_val_ptr          = nullptr;

         ck_tile::index_t stride_randval       = 0;
         ck_tile::index_t nhead_stride_randval = 0;
     };
     struct FmhaFwdBatchModeDropoutKargs : FmhaFwdCommonDropoutKargs
     {
         ck_tile::index_t batch_stride_randval = 0;
@@ -278,7 +307,8 @@ struct FmhaFwdKernel
         ck_tile::index_t mask_type,
         float p_drop,
         bool s_randval,
-        const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+        std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+            drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -344,7 +374,19 @@ struct FmhaFwdKernel
         }
         if constexpr(kHasDropout)
         {
-            kargs.init_dropout(p_drop, drop_seed_offset);
+            if(drop_seed_offset.index() == 0) // seed & offset come from host
+            {
+                const auto& [seed, offset] = std::get<0>(drop_seed_offset);
+                kargs.init_dropout(p_drop, seed, offset);
+            }
+            else // seed & offset come from device
+            {
+                const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
+                kargs.init_dropout(p_drop,
+                                   reinterpret_cast<const uint64_t*>(seed_ptr),
+                                   reinterpret_cast<const uint64_t*>(offset_ptr));
+            }
             kargs.rand_val_ptr         = rand_val_ptr;
             kargs.stride_randval       = stride_randval;
             kargs.nhead_stride_randval = nhead_stride_randval;
@@ -392,7 +434,8 @@ struct FmhaFwdKernel
         ck_tile::index_t mask_type,
         float p_drop,
         bool s_randval,
-        const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+        std::variant<std::pair<uint64_t, uint64_t>, std::pair<const void*, const void*>>
+            drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
                      k_ptr,
@@ -455,7 +498,19 @@ struct FmhaFwdKernel
         }
         if constexpr(kHasDropout)
         {
-            kargs.init_dropout(p_drop, drop_seed_offset);
+            if(drop_seed_offset.index() == 0) // seed & offset come from host
+            {
+                const auto& [seed, offset] = std::get<0>(drop_seed_offset);
+                kargs.init_dropout(p_drop, seed, offset);
+            }
+            else // seed & offset come from device
+            {
+                const auto& [seed_ptr, offset_ptr] = std::get<1>(drop_seed_offset);
+                kargs.init_dropout(p_drop,
+                                   reinterpret_cast<const uint64_t*>(seed_ptr),
+                                   reinterpret_cast<const uint64_t*>(offset_ptr));
+            }
             kargs.rand_val_ptr         = rand_val_ptr;
             kargs.stride_randval       = stride_randval;
             kargs.nhead_stride_randval = nhead_stride_randval;
@@ -748,8 +803,10 @@ struct FmhaFwdKernel
             return BlockDropout{i_batch_,
                                 i_nhead_,
                                 kargs.num_head_q,
-                                kargs.drop_seed,
-                                kargs.drop_offset,
+                                kargs.is_drop_seed_offset_from_host ? kargs.drop_seed.val
+                                                                    : *kargs.drop_seed.ptr,
+                                kargs.is_drop_seed_offset_from_host ? kargs.drop_offset.val
+                                                                    : *kargs.drop_offset.ptr,
                                 kargs.rp_undrop,
                                 kargs.p_undrop_in_uint8_t,
                                 kargs.is_store_randval};
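Note: inside both kernels the union member to read is picked by is_drop_seed_offset_from_host, as in the @@ -748 and @@ -958 hunks. The pattern in isolation (the free function is illustrative, not part of the kernels):

#include <cstdint>

template <typename T>
union ValueOrPointer
{
    T val;
    const T* ptr;
};

// When from_host is false, *v.ptr is a load from device memory performed by the
// kernel itself, so the seed/offset can be produced by an earlier kernel without
// a round trip to the host.
uint64_t resolve(ValueOrPointer<uint64_t> v, bool from_host)
{
    return from_host ? v.val : *v.ptr;
}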
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp

@@ -78,8 +78,6 @@ struct FmhaFwdSplitKVCombineKernel
         void* o_ptr;

         ck_tile::index_t batch;
-        ck_tile::index_t max_seqlen_q;
         ck_tile::index_t seqlen_q;
         ck_tile::index_t hdim_v;
         ck_tile::index_t num_splits;
@@ -91,8 +89,6 @@ struct FmhaFwdSplitKVCombineKernel
         ck_tile::index_t nhead_stride_o_acc;
         ck_tile::index_t nhead_stride_o;
-        ck_tile::index_t batch_stride_o_acc;
         ck_tile::index_t split_stride_lse_acc;
         ck_tile::index_t split_stride_o_acc;
     };
@@ -114,8 +110,9 @@ struct FmhaFwdSplitKVCombineKernel
         std::conditional_t<kStoreLSE, CommonLSEKargs, EmptyKargs<0>>,
         std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
     {
-        ck_tile::index_t batch_stride_o;
         ck_tile::index_t batch_stride_lse_acc;
+        ck_tile::index_t batch_stride_o_acc;
+        ck_tile::index_t batch_stride_o;
     };

     struct GroupModeKargs
@@ -135,7 +132,6 @@ struct FmhaFwdSplitKVCombineKernel
         void* lse_ptr,
         void* o_ptr,
         ck_tile::index_t batch,
-        ck_tile::index_t max_seqlen_q,
         ck_tile::index_t seqlen_q,
         ck_tile::index_t hdim_v,
         ck_tile::index_t num_splits,
@@ -157,7 +153,6 @@ struct FmhaFwdSplitKVCombineKernel
             o_acc_ptr,
             o_ptr,
             batch,
-            max_seqlen_q,
             seqlen_q,
             hdim_v,
             num_splits,
@@ -166,13 +161,13 @@ struct FmhaFwdSplitKVCombineKernel
             nhead_stride_lse_acc,
             nhead_stride_o_acc,
             nhead_stride_o,
-            batch_stride_o_acc,
             split_stride_lse_acc,
             split_stride_o_acc}, // args for common karg
            {},                   // placeholder for lse
            {},                   // placeholder for fp8_static_quant args
-           batch_stride_o,
-           batch_stride_lse_acc};
+           batch_stride_lse_acc,
+           batch_stride_o_acc,
+           batch_stride_o};

         if constexpr(kStoreLSE)
         {
@@ -195,7 +190,6 @@ struct FmhaFwdSplitKVCombineKernel
         void* lse_ptr,
         void* o_ptr,
         ck_tile::index_t batch,
-        ck_tile::index_t max_seqlen_q,
         const void* seqstart_q_ptr,
         ck_tile::index_t hdim_v,
         ck_tile::index_t num_splits,
@@ -206,7 +200,6 @@ struct FmhaFwdSplitKVCombineKernel
         ck_tile::index_t nhead_stride_o_acc,
         ck_tile::index_t nhead_stride_lse,
         ck_tile::index_t nhead_stride_o,
-        ck_tile::index_t batch_stride_o_acc,
         ck_tile::index_t split_stride_lse_acc,
         ck_tile::index_t split_stride_o_acc)
     {
@@ -214,7 +207,6 @@ struct FmhaFwdSplitKVCombineKernel
             o_acc_ptr,
             o_ptr,
             batch,
-            max_seqlen_q,
             -1, // seqlen will be updated by another pointer
             hdim_v,
             num_splits,
@@ -223,7 +215,6 @@ struct FmhaFwdSplitKVCombineKernel
             nhead_stride_lse_acc,
             nhead_stride_o_acc,
             nhead_stride_o,
-            batch_stride_o_acc,
             split_stride_lse_acc,
             split_stride_o_acc}, // args for common karg
            {},                   // placeholder for lse
@@ -243,12 +234,12 @@ struct FmhaFwdSplitKVCombineKernel
         return kargs;
     }

-    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size_,
-                                            ck_tile::index_t nhead_,
-                                            ck_tile::index_t seqlen_q_,
-                                            ck_tile::index_t hdim_v_)
+    __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
+                                            ck_tile::index_t nhead,
+                                            ck_tile::index_t max_seqlen_q,
+                                            ck_tile::index_t hdim_v)
     {
-        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_, hdim_v_);
+        return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v);
     }

     __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -270,10 +261,8 @@ struct FmhaFwdSplitKVCombineKernel
         const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
         const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);

-        const long_index_t batch_offset_o_acc =
-            static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
         long_index_t batch_offset_lse_acc = 0;
+        long_index_t batch_offset_o_acc   = 0;
         long_index_t batch_offset_lse     = 0;
         long_index_t batch_offset_o       = 0;
@@ -282,14 +271,16 @@ struct FmhaFwdSplitKVCombineKernel
             // get starting offset for each batch
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
-            batch_offset_o       = query_start * kargs.row_stride_o;
             batch_offset_lse_acc = query_start;
+            batch_offset_o_acc   = query_start * kargs.row_stride_o_acc;
             if constexpr(kStoreLSE)
             {
                 batch_offset_lse = query_start;
             }
+            batch_offset_o = query_start * kargs.row_stride_o;

             // get real # queries & # keys under group mode
             const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
             kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
@@ -303,13 +294,15 @@ struct FmhaFwdSplitKVCombineKernel
         }
         else
         {
-            batch_offset_o       = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
             batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
+            batch_offset_o_acc   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
             if constexpr(kStoreLSE)
             {
                 batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
             }
+            batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
         }

         // for simplicity, batch stride we just modify the pointer
@@ -341,7 +334,7 @@ struct FmhaFwdSplitKVCombineKernel
         auto o_acc_dram = [&]() {
             const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                 o_acc_ptr,
-                make_tuple(kargs.num_splits, kargs.max_seqlen_q, kargs.hdim_v),
+                make_tuple(kargs.num_splits, kargs.seqlen_q, kargs.hdim_v),
                 make_tuple(kargs.split_stride_o_acc, kargs.row_stride_o_acc, 1),
                 number<FmhaPipeline::kAlignmentOacc>{},
                 number<1>{});
@@ -351,14 +344,14 @@ struct FmhaFwdSplitKVCombineKernel
                 make_tuple(number<1>{}, number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN1>{}),
                 sequence<false, kPadSeqLenQ, kPadHeadDimV>{});

-            const index_t padded_max_seqlen_q =
+            const index_t padded_seqlen_q =
                 o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<1>{}];
             const index_t padded_hdim_v =
                 o_acc_dram_view.get_tensor_descriptor().get_lengths()[number<2>{}];

             return transform_tensor_view(
                 o_acc_dram_view,
-                make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_max_seqlen_q)),
+                make_tuple(make_merge_transform(make_tuple(kargs.num_splits, padded_seqlen_q)),
                            make_pass_through_transform(padded_hdim_v)),
                 make_tuple(sequence<0, 1>{}, sequence<2>{}),
                 make_tuple(sequence<0>{}, sequence<1>{}));
@@ -417,7 +410,7 @@ struct FmhaFwdSplitKVCombineKernel
                 identity{},                                          // lse_element_func
                 composes(saturates<fp8_t>{}, scales{kargs.scale_o}), // o_acc_element_func
                 kargs.num_splits,
-                kargs.max_seqlen_q,
+                kargs.seqlen_q,
                 smem_ptr);
         }
         else
@@ -426,7 +419,7 @@ struct FmhaFwdSplitKVCombineKernel
                 o_acc_dram_window,
                 lse_dram_window,
                 kargs.num_splits,
-                kargs.max_seqlen_q,
+                kargs.seqlen_q,
                 smem_ptr);
         }
     }();
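Note: in group mode the per-batch sequence length is derived on the device from the seqstart_q_ptr prefix-sum array, which is why max_seqlen_q could be dropped from Kargs and survive only as a host-side grid-sizing input. A sketch with hypothetical numbers:

#include <cstdint>

// seqstart_q mirrors kargs.seqstart_q_ptr: batch+1 cumulative row offsets.
// E.g. {0, 17, 40, 64}: batch 1 spans rows [17, 40), so its seqlen_q is 23.
int32_t group_mode_seqlen_q(const int32_t* seqstart_q, int i_batch)
{
    return seqstart_q[i_batch + 1] - seqstart_q[i_batch];
}
// The grid is still sized with max_seqlen_q (23 in this example), so the tile
// count covers the longest batch; shorter batches just have fewer valid rows.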
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp

@@ -13,21 +13,20 @@ struct FmhaFwdSplitKVCombineTilePartitioner
     static constexpr ck_tile::index_t kM0 = kM0_;
     static constexpr ck_tile::index_t kN1 = kN1_;

-    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size_,
-                                                ck_tile::index_t nhead_,
-                                                ck_tile::index_t seqlen_q_,
-                                                ck_tile::index_t hdim_v_)
+    CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
+                                                ck_tile::index_t nhead,
+                                                ck_tile::index_t max_seqlen_q,
+                                                ck_tile::index_t hdim_v)
     {
         // TODO: this may need tuning
-        return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0) *
-                        ck_tile::integer_divide_ceil(hdim_v_, kN1),
-                    nhead_,
-                    batch_size_);
+        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
+                        ck_tile::integer_divide_ceil(hdim_v, kN1),
+                    nhead,
+                    batch_size);
     }

     CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/, ck_tile::index_t hdim_v)
     {
-        // const index_t num_tile_m0 = seqlen_q / kM0;
         const index_t num_tile_n1 = ck_tile::integer_divide_ceil(hdim_v, kN1);
         const index_t i_block     = blockIdx.x;
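Note: with the renamed parameter, the grid now covers the longest sequence in the batch. A quick worked example with hypothetical sizes:

#include <cstdio>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    // Hypothetical: max_seqlen_q = 1000, kM0 = 64, hdim_v = 128, kN1 = 128,
    // nhead = 8, batch_size = 2 (mirrors GridSize above, minus the dim3 type).
    const int grid_x = integer_divide_ceil(1000, 64) * integer_divide_ceil(128, 128); // 16 * 1
    std::printf("grid = (%d, %d, %d)\n", grid_x, 8, 2); // grid = (16, 8, 2)
}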
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp

@@ -135,9 +135,6 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t nhead_stride_lse_acc;
         ck_tile::index_t nhead_stride_o_acc;
-        ck_tile::index_t batch_stride_lse_acc;
-        ck_tile::index_t batch_stride_o_acc;
         ck_tile::index_t split_stride_lse_acc;
         ck_tile::index_t split_stride_o_acc;
     };
@@ -201,6 +198,8 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
+        ck_tile::index_t batch_stride_lse_acc;
+        ck_tile::index_t batch_stride_o_acc;
     };

     struct GroupModeKargs
@@ -217,8 +216,8 @@ struct FmhaFwdSplitKVKernel
         const int32_t* seqstart_k_ptr;
         const int32_t* seqlen_k_ptr;

-        ck_tile::index_t batch_stride_k;
-        ck_tile::index_t batch_stride_v;
+        ck_tile::index_t batch_stride_k; // only used for paged-kvcache
+        ck_tile::index_t batch_stride_v; // only used for paged-kvcache
     };

     using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -296,8 +295,6 @@ struct FmhaFwdSplitKVKernel
             nhead_stride_v,
             nhead_stride_lse_acc,
             nhead_stride_o_acc,
-            batch_stride_lse_acc,
-            batch_stride_o_acc,
             split_stride_lse_acc,
             split_stride_o_acc}, // args for common karg
            {},                   // placeholder for bias
@@ -307,7 +304,9 @@ struct FmhaFwdSplitKVKernel
            reinterpret_cast<const int32_t*>(seqlen_k_ptr),
            batch_stride_q,
            batch_stride_k,
-           batch_stride_v};
+           batch_stride_v,
+           batch_stride_lse_acc,
+           batch_stride_o_acc};

         if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
         {
@@ -375,10 +374,8 @@ struct FmhaFwdSplitKVKernel
         ck_tile::index_t nhead_stride_bias,
         ck_tile::index_t nhead_stride_lse_acc,
         ck_tile::index_t nhead_stride_o_acc,
-        ck_tile::index_t batch_stride_k,
-        ck_tile::index_t batch_stride_v,
-        ck_tile::index_t batch_stride_lse_acc,
-        ck_tile::index_t batch_stride_o_acc,
+        ck_tile::index_t batch_stride_k, // only used for paged-kvcache
+        ck_tile::index_t batch_stride_v, // only used for paged-kvcache
         ck_tile::index_t split_stride_lse_acc,
         ck_tile::index_t split_stride_o_acc,
         ck_tile::index_t window_size_left,
@@ -412,8 +409,6 @@ struct FmhaFwdSplitKVKernel
             nhead_stride_v,
             nhead_stride_lse_acc,
             nhead_stride_o_acc,
-            batch_stride_lse_acc,
-            batch_stride_o_acc,
             split_stride_lse_acc,
             split_stride_o_acc}, // args for common karg
            {},                   // placeholder for bias
@@ -452,11 +447,11 @@ struct FmhaFwdSplitKVKernel
     __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
                                             ck_tile::index_t nhead,
-                                            ck_tile::index_t seqlen_q,
+                                            ck_tile::index_t max_seqlen_q,
                                             ck_tile::index_t hdim_v,
                                             ck_tile::index_t num_splits)
     {
-        return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, hdim_v, num_splits);
+        return TilePartitioner::GridSize(batch_size, nhead, max_seqlen_q, hdim_v, num_splits);
     }

     __host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
@@ -483,8 +478,7 @@ struct FmhaFwdSplitKVKernel
         long_index_t batch_offset_v       = 0;
         long_index_t batch_offset_bias    = 0;
         long_index_t batch_offset_lse_acc = 0;
-        const long_index_t batch_offset_o_acc =
-            static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
+        long_index_t batch_offset_o_acc   = 0;

         if constexpr(kIsGroupMode)
         {
@@ -492,9 +486,9 @@ struct FmhaFwdSplitKVKernel
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
             const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

             batch_offset_q = query_start * kargs.stride_q;
             batch_offset_k = key_start * kargs.stride_k;
-            batch_offset_lse_acc = query_start;
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
                 batch_offset_v = key_start * kargs.stride_v;
@@ -508,6 +502,9 @@ struct FmhaFwdSplitKVKernel
                 batch_offset_bias = query_start * kargs.stride_bias + key_start;
             }
+            batch_offset_lse_acc = query_start;
+            batch_offset_o_acc   = query_start * kargs.stride_o_acc;

             // get real # queries & # keys under group mode
             kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
@@ -545,6 +542,7 @@ struct FmhaFwdSplitKVKernel
             batch_offset_k       = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
             batch_offset_v       = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
             batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
+            batch_offset_o_acc   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;

             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
@@ -895,8 +893,8 @@ struct FmhaFwdSplitKVKernel
         const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
             o_acc_ptr,
             make_tuple(kargs.seqlen_q, kargs.hdim_v),
-            make_tuple(kargs.hdim_v, 1),
-            number<FmhaPipeline::kAlignmentO>{},
+            make_tuple(kargs.stride_o_acc, 1),
+            number<1>{},
             number<1>{});

         return pad_tensor_view(
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp

@@ -20,12 +20,12 @@ struct FmhaFwdSplitKVTilePartitioner
     __host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
                                             ck_tile::index_t nhead,
-                                            ck_tile::index_t seqlen_q,
+                                            ck_tile::index_t max_seqlen_q,
                                             ck_tile::index_t hdim_v,
                                             ck_tile::index_t num_splits)
     {
         // TODO: this may need tuning
-        return dim3(ck_tile::integer_divide_ceil(seqlen_q, kM0) *
+        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, kM0) *
                         ck_tile::integer_divide_ceil(hdim_v, kN1),
                     nhead * num_splits,
                     batch_size);
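Note: this partitioner packs nhead * num_splits into grid.y. One plausible way a kernel can unpack the two indices from blockIdx.y -- a sketch, not necessarily how ck_tile's partitioner actually decodes it:

// Recover (head, split) from the packed y dimension set up by GridSize above.
__device__ void decode_grid_y(int num_splits, int& i_nhead, int& i_split)
{
    i_nhead = blockIdx.y / num_splits;
    i_split = blockIdx.y % num_splits;
}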
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp

@@ -827,6 +827,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
                 },
                 s_acc,
                 bias_s_tile);
+            __builtin_amdgcn_sched_barrier(0);
         }
         else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
         {
@@ -918,6 +919,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
         gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
         HotLoopScheduler::template GemmStagedScheduler<1>();
+        __builtin_amdgcn_sched_barrier(0);

         // STAGE 4, OGrad@V Gemm2
         auto dp_acc = SPGradBlockTileType{};
@@ -927,6 +929,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
         dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
         HotLoopScheduler::template GemmStagedScheduler<2>();
+        __builtin_amdgcn_sched_barrier(0);

         // STAGE 5, P^T(PGrad^T - D)
         auto ds = SPGradBlockTileType{};
@@ -965,6 +968,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
                 Policy::template MakeBiasTileDistribution<Problem>());
             shuffle_tile(dbias_tile, shuffled_dbias_tile);
             store_tile(dbias_dram_window, dbias_tile);
+            __builtin_amdgcn_sched_barrier(0);
         }

         // STAGE 6, SGrad^T@Q^T Gemm3
@@ -984,6 +988,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
         move_tile_window(ds_lds_read_window, {0, kK4});
         HotLoopScheduler::template GemmStagedScheduler<3>();
+        __builtin_amdgcn_sched_barrier(0);

         // STAGE 7, SGrad@K^T Gemm4
         auto dq_acc = QGradBlockTileType{};
         clear_tile(dq_acc);
@@ -1005,6 +1010,7 @@ struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
         });
         HotLoopScheduler::template GemmStagedScheduler<4>();
+        __builtin_amdgcn_sched_barrier(0);

         // Results Scale
         if constexpr(FmhaDropout::IsDropout)
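Note: __builtin_amdgcn_sched_barrier(mask) is the LLVM AMDGPU scheduling-barrier intrinsic; the mask selects which instruction kinds the backend scheduler may still move across it, and mask 0 pins everything. Placing one after each GemmStagedScheduler<N>() call keeps the carefully staged instruction groups from bleeding into the next stage. A minimal usage sketch (stage_a/stage_b are hypothetical device functions):

extern __device__ void stage_a();
extern __device__ void stage_b();

__device__ void hot_loop_body()
{
    stage_a();
    // mask 0: no instruction may be scheduled across this point, so the compiler
    // cannot interleave stage_a's and stage_b's instructions.
    __builtin_amdgcn_sched_barrier(0);
    stage_b();
}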
include/ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp

@@ -1727,7 +1727,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     }

     template <>
-    CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
+    CK_TILE_DEVICE constexpr void GemmStagedScheduler<0>()
     {
         // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
         // Comp: Q x K
@@ -1759,7 +1759,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     }

     template <>
-    CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
+    CK_TILE_DEVICE constexpr void GemmStagedScheduler<1>()
     {
         // Mem: Q^T LDS load
         // Comp: OGrad x V
@@ -1777,7 +1777,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     }

     template <>
-    CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
+    CK_TILE_DEVICE constexpr void GemmStagedScheduler<2>()
     {
         // Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
         // Comp: PT x OGrad
@@ -1796,7 +1796,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     }

     template <>
-    CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
+    CK_TILE_DEVICE constexpr void GemmStagedScheduler<3>()
     {
         // Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
         // Comp: SGradT x QT
@@ -1830,7 +1830,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
     }

     template <>
-    CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
+    CK_TILE_DEVICE constexpr void GemmStagedScheduler<4>()
     {
         // Mem: SGrad, OGrad, D LDS load.
         // Comp: SGrad x KT
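Note: dropping `static` here follows the C++ rule that an explicit specialization must not repeat a storage-class specifier (clang rejects it with "explicit specialization cannot have a storage class"); the specialization inherits staticness from the primary member template. A minimal illustration at namespace scope:

struct Policy
{
    // the primary member template declares the storage class once
    template <int Stage>
    static constexpr void GemmStagedScheduler();
};

// explicit specialization: `static` is not repeated, yet the function stays static
template <>
constexpr void Policy::GemmStagedScheduler<0>() {}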
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp

@@ -107,7 +107,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                const LSEElementFunction& lse_element_func,
                const OaccElementFunction& o_acc_element_func,
                index_t num_splits,
-               index_t max_seqlen_q,
+               index_t seqlen_q,
                void* smem_ptr) const
     {
         // lse_acc tile in LDS
@@ -261,7 +261,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
         auto o_acc = make_static_distributed_tensor<OaccDataType>(o_acc_dist);
         clear_tile(o_acc);

-        const index_t padded_max_seqlen_q = integer_divide_ceil(max_seqlen_q, kM0) * kM0;
+        const index_t padded_seqlen_q = integer_divide_ceil(seqlen_q, kM0) * kM0;

         for(index_t i_split = 0; i_split < num_splits; ++i_split)
         {
@@ -282,7 +282,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                 });
             }
-            move_tile_window(o_acc_dram_window, {padded_max_seqlen_q, 0});
+            move_tile_window(o_acc_dram_window, {padded_seqlen_q, 0});
         }

         o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
@@ -297,7 +297,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                const OaccDramBlockWindow& o_acc_dram_block_window,
                LSEDramBlockWindow& lse_dram_block_window,
                index_t num_splits,
-               index_t max_seqlen_q,
+               index_t seqlen_q,
                void* smem_ptr) const
     {
         return operator()(lse_acc_dram_block_window,
@@ -306,7 +306,7 @@ struct BlockFmhaFwdSplitKVCombinePipeline
                           identity{},
                           identity{},
                           num_splits,
-                          max_seqlen_q,
+                          seqlen_q,
                           smem_ptr);
     }
 };
include/ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp

@@ -64,8 +64,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
     }();

-    static constexpr index_t kAlignmentO =
-        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
     static constexpr index_t kAlignmentBias =
         kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
@@ -212,8 +210,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
             q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);

-        // check early exit if masked and no work to do.
-        if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
+        // check early exit if no work to do
+        if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
         {
             const index_t original_num_total_loop =
                 integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
@@ -616,7 +614,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
         sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
             constexpr auto i_idx = make_tuple(idx0);
             const auto tmp = [&]() {
-                if constexpr(FmhaMask::IsMasking)
+                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
+                             FmhaMask::IsMasking)
                 {
                     return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
                 }
include/ck_tile/ops/image_to_column.hpp  (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp
0 → 100644
View file @
7ffb0921
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
>
struct
ImageToColumn
{
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
auto
I3
=
number
<
3
>
{};
static
constexpr
auto
I4
=
number
<
4
>
{};
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
InDataType
=
remove_cvref_t
<
typename
Problem
::
InDataType
>
;
using
OutDataType
=
remove_cvref_t
<
typename
Problem
::
OutDataType
>
;
static
constexpr
index_t
NDimSpatial
=
Problem
::
NDimSpatial
;
static
constexpr
index_t
AligmentIn
=
Problem
::
AligmentIn
;
static
constexpr
index_t
AligmentOut
=
Problem
::
AligmentOut
;
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
static
constexpr
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
index_t
kKPerBlock
=
Problem
::
BlockShape
::
kKPerBlock
;
struct
Kargs
{
const
void
*
p_in
;
void
*
p_out
;
const
long_index_t
G
;
const
long_index_t
N
;
const
long_index_t
C
;
const
array
<
long_index_t
,
NDimSpatial
>
input_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
>
output_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
;
const
array
<
long_index_t
,
3
>
gemm_g_m_k_strides
;
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides
;
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations
;
const
array
<
long_index_t
,
NDimSpatial
>
input_left_pads
;
const
array
<
long_index_t
,
NDimSpatial
>
input_right_pads
;
};
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
void
*
p_in
,
void
*
p_out
,
const
long_index_t
G
,
const
long_index_t
N
,
const
long_index_t
C
,
const
array
<
long_index_t
,
NDimSpatial
>
input_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
>
output_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
,
const
array
<
long_index_t
,
3
>
gemm_g_m_k_strides
,
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides
,
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations
,
const
array
<
long_index_t
,
NDimSpatial
>
input_left_pads
,
const
array
<
long_index_t
,
NDimSpatial
>
input_right_pads
)
{
return
Kargs
{
p_in
,
p_out
,
G
,
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
                     image_g_n_c_wis_strides,
                     gemm_g_m_k_strides,
                     conv_filter_strides,
                     conv_filter_dilations,
                     input_left_pads,
                     input_right_pads};
    }

    CK_TILE_HOST static constexpr auto GridSize(index_t GemmM, index_t GemmK, index_t Batch)
    {
        return dim3(integer_divide_ceil(GemmM, kMPerBlock),
                    integer_divide_ceil(GemmK, kKPerBlock),
                    Batch);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::kBlockSize; }

    CK_TILE_DEVICE auto MakeImageMKDesc(const Kargs& kargs) const
    {
        static_assert(NDimSpatial == 2, "Not supported.");

        const auto in_n_hi_wi_c_desc = make_naive_tensor_descriptor(
            make_tuple(kargs.N,
                       kargs.input_spatial_lengths[I0],
                       kargs.input_spatial_lengths[I1],
                       kargs.C),
            make_tuple(kargs.image_g_n_c_wis_strides[I1],
                       kargs.image_g_n_c_wis_strides[I3],
                       kargs.image_g_n_c_wis_strides[I4],
                       kargs.image_g_n_c_wis_strides[I2]),
            number<AligmentIn>{},
            I1);

        const auto in_n_hip_wip_c_desc = transform_tensor_descriptor(
            in_n_hi_wi_c_desc,
            make_tuple(make_pass_through_transform(kargs.N),
                       make_pad_transform(kargs.input_spatial_lengths[I0],
                                          kargs.input_left_pads[I0],
                                          kargs.input_right_pads[I0]),
                       make_pad_transform(kargs.input_spatial_lengths[I1],
                                          kargs.input_left_pads[I1],
                                          kargs.input_right_pads[I1]),
                       make_pass_through_transform(kargs.C)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}));

        const auto in_n_y_ho_x_wo_c_desc = transform_tensor_descriptor(
            in_n_hip_wip_c_desc,
            make_tuple(
                make_pass_through_transform(kargs.N),
                make_embed_transform(make_tuple(kargs.filter_spatial_lengths[I0],
                                                kargs.output_spatial_lengths[I0]),
                                     make_tuple(kargs.conv_filter_dilations[I0],
                                                kargs.conv_filter_strides[I0])),
                make_embed_transform(make_tuple(kargs.filter_spatial_lengths[I1],
                                                kargs.output_spatial_lengths[I1]),
                                     make_tuple(kargs.conv_filter_dilations[I1],
                                                kargs.conv_filter_strides[I1])),
                make_pass_through_transform(kargs.C)),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}, sequence<3>{}),
            make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}, sequence<5>{}));

        return transform_tensor_descriptor(
            in_n_y_ho_x_wo_c_desc,
            make_tuple(make_merge_transform(make_tuple(kargs.N,
                                                       kargs.output_spatial_lengths[I0],
                                                       kargs.output_spatial_lengths[I1])),
                       make_merge_transform(make_tuple(kargs.filter_spatial_lengths[I0],
                                                       kargs.filter_spatial_lengths[I1],
                                                       kargs.C))),
            make_tuple(sequence<0, 2, 4>{}, sequence<1, 3, 5>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));
    }

    CK_TILE_DEVICE auto CalculateMKDims(const Kargs& kargs) const
    {
        static_assert(NDimSpatial == 2, "Not supported.");
        const index_t M = kargs.N * static_cast<index_t>(kargs.output_spatial_lengths[I0] *
                                                         kargs.output_spatial_lengths[I1]);
        const index_t K = kargs.C * static_cast<index_t>(kargs.filter_spatial_lengths[I0] *
                                                         kargs.filter_spatial_lengths[I1]);
        return make_tuple(M, K);
    }

    CK_TILE_DEVICE static constexpr auto MakeBlockTileDistribution()
    {
        using P = typename Problem::BlockShape;
        // P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
        // Y: {kMPerThread, kKPerThread}
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<1>,
                tuple<sequence<P::kMWarpPerBlock, P::kMThreadPerWarp, P::kMPerThread>,
                      sequence<P::kKWarpPerBlock, P::kKThreadPerWarp, P::kKPerThread>>,
                tuple<sequence<1, 2>, sequence<1, 2>>,
                tuple<sequence<0, 0>, sequence<1, 1>>,
                sequence<1, 2>,
                sequence<2, 2>>{});
    }

    CK_TILE_DEVICE void ConvTensorRearrange(const Kargs& kargs) const
    {
        const auto [M, K] = CalculateMKDims(kargs);

        const index_t iM     = __builtin_amdgcn_readfirstlane(blockIdx.x * kMPerBlock);
        const index_t iK     = __builtin_amdgcn_readfirstlane(blockIdx.y * kKPerBlock);
        const index_t iBatch = __builtin_amdgcn_readfirstlane(blockIdx.z);

        const auto in_offset  = iBatch * kargs.image_g_n_c_wis_strides[I0];
        const auto out_offset = iBatch * kargs.gemm_g_m_k_strides[I0];

        const auto image_m_k = make_tensor_view<address_space_enum::global>(
            static_cast<const InDataType*>(kargs.p_in) + in_offset, MakeImageMKDesc(kargs));
        const auto gemm_m_k = make_naive_tensor_view<address_space_enum::global>(
            static_cast<OutDataType*>(kargs.p_out) + out_offset,
            make_tuple(M, K),
            make_tuple(kargs.gemm_g_m_k_strides[I1], kargs.gemm_g_m_k_strides[I2]),
            number<AligmentOut>{},
            I1);

        const auto image_m_k_padded =
            pad_tensor_view(image_m_k,
                            make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                            sequence<false, true>{});
        const auto gemm_m_k_padded =
            pad_tensor_view(gemm_m_k,
                            make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                            sequence<false, true>{});

        constexpr auto dstr = MakeBlockTileDistribution();

        const auto image_tile =
            make_tile_window(image_m_k_padded,
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             {iM, iK},
                             dstr);
        auto gemm_tile =
            make_tile_window(gemm_m_k_padded,
                             make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
                             {iM, iK},
                             dstr);

        // load from Global
        const auto loaded_tile = load_tile(image_tile);
        // save to Global
        store_tile(gemm_tile, loaded_tile);
    }

    CK_TILE_DEVICE void operator()(Kargs& kargs) const { ConvTensorRearrange(kargs); }
};
} // namespace ck_tile
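For orientation, the GEMM view this rearrange produces has M = N * Ho * Wo rows (one per output pixel) and K = C * Y * X columns (one per filter tap), as computed by CalculateMKDims above. A host-side sketch with illustrative sizes (none of these numbers come from the commit):

#include <cassert>

// Illustrative im2col sizing, mirroring CalculateMKDims/GridSize above:
// batch 2, 64 channels, 28x28 output, 3x3 filter.
int main()
{
    const int N = 2, C = 64, Ho = 28, Wo = 28, Y = 3, X = 3;
    const int GemmM = N * Ho * Wo; // 1568 rows
    const int GemmK = C * Y * X;   // 576 columns
    assert(GemmM == 1568 && GemmK == 576);

    // With assumed tile sizes kMPerBlock = 32 and kKPerBlock = 128, GridSize
    // would launch ceil(1568/32) x ceil(576/128) x Batch = 49 x 5 x G blocks.
    const int grid_x = (GemmM + 32 - 1) / 32;
    const int grid_y = (GemmK + 128 - 1) / 128;
    assert(grid_x == 49 && grid_y == 5);
    return 0;
}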
include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

template <typename InDataType_,
          typename OutDataType_,
          typename BlockShape_,
          index_t NDimSpatial_,
          index_t AligmentIn_,
          index_t AligmentOut_>
struct BlockImageToColumnProblem
{
    using InDataType  = remove_cvref_t<InDataType_>;
    using OutDataType = remove_cvref_t<OutDataType_>;
    using BlockShape  = remove_cvref_t<BlockShape_>;

    static constexpr index_t NDimSpatial = NDimSpatial_;
    static constexpr index_t AligmentIn  = AligmentIn_;
    static constexpr index_t AligmentOut = AligmentOut_;
};

} // namespace ck_tile
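A possible instantiation of this problem descriptor; the data types, tile shape, and alignments below are illustrative assumptions, not values taken from the commit (TileImageToColumnShape is defined in the next file):

#include "ck_tile/ops/image_to_column.hpp"

// Hypothetical configuration: fp16 in/out, 2 spatial dims, 8-element alignment.
using TileShape = ck_tile::TileImageToColumnShape<ck_tile::sequence<1, 8>,    // thread tile
                                                  ck_tile::sequence<8, 64>,   // warp tile
                                                  ck_tile::sequence<32, 128>>; // block tile
using Problem = ck_tile::BlockImageToColumnProblem<ck_tile::half_t, // InDataType
                                                   ck_tile::half_t, // OutDataType
                                                   TileShape,
                                                   2,  // NDimSpatial
                                                   8,  // AligmentIn
                                                   8>; // AligmentOut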
include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <typename ThreadTile, // sequence<...
          typename WarpTile,   // sequence<...
          typename BlockTile>  // sequence<...
struct TileImageToColumnShape
{
    static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
    static constexpr index_t kKPerThread = ThreadTile::at(number<1>{});

    static constexpr index_t kMPerWarp = WarpTile::at(number<0>{});
    static constexpr index_t kKPerWarp = WarpTile::at(number<1>{});

    static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
    static constexpr index_t kKThreadPerWarp = kKPerWarp / kKPerThread;

    static constexpr index_t kMPerBlock = BlockTile::at(number<0>{});
    static constexpr index_t kKPerBlock = BlockTile::at(number<1>{});

    static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
    static constexpr index_t kKWarpPerBlock = kKPerBlock / kKPerWarp;

    static constexpr index_t kBlockSize = warpSize * kMWarpPerBlock * kKWarpPerBlock;
};

} // namespace ck_tile
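To make the divisions concrete, a worked example with illustrative tile sizes on an assumed wave64 build (warpSize == 64):

#include "ck_tile/ops/image_to_column.hpp"

// ThreadTile 1x8, WarpTile 8x64, BlockTile 32x128:
//   kMThreadPerWarp = 8 / 1  = 8, kKThreadPerWarp = 64 / 8  = 8 -> 64 lanes per warp
//   kMWarpPerBlock  = 32 / 8 = 4, kKWarpPerBlock  = 128 / 64 = 2
//   kBlockSize      = warpSize * 4 * 2 = 512 threads on a wave64 target
using S = ck_tile::TileImageToColumnShape<ck_tile::sequence<1, 8>,
                                          ck_tile::sequence<8, 64>,
                                          ck_tile::sequence<32, 128>>;
static_assert(S::kMThreadPerWarp == 8 && S::kKThreadPerWarp == 8);
static_assert(S::kMWarpPerBlock == 4 && S::kKWarpPerBlock == 2);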
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
@@ -31,8 +31,14 @@ struct Layernorm2dFwd
     static constexpr ck_tile::index_t kMPerBlock = Problem::BlockShape::kMPerBlock;
     static constexpr ck_tile::index_t kNPerBlock = Problem::BlockShape::kNPerBlock;
+    static constexpr bool kPadM = Problem::kPadM;
+    static constexpr bool kPadN = Problem::kPadN;
     static constexpr ck_tile::index_t kNThreadPerWarp = Problem::BlockShape::kNThreadPerWarp;
+    static constexpr ck_tile::index_t kNPerThread     = Problem::BlockShape::kNPerThread;
+
+    static constexpr auto I0 = number<0>{};
+    static constexpr auto I1 = number<1>{};

     struct Kargs
     {
@@ -96,19 +102,25 @@ struct Layernorm2dFwd
                    sequence<2>>{});
     }

-    template <typename Dstr>
-    CK_TILE_DEVICE static constexpr auto GetNPerThread(Dstr)
-    {
-        constexpr auto nDstrSpan = Dstr::get_distributed_spans().template at<1>();
-        using Lengths            = decltype(nDstrSpan.impl_);
-        ck_tile::index_t ret     = 1;
-        ck_tile::static_for<0, Lengths::size(), 1>{}(
-            [&](auto idx) { ret *= Lengths::template at(idx); });
-        return ret;
-    }
+    CK_TILE_DEVICE static int GetWelfordMaxCount(int N)
+    {
+        constexpr ck_tile::index_t kNThreadPerBlock = kNPerBlock / kNPerThread;
+        int thread_id_n = get_thread_id() % kNThreadPerBlock;
+        int max_count   = __builtin_amdgcn_readfirstlane(
+            N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock));
+        int n_per_block_tail_loop =
+            __builtin_amdgcn_readfirstlane(N - max_count * kNThreadPerBlock);
+
+        if(n_per_block_tail_loop > 0)
+        {
+            int thread_max_n = (thread_id_n + 1) * kNPerThread;
+            int delta        = clamp(thread_max_n - n_per_block_tail_loop, 0, kNPerThread);
+            max_count += kNPerThread - delta;
+        }
+
+        return max_count;
+    }

     template <typename DistributedTensor>
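The point of the new GetWelfordMaxCount is that the per-thread element counts sum to exactly N even when the last column tile is partial, so the Welford merge weights stay correct. A host-side replica with illustrative parameters (kNPerBlock = 128, kNPerThread = 4, hence 32 threads along N):

#include <algorithm>
#include <cassert>

// Host-side replica of the device logic above (illustrative only).
int welford_max_count(int N, int kNPerBlock, int kNPerThread, int thread_id_n)
{
    const int kNThreadPerBlock = kNPerBlock / kNPerThread;
    int max_count  = N < kNPerBlock ? 0 : kNPerThread * (N / kNPerBlock);
    const int tail = N - max_count * kNThreadPerBlock;
    if(tail > 0)
    {
        const int thread_max_n = (thread_id_n + 1) * kNPerThread;
        max_count += kNPerThread - std::clamp(thread_max_n - tail, 0, kNPerThread);
    }
    return max_count;
}

int main()
{
    // N = 300: full tiles give 4 * (300 / 128) = 8 per thread; the 44-column
    // tail adds 4 to threads 0..10 only. Totals: 11 * 12 + 21 * 8 = 300 == N.
    int total = 0;
    for(int t = 0; t < 32; ++t)
        total += welford_max_count(300, 128, 4, t);
    assert(total == 300);
    return 0;
}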
@@ -129,42 +141,29 @@ struct Layernorm2dFwd
         return out_dstr_tensor;
     }

-    template <bool Cond = (kHasGamma && kHasBeta)>
-    CK_TILE_DEVICE std::enable_if_t<Cond>
-    TwoPassLayernorm2dFwd(const XDataType* p_x,
-                          const GammaDataType* p_gamma,
-                          const BetaDataType* p_beta,
-                          YDataType* p_y,
-                          MeanDataType* p_mean,
-                          InvStdDataType* p_invStd,
-                          const ComputeDataType epsilon,
-                          ck_tile::index_t M,
-                          ck_tile::index_t N) const
+    template <typename XBlockWindow,
+              typename GammaBlockWindow,
+              typename BetaBlockWindow,
+              typename YBlockWindow,
+              typename MeanBlockWindow,
+              typename InvStdBlockWindow,
+              bool Cond = (kHasGamma && kHasBeta)>
+    CK_TILE_DEVICE std::enable_if_t<Cond>
+    TwoPassLayernorm2dFwd(XBlockWindow& x_block_window,
+                          GammaBlockWindow& gamma_block_window,
+                          BetaBlockWindow& beta_block_window,
+                          YBlockWindow& y_block_window,
+                          MeanBlockWindow& mean_block_window,
+                          InvStdBlockWindow& inv_std_block_window,
+                          ComputeDataType epsilon,
+                          ck_tile::index_t N) const
     {
-        constexpr auto I0 = number<0>{};
-        constexpr auto I1 = number<1>{};
-
-        const auto x_m_n = make_naive_tensor_view<address_space_enum::global>(
-            p_x, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
-        const auto gamma_n = make_naive_tensor_view<address_space_enum::global>(
-            p_gamma, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
-        const auto beta_n = make_naive_tensor_view<address_space_enum::global>(
-            p_beta, make_tuple(N), make_tuple(1), number<32>{}, number<1>{});
-
-        const auto iM = get_block_id() * kMPerBlock;
-
-        constexpr auto xDstr = MakeXBlockTileDistribution();
-
-        auto x_block_window = make_tile_window(
-            x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
-
-        index_t num_n_tile_iteration = __builtin_amdgcn_readfirstlane(N / kNPerBlock);
-        // TODO: padding - handle max_count if N % kNPerBlock != 0
-        constexpr auto NPerThread = GetNPerThread(xDstr);
-        ThreadWelford<ComputeDataType, XDataType> thread_welford{
-            type_convert<int>(NPerThread * N / kNPerBlock)};
+        // TODO - Optimize tail loop to reduce move_tile_window()
+        index_t num_n_tile_iteration =
+            __builtin_amdgcn_readfirstlane(integer_divide_ceil(N, kNPerBlock));
+
+        int welford_max_count = GetWelfordMaxCount(N);
+        ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};

         using XTensorType = decltype(load_tile(x_block_window));
         auto mean_compute_block_tensor =
@@ -190,44 +189,14 @@ struct Layernorm2dFwd
         auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);

         if constexpr(kSaveMean)
-        {
-            const auto mean_m = make_naive_tensor_view_packed<address_space_enum::global>(
-                p_mean, make_tuple(M), number<32>{});
-            auto mean_block_window =
-                make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
             store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
-        }

         if constexpr(kSaveInvStd)
-        {
-            const auto inv_std_m = make_naive_tensor_view_packed<address_space_enum::global>(
-                p_invStd, make_tuple(M), number<32>{});
-            auto inv_std_block_window =
-                make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
-            store_tile(inv_std_block_window,
-                       cast_tile<MeanDataType>(inv_std_compute_block_tensor));
-        }
+            store_tile(inv_std_block_window,
+                       cast_tile<InvStdDataType>(inv_std_compute_block_tensor));

-        // TODO: Extract normalize pipeline
-        const auto y_m_n = make_naive_tensor_view<address_space_enum::global>(
-            p_y, make_tuple(M, N), make_tuple(N, 1), number<32>{}, number<1>{});
-        auto y_block_window = make_tile_window(
-            y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
-
-        constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
-        constexpr auto betaDstr  = gammaDstr;
-
-        auto gamma_block_window =
-            make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
-        auto beta_block_window = make_tile_window(
-            beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
-
         // reverse read x to reuse cache
-        ck_tile::index_t stride_to_right_most_window = N - kNPerBlock;
+        ck_tile::index_t stride_to_right_most_window =
+            N % kNPerBlock == 0 ? N - kNPerBlock : N - N % kNPerBlock;

         move_tile_window(x_block_window, {0, -kNPerBlock});
         move_tile_window(gamma_block_window, {stride_to_right_most_window});
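The widened stride keeps the reverse sweep tile-aligned when N is not a multiple of kNPerBlock. A small check with illustrative values (N = 300, kNPerBlock = 128):

// Old offset: 300 - 128 = 172, which is not a tile boundary.
// New offset: 300 % 128 == 44, so 300 - 44 = 256 -- the start of the last
// (partial) tile, matching where the forward sweep stopped.
static_assert(300 - 300 % 128 == 256);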
@@ -274,17 +243,209 @@ struct Layernorm2dFwd
         }
     }

+    template <typename XBlockWindow,
+              typename GammaBlockWindow,
+              typename BetaBlockWindow,
+              typename YBlockWindow,
+              typename MeanBlockWindow,
+              typename InvStdBlockWindow,
+              bool Cond = (kHasGamma && kHasBeta)>
+    CK_TILE_DEVICE std::enable_if_t<Cond>
+    OnePassLayernorm2dFwd(XBlockWindow& x_block_window,
+                          GammaBlockWindow& gamma_block_window,
+                          BetaBlockWindow& beta_block_window,
+                          YBlockWindow& y_block_window,
+                          MeanBlockWindow& mean_block_window,
+                          InvStdBlockWindow& inv_std_block_window,
+                          ComputeDataType epsilon,
+                          ck_tile::index_t N) const
+    {
+        int welford_max_count = GetWelfordMaxCount(N);
+        ThreadWelford<ComputeDataType, XDataType> thread_welford{welford_max_count};
+
+        using XTensorType = decltype(load_tile(x_block_window));
+        auto mean_compute_block_tensor =
+            thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
+        auto var_compute_block_tensor =
+            thread_welford.template MakeInitialMeanVarDistributedTensor<XTensorType>();
+
+        clear_tile(mean_compute_block_tensor);
+        clear_tile(var_compute_block_tensor);
+
+        const auto x_block_tensor = load_tile(x_block_window);
+        thread_welford(x_block_tensor, mean_compute_block_tensor, var_compute_block_tensor);
+
+        // TODO: support cross warp Welford
+        WarpMergeWelford<ComputeDataType, true>{}(
+            mean_compute_block_tensor, var_compute_block_tensor, thread_welford.cur_count_);
+
+        auto inv_std_compute_block_tensor = InvSqrt(var_compute_block_tensor, epsilon);
+
+        if constexpr(kSaveMean)
+            store_tile(mean_block_window, cast_tile<MeanDataType>(mean_compute_block_tensor));
+        if constexpr(kSaveInvStd)
+            store_tile(inv_std_block_window,
+                       cast_tile<InvStdDataType>(inv_std_compute_block_tensor));
+
+        // normalize
+        const auto gamma_block_tensor = load_tile(gamma_block_window);
+        const auto beta_block_tensor  = load_tile(beta_block_window);
+
+        constexpr auto x_spans = decltype(x_block_tensor)::get_distributed_spans();
+        auto y_block_tensor =
+            make_static_distributed_tensor<YDataType>(x_block_tensor.get_tile_distribution());
+
+        sweep_tile_span(x_spans[I1], [&](auto idx1) {
+            constexpr auto j_idx = make_tuple(idx1);
+            const auto gamma     = type_convert<ComputeDataType>(gamma_block_tensor[j_idx]);
+            const auto beta      = type_convert<ComputeDataType>(beta_block_tensor[j_idx]);
+
+            sweep_tile_span(x_spans[I0], [&](auto idx0) {
+                constexpr auto i_idx   = make_tuple(idx0);
+                constexpr auto i_j_idx = make_tuple(idx0, idx1);
+
+                const auto mean    = mean_compute_block_tensor[i_idx];
+                const auto inv_std = inv_std_compute_block_tensor[i_idx];
+                const auto x       = type_convert<ComputeDataType>(x_block_tensor[i_j_idx]);
+
+                auto y = (x - mean) * inv_std * gamma + beta;
+                y_block_tensor(i_j_idx) = type_convert<YDataType>(y);
+            });
+        });
+
+        store_tile(y_block_window, y_block_tensor);
+    }
+
     CK_TILE_DEVICE void operator()(Kargs kargs) const
     {
-        TwoPassLayernorm2dFwd(static_cast<const XDataType*>(kargs.p_x),
-                              static_cast<const GammaDataType*>(kargs.p_gamma),
-                              static_cast<const BetaDataType*>(kargs.p_beta),
-                              static_cast<YDataType*>(kargs.p_y),
-                              static_cast<MeanDataType*>(kargs.p_mean),
-                              static_cast<InvStdDataType*>(kargs.p_invStd),
-                              static_cast<const ComputeDataType>(kargs.epsilon),
-                              kargs.M,
-                              kargs.N);
+        const auto x_m_n = [&]() {
+            const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+                static_cast<const XDataType*>(kargs.p_x),
+                make_tuple(kargs.M, kargs.N),
+                make_tuple(kargs.N, 1),
+                number<kNPerThread>{},
+                number<1>{});
+            return pad_tensor_view(x_dram_naive,
+                                   make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
+                                   sequence<kPadM, kPadN>{});
+        }();
+        const auto gamma_n = [&]() {
+            const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+                static_cast<const GammaDataType*>(kargs.p_gamma),
+                make_tuple(kargs.N),
+                make_tuple(1),
+                number<kNPerThread>{},
+                number<1>{});
+            return pad_tensor_view(
+                gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
+        }();
+        const auto beta_n = [&]() {
+            const auto gamma_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+                static_cast<const BetaDataType*>(kargs.p_beta),
+                make_tuple(kargs.N),
+                make_tuple(1),
+                number<kNPerThread>{},
+                number<1>{});
+            return pad_tensor_view(
+                gamma_dram_naive, make_tuple(number<kNPerBlock>{}), sequence<kPadN>{});
+        }();
+
+        const auto iM = get_block_id() * kMPerBlock;
+
+        constexpr auto xDstr = MakeXBlockTileDistribution();
+        auto x_block_window  = make_tile_window(
+            x_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0}, xDstr);
+
+        const auto y_m_n = [&]() {
+            const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
+                static_cast<YDataType*>(kargs.p_y),
+                make_tuple(kargs.M, kargs.N),
+                make_tuple(kargs.N, 1),
+                number<kNPerThread>{},
+                number<1>{});
+            return pad_tensor_view(y_dram_naive,
+                                   make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
+                                   sequence<kPadM, kPadN>{});
+        }();
+        auto y_block_window = make_tile_window(
+            y_m_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {iM, 0});
+
+        constexpr auto gammaDstr = MakeGammaBetaBlockTileDistribution();
+        constexpr auto betaDstr  = gammaDstr;
+
+        auto gamma_block_window =
+            make_tile_window(gamma_n, make_tuple(number<kNPerBlock>{}), {0}, gammaDstr);
+        auto beta_block_window = make_tile_window(
+            beta_n, make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}), {0}, betaDstr);
+
+        auto mean_block_window = [&]() {
+            if constexpr(kSaveMean)
+            {
+                const auto mean_m = [&]() {
+                    const auto mean_dram_naive =
+                        make_naive_tensor_view_packed<address_space_enum::global>(
+                            static_cast<MeanDataType*>(kargs.p_mean),
+                            make_tuple(kargs.M),
+                            number<1>{});
+                    return pad_tensor_view(
+                        mean_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
+                }();
+                return make_tile_window(mean_m, make_tuple(number<kMPerBlock>{}), {iM});
+            }
+            else
+                return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
+        }();
+        auto inv_std_block_window = [&]() {
+            if constexpr(kSaveInvStd)
+            {
+                const auto inv_std_m = [&]() {
+                    const auto inv_std_dram_naive =
+                        make_naive_tensor_view_packed<address_space_enum::global>(
+                            static_cast<InvStdDataType*>(kargs.p_invStd),
+                            make_tuple(kargs.M),
+                            number<1>{});
+                    return pad_tensor_view(
+                        inv_std_dram_naive, make_tuple(number<kMPerBlock>{}), sequence<kPadM>{});
+                }();
+                return make_tile_window(inv_std_m, make_tuple(number<kMPerBlock>{}), {iM});
+            }
+            else
+                return make_null_tile_window(make_tuple(number<kMPerBlock>{}));
+        }();
+
+        if(kargs.N <= kNPerBlock)
+            OnePassLayernorm2dFwd(x_block_window,
+                                  gamma_block_window,
+                                  beta_block_window,
+                                  y_block_window,
+                                  mean_block_window,
+                                  inv_std_block_window,
+                                  static_cast<const ComputeDataType>(kargs.epsilon),
+                                  kargs.N);
+        else
+            TwoPassLayernorm2dFwd(x_block_window,
+                                  gamma_block_window,
+                                  beta_block_window,
+                                  y_block_window,
+                                  mean_block_window,
+                                  inv_std_block_window,
+                                  static_cast<const ComputeDataType>(kargs.epsilon),
+                                  kargs.N);
     }
 };
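For reference, the arithmetic that OnePassLayernorm2dFwd and TwoPassLayernorm2dFwd distribute over tiles is the standard layernorm transform, y = (x - mean) * inv_std * gamma + beta with inv_std = 1/sqrt(var + epsilon). A minimal scalar sketch (host-side, illustrative only; the kernel computes mean/var with Welford updates rather than two plain loops):

#include <cmath>
#include <vector>

// Scalar reference for row-wise 2D layernorm (illustrative).
void layernorm2d_ref(const std::vector<float>& x, const std::vector<float>& gamma,
                     const std::vector<float>& beta, std::vector<float>& y,
                     int M, int N, float epsilon)
{
    for(int i = 0; i < M; ++i)
    {
        float mean = 0.f;
        for(int j = 0; j < N; ++j)
            mean += x[i * N + j];
        mean /= N;

        float var = 0.f;
        for(int j = 0; j < N; ++j)
        {
            const float d = x[i * N + j] - mean;
            var += d * d;
        }
        var /= N;

        const float inv_std = 1.f / std::sqrt(var + epsilon);
        for(int j = 0; j < N; ++j)
            y[i * N + j] = (x[i * N + j] - mean) * inv_std * gamma[j] + beta[j];
    }
}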
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp

@@ -14,17 +14,21 @@ template <typename XDataType_,
           typename YDataType_,
           typename MeanDataType_,
           typename InvStdDataType_,
-          typename BlockShape_>
+          typename BlockShape_,
+          bool kPadM_,
+          bool kPadN_>
 struct BlockLayernorm2dFwdProblem
 {
     using XDataType       = remove_cvref_t<XDataType_>;
     using GammaDataType   = remove_cvref_t<GammaDataType_>;
     using BetaDataType    = remove_cvref_t<BetaDataType_>;
     using ComputeDataType = remove_cvref_t<ComputeDataType_>;
     using YDataType       = remove_cvref_t<YDataType_>;
     using MeanDataType    = remove_cvref_t<MeanDataType_>;
     using InvStdDataType  = remove_cvref_t<InvStdDataType_>;
     using BlockShape      = remove_cvref_t<BlockShape_>;
+
+    static constexpr bool kPadM = kPadM_;
+    static constexpr bool kPadN = kPadN_;
 };

 } // namespace ck_tile
library/src/tensor_operation_instance/gpu/CMakeLists.txt

@@ -64,9 +64,9 @@ function(add_instance_library INSTANCE_NAME)
                 list(REMOVE_ITEM ARGN "${source}")
             endif()
         endforeach()
-        # Do not build mha instances if gfx94 targets are not on the target list
+        # Do not build mha instances if gfx94 or gfx90a targets are not on the target list
         foreach(source IN LISTS ARGN)
-            if(NOT INST_TARGETS MATCHES "gfx94" AND source MATCHES "mha")
+            if(NOT INST_TARGETS MATCHES "gfx94" AND NOT INST_TARGETS MATCHES "gfx90a"
+               AND source MATCHES "mha")
                 message("removing mha instance ${source}")
                 list(REMOVE_ITEM ARGN "${source}")
             endif()

@@ -85,7 +85,7 @@ function(add_instance_library INSTANCE_NAME)
         elseif(ARGN MATCHES "_wmma")
             list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
         elseif(ARGN MATCHES "mha")
-            list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
+            list(REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201)
         endif()
     endif()
     set(offload_targets)
     foreach(target IN LISTS INST_TARGETS)

@@ -320,8 +320,7 @@ if(CK_DEVICE_CONV_INSTANCES)
     endif()
 endif()
 if(CK_DEVICE_MHA_INSTANCES)
     set(gpu_list ${INST_TARGETS})
-    list(FILTER gpu_list INCLUDE REGEX "^gfx94")
-    if(gpu_list)
+    if(gpu_list MATCHES "gfx94" OR gpu_list MATCHES "gfx90a")
         add_library(device_mha_operations STATIC ${CK_DEVICE_MHA_INSTANCES})
         add_library(composablekernels::device_mha_operations ALIAS device_mha_operations)
         target_compile_features(device_mha_operations PUBLIC)
script/cmake-ck-dev.sh

@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
 if [ $# -ge 2 ]; then
     GPU_TARGETS=$2
+    REST_ARGS=${@:3}
 else
     GPU_TARGETS="gfx908;gfx90a;gfx940"
+    REST_ARGS=
 fi
 cmake \

@@ -20,4 +22,5 @@ cmake
 -D GPU_TARGETS=$GPU_TARGETS \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 -D USE_BITINT_EXTENSION_INT4=OFF \
+$REST_ARGS \
 ${MY_PROJECT_SOURCE}
script/cmake-ck-release.sh

@@ -7,8 +7,10 @@ MY_PROJECT_SOURCE=$1
 if [ $# -ge 2 ]; then
     GPU_TARGETS=$2
+    REST_ARGS=${@:3}
 else
     GPU_TARGETS="gfx908;gfx90a;gfx940"
+    REST_ARGS=
 fi
 cmake \

@@ -20,5 +22,6 @@ cmake
 -D GPU_TARGETS=$GPU_TARGETS \
 -D CMAKE_VERBOSE_MAKEFILE:BOOL=ON \
 -D USE_BITINT_EXTENSION_INT4=OFF \
+$REST_ARGS \
 ${MY_PROJECT_SOURCE}