Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
6a97c046
Commit
6a97c046
authored
Feb 14, 2023
by
Alan Turner
Browse files
Add gpu-invoker to device_gemm_multiple_d_xdl_cshuffle
parent
1b62bfaa
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
172 additions
and
5 deletions
+172
-5
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
...ration/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
+167
-0
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+5
-5
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
6a97c046
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
//#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
...
@@ -89,6 +90,70 @@ namespace ck {
...
@@ -89,6 +90,70 @@ namespace ck {
namespace
tensor_operation
{
namespace
tensor_operation
{
namespace
device
{
namespace
device
{
// Maps a 1-D block id onto a 2-D (M-tile, N-tile) index of the C grid.
//
// The C matrix is split into M0 x N0 tiles, where
//   M0 = ceil(M / MPerBlock) and N0 = ceil(N / NPerBlock).
// M-tiles are grouped into bands of M01_ rows; the last band holds the
// remaining M0 % M01_ rows when M0 is not a multiple of M01_. Inside a band,
// consecutive block ids advance along M first and then along N.
//
// NOTE(review): the defaulted constructor provides no initializer for M01_,
// and CalculateBottomIndex() divides by it, so build the map from a
// descriptor before asking it for indices.
template <index_t MPerBlock, index_t NPerBlock, typename CGridDesc_M_N>
struct BlockToCTileMap_M00_N0_M01Adapt2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt2() = default;

    // c_grid_desc_m_n: descriptor of the C grid, copied into the map.
    // M01:             number of M-tiles per band (defaults to 8).
    __host__ __device__ constexpr BlockToCTileMap_M00_N0_M01Adapt2(
        const CGridDesc_M_N& c_grid_desc_m_n, index_t M01 = 8)
        : M01_(M01), c_grid_desc_m_n_(c_grid_desc_m_n)
    {
    }

    // Number of tiles (M0 * N0) covering the descriptor passed in.
    // Uses the argument, not the descriptor stored in this object.
    __host__ __device__ constexpr index_t
    CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) const
    {
        const auto tiles_along_m =
            math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I0), MPerBlock);
        const auto tiles_along_n =
            math::integer_divide_ceil(c_grid_desc_m_n.GetLength(I1), NPerBlock);

        return tiles_along_m * tiles_along_n;
    }

    // Converts idx_top[I0] (a linear block id) into a tuple (m_tile, n_tile)
    // using the descriptor stored in this object.
    template <typename TopIdx>
    __host__ __device__ constexpr auto CalculateBottomIndex(const TopIdx& idx_top) const
    {
        const auto tiles_along_m =
            math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I0), MPerBlock);
        const auto tiles_along_n =
            math::integer_divide_ceil(c_grid_desc_m_n_.GetLength(I1), NPerBlock);

        // Fold ids beyond one full grid back into range (swallow batch index).
        auto linear_id = idx_top[I0];
        linear_id      = linear_id % (tiles_along_m * tiles_along_n);

        // Plain row-major position of the block before regrouping.
        const index_t n_pos = linear_id % tiles_along_n;
        const index_t m_pos = linear_id / tiles_along_n;

        // Band this row belongs to, and the row's offset inside that band.
        const index_t band_id     = m_pos / M01_;
        const index_t row_in_band = m_pos % M01_;

        // Height of the current band: M01_ for full bands, the remainder
        // for the trailing partial band.
        const auto tail_rows   = tiles_along_m % M01_;
        const auto band_height = (m_pos < tiles_along_m - tail_rows) ? M01_ : tail_rows;

        // Linear id of the block inside its band.
        const index_t id_in_band = n_pos + row_in_band * tiles_along_n;

        return make_tuple(id_in_band % band_height + band_id * M01_,
                          id_in_band / band_height);
    }

    // Always reports the tile index as valid; this holds provided that the
    // user gets the grid size from CalculateGridSize().
    template <typename CTileIdx, typename CTileDim>
    __host__ __device__ constexpr bool ValidCTileIndex(const CTileIdx& /* c_tile_idx */,
                                                       const CTileDim& /* c_tile_dim */) const
    {
        return true;
    }

    // Accepts every descriptor unconditionally.
    __host__ __device__ constexpr bool
    CheckValidity(const CGridDesc_M_N& /* c_grid_desc_m_n */) const
    {
        return true;
    }

    private:
    index_t M01_;                   // M-tiles per band
    CGridDesc_M_N c_grid_desc_m_n_; // copy of the C grid descriptor
};
// GEMM:
// GEMM:
// input : A[M, K]
// input : A[M, K]
// input : B[N, K]
// input : B[N, K]
...
@@ -679,6 +744,108 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -679,6 +744,108 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
return
str
.
str
();
return
str
.
str
();
}
}
template
<
class
ADesc
,
class
BDesc
,
class
EDesc
,
class
...
DsDesc
>
struct
Descriptor
{
using
AGridDesc_M_K
=
decltype
(
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
ADesc
{}));
using
BGridDesc_N_K
=
decltype
(
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
BDesc
{}));
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
make_tuple
(
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
DsDesc
{})...))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
EDesc
{}))
>
;
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
__device__
constexpr
Descriptor
(
DsDesc
...
dsdesc
)
{
static_assert
(
GridwiseGemm
::
CheckValidity
(
AGridDesc_M_K
{},
BGridDesc_N_K
{},
DsGridDesc_M_N
{},
EGridDesc_M_N
{},
get_block_2_etile_map
()));
}
constexpr
auto
get_a_grid_desc_ak0_m_ak1
()
const
{
return
a_grid_desc_ak0_m_ak1
;
}
constexpr
auto
get_b_grid_desc_bk0_n_bk1
()
const
{
return
b_grid_desc_bk0_n_bk1
;
}
constexpr
auto
get_ds_grid_desc_mblock_mperblock_nblock_nperblock
()
const
{
return
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
}
constexpr
auto
get_e_grid_desc_mblock_mperblock_nblock_nperblock
()
const
{
return
e_grid_desc_mblock_mperblock_nblock_nperblock
;
}
static
constexpr
auto
get_block_2_etile_map
()
{
return
BlockToCTileMap_M00_N0_M01Adapt2
<
MPerBlock
,
NPerBlock
,
EGridDesc_M_N
>
(
EGridDesc_M_N
{});
}
constexpr
bool
has_main_k_block_loop
()
const
{
return
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
a_grid_desc_ak0_m_ak1
.
GetLength
(
ck
::
Number
<
0
>
{})
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
ck
::
Number
<
2
>
{}));
}
private:
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock
;
};
// Device-side launcher for the gridwise GEMM: builds the compile-time
// Descriptor for the given A/B/E/Ds descriptor types and forwards the data
// pointers to GridwiseGemm::Run.
template <class ADesc, class BDesc, class EDesc, class... DsDesc>
struct GPU_Invoker
{
    using Descriptor = DeviceOp::Descriptor<ADesc, BDesc, EDesc, DsDesc...>;

