Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6fcaeada
Commit
6fcaeada
authored
Oct 15, 2024
by
Astha Rai
Browse files
fixed merge conflict after merge with develop
parents
fc7a1825
d02a92cc
Changes
122
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1457 additions
and
156 deletions
+1457
-156
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
+4
-5
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+10
-7
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+424
-0
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+27
-0
include/ck_tile/ops/image_to_column.hpp
include/ck_tile/ops/image_to_column.hpp
+9
-0
include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp
...ile/ops/image_to_column/kernel/image_to_column_kernel.hpp
+224
-0
include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp
...mage_to_column/pipeline/block_image_to_column_problem.hpp
+27
-0
include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp
...s/image_to_column/pipeline/tile_image_to_column_shape.hpp
+32
-0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+247
-86
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
...ps/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
+13
-9
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp
...library/reference_tensor_operation/gpu/reference_gemm.hpp
+245
-0
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+13
-25
library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
+16
-6
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+6
-6
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+4
-0
script/cmake-ck-release.sh
script/cmake-ck-release.sh
+4
-0
test/CMakeLists.txt
test/CMakeLists.txt
+5
-12
test/ck_tile/CMakeLists.txt
test/ck_tile/CMakeLists.txt
+1
-0
test/ck_tile/image_to_column/CMakeLists.txt
test/ck_tile/image_to_column/CMakeLists.txt
+4
-0
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
+142
-0
No files found.
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp
View file @
6fcaeada
...
...
@@ -7,12 +7,11 @@
namespace
ck_tile
{
// Default policy for
Block
GemmPipelineAGmemBGmemCRegV2
// Default policy for GemmPipelineAGmemBGmemCRegV2
// Default policy class should not be templated, put template on member functions instead
// NOTE: policy should be binded to its corresponding operation. It's just a coincidence that
// BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
BlockGemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
// GemmPipelineAGmemBGmemCRegV2DefaultPolicy is the same as
// GemmPipelineAGmemBGmemCRegV1DefaultPolicy
using
GemmPipelineAGmemBGmemCRegV2DefaultPolicy
=
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
;
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/
block_
gemm_pipeline_problem.hpp
→
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
6fcaeada
...
...
@@ -13,20 +13,23 @@ template <typename ADataType_,
typename
BDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
bool
kPadA_
=
false
,
bool
kPadB_
=
false
,
bool
kPadC_
=
false
>
struct
BlockGemmPipelineProblem
typename
TileGemmTraits_
>
struct
GemmPipelineProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
using
LayoutA
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutA
>
;
using
LayoutB
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutB
>
;
using
LayoutC
=
remove_cvref_t
<
typename
GemmTraits
::
LayoutC
>
;
static
constexpr
index_t
AlignmentA
=
kPadA
?
1
:
VectorLoadSize
/
sizeof
(
ADataType
);
static
constexpr
index_t
AlignmentB
=
kPadB
?
1
:
VectorLoadSize
/
sizeof
(
BDataType
);
...
...
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace
ck_tile
{
// UniversalGemm Policy
template
<
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
UniversalGemmPipelineAgBgCrPolicy
{
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
bool
TransposeC
=
true
;
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
LayoutA
>::
value
)
{
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
);
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
*
number
<
MLdsLayer
>
{},
number
<
MPerBlock
/
MLdsLayer
>
{},
K1
),
make_tuple
(
K1
,
number
<
KPerBlock
*
MLdsLayer
>
{},
I1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
K0
*
MLdsLayer
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
MLdsLayer
>
{})),
make_pass_through_transform
(
number
<
MPerBlock
/
MLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_ak0_kMLdsLayer_m_ak1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
MPerBlock
/
MLdsLayer
>
{},
number
<
MLdsLayer
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
else
// ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr
auto
M0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I0
);
constexpr
auto
M1
=
MPerBlock
/
M0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
M0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kM
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
M0
*
sizeof
(
ADataType
)
>
128
)
?
1
:
128
/
(
K1
*
M0
*
sizeof
(
ADataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=mpair<=kN0
constexpr
auto
mpair
=
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)))
>
M0
?
M0
:
128
/
(
K1
*
WarpGemm
::
kM
*
sizeof
(
ADataType
)));
constexpr
auto
a_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{},
number
<
mpair
>
{},
K1
));
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
M1
>
{},
number
<
kfold
*
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
a_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
a_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
M1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
M0
/
mpair
>
{})),
make_pass_through_transform
(
number
<
mpair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
M0
/
mpair
>
{},
number
<
mpair
>
{},
number
<
M1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
a_lds_block_desc_m_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
LayoutB
>::
value
)
{
// NLdsLayer * K0 as logical Bank
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
?
1
:
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
);
;
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K0
*
number
<
NLdsLayer
>
{},
number
<
NPerBlock
/
NLdsLayer
>
{},
K1
),
make_tuple
(
K1
,
number
<
KPerBlock
*
NLdsLayer
>
{},
I1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
K0
*
NLdsLayer
>
{})),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
constexpr
auto
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_bk0_kNLdsLayer_n_bk1
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
K0
,
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{}))),
make_tuple
(
sequence
<
0
,
3
>
{},
sequence
<
1
,
2
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
else
// RowMajor B
{
constexpr
auto
N0
=
get_warp_size
()
*
Problem
::
BlockGemmShape
::
BlockWarps
::
at
(
I1
);
constexpr
auto
N1
=
NPerBlock
/
N0
;
constexpr
auto
KThreadWrite
=
Problem
::
kBlockSize
/
N0
;
constexpr
auto
K0PerThreadWrite
=
K0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
WarpGemm
::
kN
;
constexpr
auto
K0PerThreadRead
=
K0
/
KThreadRead
;
constexpr
auto
kfold
=
(
K1
*
N0
*
sizeof
(
BDataType
)
>
128
)
?
1
:
128
/
(
K1
*
N0
*
sizeof
(
BDataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=npair<=kN0
constexpr
auto
npair
=
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)
>
128
)
?
1
:
((
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)))
>
N0
?
N0
:
128
/
(
K1
*
WarpGemm
::
kN
*
sizeof
(
BDataType
)));
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{},
number
<
npair
>
{},
K1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
b_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
N1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
K1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
K1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
N0
/
npair
>
{},
number
<
npair
>
{},
number
<
N1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
return
b_lds_block_desc_n_k
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeA
()
{
constexpr
index_t
smem_size_a
=
sizeof
(
typename
Problem
::
ADataType
)
*
MakeALdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_a
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSizeB
()
{
constexpr
index_t
smem_size_b
=
sizeof
(
typename
Problem
::
BDataType
)
*
MakeBLdsBlockDescriptor
<
Problem
>
().
get_element_space_size
();
return
smem_size_b
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
constexpr
index_t
smem_size_a
=
GetSmemSizeA
<
Problem
>
();
constexpr
index_t
smem_size_b
=
GetSmemSizeB
<
Problem
>
();
index_t
smem_size
=
0
;
smem_size
+=
smem_size_a
+
smem_size_b
;
return
smem_size
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
using
AccDataType
=
float
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
AccDataType
,
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
BlockWarps
,
WarpGemm
>
;
return
BlockGemmASmemBSmemCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
}
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
bool
kPadA_
,
bool
kPadB_
,
bool
kPadC_
,
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
TileGemmTraits
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadC
=
kPadC_
;
using
LayoutA
=
LayoutA_
;
using
LayoutB
=
LayoutB_
;
using
LayoutC
=
LayoutC_
;
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column.hpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp"
#include "ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp"
#include "ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/image_to_column/kernel/image_to_column_kernel.hpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
template
<
typename
Problem_
>
struct
ImageToColumn
{
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
auto
I3
=
number
<
3
>
{};
static
constexpr
auto
I4
=
number
<
4
>
{};
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
InDataType
=
remove_cvref_t
<
typename
Problem
::
InDataType
>
;
using
OutDataType
=
remove_cvref_t
<
typename
Problem
::
OutDataType
>
;
static
constexpr
index_t
NDimSpatial
=
Problem
::
NDimSpatial
;
static
constexpr
index_t
AligmentIn
=
Problem
::
AligmentIn
;
static
constexpr
index_t
AligmentOut
=
Problem
::
AligmentOut
;
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
static
constexpr
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
index_t
kKPerBlock
=
Problem
::
BlockShape
::
kKPerBlock
;
struct
Kargs
{
const
void
*
p_in
;
void
*
p_out
;
const
long_index_t
G
;
const
long_index_t
N
;
const
long_index_t
C
;
const
array
<
long_index_t
,
NDimSpatial
>
input_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
>
output_spatial_lengths
;
const
array
<
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
;
const
array
<
long_index_t
,
3
>
gemm_g_m_k_strides
;
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides
;
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations
;
const
array
<
long_index_t
,
NDimSpatial
>
input_left_pads
;
const
array
<
long_index_t
,
NDimSpatial
>
input_right_pads
;
};
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
void
*
p_in
,
void
*
p_out
,
const
long_index_t
G
,
const
long_index_t
N
,
const
long_index_t
C
,
const
array
<
long_index_t
,
NDimSpatial
>
input_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
>
filter_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
>
output_spatial_lengths
,
const
array
<
long_index_t
,
NDimSpatial
+
3
>
image_g_n_c_wis_strides
,
const
array
<
long_index_t
,
3
>
gemm_g_m_k_strides
,
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_strides
,
const
array
<
long_index_t
,
NDimSpatial
>
conv_filter_dilations
,
const
array
<
long_index_t
,
NDimSpatial
>
input_left_pads
,
const
array
<
long_index_t
,
NDimSpatial
>
input_right_pads
)
{
return
Kargs
{
p_in
,
p_out
,
G
,
N
,
C
,
input_spatial_lengths
,
filter_spatial_lengths
,
output_spatial_lengths
,
image_g_n_c_wis_strides
,
gemm_g_m_k_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
};
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
index_t
GemmM
,
index_t
GemmK
,
index_t
Batch
)
{
return
dim3
(
integer_divide_ceil
(
GemmM
,
kMPerBlock
),
integer_divide_ceil
(
GemmK
,
kKPerBlock
),
Batch
);
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
kBlockSize
;
}
CK_TILE_DEVICE
auto
MakeImageMKDesc
(
const
Kargs
&
kargs
)
const
{
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
kargs
.
N
,
kargs
.
input_spatial_lengths
[
I0
],
kargs
.
input_spatial_lengths
[
I1
],
kargs
.
C
),
make_tuple
(
kargs
.
image_g_n_c_wis_strides
[
I1
],
kargs
.
image_g_n_c_wis_strides
[
I3
],
kargs
.
image_g_n_c_wis_strides
[
I4
],
kargs
.
image_g_n_c_wis_strides
[
I2
]),
number
<
AligmentIn
>
{},
I1
);
const
auto
in_n_hip_wip_c_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
kargs
.
N
),
make_pad_transform
(
kargs
.
input_spatial_lengths
[
I0
],
kargs
.
input_left_pads
[
I0
],
kargs
.
input_right_pads
[
I0
]),
make_pad_transform
(
kargs
.
input_spatial_lengths
[
I1
],
kargs
.
input_left_pads
[
I1
],
kargs
.
input_right_pads
[
I1
]),
make_pass_through_transform
(
kargs
.
C
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
kargs
.
N
),
make_embed_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I0
],
kargs
.
output_spatial_lengths
[
I0
]),
make_tuple
(
kargs
.
conv_filter_dilations
[
I0
],
kargs
.
conv_filter_strides
[
I0
])),
make_embed_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I1
],
kargs
.
output_spatial_lengths
[
I1
]),
make_tuple
(
kargs
.
conv_filter_dilations
[
I1
],
kargs
.
conv_filter_strides
[
I1
])),
make_pass_through_transform
(
kargs
.
C
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
,
2
>
{},
sequence
<
3
,
4
>
{},
sequence
<
5
>
{}));
return
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kargs
.
N
,
kargs
.
output_spatial_lengths
[
I0
],
kargs
.
output_spatial_lengths
[
I1
])),
make_merge_transform
(
make_tuple
(
kargs
.
filter_spatial_lengths
[
I0
],
kargs
.
filter_spatial_lengths
[
I1
],
kargs
.
C
))),
make_tuple
(
sequence
<
0
,
2
,
4
>
{},
sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
}
CK_TILE_DEVICE
auto
CalculateMKDims
(
const
Kargs
&
kargs
)
const
{
static_assert
(
NDimSpatial
==
2
,
"Not supported."
);
const
index_t
M
=
kargs
.
N
*
static_cast
<
index_t
>
(
kargs
.
output_spatial_lengths
[
I0
]
*
kargs
.
output_spatial_lengths
[
I1
]);
const
index_t
K
=
kargs
.
C
*
static_cast
<
index_t
>
(
kargs
.
filter_spatial_lengths
[
I0
]
*
kargs
.
filter_spatial_lengths
[
I1
]);
return
make_tuple
(
M
,
K
);
}
CK_TILE_DEVICE
static
constexpr
auto
MakeBlockTileDistribution
()
{
using
P
=
typename
Problem
::
BlockShape
;
// P: {kMWarpPerBlock * kKWarpPerBlock, kMThreadPerWarp * kKThreadPerWarp}
// Y: {kMPerThread, kKPerThread}
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
P
::
kMWarpPerBlock
,
P
::
kMThreadPerWarp
,
P
::
kMPerThread
>
,
sequence
<
P
::
kKWarpPerBlock
,
P
::
kKThreadPerWarp
,
P
::
kKPerThread
>>
,
tuple
<
sequence
<
1
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
0
>
,
sequence
<
1
,
1
>>
,
sequence
<
1
,
2
>
,
sequence
<
2
,
2
>>
{});
}
CK_TILE_DEVICE
void
ConvTensorRearrange
(
const
Kargs
&
kargs
)
const
{
const
auto
[
M
,
K
]
=
CalculateMKDims
(
kargs
);
const
index_t
iM
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
x
*
kMPerBlock
);
const
index_t
iK
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
y
*
kKPerBlock
);
const
index_t
iBatch
=
__builtin_amdgcn_readfirstlane
(
blockIdx
.
z
);
const
auto
in_offset
=
iBatch
*
kargs
.
image_g_n_c_wis_strides
[
I0
];
const
auto
out_offset
=
iBatch
*
kargs
.
gemm_g_m_k_strides
[
I0
];
const
auto
image_m_k
=
make_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
InDataType
*>
(
kargs
.
p_in
)
+
in_offset
,
MakeImageMKDesc
(
kargs
));
const
auto
gemm_m_k
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
OutDataType
*>
(
kargs
.
p_out
)
+
out_offset
,
make_tuple
(
M
,
K
),
make_tuple
(
kargs
.
gemm_g_m_k_strides
[
I1
],
kargs
.
gemm_g_m_k_strides
[
I2
]),
number
<
AligmentOut
>
{},
I1
);
const
auto
image_m_k_padded
=
pad_tensor_view
(
image_m_k
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
sequence
<
false
,
true
>
{});
const
auto
gemm_m_k_padded
=
pad_tensor_view
(
gemm_m_k
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
sequence
<
false
,
true
>
{});
constexpr
auto
dstr
=
MakeBlockTileDistribution
();
const
auto
image_tile
=
make_tile_window
(
image_m_k_padded
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
iM
,
iK
},
dstr
);
auto
gemm_tile
=
make_tile_window
(
gemm_m_k_padded
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
iM
,
iK
},
dstr
);
// load from Global
const
auto
loaded_tile
=
load_tile
(
image_tile
);
// save to Global
store_tile
(
gemm_tile
,
loaded_tile
);
}
CK_TILE_DEVICE
void
operator
()(
Kargs
&
kargs
)
const
{
ConvTensorRearrange
(
kargs
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column/pipeline/block_image_to_column_problem.hpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
template
<
typename
InDataType_
,
typename
OutDataType_
,
typename
BlockShape_
,
index_t
NDimSpatial_
,
index_t
AligmentIn_
,
index_t
AligmentOut_
>
struct
BlockImageToColumnProblem
{
using
InDataType
=
remove_cvref_t
<
InDataType_
>
;
using
OutDataType
=
remove_cvref_t
<
OutDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
index_t
NDimSpatial
=
NDimSpatial_
;
static
constexpr
index_t
AligmentIn
=
AligmentIn_
;
static
constexpr
index_t
AligmentOut
=
AligmentOut_
;
};
}
// namespace ck_tile
include/ck_tile/ops/image_to_column/pipeline/tile_image_to_column_shape.hpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
ThreadTile
,
// Sequence<...
typename
WarpTile
,
// Sequence<...
typename
BlockTile
>
// Sequence<...
struct
TileImageToColumnShape
{
static
constexpr
index_t
kMPerThread
=
ThreadTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerThread
=
ThreadTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMPerWarp
=
WarpTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerWarp
=
WarpTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMThreadPerWarp
=
kMPerWarp
/
kMPerThread
;
static
constexpr
index_t
kKThreadPerWarp
=
kKPerWarp
/
kKPerThread
;
static
constexpr
index_t
kMPerBlock
=
BlockTile
::
at
(
number
<
0
>
{});
static
constexpr
index_t
kKPerBlock
=
BlockTile
::
at
(
number
<
1
>
{});
static
constexpr
index_t
kMWarpPerBlock
=
kMPerBlock
/
kMPerWarp
;
static
constexpr
index_t
kKWarpPerBlock
=
kKPerBlock
/
kKPerWarp
;
static
constexpr
index_t
kBlockSize
=
warpSize
*
kMWarpPerBlock
*
kKWarpPerBlock
;
};
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
6fcaeada
...
...
@@ -31,8 +31,14 @@ struct Layernorm2dFwd
static
constexpr
ck_tile
::
index_t
kMPerBlock
=
Problem
::
BlockShape
::
kMPerBlock
;
static
constexpr
ck_tile
::
index_t
kNPerBlock
=
Problem
::
BlockShape
::
kNPerBlock
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
ck_tile
::
index_t
kNThreadPerWarp
=
Problem
::
BlockShape
::
kNThreadPerWarp
;
static
constexpr
ck_tile
::
index_t
kNPerThread
=
Problem
::
BlockShape
::
kNPerThread
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
struct
Kargs
{
...
...
@@ -96,19 +102,25 @@ struct Layernorm2dFwd
sequence
<
2
>>
{});
}
template
<
typename
Dstr
>
CK_TILE_DEVICE
static
constexpr
auto
GetNPerThread
(
Dstr
)
CK_TILE_DEVICE
static
int
GetWelfordMaxCount
(
int
N
)
{
constexpr
auto
nDstrSpan
=
Dstr
::
get_distributed_spans
().
template
at
<
1
>();
using
Lengths
=
decltype
(
nDstrSpan
.
impl_
);
constexpr
ck_tile
::
index_t
kNThreadPerBlock
=
kNPerBlock
/
kNPerThread
;
ck_tile
::
index_t
ret
=
1
;
int
thread_id_n
=
get_thread_id
()
%
kNThreadPerBlock
;
int
max_count
=
__builtin_amdgcn_readfirstlane
(
N
<
kNPerBlock
?
0
:
kNPerThread
*
(
N
/
kNPerBlock
));
int
n_per_block_tail_loop
=
__builtin_amdgcn_readfirstlane
(
N
-
max_count
*
kNThreadPerBlock
);
ck_tile
::
static_for
<
0
,
Lengths
::
size
(),
1
>
{}(
[
&
](
auto
idx
)
{
ret
*=
Lengths
::
template
at
(
idx
);
});
if
(
n_per_block_tail_loop
>
0
)
{
int
thread_max_n
=
(
thread_id_n
+
1
)
*
kNPerThread
;
int
delta
=
thread_max_n
-
n_per_block_tail_loop
;
delta
=
clamp
(
thread_max_n
-
n_per_block_tail_loop
,
0
,
kNPerThread
);
max_count
+=
kNPerThread
-
delta
;
}
return
re
t
;
return
max_coun
t
;
}
template
<
typename
DistributedTensor
>
...
...
@@ -129,42 +141,29 @@ struct Layernorm2dFwd
return
out_dstr_tensor
;
}
template
<
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
TwoPassLayernorm2dFwd
(
const
XDataType
*
p_x
,
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
YDataType
*
p_y
,
MeanDataType
*
p_mean
,
InvStdDataType
*
p_invStd
,
const
ComputeDataType
epsilon
,
ck_tile
::
index_t
M
,
ck_tile
::
index_t
N
)
const
template
<
typename
XBlockWindow
,
typename
GammaBlockWindow
,
typename
BetaBlockWindow
,
typename
YBlockWindow
,
typename
MeanBlockWindow
,
typename
InvStdBlockWindow
,
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
TwoPassLayernorm2dFwd
(
XBlockWindow
&
x_block_window
,
GammaBlockWindow
&
gamma_block_window
,
BetaBlockWindow
&
beta_block_window
,
YBlockWindow
&
y_block_window
,
MeanBlockWindow
&
mean_block_window
,
InvStdBlockWindow
&
inv_std_block_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
{
constexpr
auto
I0
=
number
<
0
>
{};
constexpr
auto
I1
=
number
<
1
>
{};
const
auto
x_m_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_x
,
make_tuple
(
M
,
N
),
make_tuple
(
N
,
1
),
number
<
32
>
{},
number
<
1
>
{});
const
auto
gamma_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_gamma
,
make_tuple
(
N
),
make_tuple
(
1
),
number
<
32
>
{},
number
<
1
>
{});
// TODO - Optimize tail loop to reduce move_tile_window()
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
integer_divide_ceil
(
N
,
kNPerBlock
));
const
auto
beta_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_beta
,
make_tuple
(
N
),
make_tuple
(
1
),
number
<
32
>
{},
number
<
1
>
{});
const
auto
iM
=
get_block_id
()
*
kMPerBlock
;
constexpr
auto
xDstr
=
MakeXBlockTileDistribution
();
auto
x_block_window
=
make_tile_window
(
x_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
},
xDstr
);
index_t
num_n_tile_iteration
=
__builtin_amdgcn_readfirstlane
(
N
/
kNPerBlock
);
// TODO: padding - handle max_count if N % kNPerBlock != 0
constexpr
auto
NPerThread
=
GetNPerThread
(
xDstr
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
type_convert
<
int
>
(
NPerThread
*
N
/
kNPerBlock
)};
int
welford_max_count
=
GetWelfordMaxCount
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
auto
mean_compute_block_tensor
=
...
...
@@ -190,44 +189,14 @@ struct Layernorm2dFwd
auto
inv_std_compute_block_tensor
=
InvSqrt
(
var_compute_block_tensor
,
epsilon
);
if
constexpr
(
kSaveMean
)
{
const
auto
mean_m
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_mean
,
make_tuple
(
M
),
number
<
32
>
{});
auto
mean_block_window
=
make_tile_window
(
mean_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
store_tile
(
mean_block_window
,
cast_tile
<
MeanDataType
>
(
mean_compute_block_tensor
));
}
if
constexpr
(
kSaveInvStd
)
{
const
auto
inv_std_m
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
p_invStd
,
make_tuple
(
M
),
number
<
32
>
{});
auto
inv_std_block_window
=
make_tile_window
(
inv_std_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
store_tile
(
inv_std_block_window
,
cast_tile
<
MeanDataType
>
(
inv_std_compute_block_tensor
));
}
// TODO: Extract normalize pipeline
const
auto
y_m_n
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_y
,
make_tuple
(
M
,
N
),
make_tuple
(
N
,
1
),
number
<
32
>
{},
number
<
1
>
{});
auto
y_block_window
=
make_tile_window
(
y_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
});
constexpr
auto
gammaDstr
=
MakeGammaBetaBlockTileDistribution
();
constexpr
auto
betaDstr
=
gammaDstr
;
auto
gamma_block_window
=
make_tile_window
(
gamma_n
,
make_tuple
(
number
<
kNPerBlock
>
{}),
{
0
},
gammaDstr
);
auto
beta_block_window
=
make_tile_window
(
beta_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
0
},
betaDstr
);
store_tile
(
inv_std_block_window
,
cast_tile
<
InvStdDataType
>
(
inv_std_compute_block_tensor
));
// reverse read x to reuse cache
ck_tile
::
index_t
stride_to_right_most_window
=
N
-
kNPerBlock
;
ck_tile
::
index_t
stride_to_right_most_window
=
N
%
kNPerBlock
==
0
?
N
-
kNPerBlock
:
N
-
N
%
kNPerBlock
;
move_tile_window
(
x_block_window
,
{
0
,
-
kNPerBlock
});
move_tile_window
(
gamma_block_window
,
{
stride_to_right_most_window
});
...
...
@@ -274,17 +243,209 @@ struct Layernorm2dFwd
}
}
template
<
typename
XBlockWindow
,
typename
GammaBlockWindow
,
typename
BetaBlockWindow
,
typename
YBlockWindow
,
typename
MeanBlockWindow
,
typename
InvStdBlockWindow
,
bool
Cond
=
(
kHasGamma
&&
kHasBeta
)>
CK_TILE_DEVICE
std
::
enable_if_t
<
Cond
>
OnePassLayernorm2dFwd
(
XBlockWindow
&
x_block_window
,
GammaBlockWindow
&
gamma_block_window
,
BetaBlockWindow
&
beta_block_window
,
YBlockWindow
&
y_block_window
,
MeanBlockWindow
&
mean_block_window
,
InvStdBlockWindow
&
inv_std_block_window
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
N
)
const
{
int
welford_max_count
=
GetWelfordMaxCount
(
N
);
ThreadWelford
<
ComputeDataType
,
XDataType
>
thread_welford
{
welford_max_count
};
using
XTensorType
=
decltype
(
load_tile
(
x_block_window
));
auto
mean_compute_block_tensor
=
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
auto
var_compute_block_tensor
=
thread_welford
.
template
MakeInitialMeanVarDistributedTensor
<
XTensorType
>();
clear_tile
(
mean_compute_block_tensor
);
clear_tile
(
var_compute_block_tensor
);
const
auto
x_block_tensor
=
load_tile
(
x_block_window
);
thread_welford
(
x_block_tensor
,
mean_compute_block_tensor
,
var_compute_block_tensor
);
// TODO: support cross warp Welford
WarpMergeWelford
<
ComputeDataType
,
true
>
{}(
mean_compute_block_tensor
,
var_compute_block_tensor
,
thread_welford
.
cur_count_
);
auto
inv_std_compute_block_tensor
=
InvSqrt
(
var_compute_block_tensor
,
epsilon
);
if
constexpr
(
kSaveMean
)
store_tile
(
mean_block_window
,
cast_tile
<
MeanDataType
>
(
mean_compute_block_tensor
));
if
constexpr
(
kSaveInvStd
)
store_tile
(
inv_std_block_window
,
cast_tile
<
InvStdDataType
>
(
inv_std_compute_block_tensor
));
// normalize
const
auto
gamma_block_tensor
=
load_tile
(
gamma_block_window
);
const
auto
beta_block_tensor
=
load_tile
(
beta_block_window
);
constexpr
auto
x_spans
=
decltype
(
x_block_tensor
)
::
get_distributed_spans
();
auto
y_block_tensor
=
make_static_distributed_tensor
<
YDataType
>
(
x_block_tensor
.
get_tile_distribution
());
sweep_tile_span
(
x_spans
[
I1
],
[
&
](
auto
idx1
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx1
);
const
auto
gamma
=
type_convert
<
ComputeDataType
>
(
gamma_block_tensor
[
j_idx
]);
const
auto
beta
=
type_convert
<
ComputeDataType
>
(
beta_block_tensor
[
j_idx
]);
sweep_tile_span
(
x_spans
[
I0
],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
const
auto
mean
=
mean_compute_block_tensor
[
i_idx
];
const
auto
inv_std
=
inv_std_compute_block_tensor
[
i_idx
];
const
auto
x
=
type_convert
<
ComputeDataType
>
(
x_block_tensor
[
i_j_idx
]);
auto
y
=
(
x
-
mean
)
*
inv_std
*
gamma
+
beta
;
y_block_tensor
(
i_j_idx
)
=
type_convert
<
YDataType
>
(
y
);
});
});
store_tile
(
y_block_window
,
y_block_tensor
);
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
TwoPassLayernorm2dFwd
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
static_cast
<
const
BetaDataType
*>
(
kargs
.
p_beta
),
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
static_cast
<
MeanDataType
*>
(
kargs
.
p_mean
),
static_cast
<
InvStdDataType
*>
(
kargs
.
p_invStd
),
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
M
,
kargs
.
N
);
const
auto
x_m_n
=
[
&
]()
{
const
auto
x_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
N
,
1
),
number
<
kNPerThread
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
x_dram_naive
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
}();
const
auto
gamma_n
=
[
&
]()
{
const
auto
gamma_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
make_tuple
(
kargs
.
N
),
make_tuple
(
1
),
number
<
kNPerThread
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
gamma_dram_naive
,
make_tuple
(
number
<
kNPerBlock
>
{}),
sequence
<
kPadN
>
{});
}();
const
auto
beta_n
=
[
&
]()
{
const
auto
gamma_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
BetaDataType
*>
(
kargs
.
p_beta
),
make_tuple
(
kargs
.
N
),
make_tuple
(
1
),
number
<
kNPerThread
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
gamma_dram_naive
,
make_tuple
(
number
<
kNPerBlock
>
{}),
sequence
<
kPadN
>
{});
}();
const
auto
iM
=
get_block_id
()
*
kMPerBlock
;
constexpr
auto
xDstr
=
MakeXBlockTileDistribution
();
auto
x_block_window
=
make_tile_window
(
x_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
},
xDstr
);
const
auto
y_m_n
=
[
&
]()
{
const
auto
y_dram_naive
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
M
,
kargs
.
N
),
make_tuple
(
kargs
.
N
,
1
),
number
<
kNPerThread
>
{},
number
<
1
>
{});
return
pad_tensor_view
(
y_dram_naive
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
}();
auto
y_block_window
=
make_tile_window
(
y_m_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
iM
,
0
});
constexpr
auto
gammaDstr
=
MakeGammaBetaBlockTileDistribution
();
constexpr
auto
betaDstr
=
gammaDstr
;
auto
gamma_block_window
=
make_tile_window
(
gamma_n
,
make_tuple
(
number
<
kNPerBlock
>
{}),
{
0
},
gammaDstr
);
auto
beta_block_window
=
make_tile_window
(
beta_n
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kNPerBlock
>
{}),
{
0
},
betaDstr
);
auto
mean_block_window
=
[
&
]()
{
if
constexpr
(
kSaveMean
)
{
const
auto
mean_m
=
[
&
]()
{
const
auto
mean_dram_naive
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
MeanDataType
*>
(
kargs
.
p_mean
),
make_tuple
(
kargs
.
M
),
number
<
1
>
{});
return
pad_tensor_view
(
mean_dram_naive
,
make_tuple
(
number
<
kMPerBlock
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
mean_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
}
else
return
make_null_tile_window
(
make_tuple
(
number
<
kMPerBlock
>
{}));
}();
auto
inv_std_block_window
=
[
&
]()
{
if
constexpr
(
kSaveInvStd
)
{
const
auto
inv_std_m
=
[
&
]()
{
const
auto
inv_std_dram_naive
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
InvStdDataType
*>
(
kargs
.
p_invStd
),
make_tuple
(
kargs
.
M
),
number
<
1
>
{});
return
pad_tensor_view
(
inv_std_dram_naive
,
make_tuple
(
number
<
kMPerBlock
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
inv_std_m
,
make_tuple
(
number
<
kMPerBlock
>
{}),
{
iM
});
}
else
return
make_null_tile_window
(
make_tuple
(
number
<
kMPerBlock
>
{}));
}();
if
(
kargs
.
N
<=
kNPerBlock
)
OnePassLayernorm2dFwd
(
x_block_window
,
gamma_block_window
,
beta_block_window
,
y_block_window
,
mean_block_window
,
inv_std_block_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
N
);
else
TwoPassLayernorm2dFwd
(
x_block_window
,
gamma_block_window
,
beta_block_window
,
y_block_window
,
mean_block_window
,
inv_std_block_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
N
);
}
};
...
...
include/ck_tile/ops/layernorm2d/pipeline/block_layernorm2d_fwd_problem.hpp
View file @
6fcaeada
...
...
@@ -14,17 +14,21 @@ template <typename XDataType_,
typename
YDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
BlockShape_
>
typename
BlockShape_
,
bool
kPadM_
,
bool
kPadN_
>
struct
BlockLayernorm2dFwdProblem
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
};
}
// namespace ck_tile
library/include/ck/library/reference_tensor_operation/gpu/reference_gemm.hpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <sstream>
#include "ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
namespace
ck
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
ComputeTypeA
,
typename
ComputeTypeB
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
naive_gemm_kernel
(
const
ADataType
*
__restrict__
p_a_grid
,
const
BDataType
*
__restrict__
p_b_grid
,
CDataType
*
__restrict__
p_c_grid
,
index_t
m
,
index_t
n
,
index_t
k
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
c_element_op
)
{
using
RowMajor
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
const
int
row_idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
const
int
col_idx
=
blockIdx
.
y
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
row_idx
<
m
&&
col_idx
<
n
)
{
AccDataType
v_acc
=
static_cast
<
AccDataType
>
(
0.0
);
ComputeTypeA
v_a
=
static_cast
<
ComputeTypeA
>
(
0.0
);
ComputeTypeB
v_b
=
static_cast
<
ComputeTypeB
>
(
0.0
);
CDataType
v_c
=
static_cast
<
CDataType
>
(
0.0
);
for
(
int
k_idx
=
0
;
k_idx
<
k
;
++
k_idx
)
{
// check input matrices layout
int
element_idx_a
=
0
;
int
element_idx_b
=
0
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
RowMajor
>
)
{
element_idx_a
=
row_idx
*
k
+
k_idx
;
}
else
{
element_idx_a
=
row_idx
+
m
*
k_idx
;
}
if
constexpr
(
std
::
is_same_v
<
BLayout
,
RowMajor
>
)
{
element_idx_b
=
k_idx
*
n
+
col_idx
;
}
else
{
element_idx_b
=
k_idx
+
k
*
col_idx
;
}
// apply a_element_op
a_element_op
(
v_a
,
p_a_grid
[
element_idx_a
]);
// apply b_element_op
b_element_op
(
v_b
,
p_b_grid
[
element_idx_b
]);
// multiply and accumulate
v_acc
+=
static_cast
<
AccDataType
>
(
v_a
)
*
static_cast
<
AccDataType
>
(
v_b
);
}
// apply c_element_op
c_element_op
(
v_c
,
v_acc
);
// check output matrix layout
int
element_idx_c
=
0
;
if
constexpr
(
std
::
is_same_v
<
CLayout
,
RowMajor
>
)
{
element_idx_c
=
row_idx
*
n
+
col_idx
;
}
else
{
element_idx_c
=
row_idx
+
m
*
col_idx
;
}
// prepare output
p_c_grid
[
element_idx_c
]
=
v_c
;
}
}
}
// namespace ck
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
ComputeTypeA
=
CDataType
,
typename
ComputeTypeB
=
ComputeTypeA
>
struct
ReferenceGemm
:
public
device
::
BaseOperator
{
// Argument
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
void
*
p_a_grid
,
const
void
*
p_b_grid
,
void
*
p_c_grid
,
index_t
m
,
index_t
n
,
index_t
k
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
:
p_a_grid_
{
static_cast
<
const
ADataType
*>
(
p_a_grid
)},
p_b_grid_
{
static_cast
<
const
BDataType
*>
(
p_b_grid
)},
p_c_grid_
{
static_cast
<
CDataType
*>
(
p_c_grid
)},
m_
{
m
},
n_
{
n
},
k_
{
k
},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
c_element_op_
{
c_element_op
}
{
}
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
index_t
m_
;
index_t
n_
;
index_t
k_
;
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CElementwiseOperation
c_element_op_
;
};
// Invoker
struct
Invoker
:
public
device
::
BaseInvoker
{
using
Argument
=
ReferenceGemm
::
Argument
;
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
int
block_size
=
16
;
dim3
block_dim
(
block_size
,
block_size
,
1
);
dim3
grid_dim
(
(
arg
.
m_
+
block_size
-
1
)
/
block_size
,
(
arg
.
n_
+
block_size
-
1
)
/
block_size
,
1
);
auto
launch_kernel
=
[
&
]()
{
const
auto
kernel
=
naive_gemm_kernel
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
BDataType
,
CDataType
,
AccDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
ComputeTypeA
,
ComputeTypeB
>
;
return
launch_and_time_kernel
(
stream_config
,
kernel
,
grid_dim
,
block_dim
,
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
m_
,
arg
.
n_
,
arg
.
k_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
);
};
return
launch_kernel
();
}
float
Run
(
const
device
::
BaseArgument
*
p_arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
override
{
return
Run
(
*
dynamic_cast
<
const
Argument
*>
(
p_arg
),
stream_config
);
}
};
bool
IsSupportedArgument
(
const
device
::
BaseArgument
*
)
override
{
return
true
;
}
static
auto
MakeArgument
(
const
void
*
p_a_grid
,
const
void
*
p_b_grid
,
void
*
p_c_grid
,
index_t
m
,
index_t
n
,
index_t
k
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
return
Argument
{
p_a_grid
,
p_b_grid
,
p_c_grid
,
m
,
n
,
k
,
a_element_op
,
b_element_op
,
c_element_op
};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
virtual
std
::
unique_ptr
<
device
::
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
(
Invoker
{});
}
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"Device Reference Gemm"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
6fcaeada
...
...
@@ -37,11 +37,7 @@ function(add_instance_library INSTANCE_NAME)
endforeach
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
INST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
INST_TARGETS
${
GPU_TARGETS
}
)
endif
()
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
# Do not build DL instances if DL_KERNELS macro is not set
foreach
(
source IN LISTS ARGN
)
...
...
@@ -64,9 +60,9 @@ function(add_instance_library INSTANCE_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
# Do not build mha instances if gfx94 targets are not on the target list
# Do not build mha instances if gfx94
or gfx90a
targets are not on the target list
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"mha"
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND
NOT INST_TARGETS MATCHES
"gfx90a"
AND
source MATCHES
"mha"
)
message
(
"removing mha instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
...
...
@@ -75,17 +71,13 @@ function(add_instance_library INSTANCE_NAME)
if
(
ARGN
)
set
(
INST_OBJ
)
foreach
(
source IN LISTS ARGN
)
if
(
INSTANCES_ONLY
)
set
(
INST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
INST_TARGETS
${
GPU_TARGETS
}
)
endif
()
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
if
(
source MATCHES
"_xdl"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
elseif
(
ARGN MATCHES
"_wmma"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN MATCHES
"mha"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908
gfx90a
gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
set
(
offload_targets
)
foreach
(
target IN LISTS INST_TARGETS
)
...
...
@@ -102,12 +94,14 @@ function(add_instance_library INSTANCE_NAME)
set
(
FMHA_FWD_FAST_EXP2 true
)
endif
()
if
(
FMHA_FWD_FAST_EXP2
)
list
(
APPEND
EXAMPLE_FMHA_FWD
_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero
)
list
(
APPEND
FMHA
_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero
)
else
()
list
(
APPEND
EXAMPLE_FMHA_FWD
_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0
)
list
(
APPEND
FMHA
_COMPILE_OPTIONS -Wno-undefined-func-template -DCK_TILE_FMHA_FWD_FAST_EXP2=0
)
endif
()
list
(
APPEND EXAMPLE_FMHA_FWD_COMPILE_OPTIONS -Wno-float-equal
)
target_compile_options
(
device_mha_instance PRIVATE
${
EXAMPLE_FMHA_FWD_COMPILE_OPTIONS
}
)
list
(
APPEND FMHA_COMPILE_OPTIONS -Wno-float-equal
)
list
(
APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1
)
list
(
APPEND FMHA_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1
)
target_compile_options
(
device_mha_instance PRIVATE
${
FMHA_COMPILE_OPTIONS
}
)
endif
()
target_compile_features
(
${
INSTANCE_NAME
}
PUBLIC
)
...
...
@@ -189,12 +183,7 @@ FOREACH(subdir_path ${dir_list})
set
(
add_inst 1
)
endif
()
if
(
INSTANCES_ONLY
)
set
(
INST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
INST_TARGETS
${
GPU_TARGETS
}
)
endif
()
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
if
((
"
${
cmake_instance
}
"
MATCHES
"quantization"
)
AND
(
DEFINED DTYPES
)
AND
(
NOT DTYPES MATCHES
"int8"
))
message
(
"quantization instances will not be built!"
)
...
...
@@ -318,8 +307,7 @@ if(CK_DEVICE_CONV_INSTANCES)
endif
()
if
(
CK_DEVICE_MHA_INSTANCES
)
set
(
gpu_list
${
INST_TARGETS
}
)
list
(
FILTER gpu_list INCLUDE REGEX
"^gfx94"
)
if
(
gpu_list
)
if
(
gpu_list MATCHES
"gfx94"
OR gpu_list MATCHES
"gfx90a"
)
add_library
(
device_mha_operations STATIC
${
CK_DEVICE_MHA_INSTANCES
}
)
add_library
(
composablekernels::device_mha_operations ALIAS device_mha_operations
)
target_compile_features
(
device_mha_operations PUBLIC
)
...
...
library/src/tensor_operation_instance/gpu/mha/CMakeLists.txt
View file @
6fcaeada
...
...
@@ -32,23 +32,33 @@ if(EXISTS ${FMHA_CPP_FOLDER}/blob_list.txt)
file
(
REMOVE
${
FMHA_CPP_FOLDER
}
/blob_list.txt
)
endif
()
set
(
FMHA_KNOWN_APIS
"fwd,fwd_splitkv,fwd_appendkv,bwd"
)
# generate a list of kernels, but not actually emit files at config stage
# Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
execute_process
(
COMMAND
${
PYTHON_EXECUTABLE
}
${
CMAKE_SOURCE_DIR
}
/example/ck_tile/01_fmha
/generate.py
COMMAND
${
PYTHON_EXECUTABLE
}
${
FMHA_SRC_FOLDER
}
/generate.py
--list_blobs
${
FMHA_CPP_FOLDER
}
/blob_list.txt
--api
${
FMHA_KNOWN_APIS
}
--receipt 3
RESULT_VARIABLE ret
)
if
(
ret AND NOT ret EQUAL 0
)
message
(
FATAL_ERROR
"CK Tile MHA FAILED to genrate a list of kernels via Python."
)
else
()
file
(
STRINGS
${
FMHA_CPP_FOLDER
}
/blob_list.txt FMHA_
FWD_
GEN_BLOBS
)
file
(
STRINGS
${
FMHA_CPP_FOLDER
}
/blob_list.txt FMHA_GEN_BLOBS
)
endif
()
# actually generate the kernel content now
# Note: The receipt 3 arg filters the generated backwards instances to reduce compilation time.
# With receipt 3 set, we are generating instances for datatype == {fp16 || bfp16}, bias == {no || alibi}, deterministic == off, and dpad == dvpad.
add_custom_command
(
OUTPUT
${
FMHA_
FWD_
GEN_BLOBS
}
COMMAND
${
PYTHON_EXECUTABLE
}
${
CMAKE_SOURCE_DIR
}
/example/ck_tile/01_fmha
/generate.py
OUTPUT
${
FMHA_GEN_BLOBS
}
COMMAND
${
PYTHON_EXECUTABLE
}
${
FMHA_SRC_FOLDER
}
/generate.py
--output_dir
${
FMHA_CPP_FOLDER
}
--api
${
FMHA_KNOWN_APIS
}
--receipt 3
COMMENT
"Generating mha kernel (cpp) files now ..."
VERBATIM
)
...
...
@@ -57,12 +67,12 @@ add_custom_command(
# have filename. Since, it was cauing the cmake
# to throw "File name too long"
set
(
device_files
)
foreach
(
filepath IN LISTS FMHA_
FWD_
GEN_BLOBS
)
foreach
(
filepath IN LISTS FMHA_GEN_BLOBS
)
get_filename_component
(
filename
${
filepath
}
NAME
)
# Append the filename to the device_files list
list
(
APPEND device_files
${
filename
}
)
endforeach
()
add_custom_target
(
generate_cpp_files DEPENDS
${
FMHA_
FWD_
GEN_BLOBS
}
)
add_custom_target
(
generate_cpp_files DEPENDS
${
FMHA_GEN_BLOBS
}
)
add_instance_library
(
device_mha_instance
${
device_files
}
)
...
...
profiler/src/CMakeLists.txt
View file @
6fcaeada
...
...
@@ -24,7 +24,7 @@ set(PROFILER_SOURCES
profile_permute_scale.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx9"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_contraction_bilinear.cpp
)
list
(
APPEND PROFILER_SOURCES profile_contraction_scale.cpp
)
...
...
@@ -49,7 +49,7 @@ if(GPU_TARGETS MATCHES "gfx9")
list
(
APPEND PROFILER_SOURCES profile_grouped_gemm_multiply_tile_loop.cpp
)
endif
()
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_add.cpp
)
if
(
GPU_TARGETS MATCHES
"gfx94"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx94"
)
list
(
APPEND PROFILER_SOURCES profile_gemm_multiply_multiply.cpp
)
list
(
APPEND PROFILER_SOURCES profile_gemm_ab_scale.cpp
)
endif
()
...
...
@@ -69,7 +69,7 @@ if(GPU_TARGETS MATCHES "gfx9")
endif
()
if
(
GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
OR GPU_TARGETS MATCHES
"gfx9"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx11"
OR
SUPPORTED_
GPU_TARGETS MATCHES
"gfx12"
OR
SUPPORTED_
GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
list
(
APPEND PROFILER_SOURCES profile_gemm_bilinear.cpp
)
endif
()
...
...
@@ -111,7 +111,7 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_column_to_image_inst
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_transpose_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_permute_scale_instance
)
if
(
GPU_TARGETS MATCHES
"gfx9"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx9"
)
if
(
DTYPES MATCHES
"fp32"
OR DTYPES MATCHES
"fp64"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_bilinear_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_contraction_scale_instance
)
...
...
@@ -135,7 +135,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batched_gemm_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_add_instance
)
if
(
GPU_TARGETS MATCHES
"gfx94"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx94"
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_multiply_multiply_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_ab_scale_instance
)
endif
()
...
...
@@ -159,7 +159,7 @@ if(GPU_TARGETS MATCHES "gfx9")
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_grouped_conv3d_fwd_convinvscale_instance
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx9"
OR GPU_TARGETS MATCHES
"gfx11"
OR GPU_TARGETS MATCHES
"gfx12"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx9"
OR
SUPPORTED_
GPU_TARGETS MATCHES
"gfx11"
OR
SUPPORTED_
GPU_TARGETS MATCHES
"gfx12"
)
if
(
DTYPES MATCHES
"fp16"
OR NOT DEFINED DTYPES
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_gemm_bilinear_instance
)
endif
()
...
...
script/cmake-ck-dev.sh
View file @
6fcaeada
...
...
@@ -7,8 +7,11 @@ MY_PROJECT_SOURCE=$1
if
[
$#
-ge
2
]
;
then
GPU_TARGETS
=
$2
shift
2
REST_ARGS
=
$@
else
GPU_TARGETS
=
"gfx908;gfx90a;gfx940"
REST_ARGS
=
fi
cmake
\
...
...
@@ -20,4 +23,5 @@ cmake
-D
GPU_TARGETS
=
$GPU_TARGETS
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
$REST_ARGS
\
${
MY_PROJECT_SOURCE
}
script/cmake-ck-release.sh
View file @
6fcaeada
...
...
@@ -7,8 +7,11 @@ MY_PROJECT_SOURCE=$1
if
[
$#
-ge
2
]
;
then
GPU_TARGETS
=
$2
shift
2
REST_ARGS
=
$@
else
GPU_TARGETS
=
"gfx908;gfx90a;gfx940"
REST_ARGS
=
fi
cmake
\
...
...
@@ -20,5 +23,6 @@ cmake
-D
GPU_TARGETS
=
$GPU_TARGETS
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
$REST_ARGS
\
${
MY_PROJECT_SOURCE
}
test/CMakeLists.txt
View file @
6fcaeada
...
...
@@ -41,11 +41,7 @@ function(add_test_executable TEST_NAME)
endforeach
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
TEST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
TEST_TARGETS
${
GPU_TARGETS
}
)
endif
()
set
(
TEST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
...
...
@@ -122,11 +118,7 @@ function(add_gtest_executable TEST_NAME)
endforeach
()
endif
()
if
(
INSTANCES_ONLY
)
set
(
TEST_TARGETS
${
DEFAULT_GPU_TARGETS
}
)
else
()
set
(
TEST_TARGETS
${
GPU_TARGETS
}
)
endif
()
set
(
TEST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT DEFINED DL_KERNELS AND source MATCHES
"_dl"
)
...
...
@@ -173,6 +165,7 @@ function(add_gtest_executable TEST_NAME)
endfunction
()
add_compile_options
(
-Wno-c++20-extensions
)
add_subdirectory
(
ck_tile
)
add_subdirectory
(
magic_number_division
)
add_subdirectory
(
space_filling_curve
)
add_subdirectory
(
conv_util
)
...
...
@@ -210,10 +203,10 @@ add_subdirectory(conv_tensor_rearrange)
add_subdirectory
(
transpose
)
add_subdirectory
(
permute_scale
)
add_subdirectory
(
wrapper
)
if
(
GPU_TARGETS MATCHES
"gfx11"
)
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx11"
)
add_subdirectory
(
wmma_op
)
endif
()
if
(
GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
if
(
SUPPORTED_
GPU_TARGETS MATCHES
"gfx942"
AND CK_HIP_VERSION_MAJOR GREATER_EQUAL 6 AND CK_HIP_VERSION_MINOR GREATER_EQUAL 2
)
# smfmac needs ROCm6.2
add_subdirectory
(
smfmac_op
)
endif
()
add_subdirectory
(
position_embedding
)
test/ck_tile/CMakeLists.txt
0 → 100644
View file @
6fcaeada
add_subdirectory
(
image_to_column
)
test/ck_tile/image_to_column/CMakeLists.txt
0 → 100644
View file @
6fcaeada
# Currently ck_tile is only built on gfx9
if
(
GPU_TARGETS MATCHES
"gfx9"
)
add_gtest_executable
(
test_tile_image_to_column test_tile_image_to_column.cpp
)
endif
()
test/ck_tile/image_to_column/test_tile_image_to_column.cpp
0 → 100644
View file @
6fcaeada
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <gtest/gtest.h>
#include "ck_tile/host.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/host/kernel_launch.hpp"
#include "ck_tile/ops/image_to_column.hpp"
// Host API implementation
template
<
typename
DataType
>
class
TestCkTileImageToColumn
:
public
::
testing
::
Test
{
static
constexpr
ck_tile
::
index_t
VectorSize
=
1
;
static
constexpr
ck_tile
::
index_t
NDimSpatial
=
2
;
protected:
void
Run
(
const
ck_tile
::
conv
::
ConvParam
conv_params
)
{
using
ImLayout
=
ck_tile
::
tensor_layout
::
convolution
::
NHWGC
;
const
auto
G
=
conv_params
.
G_
;
const
auto
N
=
conv_params
.
N_
;
const
auto
C
=
conv_params
.
C_
;
const
ck_tile
::
long_index_t
NDoHoWo
=
N
*
std
::
accumulate
(
conv_params
.
output_spatial_lengths_
.
begin
(),
std
::
next
(
conv_params
.
output_spatial_lengths_
.
begin
(),
NDimSpatial
),
1
,
std
::
multiplies
<>
());
const
ck_tile
::
long_index_t
CZYX
=
C
*
std
::
accumulate
(
conv_params
.
filter_spatial_lengths_
.
begin
(),
std
::
next
(
conv_params
.
filter_spatial_lengths_
.
begin
(),
NDimSpatial
),
1
,
std
::
multiplies
<>
());
const
auto
in_desc
=
ck_tile
::
conv
::
make_input_host_tensor_descriptor_g_n_c_wis_packed
<
ImLayout
>
(
conv_params
);
const
auto
out_desc
=
ck_tile
::
HostTensorDescriptor
({
G
,
NDoHoWo
,
CZYX
});
// host verify
ck_tile
::
HostTensor
<
DataType
>
in
(
in_desc
);
ck_tile
::
HostTensor
<
DataType
>
out_device
(
out_desc
);
ck_tile
::
HostTensor
<
DataType
>
out_host
(
out_desc
);
std
::
cout
<<
"input: "
<<
in
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"output: "
<<
out_device
.
mDesc
<<
std
::
endl
;
ck_tile
::
FillUniformDistributionIntegerValue
<
DataType
>
{
-
5.
f
,
5.
f
}(
in
);
ck_tile
::
DeviceMem
in_device_buf
(
in
.
get_element_space_size_in_bytes
());
ck_tile
::
DeviceMem
out_device_buf
(
out_device
.
get_element_space_size_in_bytes
());
in_device_buf
.
ToDevice
(
in
.
data
());
using
thread_tile
=
ck_tile
::
sequence
<
4
,
4
>
;
using
warp_tile
=
ck_tile
::
sequence
<
8
,
128
>
;
using
block_tile
=
ck_tile
::
sequence
<
32
,
128
>
;
using
Shape
=
ck_tile
::
TileImageToColumnShape
<
thread_tile
,
warp_tile
,
block_tile
>
;
using
PipelineProblem
=
ck_tile
::
BlockImageToColumnProblem
<
DataType
,
DataType
,
Shape
,
NDimSpatial
,
VectorSize
,
VectorSize
>
;
using
Kernel
=
ck_tile
::
ImageToColumn
<
PipelineProblem
>
;
auto
kargs
=
Kernel
::
MakeKargs
(
in_device_buf
.
GetDeviceBuffer
(),
out_device_buf
.
GetDeviceBuffer
(),
G
,
N
,
C
,
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
filter_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
output_spatial_lengths_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
+
3
>
(
in_desc
.
get_strides
()),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
3
>
(
out_desc
.
get_strides
()),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
conv_filter_strides_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
conv_filter_dilations_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_left_pads_
),
ck_tile
::
to_array
<
ck_tile
::
long_index_t
,
NDimSpatial
>
(
conv_params
.
input_right_pads_
));
const
dim3
grids
=
Kernel
::
GridSize
(
kargs
.
N
*
kargs
.
output_spatial_lengths
[
0
]
*
kargs
.
output_spatial_lengths
[
1
],
kargs
.
filter_spatial_lengths
[
0
]
*
kargs
.
filter_spatial_lengths
[
1
]
*
kargs
.
C
,
kargs
.
G
);
constexpr
dim3
blocks
=
Kernel
::
BlockSize
();
constexpr
ck_tile
::
index_t
kBlockPerCu
=
2
;
ck_tile
::
launch_kernel
(
ck_tile
::
stream_config
{},
ck_tile
::
make_kernel
<
blocks
.
x
,
kBlockPerCu
>
(
Kernel
{},
grids
,
blocks
,
0
,
kargs
));
// reference
ck_tile
::
reference_im2col
<
DataType
,
DataType
,
NDimSpatial
>
(
in
,
out_host
,
conv_params
);
out_device_buf
.
FromDevice
(
out_device
.
data
());
bool
pass
=
ck_tile
::
check_err
(
out_device
,
out_host
);
EXPECT_TRUE
(
pass
);
}
};
class
TestCkTileImageToColumnFloat
:
public
TestCkTileImageToColumn
<
float
>
{
};
class
TestCkTileImageToColumnHalf
:
public
TestCkTileImageToColumn
<
ck_tile
::
half_t
>
{
};
TEST_F
(
TestCkTileImageToColumnFloat
,
TestCorrectness
)
{
this
->
Run
({
2
,
2
,
4
,
1
,
192
,
{
3
,
3
},
{
28
,
28
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
Run
({
2
,
2
,
64
,
1
,
64
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
Run
({
2
,
1
,
64
,
1
,
64
,
{
1
,
1
},
{
7
,
7
},
{
3
,
3
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
Run
({
2
,
1
,
64
,
1
,
64
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
Run
({
2
,
2
,
64
,
1
,
64
,
{
3
,
3
},
{
28
,
28
},
{
2
,
2
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
}});
}
TEST_F
(
TestCkTileImageToColumnHalf
,
TestCorrectness
)
{
this
->
Run
({
2
,
2
,
4
,
1
,
192
,
{
3
,
3
},
{
28
,
28
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
Run
({
2
,
2
,
64
,
1
,
64
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
Run
({
2
,
1
,
64
,
1
,
64
,
{
1
,
1
},
{
7
,
7
},
{
3
,
3
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
Run
({
2
,
1
,
64
,
1
,
64
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
Run
({
2
,
2
,
64
,
1
,
64
,
{
3
,
3
},
{
28
,
28
},
{
2
,
2
},
{
2
,
2
},
{
1
,
1
},
{
1
,
1
}});
}
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment