Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0475a327
Commit
0475a327
authored
Nov 04, 2024
by
dummycoderfe
Browse files
Merge branch 'ck_tile/layernorm2d_fwd_optimize' into ck_tile/ln_add_cache_clear
parents
c9b961ab
27ff3dec
Changes
267
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1157 additions
and
182 deletions
+1157
-182
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
+14
-10
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+413
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
...le/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
+71
-0
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+12
-12
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+4
-6
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+44
-9
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+7
-9
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
...e/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
+26
-26
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
+29
-29
include/ck_tile/ops/image_to_column.hpp
include/ck_tile/ops/image_to_column.hpp
+1
-0
include/ck_tile/ops/layernorm2d.hpp
include/ck_tile/ops/layernorm2d.hpp
+2
-1
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+162
-31
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
+6
-5
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+57
-15
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
...layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
+7
-7
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+67
-19
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
..._tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
+54
-0
include/ck_tile/ops/permute.hpp
include/ck_tile/ops/permute.hpp
+9
-0
include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
...ude/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
+169
-0
No files found.
include/ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp
View file @
0475a327
...
...
@@ -9,26 +9,30 @@ namespace ck_tile {
template
<
typename
BlockGemmShape_
>
struct
GemmTilePartitioner
{
using
BlockGemmShape
=
ck_tile
::
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
ck_tile
::
index_t
kM
=
BlockGemmShape
::
kM
;
static
constexpr
ck_tile
::
index_t
kN
=
BlockGemmShape
::
kN
;
static
constexpr
ck_tile
::
index_t
kK
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
kM
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kN
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kK
=
BlockGemmShape
::
kK
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
,
ck_tile
::
index_t
batch_size
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
M
,
index_t
N
,
index_t
batch_size
)
{
ck_tile
::
index_t
GridDimX
=
(
M
+
kM
-
1
)
/
kM
;
ck_tile
::
index_t
GridDimY
=
(
N
+
kN
-
1
)
/
kN
;
ck_tile
::
index_t
GridDimZ
=
batch_size
;
index_t
GridDimX
=
(
M
+
kM
-
1
)
/
kM
;
index_t
GridDimY
=
(
N
+
kN
-
1
)
/
kN
;
index_t
GridDimZ
=
batch_size
;
return
dim3
(
GridDimX
,
GridDimY
,
GridDimZ
);
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetLoopNum
(
index_t
K
)
{
return
integer_divide_ceil
(
K
,
kK
);
}
CK_TILE_DEVICE
auto
operator
()()
{
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
*
kM
);
const
index_t
iN
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
kN
);
return
ck_tile
::
make_tuple
(
iM
,
iN
);
return
make_tuple
(
iM
,
iN
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template
<
typename
Problem
>
struct
BaseGemmPipelineAgBgCrMem
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
// TODO: Is this 32K value gfx9 arch specific?
static
constexpr
index_t
MinMemInFlyBytes
=
32768
;
static
constexpr
index_t
WgpPerCU
=
(
4
*
get_warp_size
()
/
BlockSize
)
>=
1
?
4
*
get_warp_size
()
/
BlockSize
:
1
;
static
constexpr
index_t
FullMemBandPrefetchStages
=
integer_divide_ceil
(
MinMemInFlyBytes
/
WgpPerCU
,
(
MPerBlock
*
sizeof
(
ADataType
)
+
NPerBlock
*
sizeof
(
BDataType
))
*
KPerBlock
);
static
constexpr
index_t
PrefetchStages
=
FullMemBandPrefetchStages
>=
2
?
FullMemBandPrefetchStages
<=
8
?
FullMemBandPrefetchStages
:
8
:
2
;
static
constexpr
index_t
LocalPrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
CK_TILE_HOST
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
CK_TILE_HOST
static
constexpr
TailNumber
GetBlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
return
TailNumber
::
One
;
}
else
if
(
num_loop
%
PrefetchStages
==
2
)
{
return
TailNumber
::
Two
;
}
else
if
(
num_loop
%
PrefetchStages
==
3
)
{
return
TailNumber
::
Three
;
}
else
if
(
num_loop
%
PrefetchStages
==
4
)
{
return
TailNumber
::
Four
;
}
else
if
(
num_loop
%
PrefetchStages
==
5
)
{
return
TailNumber
::
Five
;
}
else
if
(
num_loop
%
PrefetchStages
==
6
)
{
return
TailNumber
::
Six
;
}
else
if
(
num_loop
%
PrefetchStages
==
7
)
{
return
TailNumber
::
Seven
;
}
else
{
return
TailNumber
::
Full
;
}
}
};
// Maximum Global Memory throughput pipeline with >=32KB data in fly
// GlobalPrefetchStages: >=2
// LocalPreFillStages: 1
// LocalPreFetchStages: 0
// LocalSharedMemoryBuffer: 1
template
<
typename
Problem
,
typename
Policy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
>
struct
GemmPipelineAgBgCrMem
:
public
BaseGemmPipelineAgBgCrMem
<
Problem
>
{
using
Base
=
BaseGemmPipelineAgBgCrMem
<
Problem
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Problem
::
CLayout
>
;
using
BlockGemm
=
remove_cvref_t
<
decltype
(
Policy
::
template
GetBlockGemm
<
Problem
>())
>
;
using
I0
=
number
<
0
>
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
VectorSizeA
=
Problem
::
VectorSizeA
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPadA
=
Problem
::
kPadA
;
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
auto
TailNum
=
Problem
::
TailNum
;
static
constexpr
auto
Scheduler
=
Problem
::
Scheduler
;
using
Base
::
PrefetchStages
;
CK_TILE_HOST_DEVICE
constexpr
index_t
GetStaticLdsSize
()
{
return
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
16
+
sizeof
(
BDataType
)
*
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
template
<
GemmPipelineScheduler
Scheduler
>
struct
PipelineImpl
{
};
template
<
>
struct
PipelineImpl
<
GemmPipelineScheduler
::
Intrawave
>
{
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
>
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
SrcTileWindow
&
dram_tile_window
)
const
{
load_tile
(
dst_block_tile
,
dram_tile_window
);
move_tile_window
(
dram_tile_window
,
{
0
,
KPerBlock
});
}
template
<
typename
DstTileWindow
,
typename
SrcBlockTile
,
typename
ElementFunction
>
CK_TILE_DEVICE
void
LocalPrefill
(
DstTileWindow
&
lds_tile_window
,
const
SrcBlockTile
&
src_block_tile
,
const
ElementFunction
&
element_func
)
const
{
const
auto
block_tile_tmp
=
tile_elementwise_in
(
element_func
,
src_block_tile
);
store_tile
(
lds_tile_window
,
block_tile_tmp
);
}
template
<
bool
HasHotLoop
,
TailNumber
TailNum
,
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem
)
const
{
static_assert
(
std
::
is_same_v
<
ADataType
,
remove_cvref_t
<
typename
ADramBlockWindowTmp
::
DataType
>>
&&
std
::
is_same_v
<
BDataType
,
remove_cvref_t
<
typename
BDramBlockWindowTmp
::
DataType
>>
,
"A/B Dram block window should have the same data type as appropriate "
"([A|B]DataType) defined in Problem definition!"
);
static_assert
(
MPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
NPerBlock
==
BDramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
0
>
{}]
&&
KPerBlock
==
ADramBlockWindowTmp
{}.
get_window_lengths
()[
number
<
1
>
{}],
"A/B block window appropriate sizes must be equal to MPerBlock/NPerblock"
" or KPerBlock!"
);
// ------------------------------------------------------------------------------------
// Definitions of all needed tiles
// A tile in LDS
ADataType
*
p_a_lds
=
static_cast
<
ADataType
*>
(
p_smem
);
constexpr
auto
a_lds_block_desc
=
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>();
auto
a_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_a_lds
,
a_lds_block_desc
);
// TODO: LDS alignment should come from Policy!
constexpr
index_t
a_lds_block_space_size_aligned
=
integer_divide_ceil
(
sizeof
(
ADataType
)
*
a_lds_block_desc
.
get_element_space_size
(),
16
)
*
16
;
// B tile in LDS
BDataType
*
p_b_lds
=
static_cast
<
BDataType
*>
(
static_cast
<
void
*>
(
static_cast
<
char
*>
(
p_smem
)
+
a_lds_block_space_size_aligned
));
constexpr
auto
b_lds_block_desc
=
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>();
auto
b_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
p_b_lds
,
b_lds_block_desc
);
// A DRAM tile window for load
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
auto
a_copy_lds_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
// B DRAM tile window for load
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
auto
b_copy_lds_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
a_lds_block
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
// B LDS tile for block GEMM
auto
b_lds_gemm_window
=
make_tile_window
(
b_lds_block
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
// Block GEMM
constexpr
auto
block_gemm
=
BlockGemm
();
auto
c_block_tile
=
block_gemm
.
MakeCBlockTile
();
using
ABlockTileDistr
=
decltype
(
a_copy_dram_window
.
get_tile_distribution
());
using
BBlockTileDistr
=
decltype
(
b_copy_dram_window
.
get_tile_distribution
());
using
ABlockTile
=
decltype
(
make_static_distributed_tensor
<
ADataType
>
(
ABlockTileDistr
{}));
using
BBlockTile
=
decltype
(
make_static_distributed_tensor
<
BDataType
>
(
BBlockTileDistr
{}));
tuple_array
<
ABlockTile
,
PrefetchStages
>
a_block_tiles
;
tuple_array
<
BBlockTile
,
PrefetchStages
>
b_block_tiles
;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
GlobalPrefetch
(
a_block_tiles
.
get
(
I0
{}),
a_copy_dram_window
);
GlobalPrefetch
(
b_block_tiles
.
get
(
I0
{}),
b_copy_dram_window
);
// initialize C
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
I0
{}),
a_element_func
);
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
I0
{}),
b_element_func
);
// Global prefetch [1, PrefetchStages]
static_for
<
1
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
});
// main body
if
constexpr
(
HasHotLoop
)
{
index_t
i
=
0
;
do
{
static_for
<
0
,
PrefetchStages
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
block_sync_lds
();
// block_gemm.LocalPrefetch();
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
a_element_func
);
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
(
prefetch_idx
+
1
)
%
PrefetchStages
>
{}),
b_element_func
);
GlobalPrefetch
(
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_copy_dram_window
);
GlobalPrefetch
(
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_copy_dram_window
);
});
i
+=
PrefetchStages
;
}
while
(
i
<
(
num_loop
-
PrefetchStages
));
}
auto
HotLoopTail
=
[
&
](
auto
tail_num
)
{
static_for
<
1
,
tail_num
,
1
>
{}([
&
](
auto
prefetch_idx
)
{
block_sync_lds
();
// block_gemm.LocalPrefetch();
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
block_sync_lds
();
LocalPrefill
(
a_copy_lds_window
,
a_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
a_element_func
);
LocalPrefill
(
b_copy_lds_window
,
b_block_tiles
.
get
(
number
<
prefetch_idx
>
{}),
b_element_func
);
});
block_sync_lds
();
// block_gemm.LocalPrefetch();
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
};
if
constexpr
(
TailNum
==
TailNumber
::
One
)
{
block_sync_lds
();
// block_gemm.LocalPrefetch();
block_gemm
(
c_block_tile
,
a_lds_gemm_window
,
b_lds_gemm_window
);
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Two
)
{
HotLoopTail
(
number
<
2
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Three
)
{
HotLoopTail
(
number
<
3
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Four
)
{
HotLoopTail
(
number
<
4
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Five
)
{
HotLoopTail
(
number
<
5
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Six
)
{
HotLoopTail
(
number
<
6
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Seven
)
{
HotLoopTail
(
number
<
7
>
{});
}
else
if
constexpr
(
TailNum
==
TailNumber
::
Full
)
{
HotLoopTail
(
number
<
PrefetchStages
>
{});
}
return
c_block_tile
;
}
};
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
,
typename
AElementFunction
,
typename
BElementFunction
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
AElementFunction
&
a_element_func
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BElementFunction
&
b_element_func
,
index_t
num_loop
,
void
*
p_smem
)
const
{
return
PipelineImpl
<
Scheduler
>
{}.
template
operator
()
<
HasHotLoop
,
TailNum
>(
a_dram_block_window_tmp
,
a_element_func
,
b_dram_block_window_tmp
,
b_element_func
,
num_loop
,
p_smem
);
}
template
<
typename
ADramBlockWindowTmp
,
typename
BDramBlockWindowTmp
>
CK_TILE_DEVICE
auto
operator
()(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
index_t
num_loop
,
void
*
p_smem
)
const
{
return
PipelineImpl
<
Scheduler
>
{}.
template
operator
()
<
HasHotLoop
,
TailNum
>(
a_dram_block_window_tmp
,
[](
const
ADataType
&
a
)
{
return
a
;
},
b_dram_block_window_tmp
,
[](
const
BDataType
&
b
)
{
return
b
;
},
num_loop
,
p_smem
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <ostream>
#include "ck_tile/core.hpp"
namespace
ck_tile
{
enum
struct
GemmPipelineScheduler
{
Intrawave
,
Interwave
,
};
enum
struct
TailNumber
{
// Single / Double buffer pipeline
Odd
,
Even
,
// Long prefetch pipeline, up to 8
One
,
Two
,
Three
,
Four
,
Five
,
Six
,
Seven
,
// Unroll stages > Prefetch stages, number of loop is multiple of unroll stages
Empty
,
// Unroll stages <= Prefetch stages, number of loop is multiple of unroll stages add
// prefetchstages
Full
,
};
}
// namespace ck_tile
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck_tile
::
GemmPipelineScheduler
&
s
)
{
switch
(
s
)
{
case
ck_tile
::
GemmPipelineScheduler
::
Intrawave
:
os
<<
"Intrawave"
;
break
;
case
ck_tile
::
GemmPipelineScheduler
::
Interwave
:
os
<<
"Interwave"
;
break
;
default:
os
<<
""
;
}
return
os
;
}
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
ck_tile
::
TailNumber
&
s
)
{
switch
(
s
)
{
case
ck_tile
::
TailNumber
::
Odd
:
os
<<
"Odd"
;
break
;
case
ck_tile
::
TailNumber
::
Even
:
os
<<
"Even"
;
break
;
case
ck_tile
::
TailNumber
::
One
:
os
<<
"One"
;
break
;
case
ck_tile
::
TailNumber
::
Two
:
os
<<
"Two"
;
break
;
case
ck_tile
::
TailNumber
::
Three
:
os
<<
"Three"
;
break
;
case
ck_tile
::
TailNumber
::
Four
:
os
<<
"Four"
;
break
;
case
ck_tile
::
TailNumber
::
Five
:
os
<<
"Five"
;
break
;
case
ck_tile
::
TailNumber
::
Six
:
os
<<
"Six"
;
break
;
case
ck_tile
::
TailNumber
::
Seven
:
os
<<
"Seven"
;
break
;
case
ck_tile
::
TailNumber
::
Empty
:
os
<<
"Empty"
;
break
;
case
ck_tile
::
TailNumber
::
Full
:
os
<<
"Full"
;
break
;
default:
os
<<
""
;
}
return
os
;
}
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -19,27 +19,27 @@ struct GemmPipelineAGmemBGmemCRegV1
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Problem
::
CLayout
>
;
static
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kMPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
Alignment
A
=
Problem
::
Alignment
A
;
static
constexpr
index_t
Alignment
B
=
Problem
::
Alignment
B
;
static
constexpr
index_t
Alignment
C
=
Problem
::
Alignment
C
;
static
constexpr
index_t
VectorSize
A
=
Problem
::
VectorSize
A
;
static
constexpr
index_t
VectorSize
B
=
Problem
::
VectorSize
B
;
static
constexpr
index_t
VectorSize
C
=
Problem
::
VectorSize
C
;
static
constexpr
bool
kPadA
=
Problem
::
kPadA
;
static
constexpr
bool
kPadB
=
Problem
::
kPadB
;
static
constexpr
bool
kPadC
=
Problem
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
Problem
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
Problem
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
Problem
::
LayoutC
>
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
return
ck_tile
::
integer_divide_ceil
(
return
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
...
...
@@ -48,7 +48,7 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -71,8 +71,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
namespace
ck_tile
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
...
...
@@ -93,7 +91,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeA
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
...
...
@@ -101,7 +99,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSizeB
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
...
...
@@ -109,7 +107,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -25,9 +25,9 @@ struct GemmPipelineAGmemBGmemCRegV2
static
constexpr
index_t
kNPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
kKPerBlock
=
BlockGemmShape
::
kK
;
CK_TILE_HOST_DEVICE
static
constexpr
ck_tile
::
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
return
ck_tile
::
integer_divide_ceil
(
return
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#define VectorLoadSize 16
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
static
constexpr
int
_VectorSize
=
16
;
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
...
...
@@ -22,18 +23,52 @@ struct GemmPipelineProblem
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutC
>
;
static
constexpr
index_t
VectorSizeA
=
kPadA
?
1
:
_VectorSize
/
sizeof
(
ADataType
);
static
constexpr
index_t
VectorSizeB
=
kPadB
?
1
:
_VectorSize
/
sizeof
(
BDataType
);
static
constexpr
index_t
VectorSizeC
=
kPadC
?
1
:
_VectorSize
/
sizeof
(
CDataType
);
};
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
TileGemmTraits_
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
bool
HasHotLoop_
=
true
,
TailNumber
TailNum_
=
TailNumber
::
Full
>
struct
UniversalGemmPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
static
constexpr
index_t
Alignment
A
=
kPadA
?
1
:
Vector
Load
Size
/
sizeof
(
ADataType
);
static
constexpr
index_t
Alignment
B
=
kPadB
?
1
:
Vector
Load
Size
/
sizeof
(
BDataType
);
static
constexpr
index_t
Alignment
C
=
kPadC
?
1
:
Vector
Load
Size
/
sizeof
(
CDataType
);
static
constexpr
index_t
VectorSize
A
=
kPadA
?
_
VectorSize
/
sizeof
(
ADataType
)
:
1
;
static
constexpr
index_t
VectorSize
B
=
kPadB
?
_
VectorSize
/
sizeof
(
BDataType
)
:
1
;
static
constexpr
index_t
VectorSize
C
=
kPadC
?
_
VectorSize
/
sizeof
(
CDataType
)
:
1
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18-2023
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
24
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
bool
kPadA_
,
bool
kPadB_
,
bool
kPadC_
,
typename
Layout
A
_
,
typename
Layout
B
_
,
typename
Layout
C
_
>
typename
A
Layout_
,
typename
B
Layout_
,
typename
C
Layout_
>
struct
TileGemmTraits
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
using
Layout
A
=
Layout
A
_
;
using
Layout
B
=
Layout
B
_
;
using
Layout
C
=
Layout
C
_
;
using
A
Layout
=
A
Layout_
;
using
B
Layout
=
B
Layout_
;
using
C
Layout
=
C
Layout_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -39,9 +39,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
...
...
@@ -52,8 +52,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_32x32x8f16
(
a_vec
,
b_vec
,
fp32x16_t
{
0.
f
},
0
,
0
,
0
));
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
...
...
@@ -90,9 +90,9 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
#if defined(__gfx9__)
c_vec
=
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
c_vec
,
0
,
0
,
0
);
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
...
...
@@ -103,8 +103,8 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
return
bit_cast
<
CVecType
>
(
__builtin_amdgcn_mfma_f32_16x16x16f16
(
a_vec
,
b_vec
,
fp32x4_t
{
0.
f
},
0
,
0
,
0
));
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
...
...
@@ -154,9 +154,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
...
...
@@ -181,8 +181,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
});
return
c_vec
;
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
...
...
@@ -231,9 +231,9 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
...
...
@@ -258,8 +258,8 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
});
return
c_vec
;
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
...
...
@@ -320,9 +320,9 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
c_vec
=
__builtin_amdgcn_mfma_f32_32x32x2f32
(
a_f32
,
b_f32
,
c_vec
,
0
,
0
,
0
);
});
#else
ck_tile
::
ignore
=
c_vec
;
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
c_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
#endif
}
...
...
@@ -356,8 +356,8 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
});
return
c_vec
;
#else
ck_tile
::
ignore
=
a_vec
;
ck_tile
::
ignore
=
b_vec
;
ignore
=
a_vec
;
ignore
=
b_vec
;
return
CVecType
{
0.
f
};
#endif
}
...
...
include/ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// fp16
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
half_t
,
ck_tile
::
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
half_t
,
half_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaF16F16F32M32N32K16SwizzleA
;
};
// bf16
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
16
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
16
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
32
,
false
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
16
,
16
,
32
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf16_t
,
ck_tile
::
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
8
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf16_t
,
bf16_t
,
float
,
32
,
32
,
16
,
false
,
true
>
{
using
Type
=
WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA
;
};
// fp8
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
fp8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
ck_tile
::
bf8_t
,
ck_tile
::
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
fp8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
fp8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
false
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8
;
};
template
<
>
struct
WarpGemmMfmaDispatcher
<
bf8_t
,
bf8_t
,
float
,
32
,
32
,
16
,
true
>
{
using
Type
=
WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed
;
};
// clang-format on
}
// namespace impl
...
...
include/ck_tile/ops/image_to_column.hpp
View file @
0475a327
...
...
@@ -6,4 +6,5 @@
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d.hpp
View file @
0475a327
...
...
@@ -4,9 +4,10 @@
#pragma once
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp"
#include "ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_shape.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
0475a327
...
...
@@ -5,19 +5,24 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
namespace
ck_tile
{
// host side args
struct
Layernorm2dFwdHostArgs
{
const
void
*
p_x
;
const
void
*
p_gamma
;
const
void
*
p_beta
;
void
*
p_y
;
void
*
p_mean
;
void
*
p_invStd
;
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_x_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_beta
;
// [1, n], beta, prec same as input
void
*
p_y
;
// [m, n], output, fp16/bf16
void
*
p_y_residual
;
// [m, n], shortcut output, prec same as input, nullptr if not used
void
*
p_y_scale
;
// [m, 1], output a dynamic quant per row, nullptr if not used
void
*
p_mean
;
// [m, 1], output mean, prec same as input, nullptr if not used
void
*
p_invStd
;
// [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
float
epsilon
;
...
...
@@ -27,10 +32,11 @@ struct Layernorm2dFwdHostArgs
};
// TODO: Extract some type to wrapper class
template
<
typename
Pipeline_
>
template
<
typename
Pipeline_
,
typename
Epilogue_
>
struct
Layernorm2dFwd
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
...
...
@@ -40,18 +46,26 @@ struct Layernorm2dFwd
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
MeanDataType
=
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
XScaleDataType
=
remove_cvref_t
<
typename
Problem
::
XScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
// for simplicity, shortcut input/output type is same as X
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
null_type
>
;
static
constexpr
bool
kSaveMeanInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
bool
kSaveMeanInvStd
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveMean
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
Traits
::
kTwoPass
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
...
...
@@ -62,13 +76,18 @@ struct Layernorm2dFwd
struct
Kargs
{
const
void
*
p_x
;
const
void
*
p_gamma
;
const
void
*
p_beta
;
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_x_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_beta
;
// [1, n], beta, prec same as input
void
*
p_y
;
void
*
p_mean
;
void
*
p_invStd
;
void
*
p_y
;
// [m, n], output, fp16/bf16
void
*
p_y_residual
;
// [m, n], shortcut output, prec same as input, nullptr if not used
void
*
p_y_scale
;
// [m, 1], output a dynamic quant per row, nullptr if not used
void
*
p_mean
;
// [m, 1], output mean, prec same as input, nullptr if not used
void
*
p_invStd
;
// [m, 1], output inv-stdvariance, prec same as input, nullptr if not used
float
epsilon
;
...
...
@@ -81,9 +100,13 @@ struct Layernorm2dFwd
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_x_residual
,
hargs
.
p_x_scale
,
hargs
.
p_gamma
,
hargs
.
p_beta
,
hargs
.
p_y
,
hargs
.
p_y_residual
,
hargs
.
p_y_scale
,
hargs
.
p_mean
,
hargs
.
p_invStd
,
hargs
.
epsilon
,
...
...
@@ -94,7 +117,7 @@ struct Layernorm2dFwd
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
return
(
hargs
.
m
+
Block_M
-
1
)
/
Block_M
;
return
dim3
(
integer_divide_ceil
(
hargs
.
m
,
Block_M
))
;
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
BlockSize
;
}
...
...
@@ -106,6 +129,7 @@ struct Layernorm2dFwd
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
// in byte
...
...
@@ -113,24 +137,41 @@ struct Layernorm2dFwd
CK_TILE_HOST
static
std
::
string
GetName
()
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
using
S_
=
typename
Problem
::
BlockShape
;
auto
surfix
=
[
&
]
()
{
std
::
string
n
;
if
(
kFusedAdd
!=
Layernorm2dFusedAddEnum
::
NO_ADD
)
n
+=
_SS_
(
"_"
)
+
Layernorm2dFusedAddEnumName
<
kFusedAdd
>::
name
;
if
(
kFusedQuant
!=
Layernorm2dFusedQuantEnum
::
NO_SWEEP
)
n
+=
_SS_
(
"_"
)
+
Layernorm2dFusedQuantEnumName
<
kFusedQuant
>::
name
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kSaveMeanInvStd
)
n
+=
"_mv"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
//
if (kTwoPass) n += "_2p";
return
n
;
}();
#define _SS_ std::string
#define _TS_ std::to_string
return
_SS_
(
"layernorm2d_fwd_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
auto
prec_str
=
[
&
]
()
{
std
::
string
base_str
=
_SS_
(
t2s
<
XDataType
>::
name
);
if
(
!
std
::
is_same_v
<
XDataType
,
YDataType
>
)
{
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
YDataType
>::
name
);
}
if
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sx"
)
+
_SS_
(
t2s
<
XScaleDataType
>::
name
);
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
if
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
return
base_str
;
}();
return
_SS_
(
"layernorm2d_fwd_"
)
+
_SS_
(
prec_str
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
#undef _SS_
#undef _TS_
// clang-format on
#undef _SS_
#undef _TS_
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
...
...
@@ -153,6 +194,31 @@ struct Layernorm2dFwd
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
const
auto
x_residual_window
=
[
&
]()
{
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XResidualDataType
*>
(
kargs
.
p_x_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
// NOTE: we don't do any pad in this kernel for loading, assume that inside kernel
// will check the max count dynamically
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
false
,
false
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}));
}
}();
const
auto
gamma_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
...
...
@@ -194,6 +260,28 @@ struct Layernorm2dFwd
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
auto
y_residual_window
=
[
&
]()
{
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YResidualDataType
*>
(
kargs
.
p_y_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}));
}
}();
auto
mean_window
=
[
&
]()
{
if
constexpr
(
kSaveMean
)
{
...
...
@@ -232,17 +320,60 @@ struct Layernorm2dFwd
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}();
auto
x_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
const
XScaleDataType
*>
(
kargs
.
p_x_scale
),
make_tuple
(
kargs
.
n
),
number
<
Vector_N
>
{});
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
// x_scale no need pad
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}
else
return
make_null_tile_window
(
make_tuple
(
number
<
Block_N
>
{}));
}();
auto
y_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
||
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
YScaleDataType
*>
(
kargs
.
p_y_scale
),
make_tuple
(
kargs
.
m
),
number
<
1
>
{});
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}
else
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}();
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
x_residual_window
,
gamma_window
,
beta_window
,
y_window
,
y_residual_window
,
mean_window
,
inv_std_window
,
x_scale_window
,
y_scale_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
n
,
smem
);
smem
,
Epilogue
{});
}
};
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
View file @
0475a327
...
...
@@ -26,6 +26,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
0
,
3
,
0
,
3
>>
{});
}
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
{
...
...
@@ -44,7 +45,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelford
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
...
...
@@ -54,7 +55,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordSync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
...
...
@@ -64,7 +65,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordCrossWarpSync
()
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
...
...
@@ -76,13 +77,13 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
using
block_welford
=
BlockWelford
<
P_
>
;
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
X
DataType
>
(
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
Compute
DataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
using
mean_var_block_tile
=
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
0475a327
...
...
@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp"
#include <string>
#include <type_traits>
...
...
@@ -24,14 +25,19 @@ struct Layernorm2dFwdPipelineOnePass
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveMean
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
...
...
@@ -46,20 +52,30 @@ struct Layernorm2dFwdPipelineOnePass
}
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
InvStdWindow
>
typename
InvStdWindow
,
typename
XScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
YWindow
&
y_window_
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
const
XScaleWindow
&
x_scale_window_
,
YScaleWindow
&
y_scale_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
...
...
@@ -67,8 +83,14 @@ struct Layernorm2dFwdPipelineOnePass
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
beta_window
=
make_tile_window
(
beta_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
const
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
const
auto
x
=
load_tile
(
x_window
);
int
cur_count
=
0
;
int
max_count
=
block_tile_welford_calculate_max_count
<
typename
Problem
::
BlockShape
>
(
row_size
);
...
...
@@ -81,8 +103,21 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
}
// compute welford each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
x
,
cur_count
,
max_count
);
auto
[
mean
,
var
]
=
block_welford
(
acc
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
...
...
@@ -90,7 +125,8 @@ struct Layernorm2dFwdPipelineOnePass
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
+
epsilon
));
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
*
__builtin_amdgcn_rcpf
(
sqrt
(
v_
+
epsilon
));
},
var
);
...
...
@@ -100,20 +136,26 @@ struct Layernorm2dFwdPipelineOnePass
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
// layernorm computation
auto
y
=
make_static_distributed_tensor
<
Y
DataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
auto
ln
=
make_static_distributed_tensor
<
Compute
DataType
>
(
acc
.
get_tile_distribution
());
sweep_tile
(
ln
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
auto
ln_
=
(
acc
[
idx
]
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
ln
(
idx
)
=
ln_
;
});
store_tile
(
y_window
,
y
);
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
x_scale_window_
,
y_scale_window
,
ln
,
smem
);
}
else
Epilogue
{}(
y_window_
,
ln
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -14,10 +14,10 @@ template <typename XDataType_,
typename
YDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
XScaleDataType_
,
typename
YScaleDataType_
,
typename
BlockShape_
,
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kTwoPass_
>
typename
Traits_
>
struct
Layernorm2dFwdPipelineProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
...
...
@@ -27,14 +27,14 @@ struct Layernorm2dFwdPipelineProblem
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
XScaleDataType
=
remove_cvref_t
<
XScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
0475a327
...
...
@@ -24,20 +24,25 @@ struct Layernorm2dFwdPipelineTwoPass
using
MeanDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasBeta
=
!
std
::
is_same_v
<
BetaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveMean
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveMean
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kSaveInvStd
=
Problem
::
Traits
::
kSaveMeanInvStd
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
return
"bpr"
;
// block per row
return
"bpr
_2p
"
;
// block per row
else
return
"wpr"
;
// warp per row
return
"wpr
_2p
"
;
// warp per row
}();
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
...
...
@@ -46,20 +51,30 @@ struct Layernorm2dFwdPipelineTwoPass
}
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
BetaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
InvStdWindow
>
typename
InvStdWindow
,
typename
XScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
const
BetaWindow
&
beta_window_
,
YWindow
&
y_window
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
const
XScaleWindow
&
/*x_scale_window*/
,
YScaleWindow
&
/*y_scale_window*/
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
...
...
@@ -67,6 +82,10 @@ struct Layernorm2dFwdPipelineTwoPass
gamma_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
beta_window
=
make_tile_window
(
beta_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
// Problem::BlockShape
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
...
@@ -87,15 +106,33 @@ struct Layernorm2dFwdPipelineTwoPass
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
load_tile
(
x_window
));
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
))
)
;
auto
mean
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
block_welford
(
x
,
mean
,
var
,
cur_count
,
max_count
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
Block_N
});
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
block_welford
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
}
block_welford_sync
(
mean
,
var
,
cur_count
);
...
...
@@ -118,9 +155,8 @@ struct Layernorm2dFwdPipelineTwoPass
ck_tile
::
index_t
stride_to_right_most_window
=
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
// x_window.foo();
// gamma_window.foo();
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
beta_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
...
...
@@ -128,29 +164,41 @@ struct Layernorm2dFwdPipelineTwoPass
// layernorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
}
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
auto
y
=
make_static_distributed_tensor
<
Y
DataType
>
(
x
.
get_tile_distribution
());
auto
ln
=
make_static_distributed_tensor
<
Compute
DataType
>
(
acc
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
sweep_tile
(
ln
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
y_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
auto
ln_
=
(
acc
(
idx
)
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
ln
(
idx
)
=
ln_
;
});
store_tile
(
y_window
,
y
);
static_assert
(
kFusedQuant
!=
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
);
Epilogue
{}(
y_window
,
ln
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
beta_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
enum
class
Layernorm2dFusedAddEnum
{
NO_ADD
=
0
,
// fused add before layernorm and store result to global
PRE_ADD_STORE
=
1
,
// fused add before layernorm, but not store result
PRE_ADD
=
2
,
};
// clang-format off
template
<
Layernorm2dFusedAddEnum
>
struct
Layernorm2dFusedAddEnumName
;
template
<
>
struct
Layernorm2dFusedAddEnumName
<
Layernorm2dFusedAddEnum
::
NO_ADD
>
{
static
constexpr
const
char
*
name
=
"no"
;
};
template
<
>
struct
Layernorm2dFusedAddEnumName
<
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
>
{
static
constexpr
const
char
*
name
=
"pras"
;
};
template
<
>
struct
Layernorm2dFusedAddEnumName
<
Layernorm2dFusedAddEnum
::
PRE_ADD
>
{
static
constexpr
const
char
*
name
=
"pra"
;
};
// clang-format on
enum
class
Layernorm2dFusedQuantEnum
{
NO_SWEEP
=
0
,
SMOOTH_DYNAMIC_QUANT
=
1
,
// smooth oulier + rowwise quant, need input x-scale and store y_scale
DYNAMIC_QUANT
=
2
,
// rowwise quant, store out a y-scale
};
// clang-format off
template
<
Layernorm2dFusedQuantEnum
>
struct
Layernorm2dFusedQuantEnumName
;
template
<
>
struct
Layernorm2dFusedQuantEnumName
<
Layernorm2dFusedQuantEnum
::
NO_SWEEP
>
{
static
constexpr
const
char
*
name
=
"no"
;
};
template
<
>
struct
Layernorm2dFusedQuantEnumName
<
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
>
{
static
constexpr
const
char
*
name
=
"dqt"
;
};
template
<
>
struct
Layernorm2dFusedQuantEnumName
<
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
>
{
static
constexpr
const
char
*
name
=
"smdqt"
;
};
// clang-format on
template
<
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kTwoPass_
,
Layernorm2dFusedAddEnum
kFusedAdd_
,
Layernorm2dFusedQuantEnum
kFusedQuant_
>
struct
Layernorm2dFwdTraits
{
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
Layernorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Layernorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
};
}
// namespace ck_tile
include/ck_tile/ops/permute.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/permute/kernel/generic_permute_kernel.hpp"
#include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/permute/kernel/generic_permute_kernel.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
// #include "ck_tile/ops/permute/pipeline/generic_petmute_problem.hpp"
namespace
ck_tile
{
/* independent host side argument, no template
*/
struct
GenericPermuteHostArgs
{
static
constexpr
index_t
kMaxRanks
=
8
;
// TODO: hardcoded
const
void
*
p_src
;
void
*
p_dst
;
index_t
rank
;
index_t
shape
[
kMaxRanks
];
// input shape
index_t
perm
[
kMaxRanks
];
// permute index
};
/*
simulate torch.permute:
x_ = x_.view(x.shape[0],
x.shape[1]//16, 16,
x.shape[2]//32, 4, 8)
x_ = x_.permute(0,1,3,4,2,5)
x_ = x_.contiguous()
x_ = x_.view(x.shape[0], x.shape[1], x.shape[2]);//
this kernel is supposed not to be performant(just OK), with functional support up to kMaxRanks
dim of permutation, with a single kernel
*/
template
<
typename
Problem_
>
struct
GenericPermute
{
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
DataType
=
remove_cvref_t
<
typename
Problem
::
DataType
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kMaxRanks
=
Problem
::
kMaxRanks
;
static
constexpr
bool
KeepLastDim
=
Problem
::
KeepLastDim
;
struct
__attribute__
((
packed
))
Kargs
{
const
void
*
p_src
;
void
*
p_dst
;
// index_t rank;
index_t
num_elements
;
index_t
perm_length
[
kMaxRanks
];
// tensor length after permutation
index_t
perm_stride
[
kMaxRanks
];
// tensor stride after permutation
};
CK_TILE_HOST
static
constexpr
index_t
TotalElements
(
const
GenericPermuteHostArgs
&
h
)
{
index_t
n
=
1
;
for
(
auto
i
=
0
;
i
<
h
.
rank
;
i
++
)
{
n
*=
h
.
shape
[
i
];
}
return
n
;
}
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
GenericPermuteHostArgs
&
h
)
{
Kargs
a
;
a
.
p_src
=
h
.
p_src
;
a
.
p_dst
=
h
.
p_dst
;
// assert rank <= kMaxRanks
index_t
i
=
0
;
index_t
perm
[
kMaxRanks
];
index_t
x_shape
[
kMaxRanks
];
index_t
x_stride
[
kMaxRanks
];
// index_t perm_length[kMaxRanks];
for
(;
i
<
h
.
rank
;
i
++
)
{
x_shape
[
i
]
=
h
.
shape
[
i
];
perm
[
i
]
=
h
.
perm
[
i
];
}
for
(;
i
<
kMaxRanks
;
i
++
)
{
x_shape
[
i
]
=
1
;
perm
[
i
]
=
i
;
// will index to len = 1
}
index_t
stride
=
1
;
for
(
index_t
j
=
kMaxRanks
-
1
;
j
>=
0
;
j
--
)
{
x_stride
[
j
]
=
stride
;
stride
*=
x_shape
[
j
];
}
for
(
index_t
j
=
0
;
j
<
kMaxRanks
;
j
++
)
{
a
.
perm_length
[
j
]
=
x_shape
[
perm
[
j
]];
a
.
perm_stride
[
j
]
=
x_stride
[
perm
[
j
]];
}
a
.
num_elements
=
TotalElements
(
h
);
return
a
;
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
GenericPermuteHostArgs
h
)
{
auto
total
=
TotalElements
(
h
);
auto
grids
=
dim3
((
total
+
BlockSize
()
-
1
)
/
BlockSize
());
// printf("### total:%d, grids:%dx%dx%d\n", total, );
return
grids
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
BlockSize
()
{
return
Problem
::
kBlockSize
;
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
index_t
id
=
blockIdx
.
x
*
BlockSize
()
+
threadIdx
.
x
;
if
(
id
>=
kargs
.
num_elements
)
return
;
const
auto
perm_length
=
generate_tuple
([
&
](
auto
I
)
{
return
kargs
.
perm_length
[
I
];
},
number
<
kMaxRanks
>
{});
const
auto
perm_stride
=
generate_tuple
([
&
](
auto
I
)
{
return
kargs
.
perm_stride
[
I
];
},
number
<
kMaxRanks
>
{});
const
DataType
*
p_src
=
reinterpret_cast
<
const
DataType
*>
(
kargs
.
p_src
);
DataType
*
p_dst
=
reinterpret_cast
<
DataType
*>
(
kargs
.
p_dst
);
const
auto
src_view_0
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_src
,
perm_length
,
perm_stride
,
number
<
1
>
{},
number
<
1
>
{});
const
auto
src_view
=
transform_tensor_view
(
src_view_0
,
make_tuple
(
make_merge_transform
(
perm_length
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
kMaxRanks
,
1
>::
type
{}),
make_tuple
(
sequence
<
0
>
{}));
auto
dst_view_0
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_dst
,
perm_length
,
number
<
1
>
{});
auto
dst_view
=
transform_tensor_view
(
dst_view_0
,
make_tuple
(
make_merge_transform
(
perm_length
)),
make_tuple
(
typename
arithmetic_sequence_gen
<
0
,
kMaxRanks
,
1
>::
type
{}),
make_tuple
(
sequence
<
0
>
{}));
// TODO: hard code to vector 1
using
vector_t
=
thread_buffer
<
DataType
,
1
>
;
const
auto
src_coord
=
make_tensor_coordinate
(
src_view
.
get_tensor_descriptor
(),
array
<
index_t
,
1
>
{
id
});
const
auto
dst_coord
=
make_tensor_coordinate
(
dst_view
.
get_tensor_descriptor
(),
array
<
index_t
,
1
>
{
id
});
// printf("src id:%d, os:%d\n", id, src_coord.get_offset());
// printf("dst id:%d, os:%d\n", id, dst_coord.get_offset());
const
vector_t
x
=
src_view
.
template
get_vectorized_elements
<
vector_t
>(
src_coord
,
0
);
dst_view
.
template
set_vectorized_elements
<
vector_t
>(
dst_coord
,
0
,
x
);
}
};
}
// namespace ck_tile
Prev
1
…
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment