gaoqiong / composable_kernel · Commits

Commit 3a107090
Authored Jul 27, 2023 by Jing Zhang
Parent: 9e1dd262

    finish splitk
Showing 2 changed files with 56 additions and 53 deletions (+56 −53)
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp       +46 −36
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp  +10 −17
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl_fixed_nk.hpp
@@ -13,7 +13,7 @@
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
 #include "ck/tensor_operation/gpu/device/device_grouped_gemm_fixed_nk.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
-//#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
+#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp"
 #include "ck/host_utility/device_prop.hpp"
@@ -44,6 +44,7 @@ __global__ void
     kernel_grouped_gemm_xdl_fixed_nk(const void CK_CONSTANT_ADDRESS_SPACE* gemm_descs_const,
                                      const index_t group_count,
                                      const index_t grid_size_grp,
+                                     const index_t KBatch,
                                      const AElementwiseOperation a_element_op,
                                      const BElementwiseOperation b_element_op,
                                      const CDEElementwiseOperation c_element_op)
@@ -79,7 +80,7 @@ __global__ void
         const index_t BlockStart = group_id * grid_size_grp;

-        const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n};
+        const auto local_b2e_tile_map = Block2ETileMap{e_grid_desc_m_n, KBatch};

         constexpr auto NumDTensor = DsDataType::Size();
@@ -118,6 +119,7 @@ __global__ void
                                                        StrideB,
                                                        StrideDs,
                                                        StrideE,
+                                                       KBatch,
                                                        block_2_etile_map);

             m_id += 1;
@@ -193,7 +195,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
    static constexpr index_t NumDTensor = DsDataType::Size();

-   static const index_t k_batch = 1;
+   static const index_t k_batch = 2;

    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
@@ -285,7 +287,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        AElementwiseOperation,
        BElementwiseOperation,
        CDEElementwiseOperation,
-       InMemoryDataOperationEnum::Set,
+       InMemoryDataOperationEnum::AtomicAdd,
        NumPrefetch, // NumGemmKPrefetchStage
        BlockSize,
        MPerBlock,
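Note: switching the E-tensor memory operation from Set to AtomicAdd is what makes split-K correct here. Each K-batch computes only a partial product for the same output tile, so the epilogue must accumulate into E (which therefore has to be zero-initialized) rather than overwrite it. Below is a minimal host-side sketch of that accumulation model; it is not CK code, and splitk_gemm_reference plus all values are illustrative assumptions.

    #include <algorithm>
    #include <cstdio>
    #include <vector>

    // Illustrative split-K model with accumulating stores: each k-batch owns a
    // K-slice and adds its partial product into E, mirroring
    // InMemoryDataOperationEnum::AtomicAdd (E must start zeroed).
    void splitk_gemm_reference(const std::vector<float>& A, // M x K, row-major
                               const std::vector<float>& B, // K x N, row-major
                               std::vector<float>& E,       // M x N, zero-initialized
                               int M, int N, int K, int KBatch)
    {
        const int KPerBatch = (K + KBatch - 1) / KBatch;
        for(int kb = 0; kb < KBatch; ++kb) // in the real kernel this is a grid dimension
        {
            const int k_begin = kb * KPerBatch;
            const int k_end   = std::min(K, k_begin + KPerBatch);
            for(int m = 0; m < M; ++m)
                for(int n = 0; n < N; ++n)
                {
                    float partial = 0.f;
                    for(int k = k_begin; k < k_end; ++k)
                        partial += A[m * K + k] * B[k * N + n];
                    E[m * N + n] += partial; // accumulate, do not overwrite
                }
        }
    }

    int main()
    {
        const int M = 2, N = 2, K = 8, KBatch = 2;
        std::vector<float> A(M * K, 1.f), B(K * N, 1.f), E(M * N, 0.f);
        splitk_gemm_reference(A, B, E, M, N, K, KBatch);
        std::printf("E[0][0] = %.1f (expected %d)\n", E[0], K);
        return 0;
    }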
@@ -351,7 +353,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
                make_multi_index(idx_top[Number<0>{}] - block_start_));

-           return make_tuple(idx_bot[Number<0>{}] + mblock_id_off_, idx_bot[Number<1>{}]);
+           return make_tuple(idx_bot[Number<0>{}],
+                             idx_bot[Number<1>{}] + mblock_id_off_,
+                             idx_bot[Number<2>{}]);
        }

        template <typename CTileIdx, typename CTileDim>
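Note: the offsetted map now forwards a three-component bottom index; the K-batch id in slot 0 passes through untouched while the group's M-block offset (mblock_id_off_) is applied to slot 1. A plain-C++ stand-in for that remapping is sketched below; OffsetTileMapSketch is a made-up name and std::array replaces CK's multi-index.

    #include <array>
    #include <cstdio>

    // Hypothetical stand-in for the offsetted block-to-E-tile map after this change:
    // slot 0 = k-batch id (untouched), slot 1 = M-block id (group offset applied),
    // slot 2 = N-block id.
    struct OffsetTileMapSketch
    {
        int mblock_id_off_; // first M-block of this group's output

        std::array<int, 3> remap(const std::array<int, 3>& idx_bot) const
        {
            return {idx_bot[0],                  // k-batch id passes through
                    idx_bot[1] + mblock_id_off_, // offset only the M dimension
                    idx_bot[2]};
        }
    };

    int main()
    {
        const OffsetTileMapSketch map{/*mblock_id_off_=*/4};
        const auto idx = map.remap({1, 0, 3}); // k-batch 1, local M-block 0, N-block 3
        std::printf("(%d, %d, %d)\n", idx[0], idx[1], idx[2]); // prints (1, 4, 3)
        return 0;
    }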
@@ -379,34 +382,35 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
    };

    template <index_t MPerBlock_, index_t NPerBlock_>
-   struct BlockToCTileMap_M00_N0_M01Adapt_MLoops
+   struct BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops
    {
        static constexpr auto I0 = Number<0>{};
        static constexpr auto I1 = Number<1>{};

-       __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops() = default;
-       __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
-           const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
-       __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
-           BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
-       __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
-       operator=(const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
-       __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
-       operator=(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
-
-       __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(index_t M, index_t N, index_t M01 = 8)
-           : M_(M), N_(N), M01_(M01)
+       __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops() = default;
+       __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+           const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
+       __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+           BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
+       __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
+       operator=(const BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&) = default;
+       __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&
+       operator=(BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops&&) = default;
+
+       __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+           index_t M, index_t N, index_t KBatch, index_t M01 = 8)
+           : M_(M), N_(N), KBatch_(KBatch), M01_(M01)
        {
        }

        template <typename CGridDesc_M_N>
-       __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
-           const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
-           : BlockToCTileMap_M00_N0_M01Adapt_MLoops(
-                 c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
+       __host__ __device__ BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+           const CGridDesc_M_N& c_grid_desc_m_n, index_t KBatch, index_t M01 = 8)
+           : BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops(
+                 c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), KBatch, M01)
        {
        }
@@ -415,16 +419,16 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
            return math::integer_divide_ceil(M_, MPerBlock_);
        }

-       __host__ static constexpr index_t CalculateGridSize(index_t /*M*/, index_t N)
+       __host__ constexpr index_t CalculateGridSize(index_t /*M*/, index_t N) const
        {
            const auto M0 = 1; // math::integer_divide_ceil(M, MPerBlock);
            const auto N0 = math::integer_divide_ceil(N, NPerBlock);

-           return M0 * N0;
+           return M0 * N0 * KBatch_;
        }

        template <typename CGridDesc_M_N>
-       __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
+       __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
        {
            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
        }
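Note: because CalculateGridSize now multiplies by KBatch_, each group's slice of the launch grid (grid_size_grp) covers every (n-block, k-batch) pair; the methods also drop `static` since they read the KBatch_ member. A rough sketch of the launch-size arithmetic follows; calculate_grid_size_grp and the numbers are hypothetical, not CK API.

    #include <cstdio>

    // Illustrative only: mirrors the M0 * N0 * KBatch grid sizing for a single group.
    int calculate_grid_size_grp(int N, int NPerBlock, int KBatch)
    {
        const int M0 = 1;                               // the MLoops variant loops over M in-kernel
        const int N0 = (N + NPerBlock - 1) / NPerBlock; // integer_divide_ceil(N, NPerBlock)
        return M0 * N0 * KBatch;
    }

    int main()
    {
        const int group_count = 4, N = 1024, NPerBlock = 128, KBatch = 2;
        const int grid_size_grp = calculate_grid_size_grp(N, NPerBlock, KBatch);
        // every group is given the same contiguous slice of blocks in the launch
        std::printf("grid_size_grp = %d, total blocks = %d\n",
                    grid_size_grp, group_count * grid_size_grp);
        return 0;
    }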
@@ -443,7 +447,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
            const auto M0 = 1; // math::integer_divide_ceil(M_, MPerBlock_);
            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);

-           block_1d_id = block_1d_id % (M0 * N0); // swallow batch index
+           block_1d_id = block_1d_id % (M0 * N0 * KBatch_); // hide groups
+
+           const index_t idx_ksplit = block_1d_id / (M0 * N0);
+           block_1d_id               = block_1d_id % (M0 * N0);

            index_t idx_N0 = block_1d_id % N0;
            index_t idx_M0 = block_1d_id / N0;
@@ -454,7 +461,8 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
            index_t idx_M01          = idx_M0 % M01_;
            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

-           return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
+           return make_tuple(idx_ksplit,
+                             idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                              idx_N0_M01_local / M01_adapt);
        }
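Note: taken together, CalculateBottomIndex first hides the group offset, then peels the K-batch index off the flat block id, and only then performs the usual M01-swizzled (m, n) decomposition; with M0 fixed at 1 (the MLoops kernel loops over M) the swizzle collapses to the identity, so it is omitted below. A self-contained illustration of that index math; decompose_block_id is a hypothetical name, not CK API.

    #include <array>
    #include <cstdio>

    // Hypothetical sketch of the split-K block-id decomposition used by the
    // KBatch-aware MLoops tile map, with M0 fixed to 1.
    std::array<int, 3> decompose_block_id(int block_1d_id, int N0, int KBatch)
    {
        const int M0 = 1;
        block_1d_id = block_1d_id % (M0 * N0 * KBatch); // hide the per-group offset

        const int idx_ksplit = block_1d_id / (M0 * N0); // which K-slice this block owns
        block_1d_id          = block_1d_id % (M0 * N0);

        const int idx_N0 = block_1d_id % N0;
        const int idx_M0 = block_1d_id / N0; // always 0 when M0 == 1

        return {idx_ksplit, idx_M0, idx_N0};
    }

    int main()
    {
        const int N0 = 8, KBatch = 2;
        for(int bid = 0; bid < N0 * KBatch; ++bid)
        {
            const auto idx = decompose_block_id(bid, N0, KBatch);
            std::printf("block %2d -> (ksplit=%d, m0=%d, n0=%d)\n", bid, idx[0], idx[1], idx[2]);
        }
        return 0;
    }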
@@ -468,10 +476,11 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
        private:
        index_t M_;
        index_t N_;
+       index_t KBatch_;
        index_t M01_;
    };

-   using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
+   using Block2ETileMap = BlockToCTileMap_KBatch_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;

    struct GemmBiasTransKernelArg
@@ -563,7 +572,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
            const index_t StrideA = gemm_descs[i].stride_A_;
            const index_t StrideB = gemm_descs[i].stride_B_;
-           const index_t StrideC = gemm_descs[i].stride_C_;
+           const index_t StrideE = gemm_descs[i].stride_C_;

            // pointer
            std::array<const void*, NumDTensor> p_ds_grid;
@@ -607,10 +616,10 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
#endif

            const auto e_grid_desc_m_n =
-               DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);
+               DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideE);

            // block-to-e-tile map
-           const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
+           const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n, k_batch};

            grid_size_grp = local_b2c_tile_map.CalculateGridSize(e_grid_desc_m_n);
@@ -629,7 +638,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
            if(!GridwiseGemm::template CheckValidity<ALayout, BLayout, DsLayout, ELayout, GemmSpec>(
-                  M, N, K, StrideA, StrideB, StrideDs, StrideC, 1))
+                  M, N, K, StrideA, StrideB, StrideDs, StrideE, 1))
            {
                throw std::runtime_error(
                    "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
@@ -646,7 +655,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                StrideA,
                StrideB,
                StrideDs,
-               StrideC,
+               StrideE,
            });

            group_id++;
@@ -769,6 +778,7 @@ struct DeviceGroupedGemm_Xdl_Fixed_NK : public DeviceGroupedGemmFixedNK<ALayout,
                cast_pointer_to_constant_address_space(kernel_args_dev),
                arg.gemm_desc_kernel_arg_.size(),
                arg.grid_size_grp,
+               k_batch,
                arg.a_element_op_,
                arg.b_element_op_,
                arg.c_element_op_);
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_splitk_cshuffle.hpp
@@ -247,7 +247,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        return math::integer_least_multiple(N, NPerBlock);
    }

-   __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch = 1)
+   __host__ __device__ static auto CalculateKPadded(index_t K, index_t K_Batch)
    {
        return math::integer_least_multiple(K, KPerBlock * K_Batch);
    }
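Note: CalculateKPadded rounds K up to a multiple of KPerBlock * K_Batch, so each of the K_Batch slices gets an equal, KPerBlock-aligned share of the padded reduction dimension; dropping the `= 1` default forces every caller to pass the batch count explicitly. A quick standalone check of that arithmetic, with integer_least_multiple reimplemented here only for the example:

    #include <cstdio>

    // Same rounding rule as math::integer_least_multiple, rewritten for this sketch.
    int integer_least_multiple(int x, int multiple)
    {
        return ((x + multiple - 1) / multiple) * multiple;
    }

    int main()
    {
        const int K = 1000, KPerBlock = 32, K_Batch = 4;
        const int KPadded    = integer_least_multiple(K, KPerBlock * K_Batch); // 1024
        const int KPerKBatch = KPadded / K_Batch;                              // 256, KPerBlock-aligned
        std::printf("KPadded = %d, K per k-batch = %d\n", KPadded, KPerKBatch);
        return 0;
    }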
@@ -407,7 +407,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        const index_t StrideB,
        const std::array<index_t, NumDTensor> StrideDs,
        const index_t StrideE,
-       const index_t KBatch = 1)
+       const index_t KBatch)
    {
        const auto a_grid_desc_kbatch_ak0_m_ak1 =
            MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);
@@ -674,22 +674,14 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        const auto block_work_idx =
            block_2_etile_map.CalculateBottomIndex(make_multi_index(get_block_1d_id()));

-       if(!block_2_etile_map.ValidCTileIndex(
-              block_work_idx,
-              make_tuple(e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I0),
-                         e_grid_desc_mblock_mperblock_nblock_nperblock.GetLength(I2))))
-       {
-           return;
-       }
-
        // HACK: this force m/n_block_data_idx_on_grid into SGPR
-       const index_t kbatch_id = 0; // __builtin_amdgcn_readfirstlane(block_work_idx[I0]);
+       const index_t kbatch_id = __builtin_amdgcn_readfirstlane(block_work_idx[I0]);

        const index_t m_block_data_idx_on_grid =
-           __builtin_amdgcn_readfirstlane(block_work_idx[I0] * MPerBlock);
+           __builtin_amdgcn_readfirstlane(block_work_idx[I1] * MPerBlock);

        const index_t n_block_data_idx_on_grid =
-           __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);
+           __builtin_amdgcn_readfirstlane(block_work_idx[I2] * NPerBlock);

        // if(get_thread_local_1d_id() == 0)
        //{
@@ -979,7 +971,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                    make_tuple(make_multi_index(0, 0, 0, 0)),
                    generate_tuple(
                        [&](auto) {
-                           return make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0);
+                           return make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0);
                        },
                        Number<NumDTensor>{}));
@@ -1010,7 +1002,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
                {c_ds_desc_refs,
                 idx_c_ds_block_begin,
                 tie(e_grid_desc_mblock_mperblock_nblock_nperblock),
-                make_tuple(make_multi_index(block_work_idx[I0], 0, block_work_idx[I1], 0)),
+                make_tuple(make_multi_index(block_work_idx[I1], 0, block_work_idx[I2], 0)),
                 cde_element_op};

            // space filling curve for threadwise C in VGPR before shuffle
@@ -1103,6 +1095,7 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        const index_t StrideB,
        const std::array<index_t, NumDTensor> StrideDs,
        const index_t StrideE,
+       const index_t KBatch,
        const Block2ETileMap& block_2_etile_map)
    {
        const auto p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_);
@@ -1128,10 +1121,10 @@ struct GridwiseGemmMultipleD_xdl_splitk_cshuffle
        // tensor descriptors for block/thread-wise copy
        const auto a_grid_desc_kbatch_ak0_m_ak1 =
-           MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, 1);
+           MakeAGridDescriptor_KBatch_AK0_M_AK1<ALayout, GemmSpec>(M, K, StrideA, KBatch);

        const auto b_grid_desc_kbatch_bk0_n_bk1 =
-           MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, 1);
+           MakeBGridDescriptor_KBatch_BK0_N_BK1<BLayout, GemmSpec>(K, N, StrideB, KBatch);

        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
            MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(DsGridDesc_M_N{}))>;