Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
9e3825a2
Commit
9e3825a2
authored
Nov 01, 2023
by
Paul
Browse files
Format
parent
028bb4b6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
15 deletions
+26
-15
host/test/gemm_multiple_d.cpp
host/test/gemm_multiple_d.cpp
+8
-5
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
...n/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
+9
-4
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
+9
-6
No files found.
host/test/gemm_multiple_d.cpp
View file @
9e3825a2
...
@@ -32,11 +32,14 @@ extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_
...
@@ -32,11 +32,14 @@ extern "C" __global__ void f(const ck::half_t* a, const ck::half_t* b, ck::half_
static_assert(desc.IsValid(), "Invalid ck gemm.");
static_assert(desc.IsValid(), "Invalid ck gemm.");
${template}::Run(desc,
if constexpr(desc.IsValid())
a,
{
b,
${template}::Run(desc,
ck::make_tuple(),
a,
c);
b,
ck::make_tuple(),
c);
}
}
}
)__ck__"
;
)__ck__"
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle.hpp
View file @
9e3825a2
...
@@ -503,7 +503,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -503,7 +503,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
// check vector load/store
// check vector load/store
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
// check vector load of A
// check vector load of A
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
if
constexpr
(
is_same_v
<
ALayout
,
Row
>
&&
ABlockTransferSrcVectorDim
==
2
)
{
{
...
@@ -524,7 +523,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -524,7 +523,6 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
{
{
return
false
;
return
false
;
}
}
// check vector laod of B
// check vector laod of B
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
if
constexpr
(
is_same_v
<
BLayout
,
Col
>
&&
BBlockTransferSrcVectorDim
==
2
)
{
{
...
@@ -723,6 +721,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -723,6 +721,13 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
[
&
](
auto
d
)
constexpr
{
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
d
);
},
[
&
](
auto
d
)
constexpr
{
return
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
d
);
},
DsDesc
{});
DsDesc
{});
}
}
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
ADesc
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
DeviceOp
::
matrix_padder
.
PadBDescriptor_N_K
(
BDesc
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
ds_tuple
())
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
DeviceOp
::
matrix_padder
.
PadCDescriptor_M_N
(
EDesc
{}))
>
;
using
AGridDesc_AK0_M_AK1
=
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
ADesc
{})))
>
;
DeviceOp
::
matrix_padder
.
PadADescriptor_M_K
(
ADesc
{})))
>
;
...
@@ -806,7 +811,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -806,7 +811,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
constexpr
bool
IsValid
()
const
constexpr
bool
IsValid
()
const
{
{
return
GridwiseGemm
::
CheckValidity
(
(
a_grid_desc_m_k
)
,
return
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k
,
b_grid_desc_n_k
,
b_grid_desc_n_k
,
ds_grid_desc_m_n
,
ds_grid_desc_m_n
,
e_grid_desc_m_n
,
e_grid_desc_m_n
,
...
@@ -844,7 +849,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -844,7 +849,7 @@ struct DeviceGemmMultipleD_Xdl_CShuffle : public DeviceGemmMultipleD<ALayout,
EDataType
*
__restrict__
p_e_grid
)
EDataType
*
__restrict__
p_e_grid
)
{
{
__shared__
char
p_shared_block
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
__shared__
char
p_shared_block
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
assert
(
desc
.
is_v
alid
);
assert
(
desc
.
IsV
alid
()
);
if
(
desc
.
has_main_k_block_loop
)
if
(
desc
.
has_main_k_block_loop
)
{
{
GridwiseGemm
::
template
Run
<
true
>(
p_a_grid
,
GridwiseGemm
::
template
Run
<
true
>(
p_a_grid
,
...
...
include/ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp
View file @
9e3825a2
...
@@ -161,7 +161,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
...
@@ -161,7 +161,7 @@ struct BlockToCTileMap_M00_N0_M01Adapt<MPerBlock, NPerBlock, void>
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
__host__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
{
return
true
;
return
true
;
}
}
...
@@ -317,7 +317,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
...
@@ -317,7 +317,10 @@ struct BlockToCTileMap_KSplit_M00_N0_M01Adapt
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
return
true
;
// always valid provided that user gets grid size from CalculateGridSize()
}
}
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
__host__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
return
true
;
}
private:
private:
index_t
M01_
;
index_t
M01_
;
...
@@ -375,7 +378,7 @@ struct BlockToCTileMap_M00_N00_M01_N01
...
@@ -375,7 +378,7 @@ struct BlockToCTileMap_M00_N00_M01_N01
return
true
;
return
true
;
}
}
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
{
if
constexpr
(
DeviceCTileIndexCheck
)
if
constexpr
(
DeviceCTileIndexCheck
)
return
true
;
// validity check moved to kernel
return
true
;
// validity check moved to kernel
...
@@ -487,7 +490,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
...
@@ -487,7 +490,7 @@ struct BlockToCTileMap_KSplit_M00_N00_M01_N01
return
true
;
return
true
;
}
}
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
{
if
constexpr
(
DeviceCTileIndexCheck
)
if
constexpr
(
DeviceCTileIndexCheck
)
return
true
;
// validity check moved to kernel
return
true
;
// validity check moved to kernel
...
@@ -611,7 +614,7 @@ struct OffsettedBlockToCTileMap
...
@@ -611,7 +614,7 @@ struct OffsettedBlockToCTileMap
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
__host__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
const
{
{
return
block_to_ctile_map_
.
CheckValidity
(
c_grid_desc_m_n
);
return
block_to_ctile_map_
.
CheckValidity
(
c_grid_desc_m_n
);
}
}
...
@@ -668,7 +671,7 @@ struct BlockToCTileMap_3DGrid_KSplit
...
@@ -668,7 +671,7 @@ struct BlockToCTileMap_3DGrid_KSplit
}
}
template
<
typename
CGridDesc_M_N
>
template
<
typename
CGridDesc_M_N
>
__host__
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
__host__
constexpr
bool
CheckValidity
(
const
CGridDesc_M_N
&
/* c_grid_desc_m_n */
)
const
{
{
return
true
;
return
true
;
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment