Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
06f57782
Commit
06f57782
authored
May 05, 2022
by
Jianfeng yan
Browse files
regress to using 1 grid_desc
parent
308146e7
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
24 additions
and
241 deletions
+24
-241
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
...ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
+11
-104
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
...operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
+13
-137
No files found.
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
View file @
06f57782
...
@@ -29,8 +29,6 @@ template <typename GridwiseGemm,
...
@@ -29,8 +29,6 @@ template <typename GridwiseGemm,
typename
FloatC
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
BGridDesc_K0_N_K1
,
typename
AGridDesc_K0_M_K1_Tail
,
typename
BGridDesc_K0_N_K1_Tail
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
...
@@ -50,8 +48,8 @@ __global__ void
...
@@ -50,8 +48,8 @@ __global__ void
const
index_t
batch_count
,
const
index_t
batch_count
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
AGridDesc_K0_M_K1
_Tail
a_grid_desc_k0_m_k1_tail
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_tail
,
const
BGridDesc_K0_N_K1
_Tail
b_grid_desc_k0_n_k1_tail
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_tail
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
AElementwiseOperation
a_element_op
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
BElementwiseOperation
b_element_op
,
...
@@ -173,7 +171,7 @@ struct DeviceGemmXdlSplitK
...
@@ -173,7 +171,7 @@ struct DeviceGemmXdlSplitK
return
std
::
make_pair
(
actual_batch
,
KSplitted
);
return
std
::
make_pair
(
actual_batch
,
KSplitted
);
}
}
static
auto
MakeAGridDescriptor_K0_M_K1
_Tail
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
{
const
index_t
KPadded
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
)
*
K1
*
K0PerBlock
;
const
index_t
KPadded
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
)
*
K1
*
K0PerBlock
;
const
index_t
K0
=
KPadded
/
K1
;
const
index_t
K0
=
KPadded
/
K1
;
...
@@ -217,7 +215,7 @@ struct DeviceGemmXdlSplitK
...
@@ -217,7 +215,7 @@ struct DeviceGemmXdlSplitK
}
}
}
}
static
auto
MakeBGridDescriptor_K0_N_K1
_Tail
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
{
const
index_t
KPadded
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
)
*
K1
*
K0PerBlock
;
const
index_t
KPadded
=
math
::
integer_divide_ceil
(
K
,
K1
*
K0PerBlock
)
*
K1
*
K0PerBlock
;
...
@@ -262,87 +260,6 @@ struct DeviceGemmXdlSplitK
...
@@ -262,87 +260,6 @@ struct DeviceGemmXdlSplitK
}
}
}
}
static
auto
MakeAGridDescriptor_K0_M_K1
(
index_t
M
,
index_t
K
,
index_t
StrideA
)
{
// return MakeAGridDescriptor_K0_M_K1_Tail(M, K, StrideA);
assert
(
K
%
(
K1
*
K0PerBlock
)
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
a_grid_desc_m_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
M
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadM
=
(
MPerBlock
-
M
%
MPerBlock
)
%
MPerBlock
;
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
M
,
PadM
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
M
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeBGridDescriptor_K0_N_K1
(
index_t
K
,
index_t
N
,
index_t
StrideB
)
{
// return MakeBGridDescriptor_K0_N_K1_Tail(K, N, StrideB);
assert
(
K
%
(
K1
*
K0PerBlock
)
==
0
);
const
index_t
K0
=
K
/
K1
;
const
auto
b_grid_desc_k_n
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
StrideB
,
I1
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
K
,
N
),
make_tuple
(
I1
,
StrideB
));
}
}();
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MNPadding
)
{
const
auto
PadN
=
(
NPerBlock
-
N
%
NPerBlock
)
%
NPerBlock
;
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_right_pad_transform
(
N
,
PadN
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
else
{
return
transform_tensor_descriptor
(
b_grid_desc_k_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
K0
,
K1Number
)),
make_pass_through_transform
(
N
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
}
}
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
static
auto
MakeCGridDescriptor_M_N
(
index_t
M
,
index_t
N
,
index_t
StrideC
)
{
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
const
auto
c_grid_desc_m_n
=
[
&
]()
{
...
@@ -378,11 +295,9 @@ struct DeviceGemmXdlSplitK
...
@@ -378,11 +295,9 @@ struct DeviceGemmXdlSplitK
}
}
}
}
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
AGridDesc_K0_M_K1_Tail
=
decltype
(
MakeAGridDescriptor_K0_M_K1_Tail
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1_Tail
=
decltype
(
MakeBGridDescriptor_K0_N_K1_Tail
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
static
constexpr
auto
MakeBlock2CTileMap
(
index_t
batch_count
,
static
constexpr
auto
MakeBlock2CTileMap
(
index_t
batch_count
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
const
CGridDesc_M_N
&
c_grid_desc_m_n
,
...
@@ -543,9 +458,9 @@ struct DeviceGemmXdlSplitK
...
@@ -543,9 +458,9 @@ struct DeviceGemmXdlSplitK
has_tail_
=
true
;
has_tail_
=
true
;
const
auto
KTail
=
K
-
KSplitted
*
(
BatchCount_
-
1
);
const
auto
KTail
=
K
-
KSplitted
*
(
BatchCount_
-
1
);
a_grid_desc_k0_m_k1_tail_
=
a_grid_desc_k0_m_k1_tail_
=
DeviceGemmXdlSplitK
::
MakeAGridDescriptor_K0_M_K1
_Tail
(
M
,
KTail
,
StrideA
);
DeviceGemmXdlSplitK
::
MakeAGridDescriptor_K0_M_K1
(
M
,
KTail
,
StrideA
);
b_grid_desc_k0_n_k1_tail_
=
b_grid_desc_k0_n_k1_tail_
=
DeviceGemmXdlSplitK
::
MakeBGridDescriptor_K0_N_K1
_Tail
(
KTail
,
N
,
StrideB
);
DeviceGemmXdlSplitK
::
MakeBGridDescriptor_K0_N_K1
(
KTail
,
N
,
StrideB
);
is_valid
&=
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_tail_
,
is_valid
&=
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_tail_
,
b_grid_desc_k0_n_k1_tail_
,
b_grid_desc_k0_n_k1_tail_
,
...
@@ -597,8 +512,8 @@ struct DeviceGemmXdlSplitK
...
@@ -597,8 +512,8 @@ struct DeviceGemmXdlSplitK
bool
has_tail_
;
bool
has_tail_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
AGridDesc_K0_M_K1
_Tail
a_grid_desc_k0_m_k1_tail_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_tail_
;
BGridDesc_K0_N_K1
_Tail
b_grid_desc_k0_n_k1_tail_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_tail_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
ComputePtrOffsetOfStridedBatch
compute_ptr_offset_of_batch_
;
...
@@ -707,8 +622,6 @@ struct DeviceGemmXdlSplitK
...
@@ -707,8 +622,6 @@ struct DeviceGemmXdlSplitK
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1_Tail
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1_Tail
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -728,8 +641,6 @@ struct DeviceGemmXdlSplitK
...
@@ -728,8 +641,6 @@ struct DeviceGemmXdlSplitK
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1_Tail
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1_Tail
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -749,8 +660,6 @@ struct DeviceGemmXdlSplitK
...
@@ -749,8 +660,6 @@ struct DeviceGemmXdlSplitK
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1_Tail
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1_Tail
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
@@ -770,8 +679,6 @@ struct DeviceGemmXdlSplitK
...
@@ -770,8 +679,6 @@ struct DeviceGemmXdlSplitK
CDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
AGridDesc_K0_M_K1_Tail
>
,
remove_reference_t
<
DeviceGemmXdlSplitK
::
BGridDesc_K0_N_K1_Tail
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
AElementwiseOperation
,
BElementwiseOperation
,
BElementwiseOperation
,
...
...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
View file @
06f57782
...
@@ -33,8 +33,6 @@ template <typename GridwiseGemm,
...
@@ -33,8 +33,6 @@ template <typename GridwiseGemm,
typename
CElementwiseOperation
,
typename
CElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
AGridDesc_AK0_M_AK1_Tail
,
typename
BGridDesc_BK0_N_BK1_Tail
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
ComputePtrOffsetOfBatch
,
typename
ComputePtrOffsetOfBatch
,
typename
Block2CTileMap
,
typename
Block2CTileMap
,
...
@@ -54,8 +52,8 @@ __global__ void
...
@@ -54,8 +52,8 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
CElementwiseOperation
c_element_op
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1
,
const
AGridDesc_AK0_M_AK1
_Tail
a_grid_desc_ak0_m_ak1_tail
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_tail
,
const
BGridDesc_BK0_N_BK1
_Tail
b_grid_desc_bk0_n_bk1_tail
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_tail
,
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
const
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock
,
c_grid_desc_mblock_mperblock_nblock_nperblock
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
...
@@ -183,118 +181,7 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -183,118 +181,7 @@ struct DeviceGemmXdlSplitKCShuffle
return
std
::
make_pair
(
actual_batch
,
KSplitted
);
return
std
::
make_pair
(
actual_batch
,
KSplitted
);
}
}
template
<
bool
IsTail
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
static
auto
MakeAGridDescriptor_AK0_M_AK1
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
);
template
<
bool
IsTail
>
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
);
/*
* No padding in K
*/
template
<
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
<
false
>
(
index_t
MRaw
,
index_t
K
,
index_t
StrideA
)
{
// return MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, K, StrideA);
assert
(
K
%
KPerBlock
==
0
);
assert
(
K
%
AK1
==
0
);
const
auto
a_grid_desc_mraw_k
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
K
),
make_tuple
(
StrideA
,
I1
));
}
else
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
ColumnMajor
,
ALayout
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
MRaw
,
K
),
make_tuple
(
I1
,
StrideA
));
}
}();
const
auto
M
=
math
::
integer_divide_ceil
(
MRaw
,
MPerBlock
)
*
MPerBlock
;
const
auto
MPad
=
M
-
MRaw
;
const
auto
AK0
=
K
/
AK1
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
MPadding
||
GemmSpec
==
GemmSpecialization
::
MKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad M, but not K
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_right_pad_transform
(
MRaw
,
MPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
else
{
// not pad M or K
const
auto
a_grid_desc_ak0_m_ak1
=
transform_tensor_descriptor
(
a_grid_desc_mraw_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
AK0
,
AK1
)),
make_pass_through_transform
(
MRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
a_grid_desc_ak0_m_ak1
;
}
}
template
<
>
static
auto
MakeBGridDescriptor_BK0_N_BK1
<
false
>
(
index_t
K
,
index_t
NRaw
,
index_t
StrideB
)
{
// return MakeBGridDescriptor_BK0_N_BK1<true>(K, NRaw, StrideB);
assert
(
K
%
KPerBlock
==
0
);
assert
(
K
%
BK1
==
0
);
const
auto
b_grid_desc_nraw_k
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
K
),
make_tuple
(
I1
,
StrideB
));
}
else
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
BLayout
>::
value
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
NRaw
,
K
),
make_tuple
(
StrideB
,
I1
));
}
}();
const
auto
N
=
math
::
integer_divide_ceil
(
NRaw
,
NPerBlock
)
*
NPerBlock
;
const
auto
NPad
=
N
-
NRaw
;
const
auto
BK0
=
K
/
BK1
;
if
constexpr
(
GemmSpec
==
GemmSpecialization
::
NPadding
||
GemmSpec
==
GemmSpecialization
::
NKPadding
||
GemmSpec
==
GemmSpecialization
::
MNKPadding
)
{
// pad N, but not K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_right_pad_transform
(
NRaw
,
NPad
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
else
{
// not pad N or K
const
auto
b_grid_desc_bk0_n_bk1
=
transform_tensor_descriptor
(
b_grid_desc_nraw_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
BK1
)),
make_pass_through_transform
(
NRaw
)),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
>
{}));
return
b_grid_desc_bk0_n_bk1
;
}
}
template
<
>
static
auto
MakeAGridDescriptor_AK0_M_AK1
<
true
>
(
index_t
MRaw
,
index_t
KRaw
,
index_t
StrideA
)
{
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
const
auto
a_grid_desc_mraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
if
constexpr
(
is_same_v
<
tensor_layout
::
gemm
::
RowMajor
,
ALayout
>
)
...
@@ -359,8 +246,7 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -359,8 +246,7 @@ struct DeviceGemmXdlSplitKCShuffle
}
}
}
}
template
<
>
static
auto
MakeBGridDescriptor_BK0_N_BK1
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
static
auto
MakeBGridDescriptor_BK0_N_BK1
<
true
>
(
index_t
KRaw
,
index_t
NRaw
,
index_t
StrideB
)
{
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
const
auto
b_grid_desc_nraw_kraw
=
[
&
]()
{
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
if
constexpr
(
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
BLayout
>::
value
)
...
@@ -481,11 +367,9 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -481,11 +367,9 @@ struct DeviceGemmXdlSplitKCShuffle
}
}
}
}
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
false
>
(
1
,
1
,
1
));
using
AGridDesc_AK0_M_AK1
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
false
>
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
(
1
,
1
,
1
));
using
AGridDesc_AK0_M_AK1_Tail
=
decltype
(
MakeAGridDescriptor_AK0_M_AK1
<
true
>
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
BGridDesc_BK0_N_BK1_Tail
=
decltype
(
MakeBGridDescriptor_BK0_N_BK1
<
true
>
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
struct
ComputePtrOffsetOfStridedBatch
struct
ComputePtrOffsetOfStridedBatch
{
{
...
@@ -598,9 +482,9 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -598,9 +482,9 @@ struct DeviceGemmXdlSplitKCShuffle
const
auto
BKSplitted
=
actual_batch_and_ksplitted_B
.
second
;
const
auto
BKSplitted
=
actual_batch_and_ksplitted_B
.
second
;
a_grid_desc_ak0_m_ak1_
=
a_grid_desc_ak0_m_ak1_
=
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
<
false
>
(
MRaw
,
AKSplitted
,
StrideA
);
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
AKSplitted
,
StrideA
);
b_grid_desc_bk0_n_bk1_
=
b_grid_desc_bk0_n_bk1_
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
<
false
>
(
BKSplitted
,
NRaw
,
StrideB
);
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
BKSplitted
,
NRaw
,
StrideB
);
c_grid_desc_m_n_
=
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
);
c_grid_desc_m_n_
=
DeviceOp
::
MakeCGridDescriptor_M_N
(
MRaw
,
NRaw
,
StrideC
);
is_valid_
=
GridwiseGemm
::
CheckValidity
(
is_valid_
=
GridwiseGemm
::
CheckValidity
(
...
@@ -613,9 +497,9 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -613,9 +497,9 @@ struct DeviceGemmXdlSplitKCShuffle
const
auto
BKTail
=
KRaw
-
BKSplitted
*
(
BatchCount_
-
1
);
const
auto
BKTail
=
KRaw
-
BKSplitted
*
(
BatchCount_
-
1
);
a_grid_desc_ak0_m_ak1_tail_
=
a_grid_desc_ak0_m_ak1_tail_
=
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
<
true
>
(
MRaw
,
AKTail
,
StrideA
);
DeviceOp
::
MakeAGridDescriptor_AK0_M_AK1
(
MRaw
,
AKTail
,
StrideA
);
b_grid_desc_bk0_n_bk1_tail_
=
b_grid_desc_bk0_n_bk1_tail_
=
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
<
true
>
(
BKTail
,
NRaw
,
StrideB
);
DeviceOp
::
MakeBGridDescriptor_BK0_N_BK1
(
BKTail
,
NRaw
,
StrideB
);
is_valid_
&=
GridwiseGemm
::
CheckValidity
(
is_valid_
&=
GridwiseGemm
::
CheckValidity
(
a_grid_desc_ak0_m_ak1_tail_
,
b_grid_desc_bk0_n_bk1_tail_
,
c_grid_desc_m_n_
);
a_grid_desc_ak0_m_ak1_tail_
,
b_grid_desc_bk0_n_bk1_tail_
,
c_grid_desc_m_n_
);
...
@@ -668,8 +552,8 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -668,8 +552,8 @@ struct DeviceGemmXdlSplitKCShuffle
bool
is_valid_
;
bool
is_valid_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
AGridDesc_AK0_M_AK1
_Tail
a_grid_desc_ak0_m_ak1_tail_
;
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_tail_
;
BGridDesc_BK0_N_BK1
_Tail
b_grid_desc_bk0_n_bk1_tail_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_tail_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
c_grid_desc_mblock_mperblock_nblock_nperblock_
;
...
@@ -796,8 +680,6 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -796,8 +680,6 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1_Tail
,
DeviceOp
::
BGridDesc_BK0_N_BK1_Tail
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
,
Block2CTileMap
,
Block2CTileMap
,
...
@@ -817,8 +699,6 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -817,8 +699,6 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1_Tail
,
DeviceOp
::
BGridDesc_BK0_N_BK1_Tail
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
,
Block2CTileMap
,
Block2CTileMap
,
...
@@ -838,8 +718,6 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -838,8 +718,6 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1_Tail
,
DeviceOp
::
BGridDesc_BK0_N_BK1_Tail
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
,
Block2CTileMap
,
Block2CTileMap
,
...
@@ -859,8 +737,6 @@ struct DeviceGemmXdlSplitKCShuffle
...
@@ -859,8 +737,6 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation
,
CElementwiseOperation
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
BGridDesc_BK0_N_BK1
,
DeviceOp
::
AGridDesc_AK0_M_AK1_Tail
,
DeviceOp
::
BGridDesc_BK0_N_BK1_Tail
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
ComputePtrOffsetOfStridedBatch
,
ComputePtrOffsetOfStridedBatch
,
Block2CTileMap
,
Block2CTileMap
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment