Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
7ccf0bb5
Unverified
Commit
7ccf0bb5
authored
Oct 19, 2023
by
Chao Liu
Committed by
GitHub
Oct 19, 2023
Browse files
refactor gemm+softmax+gemm (#19)
* refactor gemm+softmax+gemm using block-gemm * reorg files * clean
parent
2dfbfbbc
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
406 additions
and
156 deletions
+406
-156
example/91_tile_program/gemm_softmax_gemm_impl.hpp
example/91_tile_program/gemm_softmax_gemm_impl.hpp
+25
-120
include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_problem.hpp
...program/block_tile/block_gemm_areg_bgmem_creg_problem.hpp
+31
-0
include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp
...tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp
+150
-0
include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
...ock_tile/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
+131
-0
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp
...program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp
+32
-0
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp
...tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp
+1
-16
include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_problem.hpp
...rogram/block_tile/block_gemm_asmem_bsmem_creg_problem.hpp
+31
-0
include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp
...ile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp
+1
-16
include/ck/tile_program/tile/load_tile.hpp
include/ck/tile_program/tile/load_tile.hpp
+4
-4
No files found.
example/91_tile_program/gemm_softmax_gemm_impl.hpp
View file @
7ccf0bb5
...
@@ -9,12 +9,14 @@
...
@@ -9,12 +9,14 @@
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck/tile_program/block_tile_pipeline/block_gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
#include "ck/tile_program/block_tile/block_reduce.hpp"
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
// S[M0, N0] = Q[M0, K0] * K[N0, K0]
...
@@ -46,95 +48,19 @@ struct GemmSoftmaxGemmImpl
...
@@ -46,95 +48,19 @@ struct GemmSoftmaxGemmImpl
ck
::
tile_program
::
block
::
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
;
ck
::
tile_program
::
block
::
BlockGemmPipelineAGmemBGmemCRegV2DefaultPolicy
>
;
// block gemm1
// block gemm1
using
BlockGemm1
=
ck
::
tile_program
::
block
::
BlockGemmARegB
S
memCRegV1
<
using
BlockGemm1
=
ck
::
tile_program
::
block
::
BlockGemmARegB
G
memCRegV1
<
ck
::
tile_program
::
block
::
BlockGemmARegB
S
memCReg
V1
Problem
<
ck
::
tile_program
::
block
::
BlockGemmARegB
G
memCRegProblem
<
PDataType
,
PDataType
,
VDataType
,
VDataType
,
OaccDataType
,
OaccDataType
,
kBlockSize
,
kBlockSize
,
ck
::
tile_program
::
TileGemmShape
<
kM0PerBlock
,
kN1PerBlock
,
kN0PerBlock
>>
,
ck
::
tile_program
::
TileGemmShape
<
kM0PerBlock
,
kN1PerBlock
,
kN0PerBlock
>>
,
ck
::
tile_program
::
block
::
BlockGemmARegBSmemCRegV1DefaultPolicy
>
;
ck
::
tile_program
::
block
::
BlockGemmARegBGmemCRegV1DefaultPolicy
>
;
#if 0
// 2d
__device__ static constexpr auto MakeVLdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kNPerBlock = kN1PerBlock;
constexpr index_t kKPerBlock = kN0PerBlock;
constexpr auto b_lds_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});
return b_lds_desc;
}
#else
// fake XOR
__device__
static
constexpr
auto
MakeVLdsBlockDescriptor
()
{
using
namespace
ck
;
using
BDataType
=
VDataType
;
constexpr
index_t
kNPerBlock
=
kN1PerBlock
;
constexpr
index_t
kKPerBlock
=
kN0PerBlock
;
constexpr
auto
b_lds_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
kNPerBlock
/
2
,
2
,
kKPerBlock
),
Number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
BDataType
);
constexpr
auto
b_lds_desc_d4_d5_d6
=
transform_tensor_descriptor
(
b_lds_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
kNPerBlock
/
2
,
kKPerBlock
),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
constexpr
auto
b_lds_desc_n_k
=
transform_tensor_descriptor
(
b_lds_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
kNPerBlock
/
2
,
2
)),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
b_lds_desc_n_k
;
}
#endif
__device__
static
constexpr
auto
MakeVDramTileDistribution
()
{
using
namespace
ck
;
using
namespace
ck
::
tile_program
;
using
BDataType
=
VDataType
;
constexpr
index_t
kNPerBlock
=
kN1PerBlock
;
constexpr
index_t
kKPerBlock
=
kN0PerBlock
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
StaticTileDistributionEncoding
<
Sequence
<
1
>
,
Tuple
<
Sequence
<
N0
,
N1
,
N2
>
,
Sequence
<
K0
,
K1
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
1
,
2
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
2
,
0
>>
,
Sequence
<
1
,
2
>
,
Sequence
<
0
,
1
>>
{});
}
__device__
static
constexpr
ck
::
index_t
GetStaticLdsSize
()
__device__
static
constexpr
ck
::
index_t
GetStaticLdsSize
()
{
{
using
namespace
ck
;
return
ck
::
math
::
max
(
BlockGemm0Pipeline
::
GetStaticLdsSize
(),
BlockGemm1
::
GetStaticLdsSize
());
return
math
::
max
(
BlockGemm0Pipeline
::
GetStaticLdsSize
(),
static_cast
<
index_t
>
(
MakeVLdsBlockDescriptor
().
GetElementSpaceSize
()
*
sizeof
(
VDataType
)));
}
}
__device__
void
operator
()(
const
QDataType
*
q_ptr
,
__device__
void
operator
()(
const
QDataType
*
q_ptr
,
...
@@ -162,7 +88,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -162,7 +88,7 @@ struct GemmSoftmaxGemmImpl
// allocate LDS
// allocate LDS
__shared__
char
smem_ptr
[
GetStaticLdsSize
()];
__shared__
char
smem_ptr
[
GetStaticLdsSize
()];
// Q/K/V DRAM
and DRAM window
// Q/K/V DRAM
// FIXME: assume layout Q[M0, K0], K[N0, K0], V[N1, N0], O[M0, N1]
// FIXME: assume layout Q[M0, K0], K[N0, K0], V[N1, N0], O[M0, N1]
const
auto
q_dram
=
make_naive_tensor_view
<
AddressSpaceEnum
::
Global
>
(
const
auto
q_dram
=
make_naive_tensor_view
<
AddressSpaceEnum
::
Global
>
(
q_ptr
,
make_tuple
(
M0
,
K0
),
make_tuple
(
StrideQ
,
1
),
Number
<
32
>
{},
Number
<
1
>
{});
q_ptr
,
make_tuple
(
M0
,
K0
),
make_tuple
(
StrideQ
,
1
),
Number
<
32
>
{},
Number
<
1
>
{});
...
@@ -173,25 +99,15 @@ struct GemmSoftmaxGemmImpl
...
@@ -173,25 +99,15 @@ struct GemmSoftmaxGemmImpl
const
auto
v_dram
=
make_naive_tensor_view
<
AddressSpaceEnum
::
Global
>
(
const
auto
v_dram
=
make_naive_tensor_view
<
AddressSpaceEnum
::
Global
>
(
v_ptr
,
make_tuple
(
N1
,
N0
),
make_tuple
(
StrideV
,
1
),
Number
<
32
>
{},
Number
<
1
>
{});
v_ptr
,
make_tuple
(
N1
,
N0
),
make_tuple
(
StrideV
,
1
),
Number
<
32
>
{},
Number
<
1
>
{});
// Q/K/V DRAM window
auto
q_dram_window
=
make_tile_window
(
auto
q_dram_window
=
make_tile_window
(
q_dram
,
make_tuple
(
Number
<
kM0PerBlock
>
{},
Number
<
kK0PerBlock
>
{}),
{
iM0
,
0
});
q_dram
,
make_tuple
(
Number
<
kM0PerBlock
>
{},
Number
<
kK0PerBlock
>
{}),
{
iM0
,
0
});
auto
k_dram_window
=
make_tile_window
(
auto
k_dram_window
=
make_tile_window
(
k_dram
,
make_tuple
(
Number
<
kN0PerBlock
>
{},
Number
<
kK0PerBlock
>
{}),
{
0
,
0
});
k_dram
,
make_tuple
(
Number
<
kN0PerBlock
>
{},
Number
<
kK0PerBlock
>
{}),
{
0
,
0
});
auto
v_dram_window
=
auto
v_dram_window
=
make_tile_window
(
make_tile_window
(
v_dram
,
v_dram
,
make_tuple
(
Number
<
kN1PerBlock
>
{},
Number
<
kN0PerBlock
>
{}),
{
iN1
,
0
});
make_tuple
(
Number
<
kN1PerBlock
>
{},
Number
<
kN0PerBlock
>
{}),
{
iN1
,
0
},
MakeVDramTileDistribution
());
// V LDS and LDS window
// V LDS occupies the same LDS allocation Q/K LDS
auto
v_lds
=
make_tensor_view
<
AddressSpaceEnum
::
Lds
>
(
reinterpret_cast
<
VDataType
*>
(
smem_ptr
),
MakeVLdsBlockDescriptor
());
auto
v_lds_window
=
make_tile_window
(
v_lds
,
make_tuple
(
Number
<
kN1PerBlock
>
{},
Number
<
kN0PerBlock
>
{}),
{
0
,
0
});
// Block GEMM0 pipeline and Block GEMM1
// Block GEMM0 pipeline and Block GEMM1
constexpr
auto
gemm0_pipeline
=
BlockGemm0Pipeline
{};
constexpr
auto
gemm0_pipeline
=
BlockGemm0Pipeline
{};
...
@@ -214,7 +130,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -214,7 +130,7 @@ struct GemmSoftmaxGemmImpl
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
using
MLBlockTileType
=
decltype
(
block_tile_reduce
<
SMPLComputeDataType
>
(
SBlockTileType
{},
Sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
SBlockTileType
{},
Sequence
<
1
>
{},
f_max
,
SMPLComputeDataType
{
0
}));
using
OaccBlockTileType
=
decltype
(
gemm1
(
PBlockTileType
{},
v_dram_window
));
using
OaccBlockTileType
=
decltype
(
gemm1
(
PBlockTileType
{},
v_dram_window
,
smem_ptr
));
// init Oacc, M, L
// init Oacc, M, L
auto
o_acc
=
OaccBlockTileType
{};
auto
o_acc
=
OaccBlockTileType
{};
...
@@ -286,7 +202,7 @@ struct GemmSoftmaxGemmImpl
...
@@ -286,7 +202,7 @@ struct GemmSoftmaxGemmImpl
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
// FIXME: this use different equation from FA v2 paper,
// FIXME: this use different equation from FA v2 paper,
// but produce correc result.
// but produce correc
t
result.
// Is the equation wrong?
// Is the equation wrong?
o_acc
(
i_j_idx
)
*=
tmp
;
o_acc
(
i_j_idx
)
*=
tmp
;
});
});
...
@@ -296,30 +212,19 @@ struct GemmSoftmaxGemmImpl
...
@@ -296,30 +212,19 @@ struct GemmSoftmaxGemmImpl
const
auto
p
=
const
auto
p
=
tile_elementwise_in
(
type_convert
<
PDataType
,
SMPLComputeDataType
>
,
p_compute
);
tile_elementwise_in
(
type_convert
<
PDataType
,
SMPLComputeDataType
>
,
p_compute
);
// Block GEMM1: Oacc{j} += P{j} * V{j}
// wait for gemm0 pipeline to finish reading Lds
{
// load V{j}
const
auto
v
=
load_tile
(
v_dram_window
);
// wait for gemm0 pipeline to finish
block_sync_lds
();
store_tile
(
v_lds_window
,
v
);
// wait for store_tile to finish
block_sync_lds
();
block_sync_lds
();
// Oacc{j} += P{j} * V{j}
// Block GEMM1: Oacc{j} += P{j} * V{j}
gemm1
(
o_acc
,
p
,
v_lds_window
);
gemm1
(
o_acc
,
p
,
v_dram_window
,
smem_ptr
);
// wait for gemm1 to finish
block_sync_lds
();
}
// move tile windows
// move
K/V
tile windows
for next iteration (J loop)
move_tile_window
(
k_dram_window
,
{
kN0PerBlock
,
0
});
move_tile_window
(
k_dram_window
,
{
kN0PerBlock
,
0
});
move_tile_window
(
v_dram_window
,
{
0
,
kN0PerBlock
});
move_tile_window
(
v_dram_window
,
{
0
,
kN0PerBlock
});
// wait for gemm1 to finish reading Lds, before next iteration (J loop)
block_sync_lds
();
iN0
+=
kN0PerBlock
;
iN0
+=
kN0PerBlock
;
}
while
(
iN0
<
N0
);
}
while
(
iN0
<
N0
);
...
...
include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_problem.hpp
0 → 100644
View file @
7ccf0bb5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
// Problem Description for BlockGemmARegBGmemCReg
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmARegBGmemCRegProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1.hpp
0 → 100644
View file @
7ccf0bb5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/static_tile_distribution_helper.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
// A is block distributed tensor
// B is block window on global memory
// C is block distributed tensor
// This will:
// 1. Load B from global memory into shared memory and then
// 2. Call BlockGemmARegSGmemCRegV1
template
<
typename
Problem
,
typename
Policy
=
BlockGemmARegBGmemCRegV1DefaultPolicy
>
struct
BlockGemmARegBGmemCRegV1
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
CDataType
=
remove_cvref_t
<
typename
Problem
::
CDataType
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
// use BlockGemmARegBSmemCRegV1 as the underlying block-GEMM implementation
using
BlockGemmARegBSmemCRegImpl
=
BlockGemmARegBSmemCRegV1
<
BlockGemmARegBSmemCRegProblem
<
ADataType
,
BDataType
,
CDataType
,
kBlockSize
,
BlockGemmShape
>
,
BlockGemmARegBSmemCRegV1DefaultPolicy
>
;
__host__
__device__
static
constexpr
ck
::
index_t
GetStaticLdsSize
()
{
return
sizeof
(
BDataType
)
*
Policy
::
template
MakeBSmemBlockDescriptor
<
Problem
>().
GetElementSpaceSize
();
}
// C += A * B
template
<
typename
CBlockTensor
,
typename
ABlockTensor
,
typename
BBlockGmemWindowTmp
>
__device__
void
operator
()(
CBlockTensor
&
c_block_tensor
,
const
ABlockTensor
&
a_block_tensor
,
const
BBlockGmemWindowTmp
&
b_block_gmem_window_tmp
,
void
*
smem_ptr
)
const
{
static_assert
(
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockGmemWindowTmp
::
DataType
>>
&&
is_same_v
<
CDataType
,
remove_cv_t
<
typename
CBlockTensor
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensor
{}.
GetLengths
()[
Number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockGmemWindowTmp
{}.
GetWindowLengths
()[
Number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensor
{}.
GetLengths
()[
Number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
const
auto
b_block_gmem_window
=
make_tile_window
(
b_block_gmem_window_tmp
.
GetBottomTensorView
(),
make_tuple
(
Number
<
NPerBlock
>
{},
Number
<
KPerBlock
>
{}),
b_block_gmem_window_tmp
.
GetWindowOrigin
(),
Policy
::
template
MakeBGmemTileDistribution
<
Problem
>());
// B LDS and LDS window
auto
b_block_smem
=
make_tensor_view
<
AddressSpaceEnum
::
Lds
>
(
reinterpret_cast
<
BDataType
*>
(
smem_ptr
),
Policy
::
template
MakeBSmemBlockDescriptor
<
Problem
>());
auto
b_block_smem_window
=
make_tile_window
(
b_block_smem
,
make_tuple
(
Number
<
MPerBlock
>
{},
Number
<
KPerBlock
>
{}),
{
0
,
0
});
// load B tile from global mem
const
auto
b_block_tile
=
load_tile
(
b_block_gmem_window
);
// store B tile into shared mem
store_tile
(
b_block_smem_window
,
b_block_tile
);
// wait for store_tile to finish
block_sync_lds
();
// block GEMM
BlockGemmARegBSmemCRegImpl
{}(
c_block_tensor
,
a_block_tensor
,
b_block_smem_window
);
}
// C = A * B
template
<
typename
ABlockTensor
,
typename
BBlockGmemWindowTmp
>
__device__
auto
operator
()(
const
ABlockTensor
&
a_block_tensor
,
const
BBlockGmemWindowTmp
&
b_block_gmem_window_tmp
,
void
*
smem_ptr
)
const
{
static_assert
(
is_same_v
<
ADataType
,
remove_cv_t
<
typename
ABlockTensor
::
DataType
>>
&&
is_same_v
<
BDataType
,
remove_cv_t
<
typename
BBlockGmemWindowTmp
::
DataType
>>
,
"wrong!"
);
constexpr
index_t
MPerBlock
=
ABlockTensor
{}.
GetLengths
()[
Number
<
0
>
{}];
constexpr
index_t
NPerBlock
=
BBlockGmemWindowTmp
{}.
GetWindowLengths
()[
Number
<
0
>
{}];
constexpr
index_t
KPerBlock
=
ABlockTensor
{}.
GetLengths
()[
Number
<
1
>
{}];
static_assert
(
MPerBlock
==
BlockGemmShape
::
kM
&&
NPerBlock
==
BlockGemmShape
::
kN
&&
KPerBlock
==
BlockGemmShape
::
kK
,
"wrong!"
);
const
auto
b_block_gmem_window
=
make_tile_window
(
b_block_gmem_window_tmp
.
GetBottomTensorView
(),
make_tuple
(
Number
<
NPerBlock
>
{},
Number
<
KPerBlock
>
{}),
b_block_gmem_window_tmp
.
GetWindowOrigin
(),
Policy
::
template
MakeBGmemTileDistribution
<
Problem
>());
// B LDS and LDS window
auto
b_block_smem
=
make_tensor_view
<
AddressSpaceEnum
::
Lds
>
(
reinterpret_cast
<
BDataType
*>
(
smem_ptr
),
Policy
::
template
MakeBSmemBlockDescriptor
<
Problem
>());
auto
b_block_smem_window
=
make_tile_window
(
b_block_smem
,
make_tuple
(
Number
<
MPerBlock
>
{},
Number
<
KPerBlock
>
{}),
{
0
,
0
});
// load B tile from global mem
const
auto
b_block_tile
=
load_tile
(
b_block_gmem_window
);
// store B tile into shared mem
store_tile
(
b_block_smem_window
,
b_block_tile
);
// wait for store_tile to finish
block_sync_lds
();
// block GEMM
return
BlockGemmARegBSmemCRegImpl
{}(
a_block_tensor
,
b_block_smem_window
);
}
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/block_tile/block_gemm_areg_bgmem_creg_v1_default_policy.hpp
0 → 100644
View file @
7ccf0bb5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"
#include "ck/tile_program/tile/tile_distribution.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
// Default policy for BlockGemmARegBGmemCRegV1
// Default policy class should not be templated, put template on member functions instead
struct
BlockGemmARegBGmemCRegV1DefaultPolicy
{
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
MakeBGmemTileDistribution
()
{
using
namespace
ck
;
using
namespace
ck
::
tile_program
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
StaticTileDistributionEncoding
<
Sequence
<
1
>
,
Tuple
<
Sequence
<
N0
,
N1
,
N2
>
,
Sequence
<
K0
,
K1
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
1
,
2
>>
,
Tuple
<
Sequence
<
1
>
,
Sequence
<
2
,
0
>>
,
Sequence
<
1
,
2
>
,
Sequence
<
0
,
1
>>
{});
}
#if 0
// 2d
template <typename Problem>
__host__ __device__ static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), Number<32>{});
return b_lds_block_desc;
}
#elif
0
// 3d + padding
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
MakeBSmemBlockDescriptor
()
{
using
namespace
ck
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
kKPerBlock
/
8
>
{},
Number
<
kNPerBlock
>
{},
Number
<
8
>
{}),
make_tuple
(
Number
<
(
kNPerBlock
+
1
)
*
8
>
{},
Number
<
8
>
{},
Number
<
1
>
{}),
Number
<
8
>
{},
Number
<
1
>
{});
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
b_lds_block_desc_0
,
make_tuple
(
make_pass_through_transform
(
kNPerBlock
),
make_merge_transform
(
make_tuple
(
kKPerBlock
/
8
,
8
))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
b_lds_block_desc
;
}
#elif 1
// fake XOR
template
<
typename
Problem
>
__host__
__device__
static
constexpr
auto
MakeBSmemBlockDescriptor
()
{
using
namespace
ck
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
kNPerBlock
/
2
>
{},
Number
<
2
>
{},
Number
<
kKPerBlock
>
{}),
Number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
BDataType
);
constexpr
auto
b_lds_block_desc_d4_d5_d6
=
transform_tensor_descriptor
(
b_lds_block_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
Number
<
kNPerBlock
/
2
>
{},
Number
<
kKPerBlock
>
{}),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
Number
<
kNPerBlock
/
2
>
{},
Number
<
2
>
{})),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
b_lds_block_desc_n_k
;
}
#endif
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp
0 → 100644
View file @
7ccf0bb5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
// Problem Description for BlockGemmARegBSmemCReg
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmARegBSmemCRegProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1.hpp
View file @
7ccf0bb5
...
@@ -13,28 +13,13 @@
...
@@ -13,28 +13,13 @@
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
#include "ck/tile_program/block_tile/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
namespace
ck
{
namespace
ck
{
namespace
tile_program
{
namespace
tile_program
{
namespace
block
{
namespace
block
{
// Problem Description for BlockGemmARegBSmemCRegV1
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmARegBSmemCRegV1Problem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
// A is block distributed tensor
// A is block distributed tensor
// B is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
// C is block distributed tensor
...
...
include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_problem.hpp
0 → 100644
View file @
7ccf0bb5
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/utility/type.hpp"
namespace
ck
{
namespace
tile_program
{
namespace
block
{
// Problem Description for BlockGemmASmemBSmemCRegV1
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmASmemBSmemCRegProblem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
}
// namespace block
}
// namespace tile_program
}
// namespace ck
include/ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1.hpp
View file @
7ccf0bb5
...
@@ -14,28 +14,13 @@
...
@@ -14,28 +14,13 @@
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_elementwise.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/tile/tile_gemm_shape.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/warp_tile/warp_gemm.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_problem.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck/tile_program/block_tile/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
namespace
ck
{
namespace
ck
{
namespace
tile_program
{
namespace
tile_program
{
namespace
block
{
namespace
block
{
// Problem Description for BlockGemmASmemBSmemCRegV1
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
index_t
kBlockSize_
,
typename
BlockGemmShape_
>
struct
BlockGemmASmemBSmemCRegV1Problem
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
};
// A is block window on shared memory
// A is block window on shared memory
// B is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
// C is block distributed tensor
...
...
include/ck/tile_program/tile/load_tile.hpp
View file @
7ccf0bb5
...
@@ -20,7 +20,7 @@ template <typename BottomTensorView_,
...
@@ -20,7 +20,7 @@ template <typename BottomTensorView_,
typename
WindowLengths_
,
typename
WindowLengths_
,
typename
TileDistribution_
,
typename
TileDistribution_
,
index_t
NumCoord
>
index_t
NumCoord
>
__device__
auto
load_tile
(
TileWindowWithStaticDistribution
<
BottomTensorView_
,
__device__
auto
load_tile
(
const
TileWindowWithStaticDistribution
<
BottomTensorView_
,
WindowLengths_
,
WindowLengths_
,
TileDistribution_
,
TileDistribution_
,
NumCoord
>&
tile_window
)
NumCoord
>&
tile_window
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment