gaoqiong / composable_kernel_ROCM · Commits

Commit cdfceb0a, authored Jan 08, 2025 by Astha Rai

    Merge branch 'codegen_hiprtc' of github.com:ROCm/composable_kernel into codegen_hiprtc

Parents: b46349df, 3b9a77df
Changes: 47 files in total; showing 20 changed files with 293 additions and 89 deletions (+293 -89)
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp  +131 -47
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp  +11 -10
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp  +38 -10
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp  +1 -2
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp  +5 -3
include/ck_tile/ops/fused_moe.hpp  +1 -1
include/ck_tile/ops/gemm.hpp  +1 -1
include/ck_tile/ops/image_to_column.hpp  +1 -1
include/ck_tile/ops/layernorm2d.hpp  +1 -1
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp  +28 -0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp  +18 -2
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp  +2 -0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp  +34 -5
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp  +15 -0
include/ck_tile/ops/norm_reduce.hpp  +1 -1
include/ck_tile/ops/permute.hpp  +1 -1
include/ck_tile/ops/reduce.hpp  +1 -1
include/ck_tile/ops/rmsnorm2d.hpp  +1 -1
include/ck_tile/ops/smoothquant.hpp  +1 -1
include/ck_tile/ops/softmax.hpp  +1 -1
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp — View file @ cdfceb0a

```diff
@@ -47,10 +47,16 @@ struct FmhaFwdSplitKVKernel
     static constexpr bool kStoreLSE         = FmhaPipeline::kStoreLSE;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
     static constexpr bool kIsPagedKV        = FmhaPipeline::Problem::kIsPagedKV;
+    static constexpr bool kMergeNumHeadGroupsSeqLenQ =
+        FmhaPipeline::Problem::kMergeNumHeadGroupsSeqLenQ;
     using FmhaMask                 = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;

+    static_assert(!kMergeNumHeadGroupsSeqLenQ ||
+                  (kMergeNumHeadGroupsSeqLenQ && BiasEnum == BlockAttentionBiasEnum::NO_BIAS &&
+                   !kHasMask));
+
     // clang-format off
     template <typename T> struct t2s;
     template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
```
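Aside (not part of the diff): kMergeNumHeadGroupsSeqLenQ folds the GQA head-group dimension into the query-sequence dimension so that all query heads sharing one KV head run in a single block row, and the static_assert above restricts the merge to the no-bias, no-mask case. A minimal sketch of the folded index, with hypothetical names, showing why position-dependent bias or masking would land at the wrong offsets:

```cpp
// Illustrative sketch only; not from the repository.
#include <cassert>

// With G = nhead_q / nhead_kv query heads stacked into one logical sequence of
// length G * seqlen_q, a merged row index encodes (head-in-group, position):
int merged_row(int i_group, int i_seq, int seqlen_q) { return i_group * seqlen_q + i_seq; }

int main()
{
    const int seqlen_q = 128;
    // Row 130 is head-group 1 at sequence position 2, not sequence position 130,
    // so a causal mask or positional bias keyed on the row index would be wrong.
    assert(merged_row(1, 2, seqlen_q) == 130);
    return 0;
}
```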
```diff
@@ -476,15 +482,20 @@ struct FmhaFwdSplitKVKernel
     }

     CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
-                                                ck_tile::index_t nhead,
+                                                ck_tile::index_t nhead_q,
+                                                ck_tile::index_t nhead_kv,
                                                 ck_tile::index_t max_seqlen_q,
                                                 ck_tile::index_t hdim_v,
                                                 ck_tile::index_t num_splits)
     {
+        ck_tile::index_t nhead_ = kMergeNumHeadGroupsSeqLenQ ? nhead_kv : nhead_q;
+        ck_tile::index_t max_seqlen_q_ =
+            max_seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? nhead_q / nhead_kv : 1);
+
         // TODO: this may need tuning
-        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q, FmhaPipeline::kM0) *
+        return dim3(ck_tile::integer_divide_ceil(max_seqlen_q_, FmhaPipeline::kM0) *
                         ck_tile::integer_divide_ceil(hdim_v, FmhaPipeline::kN1) * num_splits,
-                    nhead,
+                    nhead_,
                     batch_size);
     }
```
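Aside (not part of the diff): the grid arithmetic above, restated as a self-contained host-side sketch. ceil_div mirrors ck_tile::integer_divide_ceil; kM0 and kN1 stand in for the pipeline's tile sizes.

```cpp
// Illustrative sketch only; Dim3 stands in for the HIP dim3 type.
#include <cstdint>

struct Dim3 { std::uint32_t x, y, z; };

constexpr std::int64_t ceil_div(std::int64_t a, std::int64_t b) { return (a + b - 1) / b; }

constexpr Dim3 grid_size(std::int64_t batch, std::int64_t nhead_q, std::int64_t nhead_kv,
                         std::int64_t max_seqlen_q, std::int64_t hdim_v,
                         std::int64_t num_splits, std::int64_t kM0, std::int64_t kN1,
                         bool merge_head_groups)
{
    // Merging folds the G = nhead_q / nhead_kv group dimension into seqlen_q and
    // launches one block row per KV head instead of one per Q head.
    const std::int64_t nhead  = merge_head_groups ? nhead_kv : nhead_q;
    const std::int64_t seqlen = max_seqlen_q * (merge_head_groups ? nhead_q / nhead_kv : 1);
    return {static_cast<std::uint32_t>(ceil_div(seqlen, kM0) * ceil_div(hdim_v, kN1) * num_splits),
            static_cast<std::uint32_t>(nhead),
            static_cast<std::uint32_t>(batch)};
}
```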
```diff
@@ -562,7 +573,7 @@ struct FmhaFwdSplitKVKernel
         // # of required blocks is different in each groups, terminate unnecessary blocks
         // earlier
-        if(kargs.seqlen_q <= i_m0)
+        if(kargs.seqlen_q * (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) <= i_m0)
         {
             return;
         }
```
```diff
@@ -617,30 +628,60 @@ struct FmhaFwdSplitKVKernel
         }

         // for simplicity, batch stride we just modify the pointer
+        const index_t i_nhead_k =
+            (kMergeNumHeadGroupsSeqLenQ ? i_nhead : i_nhead / kargs.nhead_ratio_qk);
+
         const QDataType* q_ptr = reinterpret_cast<const QDataType*>(kargs.q_ptr) +
-                                 static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
+                                 static_cast<long_index_t>(i_nhead) *
+                                     (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
+                                     kargs.nhead_stride_q +
                                  batch_offset_q;
         const KDataType* k_ptr =
             reinterpret_cast<const KDataType*>(kargs.k_ptr) +
-            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
+            static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k +
             batch_offset_k;
         const VDataType* v_ptr =
             reinterpret_cast<const VDataType*>(kargs.v_ptr) +
-            static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
+            static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v +
             batch_offset_v;
         ODataType* o_acc_ptr = reinterpret_cast<ODataType*>(kargs.o_acc_ptr) +
-                               static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
+                               static_cast<long_index_t>(i_nhead) *
+                                   (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
+                                   kargs.nhead_stride_o_acc +
                                batch_offset_o_acc + i_split * kargs.split_stride_o_acc;

         // Q/K/V DRAM and DRAM window
-        const auto q_dram = [&]() {
-            const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-                q_ptr,
-                make_tuple(kargs.seqlen_q, kargs.hdim_q),
-                make_tuple(kargs.stride_q, 1),
-                number<FmhaPipeline::kAlignmentQ>{},
-                number<1>{});
+        const auto q_dram = [&] {
+            const auto q_dram_naive = [&] {
+                if constexpr(kMergeNumHeadGroupsSeqLenQ)
+                {
+                    // reshape: (nhead_ratio_qk, seqlen_q, hdim_q) -> (nhead_ratio_qk * seqlen_q,
+                    // hdim_q)
+                    const auto view = make_naive_tensor_view<address_space_enum::global>(
+                        q_ptr,
+                        make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_q),
+                        make_tuple(kargs.nhead_stride_q, kargs.stride_q, 1),
+                        number<FmhaPipeline::kAlignmentQ>{},
+                        number<1>{});
+
+                    return transform_tensor_view(
+                        view,
+                        make_tuple(
+                            make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
+                            make_pass_through_transform(kargs.hdim_q)),
+                        make_tuple(sequence<0, 1>{}, sequence<2>{}),
+                        make_tuple(sequence<0>{}, sequence<1>{}));
+                }
+                else
+                {
+                    return make_naive_tensor_view<address_space_enum::global>(
+                        q_ptr,
+                        make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                        make_tuple(kargs.stride_q, 1),
+                        number<FmhaPipeline::kAlignmentQ>{},
+                        number<1>{});
+                }
+            }();
+
             if constexpr(FmhaPipeline::kQLoadOnce)
             {
                 return pad_tensor_view(
```
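Aside (not part of the diff): the merge transform above makes a (nhead_ratio_qk, seqlen_q, hdim_q) strided view addressable as a 2-D (nhead_ratio_qk * seqlen_q, hdim_q) tensor. In scalar form, the element offset it computes is:

```cpp
// Illustrative sketch only; strides are in elements, as in the views above.
#include <cstdint>

std::int64_t merged_offset(std::int64_t row, std::int64_t col, std::int64_t seqlen_q,
                           std::int64_t nhead_stride_q, std::int64_t stride_q)
{
    const std::int64_t g = row / seqlen_q; // which query head inside the KV group
    const std::int64_t s = row % seqlen_q; // sequence position within that head
    return g * nhead_stride_q + s * stride_q + col;
}
```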
```diff
@@ -729,7 +770,7 @@ struct FmhaFwdSplitKVKernel
             }
         }();

-        auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+        auto k_page_block_navigator = [&, i_batch_ = i_batch]() {
             if constexpr(kIsPagedKV)
             {
                 const auto* block_indices =
@@ -739,8 +780,7 @@ struct FmhaFwdSplitKVKernel
                     integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
                 const long_index_t fixed_offset =
-                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
-                    kargs.nhead_stride_k;
+                    static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_k;

                 return make_page_block_navigator<const KDataType, 0>(
                     kargs.k_ptr,
@@ -760,7 +800,7 @@ struct FmhaFwdSplitKVKernel
             }
         }();

-        auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
+        auto v_page_block_navigator = [&, i_batch_ = i_batch]() {
             if constexpr(kIsPagedKV)
             {
                 const auto* block_indices =
@@ -770,8 +810,7 @@ struct FmhaFwdSplitKVKernel
                     integer_divide_ceil(kv_l2p_offset + kargs.seqlen_k, kargs.page_block_size);
                 const long_index_t fixed_offset =
-                    static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
-                    kargs.nhead_stride_v;
+                    static_cast<long_index_t>(i_nhead_k) * kargs.nhead_stride_v;

                 return make_page_block_navigator<const VDataType, 1>(
                     kargs.v_ptr,
```
```diff
@@ -842,19 +881,40 @@ struct FmhaFwdSplitKVKernel
         // lse acc
         auto lse_acc_dram_window = [&, i_nhead_ = i_nhead, i_split_ = i_split]() {
             constexpr auto lse_acc_dram_window_lengths = make_tuple(number<FmhaPipeline::kM0>{});
             LSEDataType* lse_acc_ptr =
                 reinterpret_cast<LSEDataType*>(kargs.lse_acc_ptr) +
-                static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_lse_acc +
+                static_cast<long_index_t>(i_nhead_) *
+                    (kMergeNumHeadGroupsSeqLenQ ? kargs.nhead_ratio_qk : 1) *
+                    kargs.nhead_stride_lse_acc +
                 batch_offset_lse_acc + i_split_ * kargs.split_stride_lse_acc;

-            const auto lse_acc_dram = [&]() {
-                const auto lse_acc_dram_naive =
-                    make_naive_tensor_view<address_space_enum::global>(lse_acc_ptr,
-                                                                       make_tuple(kargs.seqlen_q),
-                                                                       make_tuple(1),
-                                                                       number<1>{},
-                                                                       number<1>{});
+            const auto lse_acc_dram = [&] {
+                const auto lse_acc_dram_naive = [&] {
+                    if constexpr(kMergeNumHeadGroupsSeqLenQ)
+                    {
+                        // reshape: (nhead_ratio_qk, seqlen_q) -> (nhead_ratio_qk * seqlen_q)
+                        const auto view = make_naive_tensor_view<address_space_enum::global>(
+                            lse_acc_ptr,
+                            make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q),
+                            make_tuple(kargs.nhead_stride_lse_acc, 1),
+                            number<1>{},
+                            number<1>{});
+
+                        return transform_tensor_view(
+                            view,
+                            make_tuple(make_merge_transform(
+                                make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q))),
+                            make_tuple(sequence<0, 1>{}),
+                            make_tuple(sequence<0>{}));
+                    }
+                    else
+                    {
+                        return make_naive_tensor_view<address_space_enum::global>(
+                            lse_acc_ptr,
+                            make_tuple(kargs.seqlen_q),
+                            make_tuple(1),
+                            number<1>{},
+                            number<1>{});
+                    }
+                }();

                 return pad_tensor_view(
                     lse_acc_dram_naive, lse_acc_dram_window_lengths, sequence<kPadSeqLenQ>{});
             }();
```
```diff
@@ -953,13 +1013,37 @@ struct FmhaFwdSplitKVKernel
         }();

         // Oacc DRAM and Oacc DRAM window
-        auto o_acc_dram = [&]() {
-            const auto o_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
-                o_acc_ptr,
-                make_tuple(kargs.seqlen_q, kargs.hdim_v),
-                make_tuple(kargs.stride_o_acc, 1),
-                number<FmhaPipeline::kAlignmentOacc>{},
-                number<1>{});
+        auto o_acc_dram = [&] {
+            const auto o_acc_dram_naive = [&] {
+                if constexpr(kMergeNumHeadGroupsSeqLenQ)
+                {
+                    // reshape: (nhead_ratio_qk, seqlen_q, hdim_v) -> (nhead_ratio_qk * seqlen_q,
+                    // hdim_v)
+                    const auto view = make_naive_tensor_view<address_space_enum::global>(
+                        o_acc_ptr,
+                        make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q, kargs.hdim_v),
+                        make_tuple(kargs.nhead_stride_o_acc, kargs.stride_o_acc, 1),
+                        number<FmhaPipeline::kAlignmentOacc>{},
+                        number<1>{});
+
+                    return transform_tensor_view(
+                        view,
+                        make_tuple(
+                            make_merge_transform(make_tuple(kargs.nhead_ratio_qk, kargs.seqlen_q)),
+                            make_pass_through_transform(kargs.hdim_v)),
+                        make_tuple(sequence<0, 1>{}, sequence<2>{}),
+                        make_tuple(sequence<0>{}, sequence<1>{}));
+                }
+                else
+                {
+                    return make_naive_tensor_view<address_space_enum::global>(
+                        o_acc_ptr,
+                        make_tuple(kargs.seqlen_q, kargs.hdim_v),
+                        make_tuple(kargs.stride_o_acc, 1),
+                        number<FmhaPipeline::kAlignmentOacc>{},
+                        number<1>{});
+                }
+            }();
+
             return pad_tensor_view(o_acc_dram_naive,
```
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp — View file @ cdfceb0a

```diff
@@ -94,16 +94,17 @@ struct BlockFmhaFwdSplitKVPipelineProblem
     static constexpr bool kIsGroupMode = kIsGroupMode_;

     // attributes from traits
     static constexpr bool kPadSeqLenQ       = Traits::kPadSeqLenQ;
     static constexpr bool kPadSeqLenK       = Traits::kPadSeqLenK;
     static constexpr bool kPadHeadDimQ      = Traits::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV      = Traits::kPadHeadDimV;
     static constexpr auto BiasEnum          = Traits::BiasEnum;
     static constexpr bool kStoreLSE         = Traits::kStoreLSE;
     static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
     static constexpr bool kIsPagedKV        = Traits::kIsPagedKV;
     static constexpr bool kHasUnevenSplits  = kIsGroupMode || Traits::kHasUnevenSplits;
+    static constexpr bool kMergeNumHeadGroupsSeqLenQ = Traits::kMergeNumHeadGroupsSeqLenQ;
     static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
 };

 // extract tile size attributes to remove dependency on traits
```
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs.hpp — View file @ cdfceb0a

```diff
@@ -5,14 +5,14 @@

 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qs_ks_vs_default_policy.hpp"

 namespace ck_tile {

-/// NOTICE: we no-longer use this pipeline.
 // This pipeline is qkv all located in LDS
 template <typename Problem_, typename Policy_ = BlockFmhaPipelineQSKSVSDefaultPolicy>
-struct [[deprecated]] BlockFmhaPipelineQSKSVS
+struct BlockFmhaPipelineQSKSVS
 {
     using Problem = remove_cvref_t<Problem_>;
     using Policy  = remove_cvref_t<Policy_>;
```
```diff
@@ -51,6 +51,24 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
     static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
     static constexpr auto BiasEnum     = Problem::BiasEnum;
     static constexpr bool kStoreLSE    = Problem::kStoreLSE;
+    static constexpr bool kHasDropout  = Problem::kHasDropout;
+
+    // last dimension vector length used to create tensor view(and decide buffer_load vector length)
+    // ... together with tensor distribution. tensor dist should able to overwrite this
+    static constexpr index_t kAlignmentQ =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
+    static constexpr index_t kAlignmentK =
+        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
+    static constexpr index_t kAlignmentV = []() {
+        if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
+            return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
+        else
+            return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
+    }();
+
+    static constexpr index_t kAlignmentO =
+        kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
+    static constexpr index_t kAlignmentBias =
+        kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();

     static constexpr index_t kBlockPerCu = []() {
         if constexpr(Problem::kBlockPerCu != -1)
```
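Aside (not part of the diff): the kAlignment* constants above all follow one rule: a dimension that may be padded forces scalar (length-1) access, while an unpadded dimension can use the policy's preferred vector length. A toy restatement:

```cpp
// Illustrative sketch only; values are hypothetical.
constexpr int pick_alignment(bool dim_may_be_padded, int policy_alignment)
{
    // Padding can leave a partial vector at the tile edge, so fall back to
    // element-wise loads; otherwise take the policy's vectorized alignment.
    return dim_may_be_padded ? 1 : policy_alignment;
}

static_assert(pick_alignment(/*padded=*/true, 8) == 1);
static_assert(pick_alignment(/*padded=*/false, 8) == 8);
```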
```diff
@@ -81,6 +99,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
     static constexpr const char* name = "qs";

+    using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
+
     CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
     {
         return Policy::template GetSmemSize<Problem>();
```
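Aside (not part of the diff): DropoutType selects an implementation at compile time. The stand-in types below are assumptions for illustration; in ck_tile the real BlockDropout carries dropout state while the null type is empty and compiles away.

```cpp
// Illustrative sketch only; BlockDropout/NullBlockDropout are stand-ins here.
#include <type_traits>

struct BlockDropout     { /* would hold RNG state and apply dropout */ };
struct NullBlockDropout { /* empty: every operation is a no-op */ };

template <bool kHasDropout>
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;

static_assert(std::is_same_v<DropoutType<false>, NullBlockDropout>);
static_assert(std::is_same_v<DropoutType<true>, BlockDropout>);
```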
```diff
@@ -95,6 +115,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
+              typename RandValDramBlockWindowTmp,
               typename LSEDramBlockWindowTmp,
               typename QElementFunction,
               typename KElementFunction,
@@ -114,6 +135,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
               const VElementFunction& v_element_func,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
               const BiasElementFunction& bias_element_func,
+              RandValDramBlockWindowTmp& /* unused_randval_dram_block_window_tmp */,
               LSEDramBlockWindowTmp& lse_dram_window_tmp, // M0*1 tile
               const LSEElementFunction& lse_element_func,
               const SAccElementFunction& s_acc_element_func,
@@ -122,7 +144,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
               FmhaMask mask,
               PositionEncoding position_encoding,
               float scale_s,
-              void* smem_ptr) const
+              void* smem_ptr,
+              DropoutType& /* unused_dropout */) const
     {
         static_assert(
             std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -222,11 +245,11 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                                        {seqlen_k_start, 0});

         const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
-        auto bias_dram_window = make_tile_window(
-            bias_dram_block_window_tmp.get_bottom_tensor_view(),
-            bias_dram_block_window_tmp.get_window_lengths(),
-            {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
-            Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
+        auto bias_dram_window =
+            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
+                             bias_dram_block_window_tmp.get_window_lengths(),
+                             {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
+                             Policy::template MakeBiasDramTileDistribution<decltype(gemm_0)>());

         auto v_dram_window =
             make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
```
```diff
@@ -583,6 +606,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
               typename KDramBlockWindowTmp,
               typename VDramBlockWindowTmp,
               typename BiasDramBlockWindowTmp,
+              typename RandValDramBlockWindowTmp,
               typename LSEDramBlockWindowTmp,
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
@@ -590,11 +614,13 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                const KDramBlockWindowTmp& k_dram_block_window_tmp,       // N0*K0 tile
                const VDramBlockWindowTmp& v_dram_block_window_tmp,       // N1*K1 tile
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
+               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
                LSEDramBlockWindowTmp& lse_dram_block_window_tmp,         // M0*1 tile
                FmhaMask mask,
                PositionEncoding position_encoding,
                float scale_s,
-               void* smem_ptr) const
+               void* smem_ptr,
+               DropoutType& dropout) const
     {
         return operator()(q_dram_block_window_tmp,
@@ -604,6 +630,7 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                           identity{},
                           bias_dram_block_window_tmp,
                           identity{},
+                          randval_dram_block_window_tmp,
                           lse_dram_block_window_tmp,
                           identity{},
                           identity{},
@@ -612,7 +639,8 @@ struct [[deprecated]] BlockFmhaPipelineQSKSVS
                           mask,
                           position_encoding,
                           scale_s,
-                          smem_ptr);
+                          smem_ptr,
+                          dropout);
     }
 };
```
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp — View file @ cdfceb0a

```diff
@@ -125,9 +125,8 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
     }
 };

-/// NOTICE: we no-longer use this policy.
 template <>
-struct [[deprecated]] BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
+struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
 {
     static constexpr bool QLoadOnce = false;
```
include/ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp — View file @ cdfceb0a

```diff
@@ -43,7 +43,8 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
           bool kDoFp8StaticQuant_,
           bool kIsPagedKV_,
           bool kHasUnevenSplits_,
-          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
+          bool kMergeNumHeadGroupsSeqLenQ_ = false,
+          index_t kBlockPerCu_             = -1 /* overwrite occupancy if not -1 */>
 struct TileFmhaFwdSplitKVTraits
 {
     static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
@@ -56,8 +57,9 @@ struct TileFmhaFwdSplitKVTraits
     static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
     static constexpr bool kIsPagedKV        = kIsPagedKV_;
     // determine if some split (length) is not divisible by tile size
     static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
+    static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
     static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };

 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
```
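Aside (not part of the diff): because kMergeNumHeadGroupsSeqLenQ_ is defaulted to false and inserted ahead of the already-defaulted kBlockPerCu_, instantiations that stop at kHasUnevenSplits_ keep compiling unchanged; only call sites that passed kBlockPerCu_ positionally need an extra argument. A toy reduction of the pattern:

```cpp
// Illustrative sketch only; ToyTraits is hypothetical, not the real trait.
template <bool kHasUnevenSplits_,
          bool kMergeNumHeadGroupsSeqLenQ_ = false, // new, defaulted
          int  kBlockPerCu_                = -1>
struct ToyTraits
{
    static constexpr bool kMergeNumHeadGroupsSeqLenQ = kMergeNumHeadGroupsSeqLenQ_;
};

static_assert(!ToyTraits<true>::kMergeNumHeadGroupsSeqLenQ);      // old call sites unchanged
static_assert(ToyTraits<true, true>::kMergeNumHeadGroupsSeqLenQ); // new feature is opt-in
```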
include/ck_tile/ops/fused_moe.hpp — View file @ cdfceb0a
include/ck_tile/ops/gemm.hpp — View file @ cdfceb0a
include/ck_tile/ops/image_to_column.hpp — View file @ cdfceb0a
include/ck_tile/ops/layernorm2d.hpp — View file @ cdfceb0a

Each of these headers gets the same one-line copyright-year bump:

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
```
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp — View file @ cdfceb0a

```diff
@@ -15,6 +15,7 @@ struct Layernorm2dFwdHostArgs
     const void* p_x;          // [m ,n], input, fp16/bf16
     const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
     const void* p_x_scale;    // [1 ,n], smooth scale input, fp32, nullptr if not used
+    const void* p_x_bias;     // [1, n], bias, prec same as input
     const void* p_gamma;      // [1, n], gamma, prec same as input
     const void* p_beta;       // [1, n], beta, prec same as input
@@ -43,6 +44,7 @@ struct Layernorm2dFwd
     using Problem = typename Pipeline::Problem;

     using XDataType       = remove_cvref_t<typename Problem::XDataType>;
+    using XBiasDataType   = remove_cvref_t<typename Problem::XBiasDataType>;
     using GammaDataType   = remove_cvref_t<typename Problem::GammaDataType>;
     using BetaDataType    = remove_cvref_t<typename Problem::BetaDataType>;
     using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
@@ -67,6 +69,7 @@ struct Layernorm2dFwd
     static constexpr bool kPadM    = false; // always no need to pad along M
     static constexpr bool kPadN    = Problem::Traits::kPadN;
     static constexpr bool kTwoPass = Problem::Traits::kTwoPass;
+    static constexpr auto kXbias      = Problem::Traits::kXbias;
     static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -82,6 +85,7 @@ struct Layernorm2dFwd
     const void* p_x;          // [m ,n], input, fp16/bf16
     const void* p_x_residual; // [m ,n], shortcut input, prec same as input, nullptr if not used
     const void* p_x_scale;    // [1 ,n], smooth scale input, fp32, nullptr if not used
+    const void* p_x_bias;     // [1, n], bias, prec same as input
     const void* p_gamma;      // [1, n], gamma, prec same as input
     const void* p_beta;       // [1, n], beta, prec same as input
@@ -108,6 +112,7 @@ struct Layernorm2dFwd
         return Kargs{hargs.p_x,
                      hargs.p_x_residual,
                      hargs.p_x_scale,
+                     hargs.p_x_bias,
                      hargs.p_gamma,
                      hargs.p_beta,
                      hargs.p_y,
@@ -152,6 +157,7 @@ struct Layernorm2dFwd
         using S_ = typename Problem::BlockShape;
         auto surfix = [&] () {
             std::string n;
+            if (kXbias != Layernorm2dXBiasEnum::NO_BIAS) n += _SS_("_") + Layernorm2dXBiasEnumName<kXbias>::name;
             if (kFusedAdd != Layernorm2dFusedAddEnum::NO_ADD) n += _SS_("_") + Layernorm2dFusedAddEnumName<kFusedAdd>::name;
             if (kFusedQuant != Layernorm2dFusedQuantEnum::NO_SWEEP) n += _SS_("_") + Layernorm2dFusedQuantEnumName<kFusedQuant>::name;
             if (kPadN) n += "_pn";
@@ -228,6 +234,27 @@ struct Layernorm2dFwd
         }();

+        const auto x_bias_window = [&]() {
+            if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
+            {
+                const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
+                    static_cast<const XBiasDataType*>(kargs.p_x_bias),
+                    make_tuple(kargs.n),
+                    make_tuple(1),
+                    number<Vector_N>{},
+                    number<1>{});
+
+                const auto tmp2_ =
+                    pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<false>{});
+                return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
+            }
+            else
+            {
+                return make_null_tile_window(make_tuple(number<Block_N>{}));
+            }
+        }();
+
         const auto gamma_window = [&]() {
             const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                 static_cast<const GammaDataType*>(kargs.p_gamma),
```
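Aside (not part of the diff): x_bias_window uses the immediately-invoked if constexpr lambda idiom, returning a real tile window when bias is enabled and a null window otherwise; each instantiation is type-checked only against the branch it takes. A toy reduction with stand-in types:

```cpp
// Illustrative sketch only; RealWindow/NullWindow are hypothetical stand-ins.
struct RealWindow {};
struct NullWindow {};

template <bool kHasBias>
auto make_bias_window()
{
    // The two branches return different types; `if constexpr` discards the
    // untaken branch, so `auto` deduces a single return type per instantiation.
    if constexpr(kHasBias) { return RealWindow{}; }
    else                   { return NullWindow{}; }
}
```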
```diff
@@ -371,6 +398,7 @@ struct Layernorm2dFwd
         Pipeline{}(x_window,
                    x_residual_window,
+                   x_bias_window,
                    gamma_window,
                    beta_window,
                    y_window,
```
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp — View file @ cdfceb0a

```diff
@@ -18,6 +18,7 @@ struct Layernorm2dFwdPipelineOnePass
     using Policy = ck_tile::remove_cvref_t<Policy_>;

     using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
+    using XBiasDataType   = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
     using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
     using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
     using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
@@ -38,6 +39,7 @@ struct Layernorm2dFwdPipelineOnePass
     static constexpr bool kPadN     = Problem::Traits::kPadN;
     static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
     static constexpr bool kWelford  = Problem::Traits::kWelford;
+    static constexpr auto kXbias      = Problem::Traits::kXbias;
     static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -55,6 +57,7 @@ struct Layernorm2dFwdPipelineOnePass
     template <typename XWindow,
               typename XResidualWindow,
+              typename XBiasWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,
@@ -66,6 +69,7 @@ struct Layernorm2dFwdPipelineOnePass
               typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                    const XResidualWindow& x_residual_window_,
+                                   const XBiasWindow& x_bias_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
                                    YWindow& y_window_,
@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineOnePass
     {
         const auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
+        const auto x_bias_window = make_tile_window(
+            x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         const auto gamma_window = make_tile_window(
             gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         const auto beta_window = make_tile_window(
@@ -90,8 +96,9 @@ struct Layernorm2dFwdPipelineOnePass
         auto y_residual_window = make_tile_window(
             y_residual_window_, Policy::template MakeXBlockTileDistribution<Problem>());

         auto x      = load_tile(x_window);
         auto x_resi = load_tile(x_residual_window);
+        const auto x_bias = load_tile(x_bias_window);

         int cur_count = 0;
         int max_count =
@@ -112,6 +119,15 @@ struct Layernorm2dFwdPipelineOnePass
         auto acc = cast_tile<ComputeDataType>(x);

+        if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
+        {
+            sweep_tile(x, [&](auto idx) {
+                // compute x = bias + x
+                constexpr auto j_idx = make_tuple(idx[number<1>{}]);
+                acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
+            });
+        }
+
         if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                      kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
         {
```
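Aside (not part of the diff): the sweep_tile bias add above in scalar form. x_bias is a [1, n] vector broadcast over the m rows of x and added in compute precision before the mean/variance reduction:

```cpp
// Illustrative sketch only; the tiled kernel distributes this loop nest
// across lanes, but the arithmetic per element is the same.
void add_row_bias(float* acc, const float* x_bias, int m, int n)
{
    for(int i = 0; i < m; ++i)
        for(int j = 0; j < n; ++j)
            acc[i * n + j] += x_bias[j]; // j indexes the N (feature) dimension
}
```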
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp — View file @ cdfceb0a

```diff
@@ -8,6 +8,7 @@
 namespace ck_tile {

 template <typename XDataType_,
+          typename XBiasDataType_,
           typename GammaDataType_,
           typename BetaDataType_,
           typename ComputeDataType_,
@@ -21,6 +22,7 @@ template <typename XDataType_,
 struct Layernorm2dFwdPipelineProblem
 {
     using XDataType       = remove_cvref_t<XDataType_>;
+    using XBiasDataType   = remove_cvref_t<XBiasDataType_>;
     using GammaDataType   = remove_cvref_t<GammaDataType_>;
     using BetaDataType    = remove_cvref_t<BetaDataType_>;
     using ComputeDataType = remove_cvref_t<ComputeDataType_>;
```
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp — View file @ cdfceb0a

```diff
@@ -17,6 +17,7 @@ struct Layernorm2dFwdPipelineTwoPass
     using Policy = ck_tile::remove_cvref_t<Policy_>;

     using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
+    using XBiasDataType   = ck_tile::remove_cvref_t<typename Problem::XBiasDataType>;
     using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
     using BetaDataType    = ck_tile::remove_cvref_t<typename Problem::BetaDataType>;
     using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
@@ -37,6 +38,7 @@ struct Layernorm2dFwdPipelineTwoPass
     static constexpr bool kPadN     = Problem::Traits::kPadN;
     static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;
     static constexpr bool kWelford  = Problem::Traits::kWelford;
+    static constexpr auto kXbias      = Problem::Traits::kXbias;
     static constexpr auto kFusedAdd   = Problem::Traits::kFusedAdd;
     static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -54,6 +56,7 @@ struct Layernorm2dFwdPipelineTwoPass
     template <typename XWindow,
               typename XResidualWindow,
+              typename XBiasWindow,
               typename GammaWindow,
               typename BetaWindow,
               typename YWindow,
@@ -65,6 +68,7 @@ struct Layernorm2dFwdPipelineTwoPass
               typename Epilogue>
     CK_TILE_DEVICE auto operator()(const XWindow& x_window_,
                                    const XResidualWindow& x_residual_window_,
+                                   const XBiasWindow& x_bias_window_,
                                    const GammaWindow& gamma_window_,
                                    const BetaWindow& beta_window_,
                                    YWindow& y_window,
@@ -81,6 +85,8 @@ struct Layernorm2dFwdPipelineTwoPass
         static_assert(kWelford == true, "2 pass only supports welford merge");

         auto x_window =
             make_tile_window(x_window_, Policy::template MakeXBlockTileDistribution<Problem>());
+        auto x_bias_window = make_tile_window(
+            x_bias_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         auto gamma_window = make_tile_window(
             gamma_window_, Policy::template MakeGammaBetaBlockTileDistribution<Problem>());
         auto beta_window = make_tile_window(
@@ -115,13 +121,24 @@ struct Layernorm2dFwdPipelineTwoPass
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
             auto x      = load_tile(x_window);
             auto x_resi = load_tile(x_residual_window);
+            const auto x_bias = load_tile(x_bias_window);

             move_tile_window(x_window, {0, Block_N});
             move_tile_window(x_residual_window, {0, Block_N});
+            move_tile_window(x_bias_window, {Block_N});

             auto acc = cast_tile<ComputeDataType>(x);

+            if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
+            {
+                sweep_tile(x, [&](auto idx) {
+                    // compute x = bias + x
+                    constexpr auto j_idx = make_tuple(idx[number<1>{}]);
+                    acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
+                });
+            }
+
             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                          kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
             {
@@ -167,6 +184,7 @@ struct Layernorm2dFwdPipelineTwoPass
         move_tile_window(x_window, {0, -Block_N});
         move_tile_window(x_residual_window, {0, -Block_N});
+        move_tile_window(x_bias_window, {-Block_N});
         move_tile_window(gamma_window, {stride_to_right_most_window});
         move_tile_window(beta_window, {stride_to_right_most_window});
         move_tile_window(y_window, {0, stride_to_right_most_window});
@@ -174,9 +192,19 @@ struct Layernorm2dFwdPipelineTwoPass
         // layernorm computation
         for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
         {
             auto x      = load_tile(x_window);
             auto x_resi = load_tile(x_residual_window);
-            auto acc    = cast_tile<ComputeDataType>(x);
+            const auto x_bias = load_tile(x_bias_window);
+            auto acc          = cast_tile<ComputeDataType>(x);
+
+            if constexpr(kXbias == Layernorm2dXBiasEnum::ADD_BIAS)
+            {
+                sweep_tile(x, [&](auto idx) {
+                    // compute x = bias + x
+                    constexpr auto j_idx = make_tuple(idx[number<1>{}]);
+                    acc(idx) = type_convert<ComputeDataType>(x_bias[j_idx]) + acc(idx);
+                });
+            }

             if constexpr(kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD_STORE ||
                          kFusedAdd == Layernorm2dFusedAddEnum::PRE_ADD)
@@ -209,6 +237,7 @@ struct Layernorm2dFwdPipelineTwoPass
             move_tile_window(x_window, {0, -Block_N});
             move_tile_window(x_residual_window, {0, -Block_N});
+            move_tile_window(x_bias_window, {-Block_N});
             move_tile_window(gamma_window, {-Block_N});
             move_tile_window(beta_window, {-Block_N});
             move_tile_window(y_window, {0, -Block_N});
```
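Aside (not part of the diff): a loose sketch of how the bias window joins the two-pass walk. Being one-dimensional, it moves with {Block_N}/{-Block_N} offsets where the 2-D x windows use {0, Block_N}; pass 1 accumulates Welford statistics moving right, the windows are stepped back, and pass 2 normalizes moving left.

```cpp
// Illustrative sketch only; offsets are in elements along the N dimension,
// and the real code also repositions gamma/beta/y windows.
void two_pass_walk(int num_n_tiles, int Block_N)
{
    int x_col = 0, bias_col = 0;
    for(int iN = 0; iN < num_n_tiles; ++iN) // pass 1: statistics
    {
        x_col    += Block_N; // move_tile_window(x_window, {0, Block_N})
        bias_col += Block_N; // move_tile_window(x_bias_window, {Block_N})
    }
    x_col    -= Block_N;     // step back onto the last tile before pass 2
    bias_col -= Block_N;     // move_tile_window(x_bias_window, {-Block_N})
    for(int iN = 0; iN < num_n_tiles; ++iN) // pass 2: normalize, moving left
    {
        x_col    -= Block_N; // move_tile_window(x_window, {0, -Block_N})
        bias_col -= Block_N; // move_tile_window(x_bias_window, {-Block_N})
    }
}
```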
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp — View file @ cdfceb0a

```diff
@@ -7,6 +7,19 @@
 namespace ck_tile {

+enum class Layernorm2dXBiasEnum
+{
+    NO_BIAS = 0,
+    // add bias before fused add
+    ADD_BIAS = 1,
+};
+
+// clang-format off
+template<Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
+template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS> { static constexpr const char * name = "no"; };
+template<> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS> { static constexpr const char * name = "xbias"; };
+// clang-format on
+
 enum class Layernorm2dFusedAddEnum
 {
     NO_ADD = 0,
```
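Aside (not part of the diff): the Layernorm2dXBiasEnumName specializations feed the kernel-name suffix built by the surfix lambda in layernorm2d_fwd_kernel.hpp. A hypothetical, minimal rendering of that mapping:

```cpp
// Illustrative sketch only; suffix() is hypothetical, not repository code.
#include <string>

enum class Layernorm2dXBiasEnum { NO_BIAS = 0, ADD_BIAS = 1 };

template <Layernorm2dXBiasEnum> struct Layernorm2dXBiasEnumName;
template <> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::NO_BIAS>
{ static constexpr const char* name = "no"; };
template <> struct Layernorm2dXBiasEnumName<Layernorm2dXBiasEnum::ADD_BIAS>
{ static constexpr const char* name = "xbias"; };

template <Layernorm2dXBiasEnum kXbias>
std::string suffix()
{
    std::string n;
    if constexpr(kXbias != Layernorm2dXBiasEnum::NO_BIAS)
        n += std::string("_") + Layernorm2dXBiasEnumName<kXbias>::name;
    return n; // "" for NO_BIAS, "_xbias" for ADD_BIAS
}
```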
```diff
@@ -42,6 +55,7 @@ template <bool kPadN_,
           bool kFastFDiv_,
           bool kWelford_,
           bool kTwoPass_,
+          Layernorm2dXBiasEnum kXbias_,
           Layernorm2dFusedAddEnum kFusedAdd_,
           Layernorm2dFusedQuantEnum kFusedQuant_>
 struct Layernorm2dFwdTraits
@@ -51,6 +65,7 @@ struct Layernorm2dFwdTraits
     static constexpr bool kFastFDiv = kFastFDiv_;
     static constexpr bool kWelford  = kWelford_;
     static constexpr bool kTwoPass  = kTwoPass_;
+    static constexpr Layernorm2dXBiasEnum kXbias = kXbias_;
     static constexpr Layernorm2dFusedAddEnum kFusedAdd     = kFusedAdd_;
     static constexpr Layernorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
 };
```
include/ck_tile/ops/norm_reduce.hpp — View file @ cdfceb0a
include/ck_tile/ops/permute.hpp — View file @ cdfceb0a
include/ck_tile/ops/reduce.hpp — View file @ cdfceb0a
include/ck_tile/ops/rmsnorm2d.hpp — View file @ cdfceb0a
include/ck_tile/ops/smoothquant.hpp — View file @ cdfceb0a
include/ck_tile/ops/softmax.hpp — View file @ cdfceb0a

Each of these headers gets the same one-line copyright-year bump:

```diff
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
```