    // Takes one value per Ds descriptor type; the values are not used.
    __device__ constexpr GPU_Invoker(DsDesc... /* dsdesc */) {}

    // Runs the gridwise GEMM from device code.
    //
    // p_a_grid / p_b_grid: input A and B data.
    // p_ds_grid:           pointer bundle for the D tensors, forwarded unchanged.
    // p_e_grid:            output E data.
    template <class DsPointer>
    __device__ static void run(const ADataType* __restrict__ p_a_grid,
                               const BDataType* __restrict__ p_b_grid,
                               DsPointer p_ds_grid,
                               EDataType* __restrict__ p_e_grid)
    {
        // Descriptor's only constructor takes one argument per Ds descriptor,
        // so a bare `constexpr Descriptor desc;` compiles only when DsDesc is
        // empty. Passing default-constructed Ds descriptors works for any
        // pack size. Braces are required: with an empty pack, parentheses
        // would declare a function instead of an object.
        constexpr Descriptor desc{DsDesc{}...};

        // Scratch buffer sized by GridwiseGemm and handed to Run().
        __shared__ char p_shared_block[GridwiseGemm::GetSharedMemoryNumberOfByte()];

        GridwiseGemm::template Run<desc.has_main_k_block_loop()>(
            p_a_grid,
            p_b_grid,
            p_ds_grid,
            p_e_grid,
            p_shared_block,
            AElementwiseOperation{},
            BElementwiseOperation{},
            CDEElementwiseOperation{},
            desc.get_a_grid_desc_ak0_m_ak1(),
            desc.get_b_grid_desc_bk0_n_bk1(),
            desc.get_ds_grid_desc_mblock_mperblock_nblock_nperblock(),
            desc.get_e_grid_desc_mblock_mperblock_nblock_nperblock(),
            desc.get_block_2_etile_map());
    }
};
};
};
}
// namespace device
}
// namespace device
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
6a97c046
...
@@ -117,15 +117,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt
...
@@ -117,15 +117,15 @@ struct BlockToCTileMap_M00_N0_M01Adapt
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
()
=
default
;
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
()
=
default
;
__host__
__device__
BlockToCTileMap_M00_N0_M01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
__host__
__device__
constexpr
BlockToCTileMap_M00_N0_M01Adapt
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
index_t
M01
=
8
)
index_t
M01
=
8
)
:
M01_
(
M01
),
c_grid_desc_m_n_
(
c_grid_desc_m_n
)
:
M01_
(
M01
),
c_grid_desc_m_n_
(
c_grid_desc_m_n
)
{
{
}
}
__host__
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
__device__
constexpr
index_t
CalculateGridSize
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
{
const
auto
M0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I0
),
MPerBlock
);
const
auto
M0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I0
),
MPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I1
),
NPerBlock
);
const
auto
N0
=
math
::
integer_divide_ceil
(
c_grid_desc_m_n
.
GetLength
(
I1
),
NPerBlock
);
...
@@ -159,13 +159,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt
...
@@ -159,13 +159,13 @@ struct BlockToCTileMap_M00_N0_M01Adapt
}
}
template
<
typename
CTileIdx
,
typename
CTileDim
>
template
<
typename
CTileIdx
,
typename
CTileDim
>
__host__
__device__
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
__host__
__device__
constexpr
bool
ValidCTileIndex
(
const
CTileIdx
&
/* c_tile_idx */
,
const
CTileDim
&
/* c_tile_dim */
)
const
const
CTileDim
&
/* c_tile_dim */
)
const
{
{
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
}
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
__host__
__device__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
private:
private:
index_t
M01_
;
index_t
M01_
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment