gaoqiong / composable_kernel · Commit 8e20f747

Authored Oct 31, 2023 by Paul

Format

parent 6e4a1075
Showing 2 changed files with 253 additions and 81 deletions.
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp  +233 -63
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp  +20 -18
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp (view file @ 8e20f747)
@@ -498,94 +498,99 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         }
     };

-    static bool IsSupportedArgument(const Argument& arg)
+    static constexpr bool IsSupported(index_t MRaw_, index_t NRaw_, index_t KRaw_)
     {
-        if(!ck::is_xdl_supported())
-        {
-            return false;
-        }
-
         // check vector load/store
         {
             using Row = ck::tensor_layout::gemm::RowMajor;
             using Col = ck::tensor_layout::gemm::ColumnMajor;

             // check vector load of A
             if constexpr(is_same_v<ALayout, Row> && ABlockTransferSrcVectorDim == 2)
             {
-                if(arg.KRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                if(KRaw_ % ABlockTransferSrcScalarPerVector != 0)
                 {
                     return false;
                 }
             }
             else if constexpr(is_same_v<ALayout, Col> && ABlockTransferSrcVectorDim == 1)
             {
                 // FIXME: not rigorous
-                if(arg.MRaw_ % ABlockTransferSrcScalarPerVector != 0)
+                if(MRaw_ % ABlockTransferSrcScalarPerVector != 0)
                 {
                     return false;
                 }
             }
             else
             {
                 return false;
             }

             // check vector laod of B
             if constexpr(is_same_v<BLayout, Col> && BBlockTransferSrcVectorDim == 2)
             {
-                if(arg.KRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                if(KRaw_ % BBlockTransferSrcScalarPerVector != 0)
                 {
                     return false;
                 }
             }
             else if constexpr(is_same_v<BLayout, Row> && BBlockTransferSrcVectorDim == 1)
             {
                 // FIXME: not rigorous
-                if(arg.NRaw_ % BBlockTransferSrcScalarPerVector != 0)
+                if(NRaw_ % BBlockTransferSrcScalarPerVector != 0)
                 {
                     return false;
                 }
             }
             else
             {
                 return false;
             }

             // check vector load of Ds
             // only support RowMajor for now
             bool all_valid = true;

             static_for<0, NumDTensor, 1>{}([&](auto i) {
                 using DLayout = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;

                 if constexpr(!is_same_v<DLayout, Row>)
                 {
                     all_valid = false;
                 }
             });

             if(!all_valid)
             {
                 return false;
             }

             // check vector store of E
             // only support RowMajor for now
             if constexpr(is_same_v<ELayout, Row>)
             {
-                if(arg.NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
+                if(NRaw_ % CDEBlockTransferScalarPerVector_NPerBlock != 0)
                 {
                     return false;
                 }
             }
             else
             {
                 return false;
             }
         }

-        return GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
+        return true;
+    }
+
+    static bool IsSupportedArgument(const Argument& arg)
+    {
+        if(!(ck::get_device_name() == "gfx908" || ck::get_device_name() == "gfx90a" ||
+             ck::get_device_name() == "gfx940" || ck::get_device_name() == "gfx941" ||
+             ck::get_device_name() == "gfx942"))
+        {
+            return false;
+        }
+
+        return IsSupported(arg.MRaw_, arg.NRaw_, arg.KRaw_) and
+               GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
                                            arg.b_grid_desc_n_k_,
                                            arg.ds_grid_desc_m_n_,
                                            arg.e_grid_desc_m_n_,
...
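Note on this hunk: the old IsSupportedArgument is split in two. The raw-size divisibility checks move into a constexpr IsSupported(MRaw_, NRaw_, KRaw_) that no longer needs an Argument object, while IsSupportedArgument keeps the runtime-only device-name check and delegates the rest. A minimal standalone sketch of that pattern follows; VectorSize, is_supported, and is_supported_argument are hypothetical stand-ins for the CK template parameters and functions, not CK names:

#include <cassert>
#include <string>

// Hypothetical stand-in for a vector-width template parameter.
constexpr int VectorSize = 4;

// Shape-only check: constexpr, so it can be evaluated at compile time
// when the problem sizes are compile-time constants.
constexpr bool is_supported(int M, int N, int K)
{
    // mirrors the divisibility checks on MRaw_/NRaw_/KRaw_ above
    return (K % VectorSize == 0) && (N % VectorSize == 0) && (M > 0);
}

// Runtime check: the device identity is only known at runtime,
// so it stays out of the constexpr function.
bool is_supported_argument(const std::string& device, int M, int N, int K)
{
    if(device != "gfx908" && device != "gfx90a" && device != "gfx940" &&
       device != "gfx941" && device != "gfx942")
    {
        return false;
    }
    return is_supported(M, N, K);
}

int main()
{
    static_assert(is_supported(256, 128, 64), "shape check can run at compile time");
    assert(is_supported_argument("gfx90a", 256, 128, 64));
    assert(!is_supported_argument("gfx906", 256, 128, 64));
}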
@@ -708,6 +713,171 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
         return str.str();
     }
+
+    template <class ADesc, class BDesc, class DsDesc, class EDesc>
+    struct Descriptor
+    {
+        static constexpr auto ds_tuple()
+        {
+            return transform_tuples(
+                [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
+                DsDesc{});
+        }
+
+        using AGridDesc_AK0_M_AK1 =
+            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(
+                DeviceOp::matrix_padder.PadADescriptor_M_K(ADesc{})))>;
+        using BGridDesc_BK0_N_BK1 =
+            remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(
+                DeviceOp::matrix_padder.PadBDescriptor_N_K(BDesc{})))>;
+        using DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+            GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(ds_tuple()))>;
+        using EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
+            GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
+        using Block2ETileMap = remove_cvref_t<decltype(GridwiseGemm::MakeDefaultBlock2ETileMap(
+            DeviceOp::matrix_padder.PadCDescriptor_M_N(EDesc{})))>;
+
+        // tensor descriptors for problem definiton
+        AGridDesc_M_K a_grid_desc_m_k;
+        BGridDesc_N_K b_grid_desc_n_k;
+        DsGridDesc_M_N ds_grid_desc_m_n;
+        EGridDesc_M_N e_grid_desc_m_n;
+
+        // tensor descriptors for block/thread-wise copy
+        AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1;
+        BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1;
+        DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock ds_grid_desc_mblock_mperblock_nblock_nperblock;
+        EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock e_grid_desc_mblock_mperblock_nblock_nperblock;
+
+        // block-to-e-tile map
+        Block2ETileMap block_2_etile_map;
+
+        // element-wise op
+        AElementwiseOperation a_element_op;
+        BElementwiseOperation b_element_op;
+        CDEElementwiseOperation cde_element_op;
+
+        // for checking vector load/store
+        index_t MRaw;
+        index_t NRaw;
+        index_t KRaw;
+
+        bool has_main_k_block_loop = true;
+
+        constexpr Descriptor(ADesc a,
+                             BDesc b,
+                             DsDesc ds,
+                             EDesc e,
+                             AElementwiseOperation a_element_op_,
+                             BElementwiseOperation b_element_op_,
+                             CDEElementwiseOperation cde_element_op_)
+            : a_grid_desc_m_k{DeviceOp::matrix_padder.PadADescriptor_M_K(a)},
+              b_grid_desc_n_k{DeviceOp::matrix_padder.PadBDescriptor_N_K(b)},
+              ds_grid_desc_m_n{transform_tuples(
+                  [&](auto d) constexpr { return DeviceOp::matrix_padder.PadCDescriptor_M_N(d); },
+                  ds)},
+              e_grid_desc_m_n{DeviceOp::matrix_padder.PadCDescriptor_M_N(e)},
+              a_grid_desc_ak0_m_ak1{
+                  GridwiseGemm::MakeDefaultAGridDescriptor_AK0_M_AK1(a_grid_desc_m_k)},
+              b_grid_desc_bk0_n_bk1{
+                  GridwiseGemm::MakeDefaultBGridDescriptor_BK0_N_BK1(b_grid_desc_n_k)},
+              ds_grid_desc_mblock_mperblock_nblock_nperblock{
+                  GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      transform_tuples(
+                          [&](auto d) constexpr {
+                              return DeviceOp::matrix_padder.PadCDescriptor_M_N(d);
+                          },
+                          ds))},
+              e_grid_desc_mblock_mperblock_nblock_nperblock{
+                  GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                      e_grid_desc_m_n)},
+              block_2_etile_map{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n)},
+              has_main_k_block_loop{GridwiseGemm::CalculateHasMainKBlockLoop(
+                  a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2))},
+              a_element_op{a_element_op_},
+              b_element_op{b_element_op_},
+              cde_element_op{cde_element_op_},
+              MRaw{e.GetLength(I0)},
+              NRaw{e.GetLength(I1)},
+              KRaw{a.GetLength(I1)}
+        {
+        }
+
+        constexpr bool IsValid() const
+        {
+            return GridwiseGemm::CheckValidity(a_grid_desc_m_k,
+                                               b_grid_desc_n_k,
+                                               ds_grid_desc_m_n,
+                                               e_grid_desc_m_n,
+                                               block_2_etile_map) and
+                   IsSupported(MRaw, NRaw, KRaw);
+        }
+
+        constexpr index_t GetBlockSize() const { return BlockSize; }
+
+        constexpr index_t GetGridSize() const
+        {
+            return block_2_etile_map.CalculateGridSize(e_grid_desc_m_n);
+        }
+    };
+
+    template <class ADesc, class BDesc, class DsDesc, class EDesc>
+    static constexpr auto
+    make_descriptor(ADesc a,
+                    BDesc b,
+                    DsDesc ds,
+                    EDesc e,
+                    AElementwiseOperation a_element_op     = AElementwiseOperation{},
+                    BElementwiseOperation b_element_op     = BElementwiseOperation{},
+                    CDEElementwiseOperation cde_element_op = CDEElementwiseOperation{})
+    {
+        return Descriptor<ADesc, BDesc, DsDesc, EDesc>(
+            a, b, ds, e, a_element_op, b_element_op, cde_element_op);
+    }
+
+    template <class Desc, class DsPointer>
+    __device__ static void Run(const Desc& desc,
+                               const ADataType* __restrict__ p_a_grid,
+                               const BDataType* __restrict__ p_b_grid,
+                               DsPointer p_ds_grid,
+                               EDataType* __restrict__ p_e_grid)
+    {
+        __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+        assert(desc.is_valid);
+
+        if(desc.has_main_k_block_loop)
+        {
+            GridwiseGemm::template Run<true>(p_a_grid,
+                                             p_b_grid,
+                                             p_ds_grid,
+                                             p_e_grid,
+                                             p_shared_block,
+                                             desc.a_element_op,
+                                             desc.b_element_op,
+                                             desc.cde_element_op,
+                                             desc.a_grid_desc_ak0_m_ak1,
+                                             desc.b_grid_desc_bk0_n_bk1,
+                                             desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                             desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                             desc.block_2_etile_map);
+        }
+        else
+        {
+            GridwiseGemm::template Run<false>(p_a_grid,
+                                              p_b_grid,
+                                              p_ds_grid,
+                                              p_e_grid,
+                                              p_shared_block,
+                                              desc.a_element_op,
+                                              desc.b_element_op,
+                                              desc.cde_element_op,
+                                              desc.a_grid_desc_ak0_m_ak1,
+                                              desc.b_grid_desc_bk0_n_bk1,
+                                              desc.ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                              desc.e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                              desc.block_2_etile_map);
+        }
+    }
 };
 } // namespace device
...
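The second hunk is all additions: a Descriptor aggregate that precomputes every grid descriptor, the block-to-E-tile map, and the has_main_k_block_loop flag in a constexpr constructor, plus a make_descriptor factory and a __device__ Run that dispatches once on the precomputed flag. A self-contained sketch of the same shape follows; Desc, Problem, make_problem, and run are hypothetical stand-ins, not the CK types:

#include <cstdio>

// Hypothetical stand-in for a grid descriptor: just the problem sizes.
struct Desc
{
    int M, N, K;
};

// Mirrors Descriptor: bundles precomputed state behind constexpr queries.
struct Problem
{
    Desc desc;
    bool has_main_k_loop;

    constexpr Problem(Desc d, int k_per_block)
        : desc{d}, has_main_k_loop{d.K > k_per_block}
    {
    }

    constexpr bool IsValid() const { return desc.M > 0 && desc.N > 0 && desc.K > 0; }
    constexpr int GetGridSize() const { return (desc.M / 64) * (desc.N / 64); }
};

// Mirrors make_descriptor: a constexpr factory with defaults.
constexpr Problem make_problem(Desc d, int k_per_block = 32)
{
    return Problem(d, k_per_block);
}

// Mirrors Run: branches once on the flag precomputed in the constructor,
// so the hot path is instantiated for each case.
template <bool HasMainKLoop>
void run_impl(const Problem&) { /* kernel body would go here */ }

void run(const Problem& p)
{
    if(p.has_main_k_loop)
        run_impl<true>(p);
    else
        run_impl<false>(p);
}

int main()
{
    constexpr auto p = make_problem({256, 256, 64});
    static_assert(p.IsValid(), "validated at compile time");
    run(p);
    std::printf("grid size: %d\n", p.GetGridSize());
}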
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp (view file @ 8e20f747)
@@ -24,10 +24,10 @@ struct BlockToCTileMap_M00_N0_M01
     static constexpr auto I2 = Number<2>{};
     static constexpr auto I3 = Number<3>{};

-    __host__ __device__ BlockToCTileMap_M00_N0_M01() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01() = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
-                                                   index_t M01 = 1)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                             index_t M01 = 1)
         : M01_(M01), underlying_map_(GetBlockToCTileMap(c_grid_desc_m_n, M01))
     {
     }
...
@@ -51,8 +51,8 @@ struct BlockToCTileMap_M00_N0_M01
     }

     template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& c_tile_idx,
-                                             const CTileDim& c_tile_dim) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& c_tile_idx,
+                                                       const CTileDim& c_tile_dim) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return DefaultValidCTileIndex(c_tile_idx, c_tile_dim);
...
@@ -60,7 +60,7 @@ struct BlockToCTileMap_M00_N0_M01
             return true;
     }

-    __host__ bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
+    __host__ constexpr bool CheckValidity(const CGridDesc_M_N& c_grid_desc_m_n) const
     {
         if constexpr(DeviceCTileIndexCheck)
             return true; // validity check moved to kernel
...
@@ -120,25 +120,27 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt() = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt() = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(
-        const BlockToCTileMap_M00_N0_M01Adapt&) = default;
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(
-        BlockToCTileMap_M00_N0_M01Adapt&&) = default;
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
-    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt&
-    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
+        const BlockToCTileMap_M00_N0_M01Adapt&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(
+        BlockToCTileMap_M00_N0_M01Adapt&&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(const BlockToCTileMap_M00_N0_M01Adapt&) = default;
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt&
+    operator=(BlockToCTileMap_M00_N0_M01Adapt&&) = default;

-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(index_t M, index_t N, index_t M01 = 8)
+    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(index_t M,
+                                                                  index_t N,
+                                                                  index_t M01 = 8)
         : M_(M), N_(N), M01_(M01)
     {
     }

     template <typename CGridDesc_M_N>
-    __host__ __device__ BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
-                                                        index_t M01 = 8)
+    __host__
+        __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt(const CGridDesc_M_N& c_grid_desc_m_n,
+                                                             index_t M01 = 8)
         : BlockToCTileMap_M00_N0_M01Adapt(
               c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01)
     {
...
@@ -232,8 +234,8 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
     }

     template <typename CTileIdx, typename CTileDim>
-    __host__ __device__ bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
-                                             const CTileDim& /* c_tile_dim */) const
+    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
+                                                       const CTileDim& /* c_tile_dim */) const
     {
         return true; // always valid provided that user gets grid size from CalculateGridSize()
     }
...
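These hunks only add constexpr to the constructors, defaulted special members, and query methods of the block-to-C-tile maps; that is what allows the new Descriptor in the first file to build its Block2ETileMap and evaluate GetGridSize()/IsValid() in constant expressions. A miniature sketch of the idea follows; TileMap here is a hypothetical stand-in, not the CK class:

#include <cstdio>

// Hypothetical miniature of a block-to-C-tile map: maps a linear block id
// onto (m, n) tile coordinates, with every member usable in constant expressions.
struct TileMap
{
    int m_tiles_, n_tiles_;

    constexpr TileMap(int M, int N, int m_per_block, int n_per_block)
        : m_tiles_{(M + m_per_block - 1) / m_per_block},
          n_tiles_{(N + n_per_block - 1) / n_per_block}
    {
    }

    constexpr int CalculateGridSize() const { return m_tiles_ * n_tiles_; }

    // Row-major block id -> (m tile, n tile).
    constexpr int TileM(int block_id) const { return block_id / n_tiles_; }
    constexpr int TileN(int block_id) const { return block_id % n_tiles_; }
};

// Because the constructor and queries are constexpr, the launch grid size
// can be fixed at compile time, which is what a constexpr GetGridSize() needs.
static_assert(TileMap(256, 384, 128, 128).CalculateGridSize() == 6, "grid size known at compile time");

int main()
{
    constexpr TileMap map(256, 384, 128, 128);
    std::printf("block 4 -> tile (%d, %d)\n", map.TileM(4), map.TileN(4));
}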