Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dd0188b3
Commit
dd0188b3
authored
Jan 06, 2025
by
Mateusz Ozga
Browse files
Merge remote-tracking branch 'origin/develop' into mozga-amd/universal_gemm_weight
parents
7ff4d613
888317e6
Changes
160
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
445 additions
and
390 deletions
+445
-390
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
+71
-7
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
..._tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
+30
-9
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
...fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
+0
-48
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
+30
-10
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
...ile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
+0
-54
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
...ude/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
+0
-105
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
+25
-7
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+120
-44
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
...tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+8
-6
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+2
-0
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+2
-0
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
+30
-27
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+27
-13
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+11
-9
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
..._tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
+2
-0
include/ck_tile/ops/norm_reduce.hpp
include/ck_tile/ops/norm_reduce.hpp
+3
-3
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
include/ck_tile/ops/norm_reduce/block/block_norm_reduce.hpp
+78
-48
No files found.
include/ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp
View file @
dd0188b3
...
...
@@ -20,10 +20,9 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
EpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
EpiloguePipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
...
...
@@ -84,7 +83,7 @@ struct FmhaFwdKernel
return
n
.
empty
()
?
n
:
std
::
string
(
"p"
)
+
n
;
}();
return
_SS_
(
"fmha_fwd_d"
)
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
_SS_
(
t2s
<
QDataType
>::
name
)
+
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
+
_SS_
(
TilePartitioner
::
name
)
+
"_"
"_"
+
(
kIsGroupMode
?
"group"
:
"batch"
)
+
"_"
"b"
+
_TS_
(
bfs
::
kM0
)
+
"x"
+
_TS_
(
bfs
::
kN0
)
+
"x"
+
_TS_
(
bfs
::
kK0
)
+
"x"
+
_TS_
(
bfs
::
kN1
)
+
"x"
+
_TS_
(
bfs
::
kK1
)
+
"x"
+
_TS_
(
bfs
::
kQKHeaddim
)
+
"_"
+
"r"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
0
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
1
>
{}))
+
"x"
+
_TS_
(
g0br
::
at
(
ck_tile
::
number
<
2
>
{}))
+
"_"
+
...
...
@@ -867,9 +866,75 @@ struct FmhaFwdKernel
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
ck_tile
::
index_t
hdim_v_
,
bool
has_padded_seqlen_k
=
false
)
{
return
TilePartitioner
::
GridSize
(
batch_size_
,
nhead_
,
seqlen_q_
,
hdim_v_
);
// has_padded_seqlen_k is determined by checking (seqlen_k_ptr != nullptr)
if
(
has_padded_seqlen_k
)
{
// TODO: this may need tuning
return
dim3
(
nhead_
,
batch_size_
,
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v_
,
FmhaPipeline
::
kN1
));
}
else
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v_
,
FmhaPipeline
::
kN1
),
nhead_
,
batch_size_
);
}
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
kargs
)
{
bool
has_padded_seqlen_k
=
false
;
if
constexpr
(
kIsGroupMode
)
has_padded_seqlen_k
=
(
kargs
.
seqlen_k_ptr
!=
nullptr
);
if
(
has_padded_seqlen_k
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
kargs
.
hdim_v
,
FmhaPipeline
::
kN1
);
const
index_t
i_block
=
blockIdx
.
z
;
const
index_t
i_nhead
=
blockIdx
.
x
;
const
index_t
i_batch
=
blockIdx
.
y
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
else
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
kargs
.
hdim_v
,
FmhaPipeline
::
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -885,8 +950,7 @@ struct FmhaFwdKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
TilePartitioner
{}(
kargs
.
seqlen_q
,
kargs
.
hdim_v
);
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
GetTileIndex
(
kargs
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp
View file @
dd0188b3
...
...
@@ -5,10 +5,9 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdSplitKVCombineKernel
{
using
TilePartitioner
=
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
remove_cvref_t
<
FmhaPipeline_
>
;
using
EpiloguePipeline
=
remove_cvref_t
<
EpiloguePipeline_
>
;
...
...
@@ -235,12 +234,35 @@ struct FmhaFwdSplitKVCombineKernel
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_seqlen_q
,
hdim_v
);
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
FmhaPipeline
::
kN1
),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
kargs
)
{
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
kargs
.
hdim_v
,
FmhaPipeline
::
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -256,8 +278,7 @@ struct FmhaFwdSplitKVCombineKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
TilePartitioner
{}(
kargs
.
seqlen_q
,
kargs
.
hdim_v
);
const
auto
[
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
]
=
GetTileIndex
(
kargs
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp
deleted
100644 → 0
View file @
7ff4d613
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
index_t
kM0_
,
index_t
kN1_
>
struct
FmhaFwdSplitKVCombineTilePartitioner
{
static
constexpr
ck_tile
::
index_t
kM0
=
kM0_
;
static
constexpr
ck_tile
::
index_t
kN1
=
kN1_
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
),
nhead
,
batch_size
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp
View file @
dd0188b3
...
...
@@ -17,10 +17,9 @@
namespace
ck_tile
{
template
<
typename
TilePartitioner_
,
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
template
<
typename
FmhaPipeline_
,
typename
EpiloguePipeline_
>
struct
FmhaFwdSplitKVKernel
{
using
TilePartitioner
=
ck_tile
::
remove_cvref_t
<
TilePartitioner_
>
;
using
FmhaPipeline
=
ck_tile
::
remove_cvref_t
<
FmhaPipeline_
>
;
using
EpiloguePipeline
=
ck_tile
::
remove_cvref_t
<
EpiloguePipeline_
>
;
static
constexpr
ck_tile
::
index_t
kBlockSize
=
FmhaPipeline
::
kBlockSize
;
...
...
@@ -476,13 +475,35 @@ struct FmhaFwdSplitKVKernel
return
kargs
;
}
__host__
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
return
TilePartitioner
::
GridSize
(
batch_size
,
nhead
,
max_seqlen_q
,
hdim_v
,
num_splits
);
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
FmhaPipeline
::
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
FmhaPipeline
::
kN1
)
*
num_splits
,
nhead
,
batch_size
);
}
CK_TILE_DEVICE
static
constexpr
auto
GetTileIndex
(
const
Kargs
&
kargs
)
{
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
kargs
.
hdim_v
,
FmhaPipeline
::
kN1
);
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
x
,
kargs
.
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
kBlockSize
);
}
...
...
@@ -498,8 +519,7 @@ struct FmhaFwdSplitKVKernel
__shared__
char
smem_ptr
[
GetSmemSize
()];
// divide problem
const
auto
[
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
]
=
TilePartitioner
{}(
kargs
.
seqlen_q
,
kargs
.
hdim_v
,
kargs
.
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
]
=
GetTileIndex
(
kargs
);
const
index_t
i_m0
=
__builtin_amdgcn_readfirstlane
(
i_tile_m
*
FmhaPipeline
::
kM0
);
const
index_t
i_n1
=
__builtin_amdgcn_readfirstlane
(
i_tile_n
*
FmhaPipeline
::
kN1
);
...
...
include/ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp
deleted
100644 → 0
View file @
7ff4d613
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
BlockFmhaShape_
>
struct
FmhaFwdSplitKVTilePartitioner
{
using
BlockFmhaShape
=
ck_tile
::
remove_cvref_t
<
BlockFmhaShape_
>
;
static
constexpr
ck_tile
::
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
ck_tile
::
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size
,
ck_tile
::
index_t
nhead
,
ck_tile
::
index_t
max_seqlen_q
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
max_seqlen_q
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
)
*
num_splits
,
nhead
,
batch_size
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
,
ck_tile
::
index_t
num_splits
)
{
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
mn
,
i_split
]
=
f
(
blockIdx
.
x
,
num_splits
);
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
mn
,
num_tile_n1
);
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_split
,
i_nhead
,
i_batch
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp
deleted
100644 → 0
View file @
7ff4d613
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
BlockFmhaShape_
>
struct
FmhaFwdTilePartitioner
{
using
BlockFmhaShape
=
ck_tile
::
remove_cvref_t
<
BlockFmhaShape_
>
;
static
constexpr
ck_tile
::
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
ck_tile
::
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
const
char
*
name
=
"shb"
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
{
// TODO: this may need tuning
return
dim3
(
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v_
,
kN1
),
nhead_
,
batch_size_
);
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
x
;
const
index_t
i_nhead
=
blockIdx
.
y
;
const
index_t
i_batch
=
blockIdx
.
z
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
};
template
<
typename
BlockFmhaShape_
>
using
FmhaFwdTilePartitioner_SHB
=
FmhaFwdTilePartitioner
<
BlockFmhaShape_
>
;
template
<
typename
BlockFmhaShape_
>
struct
FmhaFwdTilePartitioner_HBS
{
using
BlockFmhaShape
=
ck_tile
::
remove_cvref_t
<
BlockFmhaShape_
>
;
static
constexpr
ck_tile
::
index_t
kM0
=
BlockFmhaShape
::
kM0
;
static
constexpr
ck_tile
::
index_t
kN0
=
BlockFmhaShape
::
kN0
;
static
constexpr
ck_tile
::
index_t
kK0
=
BlockFmhaShape
::
kK0
;
static
constexpr
ck_tile
::
index_t
kN1
=
BlockFmhaShape
::
kN1
;
static
constexpr
ck_tile
::
index_t
kK1
=
BlockFmhaShape
::
kK1
;
static
constexpr
const
char
*
name
=
"hbs"
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
batch_size_
,
ck_tile
::
index_t
nhead_
,
ck_tile
::
index_t
seqlen_q_
,
ck_tile
::
index_t
hdim_v_
)
{
// TODO: this may need tuning
return
dim3
(
nhead_
,
batch_size_
,
ck_tile
::
integer_divide_ceil
(
seqlen_q_
,
kM0
)
*
ck_tile
::
integer_divide_ceil
(
hdim_v_
,
kN1
));
}
CK_TILE_DEVICE
auto
operator
()(
ck_tile
::
index_t
/*seqlen_q*/
,
ck_tile
::
index_t
hdim_v
)
{
// const index_t num_tile_m0 = seqlen_q / kM0;
const
index_t
num_tile_n1
=
ck_tile
::
integer_divide_ceil
(
hdim_v
,
kN1
);
const
index_t
i_block
=
blockIdx
.
z
;
const
index_t
i_nhead
=
blockIdx
.
x
;
const
index_t
i_batch
=
blockIdx
.
y
;
const
auto
f
=
[](
index_t
dividend
,
index_t
divisor
)
{
index_t
quotient
=
dividend
/
divisor
;
index_t
modulus
=
dividend
-
quotient
*
divisor
;
return
ck_tile
::
make_tuple
(
quotient
,
modulus
);
};
const
auto
[
i_tile_m
,
i_tile_n
]
=
f
(
i_block
,
num_tile_n1
);
return
ck_tile
::
make_tuple
(
i_tile_m
,
i_tile_n
,
i_nhead
,
i_batch
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/batched_gemm_kernel.hpp
View file @
dd0188b3
...
...
@@ -67,9 +67,10 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
using
KernelArgs
=
BatchedGemmKernelArgs
;
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
batch_count
)
__host__
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
KBatch
,
index_t
batch_count
)
{
return
TilePartitioner
::
GridSize
(
M
,
N
,
batch_count
);
return
TilePartitioner
::
GridSize
(
M
,
N
,
KBatch
*
batch_count
);
}
__host__
static
constexpr
auto
BlockSize
()
{
return
dim3
(
Base
::
KernelBlockSize
);
}
...
...
@@ -85,7 +86,8 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
hostArgs
.
K
,
hostArgs
.
stride_A
,
hostArgs
.
stride_B
,
hostArgs
.
stride_C
},
hostArgs
.
stride_C
,
hostArgs
.
k_batch
},
hostArgs
.
batch_stride_A
,
hostArgs
.
batch_stride_B
,
hostArgs
.
batch_stride_C
,
...
...
@@ -100,22 +102,38 @@ struct BatchedGemmKernel : public GemmKernel<TilePartitioner_, GemmPipeline_, Ep
CK_TILE_DEVICE
void
operator
()(
BatchedGemmKernelArgs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
const
auto
i_batch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
const
auto
i_batch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
/
kargs
.
KBatch
);
const
auto
i_k
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
-
i_batch
*
kargs
.
KBatch
);
const
typename
Base
::
SplitKBatchOffset
splitk_batch_offset
(
kargs
,
i_k
);
// options
const
auto
batch_stride_A
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_A
);
const
auto
batch_offset_A
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_A
);
const
ADataType
*
a_ptr
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
)
+
batch_offset_A
;
const
ADataType
*
a_ptr
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
)
+
batch_offset_A
+
splitk_batch_offset
.
a_k_split_offset
;
const
auto
batch_stride_B
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_B
);
const
auto
batch_offset_B
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_B
);
const
BDataType
*
b_ptr
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
)
+
batch_offset_B
;
const
BDataType
*
b_ptr
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
)
+
batch_offset_B
+
splitk_batch_offset
.
b_k_split_offset
;
const
auto
batch_stride_C
=
__builtin_amdgcn_readfirstlane
(
kargs
.
batch_stride_C
);
const
auto
batch_offset_C
=
__builtin_amdgcn_readfirstlane
(
i_batch
*
batch_stride_C
);
CDataType
*
c_ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
)
+
batch_offset_C
;
this
->
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
i_m
,
i_n
);
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
if
(
kargs
.
KBatch
==
1
)
{
this
->
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
else
{
this
->
template
RunGemm
<
memory_operation_enum
::
atomic_add
>(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
}
};
...
...
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
dd0188b3
...
...
@@ -93,6 +93,7 @@ struct GemmKernel
index_t
stride_A
;
index_t
stride_B
;
index_t
stride_C
;
index_t
KBatch
;
};
CK_TILE_HOST
static
constexpr
GemmKernelArgs
MakeKernelArgs
(
const
GemmHostArgs
&
hostArgs
)
...
...
@@ -105,28 +106,72 @@ struct GemmKernel
hostArgs
.
K
,
hostArgs
.
stride_A
,
hostArgs
.
stride_B
,
hostArgs
.
stride_C
};
hostArgs
.
stride_C
,
hostArgs
.
k_batch
};
}
// CK_TILE_HOST static constexpr GemmKernelArgs MakeKernelArgs(const void* a_ptr,
// const void* b_ptr,
// void* c_ptr,
// index_t M,
// index_t N,
// index_t K,
// index_t stride_A,
// index_t stride_B,
// index_t stride_C)
// {
// return GemmKernelArgs{a_ptr, b_ptr, c_ptr, M, N, K, stride_A, stride_B, stride_C};
// }
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
max
(
GemmPipeline
::
GetSmemSize
(),
EpiloguePipeline
::
GetSmemSize
());
}
struct
SplitKBatchOffset
{
__device__
SplitKBatchOffset
(
const
GemmKernelArgs
&
kargs
,
const
std
::
size_t
k_id
=
blockIdx
.
z
)
{
constexpr
auto
K1
=
TilePartitioner
::
BlockGemmShape
::
WarpTile
::
at
(
number
<
2
>
{});
const
index_t
K_t
=
kargs
.
KBatch
*
K1
;
const
index_t
KRead
=
(
kargs
.
K
+
K_t
-
1
)
/
K_t
*
K1
;
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
a_k_split_offset
=
k_id
*
KRead
;
}
else
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
a_k_split_offset
=
k_id
*
KRead
*
kargs
.
stride_A
;
}
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>
)
{
b_k_split_offset
=
k_id
*
KRead
*
kargs
.
stride_B
;
}
else
if
constexpr
(
std
::
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>
)
{
b_k_split_offset
=
k_id
*
KRead
;
}
if
(
k_id
<
static_cast
<
uint32_t
>
(
kargs
.
KBatch
-
1
))
{
splitted_k
=
KRead
;
}
else
{
splitted_k
=
kargs
.
K
-
KRead
*
(
kargs
.
KBatch
-
1
);
}
}
index_t
a_k_split_offset
;
index_t
b_k_split_offset
;
index_t
splitted_k
;
};
CK_TILE_HOST
static
bool
IsSupportedArgument
(
const
GemmKernelArgs
&
kargs
)
{
constexpr
bool
is_output_c_reg_transposed
=
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
if
constexpr
(
!
((
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
)
||
!
(
std
::
is_same_v
<
CDataType
,
fp16_t
>
||
std
::
is_same_v
<
CDataType
,
bf16_t
>
)))
{
if
(
kargs
.
KBatch
!=
1
)
{
return
false
;
}
}
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
(
kargs
.
K
%
TilePartitioner
::
kK
!=
0
&&
GemmPipeline
::
kPadK
==
false
)
...
...
@@ -198,17 +243,19 @@ struct GemmKernel
return
true
;
}
CK_TILE_DEVICE
auto
MakeGemmTensorViews
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
const
GemmKernelArgs
&
kargs
)
const
template
<
memory_operation_enum
DstInMemOp
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
static
auto
MakeGemmTensorViews
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
)
{
const
auto
&
a_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
M
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
1
>
{});
...
...
@@ -217,7 +264,7 @@ struct GemmKernel
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
M
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
1
,
kargs
.
stride_A
),
number
<
1
>
{},
number
<
1
>
{});
...
...
@@ -229,7 +276,7 @@ struct GemmKernel
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_ptr
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
1
,
kargs
.
stride_B
),
number
<
1
>
{},
number
<
1
>
{});
...
...
@@ -238,7 +285,7 @@ struct GemmKernel
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_ptr
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
N
,
splitk_batch_offset
.
splitted_k
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
1
>
{});
...
...
@@ -248,7 +295,7 @@ struct GemmKernel
const
auto
&
c_tensor_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
,
DstInMemOp
>
(
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
stride_C
,
1
),
...
...
@@ -257,7 +304,7 @@ struct GemmKernel
}
else
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
,
DstInMemOp
>
(
c_ptr
,
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_C
),
...
...
@@ -270,7 +317,7 @@ struct GemmKernel
}
template
<
typename
TensorView
>
CK_TILE_DEVICE
auto
MakeGemmPadViews
(
const
TensorView
&
views
)
const
CK_TILE_DEVICE
static
auto
MakeGemmPadViews
(
const
TensorView
&
views
)
{
const
auto
&
a_pad_view
=
[
&
]()
{
const
auto
&
a_tensor_view
=
views
.
at
(
I0
);
...
...
@@ -330,8 +377,8 @@ struct GemmKernel
}
template
<
typename
PadView
>
CK_TILE_DEVICE
auto
MakeGemmTileWindows
(
const
PadView
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
const
CK_TILE_DEVICE
static
auto
MakeGemmTileWindows
(
const
PadView
&
views
,
const
index_t
i_m
,
const
index_t
i_n
)
{
const
auto
&
a_pad_view
=
views
.
at
(
I0
);
const
auto
&
a_block_window
=
make_tile_window
(
...
...
@@ -363,23 +410,27 @@ struct GemmKernel
* @param kargs GEMM kernel arguments
* @param block_idx_m The GEMM's output M dimension tile index processed by this workgroup.
* @param block_idx_n The GEMM's output N dimension tile index processed by this workgroup.
*
* @tparam DstInMemOp Destination memory operation (default: set).
*/
CK_TILE_DEVICE
void
RunGemm
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
const
GemmKernelArgs
&
kargs
,
const
index_t
block_idx_m
,
const
index_t
block_idx_n
)
const
template
<
memory_operation_enum
DstInMemOp
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
static
void
RunGemm
(
const
ADataType
*
a_ptr
,
const
BDataType
*
b_ptr
,
CDataType
*
c_ptr
,
void
*
smem_ptr
,
const
GemmKernelArgs
&
kargs
,
const
SplitKBatchOffset
&
splitk_batch_offset
,
const
index_t
block_idx_m
,
const
index_t
block_idx_n
)
{
// Create Gemm tensor views, pad views and tile windows
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
);
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
auto
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
const
auto
&
gemm_tensor_views_tuple
=
MakeGemmTensorViews
<
DstInMemOp
>
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
splitk_batch_offset
);
;
const
auto
&
gemm_pad_views
=
MakeGemmPadViews
(
gemm_tensor_views_tuple
);
auto
gemm_tile_windows
=
MakeGemmTileWindows
(
gemm_pad_views
,
block_idx_m
,
block_idx_n
);
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
kargs
.
K
);
const
index_t
num_loop
=
TilePartitioner
::
GetLoopNum
(
splitk_batch_offset
.
splitted_k
);
// Run GEMM cooperatively by whole workgroup.
const
auto
&
a_block_window
=
gemm_tile_windows
.
at
(
I0
);
...
...
@@ -389,18 +440,43 @@ struct GemmKernel
// Run Epilogue Pipeline
auto
&
c_block_window
=
gemm_tile_windows
.
at
(
I2
);
EpiloguePipeline
{}(
c_block_window
,
c_block_tile
);
constexpr
bool
is_output_c_reg_transposed
=
EpiloguePipeline
::
IsOutputTransposed
()
!=
GemmPipeline
::
IsTransposeC
();
if
constexpr
((
DstInMemOp
==
memory_operation_enum
::
set
)
||
(
sizeof
(
CDataType
)
>
2
)
||
(
GemmPipeline
::
VectorSizeC
%
2
==
0
&&
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
&&
is_output_c_reg_transposed
))
{
EpiloguePipeline
{}
.
template
operator
()
<
decltype
(
c_block_window
),
decltype
(
c_block_tile
),
DstInMemOp
>(
c_block_window
,
c_block_tile
);
}
}
CK_TILE_DEVICE
void
operator
()(
GemmKernelArgs
kargs
)
const
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
const
SplitKBatchOffset
splitk_batch_offset
(
kargs
);
// options
const
ADataType
*
a_ptr
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_ptr
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
CDataType
*
c_ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
const
ADataType
*
a_ptr
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
)
+
splitk_batch_offset
.
a_k_split_offset
;
const
BDataType
*
b_ptr
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
)
+
splitk_batch_offset
.
b_k_split_offset
;
CDataType
*
c_ptr
=
static_cast
<
CDataType
*>
(
kargs
.
c_ptr
);
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
kargs
,
i_m
,
i_n
);
if
(
kargs
.
KBatch
==
1
)
{
RunGemm
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
else
{
RunGemm
<
memory_operation_enum
::
atomic_add
>
(
a_ptr
,
b_ptr
,
c_ptr
,
smem_ptr
,
kargs
,
splitk_batch_offset
,
i_m
,
i_n
);
}
}
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v3.hpp
View file @
dd0188b3
...
...
@@ -82,6 +82,8 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
dd0188b3
...
...
@@ -132,6 +132,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
:
public
PipelineImplBase
{
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
dd0188b3
...
...
@@ -53,6 +53,8 @@ struct GemmPipelineAGmemBGmemCRegV1
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
dd0188b3
...
...
@@ -13,6 +13,8 @@ namespace ck_tile {
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
static
constexpr
bool
TransposeC
=
false
;
#if 0
// 2d
template <typename Problem>
...
...
@@ -114,8 +116,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
index_t
smem_size
=
0
;
smem_size
+=
smem_size_a
+
smem_size_b
;
constexpr
index_t
smem_size
=
smem_size_a
+
smem_size_b
;
return
smem_size
;
}
...
...
@@ -485,13 +486,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
TransposeC
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
constexpr
bool
TransposeC
=
false
;
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
I2
=
number
<
2
>
{};
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
constexpr
auto
I2
=
number
<
2
>
{};
using
AccDataType
=
float
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
dd0188b3
...
...
@@ -36,6 +36,8 @@ struct GemmPipelineAGmemBGmemCRegV2
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Policy
::
IsTransposeC
();
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
dd0188b3
...
...
@@ -444,6 +444,8 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
TransposeC
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
View file @
dd0188b3
...
...
@@ -4,8 +4,8 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
_problem.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
_problem.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp"
namespace
ck_tile
{
...
...
@@ -43,36 +43,38 @@ struct Layernorm2dFwdPipelineDefaultPolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
()
{
using
P_
=
Block
Welford
Problem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
return
Block
Welford
<
P_
>
{};
using
P_
=
Block
NormReduce
Problem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
NormReduce
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
Sync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
Sync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
Welford
Sync
<
P_
>
{};
return
Block
NormReduce
Sync
<
P_
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
Welford
CrossWarpSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlock
NormReduce
CrossWarpSync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
return
Block
Welford
CrossWarpSync
<
P_
>
{};
return
Block
NormReduce
CrossWarpSync
<
P_
>
{};
}
template
<
typename
Problem
>
...
...
@@ -80,19 +82,20 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
using
P_
=
BlockNormReduceProblem
<
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
,
Problem
::
Traits
::
kWelford
>
;
using
block_welford
=
Block
Welford
<
P_
>
;
using
block_welford
=
Block
NormReduce
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
ComputeDataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
using
mean_var_block_tile
=
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
return
GetBlock
Welford
CrossWarpSync
<
Problem
>
()
return
GetBlock
NormReduce
CrossWarpSync
<
Problem
>
()
.
template
GetSmemSize
<
mean_var_block_tile
>();
}
else
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
dd0188b3
...
...
@@ -37,6 +37,7 @@ struct Layernorm2dFwdPipelineOnePass
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
Traits
::
kWelford
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
...
@@ -95,11 +96,16 @@ struct Layernorm2dFwdPipelineOnePass
int
cur_count
=
0
;
int
max_count
=
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
auto
block_welford
=
Policy
::
template
GetBlockWelford
<
Problem
>();
auto
block_welford_sync
=
Policy
::
template
GetBlockWelfordSync
<
Problem
>();
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
auto
block_norm_reduce
=
Policy
::
template
GetBlockNormReduce
<
Problem
>();
auto
block_norm_reduce_sync
=
Policy
::
template
GetBlockNormReduceSync
<
Problem
>();
auto
block_norm_reduce_cross_warp_sync
=
Policy
::
template
GetBlockNormReduceCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
x
));
auto
mean
=
block_norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
clear_tile
(
mean
);
clear_tile
(
var
);
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
...
...
@@ -117,12 +123,21 @@ struct Layernorm2dFwdPipelineOnePass
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
}
// compute welford each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
acc
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
// compute reduce each-thread->cross-lane->cross-warp
block_norm_reduce
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
block_norm_reduce_sync
(
mean
,
var
,
cur_count
);
block_norm_reduce_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
if
(
kWelford
)
{
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
}
else
{
sweep_tile
(
mean
,
[
&
](
auto
idx
)
{
mean
(
idx
)
=
mean
(
idx
)
/
type_convert
<
MeanDataType
>
(
row_size
);
var
(
idx
)
=
var
(
idx
)
/
type_convert
<
MeanDataType
>
(
row_size
)
-
mean
(
idx
)
*
mean
(
idx
);
});
}
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
...
...
@@ -153,8 +168,7 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
auto
ln_
=
(
acc
[
idx
]
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
ln
(
idx
)
=
ln_
;
ln
(
idx
)
=
ln_
;
});
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
dd0188b3
...
...
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineTwoPass
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
Traits
::
kWelford
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
...
@@ -77,6 +78,7 @@ struct Layernorm2dFwdPipelineTwoPass
void
*
smem
,
Epilogue
)
const
{
static_assert
(
kWelford
==
true
,
"2 pass only supports welford merge"
);
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
...
...
@@ -102,14 +104,14 @@ struct Layernorm2dFwdPipelineTwoPass
int
max_count
=
(
num_n_tile_iteration
-
1
)
*
count_per_iter
+
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
last_iter_n
);
auto
block_
welford
=
Policy
::
template
GetBlock
Welford
<
Problem
>();
auto
block_
welford
_sync
=
Policy
::
template
GetBlock
Welford
Sync
<
Problem
>();
auto
block_
welford
_cross_warp_sync
=
Policy
::
template
GetBlock
Welford
CrossWarpSync
<
Problem
>();
auto
block_
norm_reduce
=
Policy
::
template
GetBlock
NormReduce
<
Problem
>();
auto
block_
norm_reduce
_sync
=
Policy
::
template
GetBlock
NormReduce
Sync
<
Problem
>();
auto
block_
norm_reduce
_cross_warp_sync
=
Policy
::
template
GetBlock
NormReduce
CrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
)));
auto
mean
=
block_
welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_
welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
mean
=
block_
norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_
norm_reduce
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
...
...
@@ -133,11 +135,11 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
block_
welford
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
block_
norm_reduce
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
}
block_
welford
_sync
(
mean
,
var
,
cur_count
);
block_
welford
_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_
norm_reduce
_sync
(
mean
,
var
,
cur_count
);
block_
norm_reduce
_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
,
constant
<
kFastFDiv
>
{});
// compute inv-std
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
View file @
dd0188b3
...
...
@@ -40,6 +40,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
template
<
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kFastFDiv_
,
bool
kWelford_
,
bool
kTwoPass_
,
Layernorm2dFusedAddEnum
kFusedAdd_
,
Layernorm2dFusedQuantEnum
kFusedQuant_
>
...
...
@@ -48,6 +49,7 @@ struct Layernorm2dFwdTraits
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kWelford
=
kWelford_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
Layernorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Layernorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
...
...
include/ck_tile/ops/
welford
.hpp
→
include/ck_tile/ops/
norm_reduce
.hpp
View file @
dd0188b3
...
...
@@ -3,8 +3,8 @@
#pragma once
#include "ck_tile/ops/
welford
/block/block_
welford
.hpp"
#include "ck_tile/ops/
welford
/block/block_
welford
_problem.hpp"
#include "ck_tile/ops/
welford
/thread/thread_welford.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp"
#include "ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
_problem.hpp"
#include "ck_tile/ops/
norm_reduce
/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/
welford
/block/block_
welford
.hpp
→
include/ck_tile/ops/
norm_reduce
/block/block_
norm_reduce
.hpp
View file @
dd0188b3
...
...
@@ -4,22 +4,23 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/
welford
/thread/thread_welford.hpp"
#include "ck_tile/ops/
norm_reduce
/thread/thread_welford.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
struct
Block
NormReduce
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
XDataType
=
typename
Problem
::
XDataType
;
using
ComputeDataType
=
typename
Problem
::
ComputeDataType
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
CK_TILE_DEVICE
constexpr
Block
Welford
()
{}
CK_TILE_DEVICE
constexpr
Block
NormReduce
()
{}
// [CAUSION] - max_count_ is to deal with the padding problem
// max_count_ is depend on caller, eg: naive and splitN
welford
will have different
// max_count_ is depend on caller, eg: naive and splitN
norm_reduce
will have different
// calculation of max_count_
// -> use block_welford_calculate_max_count to compute
template
<
typename
XDistributedTensor_
,
...
...
@@ -40,18 +41,24 @@ struct BlockWelford
if
(
cur_count_
<
max_count_
)
{
++
cur_count_
;
sweep_tile_span
(
spans
[
I0
],
[
&
](
auto
dstr_idx_i0
)
{
constexpr
auto
in_dstr_idx
=
make_tuple
(
dstr_idx_i0
,
dstr_idx_i1
);
constexpr
auto
out_dstr_idx
=
make_tuple
(
dstr_idx_i0
);
auto
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
in_dstr_idx
]);
welford_update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
,
cur_count_
,
constant
<
kFastFDiv
>
{});
if
(
kWelford
)
{
welford_update
(
mean_tensor
(
out_dstr_idx
),
var_tensor
(
out_dstr_idx
),
x
,
cur_count_
,
constant
<
kFastFDiv
>
{});
}
else
{
mean_tensor
(
out_dstr_idx
)
+=
x
;
var_tensor
(
out_dstr_idx
)
+=
x
*
x
;
}
});
}
});
...
...
@@ -91,10 +98,11 @@ struct BlockWelford
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
Sync
struct
Block
NormReduce
Sync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
template
<
typename
MeanDistributedTensor_
,
typename
VarDistributedTensor_
>
CK_TILE_DEVICE
void
...
...
@@ -152,36 +160,48 @@ struct BlockWelfordSync
(
number
<
lid_over_rid_derivative
<<
istage
.
value
>
{}.
value
);
// pull data from remote lane
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
// welford merge
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
const
auto
v_remote_mean
=
warp_shuffle
(
v_local_mean
,
src_lane
);
const
auto
v_remote_var
=
warp_shuffle
(
v_local_var
,
src_lane
);
if
(
kWelford
)
{
const
auto
v_remote_count
=
warp_shuffle
(
v_local_count
,
src_lane
);
// norm_reduce merge
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
}
else
{
v_local_mean
+=
v_remote_mean
;
v_local_var
+=
v_remote_var
;
}
});
}
});
mean_tensor
.
get_thread_buffer
()(
i
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i
)
=
v_local_var
;
count
=
v_local_count
;
if
(
kWelford
)
{
count
=
v_local_count
;
}
});
}
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Block
Welford
CrossWarpSync
struct
Block
NormReduce
CrossWarpSync
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
BlockShape
=
typename
Problem
::
BlockShape
;
static
constexpr
bool
kFastFDiv
=
Problem
::
kFastFDiv
;
static
constexpr
bool
kWelford
=
Problem
::
kWelford
;
using
smem_dtype
=
std
::
conditional_t
<
kWelford
,
fp32x4_t
,
fp32x2_t
>
;
template
<
typename
MeanDistributedTensor_
>
CK_TILE_DEVICE
static
constexpr
index_t
GetReduceWarps
()
...
...
@@ -252,7 +272,7 @@ struct BlockWelfordCrossWarpSync
static_assert
(
thread_buf_size
==
VarDistributedTensor_
::
get_thread_buffer_size
());
// Note: we always pack everything into fp32x4
fp32x4_t
*
smem_ptr
=
reinterpret_cast
<
fp32x4_t
*>
(
smem
);
smem_dtype
*
smem_ptr
=
reinterpret_cast
<
smem_dtype
*>
(
smem
);
const
index_t
lane_id
=
get_lane_id
();
const
index_t
warp_id
=
get_warp_id
();
constexpr
auto
num_reduce_warps
=
GetReduceWarps
<
MeanDistributedTensor_
>
();
...
...
@@ -267,11 +287,13 @@ struct BlockWelfordCrossWarpSync
if
(
lane_id
==
0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i
)
{
fp32x4_t
local_scratch_
;
smem_dtype
local_scratch_
;
local_scratch_
[
0
]
=
bit_cast
<
float
>
(
mean_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
1
]
=
bit_cast
<
float
>
(
var_tensor
.
get_thread_buffer
()[
i
]);
local_scratch_
[
2
]
=
bit_cast
<
float
>
(
count
);
if
(
kWelford
)
{
local_scratch_
[
2
]
=
bit_cast
<
float
>
(
count
);
}
smem_ptr
[
smem_offset
+
i
*
num_warps
]
=
local_scratch_
;
});
}
...
...
@@ -280,7 +302,7 @@ struct BlockWelfordCrossWarpSync
// load from smem. here we let everythread to do compute :)
index_t
local_warp_id
=
warp_id
/
num_reduce_warps
;
index_t
local_smem_os
=
local_warp_id
*
num_reduce_warps
;
fp32x4_t
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
smem_dtype
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
]
=
...
...
@@ -293,32 +315,40 @@ struct BlockWelfordCrossWarpSync
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
// TODO: use descriptor for this
auto
v_local
=
all_scratch
[
i_0
*
num_reduce_warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_count
=
bit_cast
<
int
>
(
v_local
[
2
]);
auto
v_local
=
all_scratch
[
i_0
*
num_reduce_warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
int
v_local_count
=
kWelford
?
bit_cast
<
int
>
(
v_local
[
2
])
:
0
;
// further reduce mean/var
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
];
const
smem_dtype
v_remote
=
all_scratch
[
i_0
*
num_reduce_warps
+
i_1
];
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
if
(
kWelford
)
{
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
welford_merge
(
v_local_mean
,
v_local_var
,
v_local_count
,
v_remote_mean
,
v_remote_var
,
v_remote_count
,
constant
<
kFastFDiv
>
{});
}
else
{
v_local_mean
+=
v_remote_mean
;
v_local_var
+=
v_remote_var
;
}
});
mean_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_mean
;
var_tensor
.
get_thread_buffer
()(
i_0
)
=
v_local_var
;
count
=
v_local_count
;
if
(
kWelford
)
count
=
v_local_count
;
});
}
};
...
...
Prev
1
2
3
4
5
6
7
8
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment