gaoqiong / composable_kernel · Commit fee2002c

add b2c_tile_map

Authored Jul 17, 2023 by Jing Zhang
Parent: 326d6bc6

Showing 3 changed files with 273 additions and 48 deletions.
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp                    +18  -0
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp           +246 -44
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp    +9   -4
include/ck/tensor_operation/gpu/device/device_grouped_gemm.hpp
...
@@ -20,6 +20,24 @@ struct GemmDesc
    std::vector<ck::index_t> stride_Ds_;
};

template <index_t NumDTensor = 0>
struct GroupedGemmKernelArgument
{
    const void* p_a_grid;
    const void* p_b_grid;
    std::array<const void*, NumDTensor> p_ds_grid;
    void* p_e_grid;

    index_t M;
    index_t N;
    index_t K;

    index_t StrideA;
    index_t StrideB;
    std::array<index_t, NumDTensor> StrideDs;
    index_t StrideE;
};

template <typename ALayout,
          typename BLayout,
          typename DsLayout,
...
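The new GroupedGemmKernelArgument is a plain struct of per-group pointers, problem sizes, and leading dimensions, so a host-side array of them can be copied verbatim into the device workspace that the grouped kernel reads. As a rough illustration (not part of the commit; it assumes the struct sits in ck::tensor_operation::device next to GemmDesc, and packed row-major A/E with column-major B), one entry could be filled like this:

    #include <array>
    #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"

    // Hypothetical helper: pack one group's pointers and sizes into the new POD
    // argument. NumDTensor = 0, so the D-tensor arrays stay empty.
    ck::tensor_operation::device::GroupedGemmKernelArgument<0>
    make_group_arg(const void* p_a, const void* p_b, void* p_e,
                   ck::index_t M, ck::index_t N, ck::index_t K)
    {
        // Packed row-major A (M x K) and E (M x N), column-major B (K x N):
        // the leading dimensions are K, K and N respectively.
        return {p_a, p_b, {}, p_e, M, N, K, /*StrideA*/ K, /*StrideB*/ K, {}, /*StrideE*/ N};
    }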
include/ck/tensor_operation/gpu/device/impl/device_grouped_gemm_xdl.hpp
...
@@ -25,7 +25,12 @@ namespace device {
template <typename GridwiseGemm,
          typename GemmDesc,
          GemmSpecialization GemmSpec,
          typename Block2ETileMapKSplit,
          typename ALayout,
          typename BLayout,
          typename DsLayout,
          typename ELayout,
          typename Block2ETileMap,
          typename GroupedGemmBlock2ETileMap,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CDEElementwiseOperation,
...
@@ -98,11 +103,12 @@ __global__ void
        if(M == 0 || N == 0 || K == 0)
            return;

        const index_t StrideA    = K;
        const index_t StrideB    = K;
        const index_t StrideDs[] = {};
        const index_t StrideE    = N;

        const auto StrideA  = gemm_desc_ptr[group_id].StrideA;
        const auto StrideB  = gemm_desc_ptr[group_id].StrideB;
        const auto StrideDs = gemm_desc_ptr[group_id].StrideDs;
        const auto StrideE  = gemm_desc_ptr[group_id].StrideE;

#if 0
        using Row = ck::tensor_layout::gemm::RowMajor;
        using Col = ck::tensor_layout::gemm::ColumnMajor;
...
@@ -110,34 +116,60 @@ __global__ void
        using BLayout  = Col;
        using DsLayout = ck::Tuple<>;
        using ELayout  = Row;
#endif
        using DsDataType = ck::Tuple<>;

        const auto e_grid_desc_m_n =
            GridwiseGemm::template MakeEGridDescriptor_M_N<ELayout, GemmSpec>(M, N, StrideE);

        const index_t BlockStart = group_id * grid_size_grp;

        using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMapKSplit>;
        const auto local_b2e_tile_map   = Block2ETileMap{e_grid_desc_m_n};
        const auto local_b2e_tile_map   = Block2ETileMapKSplit{e_grid_desc_m_n};
        const auto block_2_etile_map    = GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart);

        constexpr auto NumDTensor = 0;

        using DsGridPointer = decltype(GridwiseGemm::MakeDsGridPointer());
        DsGridPointer p_ds_grid_;

        static_for<0, NumDTensor, 1>{}([&](auto i) {
            using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;

            // D pointer
            p_ds_grid_(i) = static_cast<const DDataType*>(gemm_desc_ptr[group_id].p_ds_grid[i]);
        });

        auto m_loops = local_b2e_tile_map.CalculateMLoops();
        index_t m_id = 0;

        do
        {
            const auto block_2_etile_map =
                GroupedGemmBlock2ETileMap(local_b2e_tile_map, BlockStart, m_id);

            GridwiseGemm::template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
                gemm_desc_ptr[group_id].p_a_grid,
                gemm_desc_ptr[group_id].p_b_grid,
                p_ds_grid_,
                gemm_desc_ptr[group_id].p_e_grid,
                p_shared,
                a_element_op,
                b_element_op,
                c_element_op,
                M,
                N,
                K,
                StrideA,
                StrideB,
                StrideDs,
                StrideE,
                block_2_etile_map);

            m_id += 1;
        } while(m_id < m_loops);

        GridwiseGemm::template Run<HasMainKBlockLoop, GemmSpec, ALayout, BLayout, DsLayout, ELayout>(
            gemm_desc_ptr[group_id].a_ptr_,
            gemm_desc_ptr[group_id].b_ptr_,
            gemm_desc_ptr[group_id].ds_ptr_,
            gemm_desc_ptr[group_id].e_ptr_,
            p_shared,
            a_element_op,
            b_element_op,
            c_element_op,
            M,
            N,
            K,
            StrideA,
            StrideB,
            StrideDs,
            StrideE,
            block_2_etile_map);
#endif
#else
...
@@ -342,18 +374,162 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
    using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
        GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(EGridDesc_M_N{}))>;

    using Block2ETileMap = BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, EGridDesc_M_N>;
    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMap<Block2ETileMap>;
    template <typename UnderlyingBlockToCTileMap>
    struct OffsettedBlockToCTileMapMLoops
    {
        using underlying_type = UnderlyingBlockToCTileMap;

        __host__ __device__ OffsettedBlockToCTileMapMLoops(UnderlyingBlockToCTileMap block_to_ctile_map,
                                                           index_t block_start,
                                                           index_t mblock_id_off = 0)
        {
            block_to_ctile_map_ = block_to_ctile_map;
            block_start_        = block_start;
            mblock_id_off_      = mblock_id_off;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            auto idx_bot = block_to_ctile_map_.CalculateBottomIndex(
                make_multi_index(idx_top[Number<0>{}] - block_start_));

            return make_tuple(idx_bot[Number<0>{}] + mblock_id_off_, idx_bot[Number<1>{}]);
        }

        template <typename CTileIdx, typename CTileDim>
        __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
                                                 const CTileDim& c_tile_dim) const
        {
            return block_to_ctile_map_.ValidCTileIndex(c_tile_idx, c_tile_dim);
        }

        template <typename CGridDesc_M_N>
        __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
        {
            return block_to_ctile_map_.CheckValidity(c_grid_desc_m_n);
        }

        template <typename CGridDesc_M_N>
        __host__ constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
        {
            return block_to_ctile_map_.CalculateGridSize(c_grid_desc_m_n);
        }

        UnderlyingBlockToCTileMap block_to_ctile_map_;
        index_t block_start_;
        index_t mblock_id_off_;
    };
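Relative to the existing OffsettedBlockToCTileMap, the MLoops variant takes a third constructor argument, mblock_id_off, and adds it to the M coordinate returned by the underlying map: the group's starting block id is subtracted from the 1D block id, the underlying map turns the local id into an (m, n) tile index, and the result becomes (m + mblock_id_off, n). A standalone plain-integer sketch of that translation (illustrative only; the simple row-major mapping below stands in for the underlying M01-adapt map):

    #include <cstdio>
    #include <utility>

    // local_id = global_block_id - block_start
    // (m, n)   = underlying_map(local_id)
    // result   = (m + mblock_id_off, n)
    std::pair<int, int> offsetted_map(int global_block_id, int block_start,
                                      int mblock_id_off, int N0 /* N tiles in this group */)
    {
        const int local_id = global_block_id - block_start;
        const int m        = local_id / N0; // stand-in for the underlying map
        const int n        = local_id % N0;
        return {m + mblock_id_off, n};
    }

    int main()
    {
        // Group starting at block 96, second M sweep (mblock_id_off = 1), 8 N tiles.
        auto [m, n] = offsetted_map(101, 96, 1, 8);
        std::printf("tile (m, n) = (%d, %d)\n", m, n); // (1, 5)
    }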
    template <index_t MPerBlock_, index_t NPerBlock_>
    struct BlockToCTileMap_M00_N0_M01Adapt_MLoops
    {
        static constexpr auto I0 = Number<0>{};
        static constexpr auto I1 = Number<1>{};

        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops() = default;
        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
            const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(
            BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;
        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
        operator=(const BlockToCTileMap_M00_N0_M01Adapt_MLoops&) = default;
        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops&
        operator=(BlockToCTileMap_M00_N0_M01Adapt_MLoops&&) = default;

        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(index_t M, index_t N, index_t M01 = 8)
            : M_(M), N_(N), M01_(M01)
        {
        }

        template <typename CGridDesc_M_N>
        __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt_MLoops(const CGridDesc_M_N& c_grid_desc_m_n,
                                                                   index_t M01 = 8)
            : BlockToCTileMap_M00_N0_M01Adapt_MLoops(
                  c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
        {
        }

        __host__ __device__ constexpr index_t CalculateMLoops() const
        {
            return math::integer_divide_ceil(M_, MPerBlock_);
        }

        __host__ static constexpr index_t CalculateGridSize(index_t M, index_t N)
        {
            const auto M0 = math::integer_divide_ceil(M, MPerBlock);
            const auto N0 = math::integer_divide_ceil(N, NPerBlock);

            return M0 * N0;
        }

        template <typename CGridDesc_M_N>
        __host__ static constexpr index_t CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
        {
            return CalculateGridSize(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
        }

        template <typename CGridDesc_M_N>
        __host__ bool CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
        {
            return true;
        }

        template <typename TopIdx>
        __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
        {
            auto block_1d_id = idx_top[I0];

            const auto M0 = 1; // math::integer_divide_ceil(M_, MPerBlock_);
            const auto N0 = math::integer_divide_ceil(N_, NPerBlock_);

            block_1d_id = block_1d_id % (M0 * N0); // swallow batch index

            index_t idx_N0 = block_1d_id % N0;
            index_t idx_M0 = block_1d_id / N0;

            const auto M01_adapt = (idx_M0 < M0 - M0 % M01_) ? M01_ : M0 % M01_;

            index_t idx_M00          = idx_M0 / M01_;
            index_t idx_M01          = idx_M0 % M01_;
            index_t idx_N0_M01_local = idx_N0 + idx_M01 * N0;

            return make_tuple(idx_N0_M01_local % M01_adapt + idx_M00 * M01_,
                              idx_N0_M01_local / M01_adapt);
        }

        template <typename CTileIdx, typename CTileDim>
        __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                 const CTileDim& /* c_tile_dim */) const
        {
            return true; // always valid provided that user gets grid size from CalculateGridSize()
        }

        private:
        index_t M_;
        index_t N_;
        index_t M01_;
    };
    using Block2ETileMap            = BlockToCTileMap_M00_N0_M01Adapt_MLoops<MPerBlock, NPerBlock>;
    using GroupedGemmBlock2ETileMap = OffsettedBlockToCTileMapMLoops<Block2ETileMap>;
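Forcing M0 to 1 means the launch grid only covers a single row of output tiles per group; CalculateMLoops() = ceil(M / MPerBlock) tells the kernel how many sweeps over that row are needed, and each sweep rebuilds the offsetted map with mblock_id_off = m_id (the do/while loop in kernel_grouped_gemm_xdl above). A small standalone sketch of that bookkeeping, with hypothetical sizes:

    #include <algorithm>
    #include <cstdio>

    int main()
    {
        const int M = 1000, MPerBlock = 128;                  // hypothetical problem / tile sizes
        const int m_loops = (M + MPerBlock - 1) / MPerBlock;  // CalculateMLoops()

        for(int m_id = 0; m_id < m_loops; ++m_id)
        {
            // In the kernel, this is where a fresh GroupedGemmBlock2ETileMap is built
            // with mblock_id_off = m_id and GridwiseGemm::Run is invoked.
            const int row_begin = m_id * MPerBlock;
            const int row_end   = std::min(M, (m_id + 1) * MPerBlock);
            std::printf("pass %d covers output rows [%d, %d)\n", m_id, row_begin, row_end);
        }
    }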
    struct GemmBiasTransKernelArg
    {
        // pointers
        const ADataType* a_ptr_;
        const BDataType* b_ptr_;
        typename GridwiseGemm::DsGridPointer ds_ptr_;
        EDataType* e_ptr_;
        const void* a_ptr_;
        const void* b_ptr_;
        std::array<const void*, NumDTensor> ds_ptr_;
        void* e_ptr_;

        index_t M, N, K;
        index_t M_, N_, K_;
        index_t StrideA_, StrideB_;
        std::array<index_t, NumDTensor> StrideDs_;
        index_t StrideE_;

        // tensor descriptors for problem definiton
        AGridDesc_M_K a_grid_desc_m_k_;
...
@@ -415,12 +591,12 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                const index_t StrideC = gemm_descs[i].stride_C_;

                // pointer
                typename GridwiseGemm::DsGridPointer p_ds_grid{};
                std::array<const void*, NumDTensor> p_ds_grid;

                static_for<0, NumDTensor, 1>{}([&](auto j) {
                    using DDataType = remove_cvref_t<tuple_element_t<j.value, DsDataType>>;

                    p_ds_grid(j) = static_cast<const DDataType*>(p_Ds[i][j]);
                    p_ds_grid[j] = static_cast<const DDataType*>(p_Ds[i][j]);
                });

                // tensor descriptors for problem definiton
...
@@ -436,9 +612,6 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                        M, N, gemm_descs[i].stride_Ds_[j]);
                });

                const auto e_grid_desc_m_n =
                    DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);

                // tensor descriptors for block/thread-wise copy
                const auto a_grid_desc_ak0_m_ak1 =
                    GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k);
...
@@ -446,6 +619,9 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                const auto b_grid_desc_bk0_n_bk1 =
                    GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k);

                const auto e_grid_desc_m_n =
                    DeviceOp::MakeEGridDescriptor_M_N<ELayout>(M, N, StrideC);

                // block-to-e-tile map
                const auto local_b2c_tile_map = Block2ETileMap{e_grid_desc_m_n};
...
@@ -479,13 +655,17 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                        e_grid_desc_m_n);

                gemm_desc_kernel_arg_.push_back(
                    GemmBiasTransKernelArg{static_cast<const ADataType*>(p_As[i]),
                                           static_cast<const BDataType*>(p_Bs[i]),
                    GemmBiasTransKernelArg{p_As[i],
                                           p_Bs[i],
                                           p_ds_grid,
                                           static_cast<EDataType*>(p_Es[i]),
                                           p_Es[i],
                                           M,
                                           N,
                                           K,
                                           StrideA,
                                           StrideB,
                                           {},
                                           StrideC,
                                           a_grid_desc_m_k,
                                           b_grid_desc_n_k,
                                           ds_grid_desc_m_n,
...
@@ -526,6 +706,10 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
        {
            bool has_main_k_block_loop = true;

            std::vector<GroupedGemmKernelArgument<NumDTensor>> grouped_gemm_kernel_args;
            grouped_gemm_kernel_args.reserve(arg.gemm_desc_kernel_arg_.size());

            for(std::size_t i = 0; i < arg.gemm_desc_kernel_arg_.size(); i++)
            {
#if DEBUG_LOG
...
@@ -568,12 +752,25 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
                {
                    throw std::runtime_error("wrong! not all gemm has_main_k_block_loop");
                }

                grouped_gemm_kernel_args.push_back(
                    GroupedGemmKernelArgument<NumDTensor>{arg.gemm_desc_kernel_arg_[i].a_ptr_,
                                                          arg.gemm_desc_kernel_arg_[i].b_ptr_,
                                                          {},
                                                          arg.gemm_desc_kernel_arg_[i].e_ptr_,
                                                          arg.gemm_desc_kernel_arg_[i].M_,
                                                          arg.gemm_desc_kernel_arg_[i].N_,
                                                          arg.gemm_desc_kernel_arg_[i].K_,
                                                          arg.gemm_desc_kernel_arg_[i].StrideA_,
                                                          arg.gemm_desc_kernel_arg_[i].StrideB_,
                                                          arg.gemm_desc_kernel_arg_[i].StrideDs_,
                                                          arg.gemm_desc_kernel_arg_[i].StrideE_});
            }

            hipGetErrorString(
                hipMemcpyWithStream(arg.p_workspace_,
                                    arg.gemm_desc_kernel_arg_.data(),
                                    arg.gemm_desc_kernel_arg_.size() * sizeof(GemmBiasTransKernelArg),
                                    grouped_gemm_kernel_args.data(),
                                    grouped_gemm_kernel_args.size() *
                                        sizeof(GroupedGemmKernelArgument<NumDTensor>),
                                    hipMemcpyHostToDevice,
                                    stream_config.stream_id_));
...
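The hipMemcpyWithStream above copies the freshly built grouped_gemm_kernel_args vector into arg.p_workspace_, so the workspace the caller provides must hold one GroupedGemmKernelArgument per group. A minimal sketch of that size requirement (illustrative; it assumes the struct lives in ck::tensor_operation::device and that NumDTensor is 0):

    #include <cstddef>
    #include "ck/tensor_operation/gpu/device/device_grouped_gemm.hpp"

    // Illustrative lower bound for the workspace handed to the device op: the
    // host-side argument vector is copied into arg.p_workspace_ verbatim.
    std::size_t min_workspace_bytes(std::size_t group_count)
    {
        return group_count *
               sizeof(ck::tensor_operation::device::GroupedGemmKernelArgument<0>);
    }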
@@ -581,9 +778,14 @@ struct DeviceGroupedGemm_Xdl : public DeviceGroupedGemm<ALayout,
            auto launch_kernel = [&](auto has_main_k_block_loop_) {
                const auto kernel =
                    kernel_grouped_gemm_xdl<GridwiseGemm,
                                            GemmBiasTransKernelArg,
                                            GroupedGemmKernelArgument<NumDTensor>,
                                            GemmSpec,
                                            ALayout,
                                            BLayout,
                                            DsLayout,
                                            ELayout,
                                            Block2ETileMap,
                                            GroupedGemmBlock2ETileMap,
                                            AElementwiseOperation,
                                            BElementwiseOperation,
                                            CDEElementwiseOperation,
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp
...
@@ -425,6 +425,8 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
            Number<NumDTensor>{});
    }

    __device__ __host__ static constexpr auto GetMPerBlock() { return MPerBlock; }

    template <bool HasMainKBlockLoop,
              typename AGridDesc_AK0_M_AK1,
              typename BGridDesc_BK0_N_BK1,
...
@@ -868,10 +870,10 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
              typename EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock,
#endif
              typename Block2ETileMap>
    __device__ static void Run(const ABDataType* __restrict__ p_a_grid,
                               const ABDataType* __restrict__ p_b_grid,
    __device__ static void Run(const void* __restrict__ p_a_grid_,
                               const void* __restrict__ p_b_grid_,
                               DsGridPointer p_ds_grid,
                               EDataType* __restrict__ p_e_grid,
                               void* __restrict__ p_e_grid_,
                               void* __restrict__ p_shared,
                               const AElementwiseOperation& a_element_op,
                               const BElementwiseOperation& b_element_op,
...
@@ -881,7 +883,7 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
                               const index_t K,
                               const index_t StrideA,
                               const index_t StrideB,
                               const index_t StrideDs[],
                               const std::array<index_t, NumDTensor> StrideDs,
                               const index_t StrideE,
#if 0
                               const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
...
@@ -893,6 +895,9 @@ struct GridwiseGemmMultipleD_xdl_cshuffle
#endif
                               const Block2ETileMap& block_2_etile_map)
    {
        const auto p_a_grid = reinterpret_cast<const ABDataType*>(p_a_grid_);
        const auto p_b_grid = reinterpret_cast<const ABDataType*>(p_b_grid_);
        const auto p_e_grid = reinterpret_cast<EDataType*>(p_e_grid_);

        // tensor descriptors for problem definiton
        const auto a_grid_desc_m_k = MakeAGridDescriptor_M_K<ALayout, GemmSpec>(M, K, StrideA);
...
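Since Run() now receives type-erased pointers plus (M, N, K, Stride*) and builds the grid descriptors itself, the strides keep their usual meaning as leading dimensions: for row-major A, MakeAGridDescriptor_M_K places element (m, k) at offset m * StrideA + k. A trivial standalone check of that convention (illustrative, not CK code):

    // Row-major layout convention behind the StrideA passed to Run():
    // element (m, k) of an M x K matrix lives at offset m * StrideA + k.
    constexpr int offset_row_major(int m, int k, int stride) { return m * stride + k; }

    static_assert(offset_row_major(0, 0, 128) == 0, "origin");
    static_assert(offset_row_major(2, 5, 128) == 2 * 128 + 5, "packed row-major: StrideA == K");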