Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0cf90eaf
Commit
0cf90eaf
authored
May 04, 2023
by
Po-Yen, Chen
Browse files
Remove unnesscary type parameters
parent
670ce6b9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
472 additions
and
118 deletions
+472
-118
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
...or_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
+100
-57
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
+372
-61
No files found.
include/ck/tensor_operation/gpu/device/impl/device_gemm_xdl_cshuffle.hpp
View file @
0cf90eaf
...
...
@@ -299,6 +299,9 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
<
ALayout
,
BLayout
,
CLayout
,
ADataType
,
// TODO: distinguish A/B datatype
GemmAccDataType
,
CShuffleDataType
,
...
...
@@ -308,9 +311,6 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
CElementwiseOperation
,
GemmSpec
,
InMemoryDataOperationEnum
::
Set
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
CGridDesc_M_N
,
NumGemmKPrefetchStage
,
BlockSize
,
MPerBlock
,
...
...
@@ -363,29 +363,89 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
M
{
M_
},
N
{
N_
},
K
{
K_
},
StrideA
{
StrideA_
},
StrideB
{
StrideB_
},
StrideC
{
StrideC_
},
MPadded
{
GridwiseGemm
::
CalculateMPadded
(
M_
)},
NPadded
{
GridwiseGemm
::
CalculateNPadded
(
N_
)},
KPadded
{
GridwiseGemm
::
CalculateKPadded
(
K_
)},
AK0
{
GridwiseGemm
::
CalculateAK0
(
K_
)},
BK0
{
GridwiseGemm
::
CalculateBK0
(
K_
)},
a_grid_desc_ak0_m_ak1
{
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
M_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
K_
,
GridwiseGemm
::
CalculateKPadded
(
K_
),
StrideA_
,
GridwiseGemm
::
CalculateAK0
(
K_
))},
GridwiseGemm
::
MakeAGridDescriptor_AK0_M_AK1
(
M_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
K_
,
GridwiseGemm
::
CalculateKPadded
(
K_
),
StrideA_
,
GridwiseGemm
::
CalculateAK0
(
K_
))},
b_grid_desc_bk0_n_bk1
{
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
K_
,
GridwiseGemm
::
CalculateKPadded
(
K_
),
N_
,
GridwiseGemm
::
CalculateNPadded
(
N_
),
StrideB_
,
GridwiseGemm
::
CalculateBK0
(
K_
))},
c_grid_desc_m_n
{
DeviceOp
::
MakeCGridDescriptor_M_N
(
M_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
N_
,
GridwiseGemm
::
CalculateNPadded
(
N_
),
StrideC_
)}
GridwiseGemm
::
MakeBGridDescriptor_BK0_N_BK1
(
K_
,
GridwiseGemm
::
CalculateKPadded
(
K_
),
N_
,
GridwiseGemm
::
CalculateNPadded
(
N_
),
StrideB_
,
GridwiseGemm
::
CalculateBK0
(
K_
))},
c_grid_desc_m_n
{
GridwiseGemm
::
MakeCGridDescriptor_M_N
(
M_
,
GridwiseGemm
::
CalculateMPadded
(
M_
),
N_
,
GridwiseGemm
::
CalculateNPadded
(
N_
),
StrideC_
)}
{
}
__host__
__device__
Argument
(
const
Argument
&
)
=
default
;
__host__
__device__
void
Print
()
const
{
printf
(
"arg {M: %d, N: %d, K: %d, "
"SA: %d, SB: %d, SC: %d, "
"MP: %d, NP: %d, KP: %d, "
"AK0: %d, BK0: %d}
\n
"
,
M
,
N
,
K
,
StrideA
,
StrideB
,
StrideC
,
MPadded
,
NPadded
,
KPadded
,
AK0
,
BK0
);
// std::cout << "arg {"
// << "M:" << M << ", "
// << "N:" << N << ", "
// << "K:" << K << ", "
// << "SA:" << StrideA << ", "
// << "SB:" << StrideB << ", "
// << "SC:" << StrideC << ", "
// << "MP:" << MPadded << ", "
// << "NP:" << NPadded << ", "
// << "KP:" << KPadded << ", "
// << "AK0:" << AK0 << ", "
// << "BK0:" << BK0 << "}" << std::endl;
}
__host__
__device__
Argument
(
const
Argument
&
other
)
:
p_a_grid
{
other
.
p_a_grid
},
p_b_grid
{
other
.
p_b_grid
},
p_c_grid
{
other
.
p_c_grid
},
M
{
other
.
M
},
N
{
other
.
N
},
K
{
other
.
K
},
StrideA
{
other
.
StrideA
},
StrideB
{
other
.
StrideB
},
StrideC
{
other
.
StrideC
},
MPadded
{
other
.
MPadded
},
NPadded
{
other
.
NPadded
},
KPadded
{
other
.
KPadded
},
AK0
{
other
.
AK0
},
BK0
{
other
.
BK0
},
a_grid_desc_ak0_m_ak1
{
other
.
a_grid_desc_ak0_m_ak1
},
b_grid_desc_bk0_n_bk1
{
other
.
b_grid_desc_bk0_n_bk1
},
c_grid_desc_m_n
{
other
.
c_grid_desc_m_n
}
{
}
__host__
__device__
~
Argument
()
override
{}
...
...
@@ -396,6 +456,14 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
index_t
M
;
index_t
N
;
index_t
K
;
index_t
StrideA
;
index_t
StrideB
;
index_t
StrideC
;
index_t
MPadded
;
index_t
NPadded
;
index_t
KPadded
;
index_t
AK0
;
index_t
BK0
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
;
CGridDesc_M_N
c_grid_desc_m_n
;
...
...
@@ -406,28 +474,16 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
{
using
Argument
=
DeviceOp
::
Argument
;
void
Print
(
const
Argument
&
karg
)
{
karg
.
Print
();
}
float
Run
(
const
Argument
&
karg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
#if DEBUG_LOG
if
(
stream_config
.
log_level_
>
0
)
{
// std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I1) << ", "
// << karg.a_grid_desc_ak0_m_ak1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.b_grid_desc_bk0_n_bk1_{"
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I0) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I1) << ", "
// << karg.b_grid_desc_bk0_n_bk1_.GetLength(I2) << "}" << std::endl;
// std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ",
// "
// << karg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
Print
(
karg
);
}
#endif
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
.
a_grid_desc_ak0_m_ak1
,
karg
.
b_grid_desc_bk0_n_bk1
,
karg
.
c_grid_desc_m_n
))
if
(
!
GridwiseGemm
::
CheckValidity
(
karg
))
{
throw
std
::
runtime_error
(
"wrong! GridwiseGemm has invalid setting"
);
}
...
...
@@ -441,15 +497,17 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
Argument
,
true
>
;
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
Argument
,
true
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
}
else
{
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1
<
GridwiseGemm
,
Argument
,
false
>
;
ave_time
=
launch_and_time_kernel
(
const
auto
kernel
=
kernel_gemm_xdl_cshuffle_v1_simplified
<
GridwiseGemm
,
Argument
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
gdx
,
gdy
,
gdz
),
dim3
(
BlockSize
),
0
,
karg
);
}
...
...
@@ -472,22 +530,7 @@ struct DeviceGemm_Xdl_CShuffle : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
karg
)
{
if
(
!
(
ck
::
get_device_name
()
==
"gfx908"
||
ck
::
get_device_name
()
==
"gfx90a"
))
{
return
false
;
}
if
((
karg
.
K
%
AK1
!=
0
||
karg
.
K
%
BK1
!=
0
)
&&
!
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
||
GemmSpec
==
GemmSpecialization
::
KPadding
))
{
return
false
;
}
return
GridwiseGemm
::
CheckValidity
(
karg
.
a_grid_desc_ak0_m_ak1
,
karg
.
b_grid_desc_bk0_n_bk1
,
karg
.
c_grid_desc_m_n
);
return
GridwiseGemm
::
CheckValidity
(
karg
);
}
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp
View file @
0cf90eaf
...
...
@@ -22,7 +22,7 @@ __global__ void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdl_cshuffle_v1
(
const
Argument
karg
)
kernel_gemm_xdl_cshuffle_v1
_simplified
(
Argument
karg
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -33,7 +33,10 @@ __global__ void
#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
}
template
<
typename
FloatAB
,
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
FloatAB
,
typename
FloatGemmAcc
,
typename
FloatCShuffle
,
typename
FloatC
,
...
...
@@ -42,9 +45,6 @@ template <typename FloatAB,
typename
CElementwiseOperation
,
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
,
InMemoryDataOperationEnum
CGlobalMemoryDataOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
CGridDesc_M_N
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
...
...
@@ -90,10 +90,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static
constexpr
auto
I7
=
Number
<
7
>
{};
// K1 should be Number<...>
static
constexpr
auto
AK0
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
=
Number
<
BK1Value
>
{};
static
constexpr
auto
AK0
_
=
Number
<
KPerBlock
/
AK1Value
>
{};
static
constexpr
auto
BK0
_
=
Number
<
KPerBlock
/
BK1Value
>
{};
static
constexpr
auto
AK1
_
=
Number
<
AK1Value
>
{};
static
constexpr
auto
BK1
_
=
Number
<
BK1Value
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
@@ -102,29 +102,28 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
#endif
#define INTEGER_DIVIDE_CEIL(x, y) (((x) + (y)-1) / (y))
__host__
__device__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
__host__
static
auto
CalculateGridSize
(
index_t
M
,
index_t
N
)
{
// reference the implementation of class 'BlockToCTileMap_M00_N0_M01Adapt'
return
std
::
make_tuple
(
Block2CTileMap
::
CalculateGridSize
(
M
,
N
),
1
,
1
);
}
__host__
__device__
static
auto
CalculateMPadded
(
index_t
M
)
__host__
static
auto
CalculateMPadded
(
index_t
M
)
{
return
INTEGER_DIVIDE_CEIL
(
M
,
MPerBlock
)
*
MPerBlock
;
}
__host__
__device__
static
auto
CalculateNPadded
(
index_t
N
)
__host__
static
auto
CalculateNPadded
(
index_t
N
)
{
return
INTEGER_DIVIDE_CEIL
(
N
,
NPerBlock
)
*
NPerBlock
;
}
__host__
__device__
static
auto
CalculateKPadded
(
index_t
K
)
__host__
static
auto
CalculateKPadded
(
index_t
K
)
{
return
INTEGER_DIVIDE_CEIL
(
K
,
KPerBlock
)
*
KPerBlock
;
}
#undef INTEGER_DIVIDE_CEIL
__host__
__device__
static
auto
CalculateAK0
(
index_t
K
)
__host__
static
auto
CalculateAK0
(
index_t
K
)
{
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
...
...
@@ -133,19 +132,19 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
assert
(
CalculateKPadded
(
K
)
%
AK1
==
0
);
assert
(
CalculateKPadded
(
K
)
%
AK1
Value
==
0
);
return
CalculateKPadded
(
K
)
/
AK1
;
return
CalculateKPadded
(
K
)
/
AK1
Value
;
}
else
{
assert
(
K
%
AK1
==
0
);
assert
(
K
%
AK1
Value
==
0
);
return
K
/
AK1
;
return
K
/
AK1
Value
;
}
}
__host__
__device__
static
auto
CalculateBK0
(
index_t
K
)
__host__
static
auto
CalculateBK0
(
index_t
K
)
{
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
...
...
@@ -154,15 +153,232 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
assert
(
CalculateKPadded
(
K
)
%
BK1
==
0
);
assert
(
CalculateKPadded
(
K
)
%
BK1
Value
==
0
);
return
CalculateKPadded
(
K
)
/
BK1
;
return
CalculateKPadded
(
K
)
/
BK1
Value
;
}
else
{
assert
(
K
%
BK1
==
0
);
assert
(
K
%
BK1
Value
==
0
);
return
K
/
BK1
;
return
K
/
BK1Value
;
}
}
__host__
__device__
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
M
,
index_t
MPad
,
index_t
K
,
index_t
KPad
,
index_t
StrideA
,
index_t
AK0
)
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both M and K
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad M, but not K
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_right_pad_transform
(
M
,
MPad
-
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad K, but not M
const
auto
a_grid_desc_m_k
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1Value
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
__host__
__device__
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
K
,
index_t
KPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideB
,
index_t
BK0
)
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
N
,
K
),
make_tuple
(
StrideB
,
I1
));
}
}();
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad both N and K
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_right_pad_transform
(
N
,
NPad
-
N
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
// pad N, but not K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
KPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad K, but not N
const
auto
b_grid_desc_n_k
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_pass_through_transform
(
N
),
make_right_pad_transform
(
K
,
KPad
-
K
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_kraw
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1Value
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
__host__
__device__
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
MPad
,
index_t
N
,
index_t
NPad
,
index_t
StrideC
)
{
const
auto
c_grid_desc_mraw_nraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
StrideC
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
CLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
StrideC
));
}
}();
using
GemmSpecialization
=
tensor_operation
::
device
::
GemmSpecialization
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M and N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
)
{
// pad M, but not N
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_right_pad_transform
(
M
,
MPad
-
M
),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
)
{
// pad N, but not M
return
transform_tensor_descriptor
(
c_grid_desc_mraw_nraw
,
make_tuple
(
make_pass_through_transform
(
M
),
make_right_pad_transform
(
N
,
NPad
-
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
}
else
{
// not pad M or N
return
c_grid_desc_mraw_nraw
;
}
}
...
...
@@ -174,16 +390,16 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
// A matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
AK0
,
Number
<
MPerBlock
>
{},
AK1
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
,
AK1
,
I1
));
make_tuple
(
AK0
_
,
Number
<
MPerBlock
>
{},
AK1
_
),
make_tuple
(
Number
<
MPerBlock
+
ABlockLdsExtraM
>
{}
*
AK1
_
,
AK1
_
,
I1
));
}
__host__
__device__
static
constexpr
auto
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
()
{
// B matrix in LDS memory, dst of blockwise copy
return
make_naive_tensor_descriptor
(
make_tuple
(
BK0
,
Number
<
NPerBlock
>
{},
BK1
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
,
BK1
,
I1
));
make_tuple
(
BK0
_
,
Number
<
NPerBlock
>
{},
BK1
_
),
make_tuple
(
Number
<
NPerBlock
+
BBlockLdsExtraN
>
{}
*
BK1
_
,
BK1
_
,
I1
));
}
__host__
__device__
static
constexpr
auto
...
...
@@ -209,7 +425,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
constexpr
auto
b_block_desc_bk0_n_bk1
=
GetBBlockDescriptor_BK0PerBlock_NPerBlock_BK1
();
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
_
,
BK1
_
);
constexpr
auto
a_block_space_size_aligned
=
math
::
integer_least_multiple
(
a_block_desc_ak0_m_ak1
.
GetElementSpaceSize
(),
max_lds_align
);
...
...
@@ -230,27 +446,82 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
AGridDesc_AK0_M_AK1
&
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
&
b_grid_desc_bk0_n_bk1
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
template
<
typename
Argument
>
__host__
__device__
static
constexpr
bool
CheckValidity
(
const
Argument
&
karg
)
{
static_assert
((
MPerBlock
%
(
MPerXdl
*
MXdlPerWave
)
==
0
)
&&
(
NPerBlock
%
(
NXdlPerWave
*
NPerXdl
))
==
0
,
"Invalid tuning param!"
);
const
auto
M
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I1
);
const
auto
N
=
b_grid_desc_bk0_n_bk1
.
GetLength
(
I1
);
const
auto
K
=
a_grid_desc_ak0_m_ak1
.
GetLength
(
I0
)
*
a_grid_desc_ak0_m_ak1
.
GetLength
(
I2
);
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
M
%
MPerBlock
==
0
))
{
return
false
;
}
}
if
(
!
(
M
==
c_grid_desc_m_n
.
GetLength
(
I0
)
&&
N
==
c_grid_desc_m_n
.
GetLength
(
I1
)))
return
false
;
if
constexpr
(
!
(
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
NKPadding
||
GemmSpec
==
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
))
{
if
(
!
(
karg
.
N
%
NPerBlock
==
0
))
{
return
false
;
}
}
if
(
!
(
M
%
MPerBlock
==
0
&&
N
%
NPerBlock
==
0
&&
K
%
KPerBlock
==
0
))
return
false
;
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
if
(
karg
.
K
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
M
%
ABlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
if
(
karg
.
N
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
K
%
BBlockTransferSrcScalarPerVector
!=
0
)
{
return
false
;
}
}
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
CLayout
>::
value
)
{
if
(
karg
.
N
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
else
{
if
(
karg
.
M
%
CShuffleBlockTransferScalarPerVector_NPerBlock
!=
0
)
{
return
false
;
}
}
// check gridwise gemm pipeline
const
auto
num_k_loop
=
K
/
KPerBlock
;
const
auto
num_k_loop
=
(
CalculateAK0
(
karg
.
K
)
*
AK1Value
)
/
KPerBlock
;
if
(
!
GridwiseGemmPipe
::
IsSupported
(
num_k_loop
))
{
...
...
@@ -268,8 +539,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return
GridwiseGemmPipe
::
CalculateHasMainLoop
(
num_loop
);
}
template
<
typename
CGridDesc
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc
_M_N
&
c_grid_desc_m_n
)
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
const
CGridDesc
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
...
...
@@ -288,28 +560,66 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__
__device__
static
constexpr
auto
MakeBlock2CTileMap
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
using
Block2CTileMap
=
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
>
;
__host__
__device__
static
void
print_bytes
(
const
uint8_t
*
memory
,
std
::
size_t
size
)
{
return
BlockToCTileMap_M00_N0_M01Adapt
<
MPerBlock
,
NPerBlock
,
CGridDesc_M_N
>
(
c_grid_desc_m_n
);
(
void
)
memory
;
(
void
)
size
;
for
(
std
::
size_t
idx
=
0
;
idx
<
size
;
++
idx
)
{
if
(
idx
%
10
==
0
)
{
printf
(
"
\n
"
);
}
printf
(
"0x%02X "
,
static_cast
<
unsigned
>
(
memory
[
idx
]));
}
printf
(
"
\n
"
);
}
using
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
CGridDesc_M_N
{}))
>
;
template
<
typename
T
>
__host__
__device__
static
void
print_bytes
(
const
T
&
obj
)
{
uint8_t
memory
[
sizeof
(
T
)];
memcpy
(
memory
,
&
obj
,
sizeof
(
T
));
using
Block2CTileMap
=
remove_cvref_t
<
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{}))
>
;
print_bytes
(
memory
,
sizeof
(
T
));
}
template
<
bool
HasMainKBlockLoop
,
typename
Argument
>
__device__
static
void
Run
(
const
Argument
karg
,
void
*
__restrict__
p_shared
)
__device__
static
void
Run
(
const
Argument
&
karg
,
void
*
__restrict__
p_shared
)
{
const
FloatAB
*
p_a_grid
=
karg
.
p_a_grid
;
const
FloatAB
*
p_b_grid
=
karg
.
p_b_grid
;
FloatC
*
p_c_grid
=
karg
.
p_c_grid
;
const
auto
&
a_grid_desc_ak0_m_ak1
=
karg
.
a_grid_desc_ak0_m_ak1
;
const
auto
&
b_grid_desc_bk0_n_bk1
=
karg
.
b_grid_desc_bk0_n_bk1
;
const
auto
&
c_grid_desc_m_n
=
karg
.
c_grid_desc_m_n
;
const
auto
a_grid_desc_ak0_m_ak1
=
karg
.
a_grid_desc_ak0_m_ak1
;
const
auto
b_grid_desc_bk0_n_bk1
=
karg
.
b_grid_desc_bk0_n_bk1
;
const
auto
c_grid_desc_m_n
=
karg
.
c_grid_desc_m_n
;
// const auto a_grid_desc_ak0_m_ak1 = MakeAGridDescriptor_AK0_M_AK1(karg.M,
// karg.MPadded,
// karg.K,
// karg.KPadded,
// karg.StrideA,
// karg.AK0);
// const auto b_grid_desc_bk0_n_bk1 = MakeBGridDescriptor_BK0_N_BK1(karg.K,
// karg.KPadded,
// karg.N,
// karg.NPadded,
// karg.StrideB,
// karg.BK0);
// const auto c_grid_desc_m_n = MakeCGridDescriptor_M_N(karg.M,
// karg.MPadded,
// karg.N,
// karg.NPadded,
// karg.StrideC);
// if (blockIdx.x == 0 && threadIdx.x == 0) {
// print_bytes(a_grid_desc_ak0_m_ak1);
// }
const
auto
c_grid_desc_mblock_mperblock_nblock_nperblock
=
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
c_grid_desc_m_n
);
...
...
@@ -326,7 +636,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
const
CElementwiseOperation
c_element_op
{};
// divide block work by [M, N]
const
auto
block_2_ctile_map
=
Make
Block2CTileMap
(
c_grid_desc_m_n
)
;
const
auto
block_2_ctile_map
=
Block2CTileMap
{
karg
.
M
,
karg
.
N
}
;
const
auto
block_work_idx
=
block_2_ctile_map
.
CalculateBottomIndex
(
make_multi_index
(
get_block_1d_id
()));
...
...
@@ -347,7 +657,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
__builtin_amdgcn_readfirstlane
(
block_work_idx
[
I1
]
*
NPerBlock
);
// lds max alignment
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
,
BK1
);
constexpr
auto
max_lds_align
=
math
::
lcm
(
AK1
_
,
BK1
_
);
// A matrix in LDS memory, dst of blockwise copy
constexpr
auto
a_block_desc_ak0_m_ak1
=
GetABlockDescriptor_AK0PerBlock_MPerBlock_AK1
();
...
...
@@ -361,7 +671,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
AElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
AK0
,
MPerBlock
,
AK1
>
,
Sequence
<
AK0
_
,
MPerBlock
,
AK1
_
>
,
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -392,7 +702,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BElementwiseOperation
,
ck
::
tensor_operation
::
element_wise
::
PassThrough
,
InMemoryDataOperationEnum
::
Set
,
Sequence
<
BK0
,
NPerBlock
,
BK1
>
,
Sequence
<
BK0
_
,
NPerBlock
,
BK1
_
>
,
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
BBlockTransferThreadClusterArrangeOrder
,
FloatAB
,
...
...
@@ -424,8 +734,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
// register
// sanity check
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1
,
BK1
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
constexpr
index_t
KPack
=
math
::
max
(
math
::
lcm
(
AK1_
,
BK1_
),
MfmaSelector
<
FloatAB
,
MPerXdl
,
NPerXdl
>::
selected_mfma
.
k_per_blk
);
auto
blockwise_gemm
=
BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_Selector
<
BlockSize
,
...
...
@@ -453,8 +764,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
static_cast
<
FloatAB
*>
(
p_shared
)
+
a_block_space_size_aligned
,
b_block_desc_bk0_n_bk1
.
GetElementSpaceSize
());
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
,
0
,
0
);
constexpr
auto
a_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
AK1
_
,
0
,
0
);
constexpr
auto
b_block_slice_copy_step
=
make_multi_index
(
KPerBlock
/
BK1
_
,
0
,
0
);
// gridwise GEMM pipeline
static_assert
(
std
::
is_default_constructible_v
<
GridwiseGemmPipe
>
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment