gaoqiong / composable_kernel

Commit 7910f486, authored May 01, 2022 by Jianfeng Yan

DeviceGemmXdlSplit and DeviceGemmXdlSplitKCShuffle both work for arbitrary K

Parent: b5a9f642
Showing 3 changed files with 541 additions and 302 deletions.
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp            +40   -85
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp  +500  -217
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp       +1    -0
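What the commit changes, in one sentence: the split-K device ops no longer require K to be an exact multiple of k_batch * KPerBlock; the first BatchCount - 1 splits use a rounded per-split K (KSplitted), and the remainder becomes a "tail" split with its own grid descriptors and its own HasMainKBlockLoop flag. A minimal standalone sketch of that arithmetic, mirroring GetActualBatchAndKSplitted from this diff but with made-up tile sizes (KPerBlock = 32, K1 = 8) and inputs that are not taken from the commit:

#include <cassert>
#include <utility>

// Hypothetical tile parameters, chosen only to make the numbers concrete;
// in the library they are template parameters of the device op.
constexpr int KPerBlock = 32;
constexpr int K1        = 8;

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

// Mirrors GetActualBatchAndKSplitted from this commit.
std::pair<int, int> get_actual_batch_and_ksplitted(int K, int KBatch)
{
    const int K0PerBlock   = KPerBlock / K1;
    const int K0           = integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock;
    const int KSplitted    = K0 * K1;
    const int actual_batch = integer_divide_ceil(K, KSplitted);
    return {actual_batch, KSplitted};
}

int main()
{
    const int KRaw    = 200; // deliberately not a multiple of KPerBlock * k_batch
    const int k_batch = 4;

    const auto [batch, KSplitted] = get_actual_batch_and_ksplitted(KRaw, k_batch);
    assert(KSplitted == 64 && batch == 4); // ceil(200 / 128) * 32 = 64, ceil(200 / 64) = 4

    // The first batch - 1 splits each cover KSplitted elements of K;
    // the last split covers whatever is left and is handled by the tail descriptors.
    const bool has_tail = (KRaw != KSplitted * batch);
    const int  KTail    = KRaw - KSplitted * (batch - 1); // 200 - 3 * 64 = 8
    assert(has_tail && KTail == 8);
    return 0;
}

This is why the device ops below now carry *_tail_ grid descriptors and the kernels take a second TailHasMainKBlockLoop template parameter.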
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp

@@ -19,6 +19,11 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ */
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,

@@ -159,15 +164,12 @@ struct DeviceGemmXdlSplitK
     static constexpr auto K1Number = Number<K1>{};
 
-    // static constexpr index_t Getk
     static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
     {
         const index_t K0 =
             math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
         const index_t KSplitted    = K0 * K1;
         const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
 
-        // return std::make_pair<index_t, index_t>(actual_batch, KSplitted);
         return std::make_pair(actual_batch, KSplitted);
     }

@@ -251,8 +253,8 @@ struct DeviceGemmXdlSplitK
     static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
     {
-        const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0   = KPad / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0      = KPadded / K1;
 
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)

@@ -267,7 +269,7 @@ struct DeviceGemmXdlSplitK
         const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
             a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

@@ -295,9 +297,9 @@ struct DeviceGemmXdlSplitK
     static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
     {
-        const index_t KPad = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0   = KPad / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0      = KPadded / K1;
 
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)

@@ -312,7 +314,7 @@ struct DeviceGemmXdlSplitK
         const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
             b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));

@@ -672,26 +674,9 @@ struct DeviceGemmXdlSplitK
         const bool tail_has_main_k0_block_loop =
             GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
 
-        if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
-        {
-            const auto kernel = kernel_batched_gemm_xdlops_v2r3<
-                GridwiseGemm,
-                ADataType, // TODO: distiguish A/B datatype
-                CDataType,
-                remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
-                remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
-                remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
-                remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                AElementwiseOperation,
-                BElementwiseOperation,
-                CElementwiseOperation,
-                ComputePtrOffsetOfStridedBatch,
-                remove_reference_t<Block2CTileMap>,
-                true,
-                true>;
-
-            ave_time = launch_and_time_kernel(kernel,
-                                              nrepeat,
-                                              dim3(grid_size),
-                                              dim3(BlockSize),
+        const auto Run = [&](const auto& kernel) {
+            return launch_and_time_kernel(kernel,
+                                          nrepeat,
+                                          dim3(grid_size),
+                                          dim3(BlockSize),

@@ -710,6 +695,30 @@ struct DeviceGemmXdlSplitK
                                           arg.c_element_op_,
                                           arg.compute_ptr_offset_of_batch_,
                                           arg.block_2_ctile_map_);
+        };
+
+        if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
+        {
+            const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+                GridwiseGemm,
+                ADataType, // TODO: distiguish A/B datatype
+                CDataType,
+                remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
+                remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
+                remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                AElementwiseOperation,
+                BElementwiseOperation,
+                CElementwiseOperation,
+                ComputePtrOffsetOfStridedBatch,
+                remove_reference_t<Block2CTileMap>,
+                true,
+                true>;
+
+            ave_time = Run(kernel);
         }
         else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
         {

@@ -730,25 +739,7 @@ struct DeviceGemmXdlSplitK
                 true,
                 false>;
 
-            ave_time = launch_and_time_kernel(
-                kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0,
-                arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
-                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_,
-                arg.a_grid_desc_k0_m_k1_tail_, arg.b_grid_desc_k0_n_k1_tail_,
-                arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
-                arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
+            ave_time = Run(kernel);
         }
         else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
         {

@@ -769,25 +760,7 @@ struct DeviceGemmXdlSplitK
                 false,
                 true>;
 
-            ave_time = launch_and_time_kernel(
-                kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0,
-                arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
-                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_,
-                arg.a_grid_desc_k0_m_k1_tail_, arg.b_grid_desc_k0_n_k1_tail_,
-                arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
-                arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
+            ave_time = Run(kernel);
         }
         else
         {

@@ -808,25 +781,7 @@ struct DeviceGemmXdlSplitK
                 false,
                 false>;
 
-            ave_time = launch_and_time_kernel(
-                kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0,
-                arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
-                arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_,
-                arg.a_grid_desc_k0_m_k1_tail_, arg.b_grid_desc_k0_n_k1_tail_,
-                arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
-                arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
+            ave_time = Run(kernel);
         }
     }
     else
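The shape of the refactor above: the four instantiations of kernel_batched_gemm_xdlops_v2r3 differ only in the two trailing compile-time booleans (the main-loop flag for the regular splits and for the tail), so the launch-and-time call is hoisted into a single Run lambda and each branch only picks the right kernel. A hedged sketch of the same shape, using an invented stand-in function template rather than the real GPU kernel:

#include <cstdio>

// Stand-in for the real kernel template; the two bools model
// HasMainK0BlockLoop and TailHasMainK0BlockLoop.
template <bool MainLoop, bool TailLoop>
void fake_kernel()
{
    std::printf("launched kernel<%d, %d>\n", MainLoop, TailLoop);
}

float dispatch(bool has_main_loop, bool tail_has_main_loop)
{
    // One place that knows how to launch (and, in the real code, time) a kernel.
    const auto Run = [&](auto kernel) {
        kernel(); // launch_and_time_kernel(kernel, ...) in the real code
        return 0.0f;
    };

    // Four branches that only differ in the compile-time flags.
    if(has_main_loop && tail_has_main_loop) { return Run(fake_kernel<true, true>); }
    if(has_main_loop && !tail_has_main_loop) { return Run(fake_kernel<true, false>); }
    if(!has_main_loop && tail_has_main_loop) { return Run(fake_kernel<false, true>); }
    return Run(fake_kernel<false, false>);
}

int main()
{
    return static_cast<int>(dispatch(true, false));
}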
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp

@@ -19,6 +19,108 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ */
+template <typename GridwiseGemm,
+          typename FloatAB,
+          typename FloatC,
+          typename AElementwiseOperation,
+          typename BElementwiseOperation,
+          typename CElementwiseOperation,
+          typename AGridDesc_AK0_M_AK1,
+          typename BGridDesc_BK0_N_BK1,
+          typename AGridDesc_AK0_M_AK1_Tail,
+          typename BGridDesc_BK0_N_BK1_Tail,
+          typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+          typename ComputePtrOffsetOfBatch,
+          typename Block2CTileMap,
+          bool HasMainKBlockLoop,
+          bool TailHasMainKBlockLoop>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_batched_gemm_xdl_cshuffle_v1(
+            const FloatAB* __restrict__ p_a_grid,
+            const FloatAB* __restrict__ p_b_grid,
+            FloatC* __restrict__ p_c_grid,
+            const index_t batch_count,
+            const AElementwiseOperation a_element_op,
+            const BElementwiseOperation b_element_op,
+            const CElementwiseOperation c_element_op,
+            const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
+            const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
+            const AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail,
+            const BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail,
+            const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+                c_grid_desc_mblock_mperblock_nblock_nperblock,
+            const ComputePtrOffsetOfBatch compute_ptr_offset_of_batch,
+            const Block2CTileMap block_2_ctile_map)
+{
+#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
+    const index_t num_blocks_per_batch =
+        __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
+    const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
+
+    const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
+    const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetBPtrOffset(g_idx)));
+    const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
+        static_cast<long_index_t>(compute_ptr_offset_of_batch.GetCPtrOffset(g_idx)));
+
+    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+
+    if(g_idx < batch_count - 1)
+    {
+        GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
+                                                      p_b_grid + b_batch_offset,
+                                                      p_c_grid + c_batch_offset,
+                                                      p_shared,
+                                                      a_element_op,
+                                                      b_element_op,
+                                                      c_element_op,
+                                                      a_grid_desc_ak0_m_ak1,
+                                                      b_grid_desc_bk0_n_bk1,
+                                                      c_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                      block_2_ctile_map);
+    }
+    else
+    {
+        GridwiseGemm::template Run<TailHasMainKBlockLoop>(
+            p_a_grid + a_batch_offset,
+            p_b_grid + b_batch_offset,
+            p_c_grid + c_batch_offset,
+            p_shared,
+            a_element_op,
+            b_element_op,
+            c_element_op,
+            a_grid_desc_ak0_m_ak1_tail,
+            b_grid_desc_bk0_n_bk1_tail,
+            c_grid_desc_mblock_mperblock_nblock_nperblock,
+            block_2_ctile_map);
+    }
+#else
+    ignore = p_a_grid;
+    ignore = p_b_grid;
+    ignore = p_c_grid;
+    ignore = batch_count;
+    ignore = a_element_op;
+    ignore = b_element_op;
+    ignore = c_element_op;
+    ignore = a_grid_desc_ak0_m_ak1;
+    ignore = b_grid_desc_bk0_n_bk1;
+    ignore = a_grid_desc_ak0_m_ak1_tail;
+    ignore = b_grid_desc_bk0_n_bk1_tail;
+    ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
+    ignore = compute_ptr_offset_of_batch;
+    ignore = block_2_ctile_map;
+#endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
+}
+
 template <typename ALayout,
           typename BLayout,
           typename CLayout,
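Inside this new wrapper every workgroup derives its batch index g_idx from its 1-D block id, and only the workgroups of the last batch (the tail split) switch to the *_tail descriptors; everything else about the launch is shared. A small host-side model of that index math, where the grid and batch sizes are invented for illustration and not taken from the diff:

#include <cstdio>

int main()
{
    const int batch_count = 4;  // actual batch count chosen on the host
    const int grid_size   = 32; // tiles per split * batch_count
    const int num_blocks_per_batch = grid_size / batch_count;

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx = block_id / num_blocks_per_batch;
        // Same test as in the kernel: the last batch runs the tail descriptors
        // (TailHasMainKBlockLoop path), all earlier batches run the regular ones.
        const bool is_tail = !(g_idx < batch_count - 1);
        std::printf("block %2d -> split %d (%s)\n", block_id, g_idx, is_tail ? "tail" : "main");
    }
    return 0;
}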
@@ -69,14 +171,127 @@ struct DeviceGemmXdlSplitKCShuffle
     static constexpr auto I1 = Number<1>{};
     static constexpr auto I2 = Number<2>{};
 
-    static auto GetKPad(index_t K1, index_t K, index_t KBatch)
+    template <index_t K1>
+    static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
     {
-        const index_t K0   = math::integer_divide_ceil(K, K1 * KPerBlock * KBatch) * KPerBlock;
-        const index_t KPad = KBatch * K0 * K1;
-        return KPad;
+        const index_t K0PerBlock = KPerBlock / K1;
+        const index_t K0         = math::integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock;
+        const index_t KSplitted    = K0 * K1;
+        const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
+
+        return std::make_pair(actual_batch, KSplitted);
     }
 
+    template <bool IsTail>
+    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA);
+
+    template <bool IsTail>
+    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB);
+
+    /*
+     * No padding in K
+     */
+    template <>
+    static auto MakeAGridDescriptor_AK0_M_AK1<false>(index_t MRaw, index_t K, index_t StrideA)
+    {
+        assert(K % KPerBlock == 0);
+        assert(K % AK1 == 0);
+
+        const auto a_grid_desc_mraw_k = [&]() {
+            if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(StrideA, I1));
+            }
+            else if constexpr(is_same_v<tensor_layout::gemm::ColumnMajor, ALayout>)
+            {
+                return make_naive_tensor_descriptor(make_tuple(MRaw, K), make_tuple(I1, StrideA));
+            }
+        }();
+
+        const auto M    = math::integer_divide_ceil(MRaw, MPerBlock) * MPerBlock;
+        const auto MPad = M - MRaw;
+        const auto AK0  = K / AK1;
+
+        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                     GemmSpec == GemmSpecialization::MKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad M, but not K
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_mraw_k,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                           make_right_pad_transform(MRaw, MPad)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return a_grid_desc_ak0_m_ak1;
+        }
+        else
+        {
+            // not pad M or K
+            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
+                a_grid_desc_mraw_k,
+                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
+                           make_pass_through_transform(MRaw)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return a_grid_desc_ak0_m_ak1;
+        }
+    }
+
+    template <>
+    static auto MakeBGridDescriptor_BK0_N_BK1<false>(index_t K, index_t NRaw, index_t StrideB)
+    {
+        assert(K % KPerBlock == 0);
+        assert(K % BK1 == 0);
+
+        const auto b_grid_desc_nraw_k = [&]() {
+            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(I1, StrideB));
+            }
+            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
+            {
+                return make_naive_tensor_descriptor(make_tuple(NRaw, K), make_tuple(StrideB, I1));
+            }
+        }();
+
+        const auto N    = math::integer_divide_ceil(NRaw, NPerBlock) * NPerBlock;
+        const auto NPad = N - NRaw;
+        const auto BK0  = K / BK1;
+
+        if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                     GemmSpec == GemmSpecialization::NKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
+        {
+            // pad N, but not K
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_nraw_k,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                           make_right_pad_transform(NRaw, NPad)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return b_grid_desc_bk0_n_bk1;
+        }
+        else
+        {
+            // not pad N or K
+            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
+                b_grid_desc_nraw_k,
+                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
+                           make_pass_through_transform(NRaw)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+
+            return b_grid_desc_bk0_n_bk1;
+        }
+    }
+
-    static auto MakeAGridDescriptor_AK0_M_AK1(index_t MRaw, index_t KRaw, index_t StrideA)
+    template <>
+    static auto MakeAGridDescriptor_AK0_M_AK1<true>(index_t MRaw, index_t KRaw, index_t StrideA)
     {
         const auto a_grid_desc_mraw_kraw = [&]() {
             if constexpr(is_same_v<tensor_layout::gemm::RowMajor, ALayout>)
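Note the asymmetry the <false> builders rely on: every non-tail split has K % KPerBlock == 0 and K % AK1 == 0 by construction (that is exactly what GetActualBatchAndKSplitted rounds for), so only M or N may still need rounding up to a block multiple. A quick standalone check of that round-up, using hypothetical tile sizes (MPerBlock = 128, AK1 = 8) that are not from this diff:

#include <cassert>

constexpr int MPerBlock = 128; // hypothetical tile sizes for illustration
constexpr int AK1       = 8;

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int MRaw = 1000; // raw GEMM M
    const int K    = 256;  // per-split K, already a multiple of KPerBlock and AK1

    const int M    = integer_divide_ceil(MRaw, MPerBlock) * MPerBlock; // 1024
    const int MPad = M - MRaw;                                         // 24 padded rows
    const int AK0  = K / AK1;                                          // 32

    // MPad feeds make_right_pad_transform(MRaw, MPad); AK0/AK1 feed the unmerge transform.
    assert(M == 1024 && MPad == 24 && AK0 == 32);
    return 0;
}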
@@ -96,14 +311,14 @@ struct DeviceGemmXdlSplitKCShuffle
         const auto MPad = M - MRaw;
         const auto KPad = K - KRaw;
 
-        assert(K % AK1 == 0);
-        const auto AK0 = K / AK1;
-
-        if constexpr(GemmSpec == GemmSpecialization::MKPadding ||
+        if constexpr(GemmSpec == GemmSpecialization::MPadding ||
+                     GemmSpec == GemmSpecialization::MKPadding ||
                      GemmSpec == GemmSpecialization::MNKPadding)
         {
             // pad both M and K
+            assert(K % AK1 == 0);
+            const auto AK0 = K / AK1;
+
             const auto a_grid_desc_m_k =
                 transform_tensor_descriptor(a_grid_desc_mraw_kraw,

@@ -121,31 +336,9 @@ struct DeviceGemmXdlSplitKCShuffle
             return a_grid_desc_ak0_m_ak1;
         }
-        else if constexpr(GemmSpec == GemmSpecialization::MPadding ||
-                          GemmSpec == GemmSpecialization::MNPadding)
-        {
-            // pad M, but not K
-            assert(KRaw % AK1 == 0);
-            const auto AK0 = KRaw / AK1;
-
-            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
-                a_grid_desc_mraw_kraw,
-                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                           make_right_pad_transform(MRaw, MPad)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                          GemmSpec == GemmSpecialization::NKPadding)
+        else
         {
             // pad K, but not M
+            assert(K % AK1 == 0);
+            const auto AK0 = K / AK1;
+
             const auto a_grid_desc_m_k = transform_tensor_descriptor(
                 a_grid_desc_mraw_kraw,
                 make_tuple(make_pass_through_transform(MRaw), make_right_pad_transform(KRaw, KPad)),

@@ -161,25 +354,10 @@ struct DeviceGemmXdlSplitKCShuffle
             return a_grid_desc_ak0_m_ak1;
         }
-        else
-        {
-            // not pad M or K
-            assert(KRaw % AK1 == 0);
-            const auto AK0 = KRaw / AK1;
-
-            const auto a_grid_desc_ak0_m_ak1 = transform_tensor_descriptor(
-                a_grid_desc_mraw_kraw,
-                make_tuple(make_unmerge_transform(make_tuple(AK0, AK1)),
-                           make_pass_through_transform(MRaw)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return a_grid_desc_ak0_m_ak1;
-        }
     }
 
-    static auto MakeBGridDescriptor_BK0_N_BK1(index_t KRaw, index_t NRaw, index_t StrideB)
+    template <>
+    static auto MakeBGridDescriptor_BK0_N_BK1<true>(index_t KRaw, index_t NRaw, index_t StrideB)
     {
         const auto b_grid_desc_nraw_kraw = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)

@@ -200,14 +378,13 @@ struct DeviceGemmXdlSplitKCShuffle
         const auto NPad = N - NRaw;
         const auto KPad = K - KRaw;
 
-        if constexpr(GemmSpec == GemmSpecialization::NKPadding ||
-                     GemmSpec == GemmSpecialization::MNKPadding)
+        assert(K % BK1 == 0);
+        const auto BK0 = K / BK1;
+
+        if constexpr(GemmSpec == GemmSpecialization::NPadding ||
+                     GemmSpec == GemmSpecialization::NKPadding ||
+                     GemmSpec == GemmSpecialization::MNKPadding)
         {
             // pad both N and K
-            assert(K % BK1 == 0);
-            const auto BK0 = K / BK1;
-
             const auto b_grid_desc_n_k =
                 transform_tensor_descriptor(b_grid_desc_nraw_kraw,
                                             make_tuple(make_right_pad_transform(NRaw, NPad),

@@ -224,31 +401,8 @@ struct DeviceGemmXdlSplitKCShuffle
             return b_grid_desc_bk0_n_bk1;
         }
-        else if constexpr(GemmSpec == GemmSpecialization::NPadding ||
-                          GemmSpec == GemmSpecialization::MNPadding)
-        {
-            // pad N, but not K
-            assert(KRaw % BK1 == 0);
-            const auto BK0 = KRaw / BK1;
-
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                           make_right_pad_transform(NRaw, NPad)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
-        else if constexpr(GemmSpec == GemmSpecialization::KPadding ||
-                          GemmSpec == GemmSpecialization::MKPadding)
+        else
         {
             // pad K, but not N
-            assert(K % BK1 == 0);
-            const auto BK0 = K / BK1;
-
             const auto b_grid_desc_n_k = transform_tensor_descriptor(
                 b_grid_desc_nraw_kraw,
                 make_tuple(make_pass_through_transform(NRaw), make_right_pad_transform(KRaw, KPad)),

@@ -264,22 +418,6 @@ struct DeviceGemmXdlSplitKCShuffle
             return b_grid_desc_bk0_n_bk1;
         }
-        else
-        {
-            // not pad N or K
-            assert(KRaw % BK1 == 0);
-            const auto BK0 = KRaw / BK1;
-
-            const auto b_grid_desc_bk0_n_bk1 = transform_tensor_descriptor(
-                b_grid_desc_nraw_kraw,
-                make_tuple(make_unmerge_transform(make_tuple(BK0, BK1)),
-                           make_pass_through_transform(NRaw)),
-                make_tuple(Sequence<1>{}, Sequence<0>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
-
-            return b_grid_desc_bk0_n_bk1;
-        }
     }
 
     static auto MakeCGridDescriptor_M_N(index_t MRaw, index_t NRaw, index_t StrideC)

@@ -340,10 +478,11 @@ struct DeviceGemmXdlSplitKCShuffle
         }
     }
 
-    using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
-    using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
+    using AGridDesc_AK0_M_AK1      = decltype(MakeAGridDescriptor_AK0_M_AK1<false>(1, 1, 1));
+    using BGridDesc_BK0_N_BK1      = decltype(MakeBGridDescriptor_BK0_N_BK1<false>(1, 1, 1));
+    using AGridDesc_AK0_M_AK1_Tail = decltype(MakeAGridDescriptor_AK0_M_AK1<true>(1, 1, 1));
+    using BGridDesc_BK0_N_BK1_Tail = decltype(MakeBGridDescriptor_BK0_N_BK1<true>(1, 1, 1));
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
 
     struct ComputePtrOffsetOfStridedBatch
     {

@@ -418,7 +557,8 @@ struct DeviceGemmXdlSplitKCShuffle
     using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
         GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
 
     using Block2CTileMap =
        decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
 
     struct Argument : public BaseArgument
     {

@@ -445,21 +585,40 @@ struct DeviceGemmXdlSplitKCShuffle
               b_element_op_{b_element_op},
               c_element_op_{c_element_op}
         {
-            const auto AKPad = GetKPad(AK1, KRaw, k_batch);
-            assert(AKPad % k_batch == 0);
-            const auto BKPad = GetKPad(BK1, KRaw, k_batch);
-            assert(BKPad % k_batch == 0);
-            const auto AKSplitted = AKPad / k_batch;
-            const auto BKSplitted = BKPad / k_batch;
+            const auto actual_batch_and_ksplitted_A =
+                GetActualBatchAndKSplitted<AK1>(KRaw, k_batch);
+            const auto actual_batch_and_ksplitted_B =
+                GetActualBatchAndKSplitted<BK1>(KRaw, k_batch);
+            assert(actual_batch_and_ksplitted_A.first == actual_batch_and_ksplitted_B.first);
+            BatchCount_ = actual_batch_and_ksplitted_A.first;
+
+            const auto AKSplitted = actual_batch_and_ksplitted_A.second;
+            const auto BKSplitted = actual_batch_and_ksplitted_B.second;
 
             a_grid_desc_ak0_m_ak1_ =
-                DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, AKSplitted, StrideA);
+                DeviceOp::MakeAGridDescriptor_AK0_M_AK1<false>(MRaw, AKSplitted, StrideA);
             b_grid_desc_bk0_n_bk1_ =
-                DeviceOp::MakeBGridDescriptor_BK0_N_BK1(BKSplitted, NRaw, StrideB);
+                DeviceOp::MakeBGridDescriptor_BK0_N_BK1<false>(BKSplitted, NRaw, StrideB);
             c_grid_desc_m_n_ = DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC);
 
-            if(GridwiseGemm::CheckValidity(
-                   a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_))
+            is_valid_ = GridwiseGemm::CheckValidity(
+                a_grid_desc_ak0_m_ak1_, b_grid_desc_bk0_n_bk1_, c_grid_desc_m_n_);
+
+            if(KRaw != AKSplitted * BatchCount_ || KRaw != BKSplitted * BatchCount_)
+            {
+                has_tail_ = true;
+
+                const auto AKTail = KRaw - AKSplitted * (BatchCount_ - 1);
+                const auto BKTail = KRaw - BKSplitted * (BatchCount_ - 1);
+
+                a_grid_desc_ak0_m_ak1_tail_ =
+                    DeviceOp::MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, AKTail, StrideA);
+                b_grid_desc_bk0_n_bk1_tail_ =
+                    DeviceOp::MakeBGridDescriptor_BK0_N_BK1<true>(BKTail, NRaw, StrideB);
+
+                is_valid_ &= GridwiseGemm::CheckValidity(
+                    a_grid_desc_ak0_m_ak1_tail_, b_grid_desc_bk0_n_bk1_tail_, c_grid_desc_m_n_);
+            }
+
+            if(is_valid_)
             {
                 c_grid_desc_mblock_mperblock_nblock_nperblock_ =
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(

@@ -492,7 +651,8 @@ struct DeviceGemmXdlSplitKCShuffle
                 compute_ptr_offset_of_batch_ =
                     ComputePtrOffsetOfStridedBatch{a_batch_stride, b_batch_stride};
 
                 block_2_ctile_map_ = BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(
                     BatchCount_, c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1));
             }
         }
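For the tail split the remaining K is generally not a multiple of KPerBlock, which is why the <true> builders always take a K-padding branch (and why MPadding there also implies padding K). Continuing the made-up numbers from the first sketch (KPerBlock = 32, AK1 = 8, tail K = 8), and assuming the padded K is first rounded up to a multiple of KPerBlock as in the non-tail path:

#include <cassert>

constexpr int KPerBlock = 32; // hypothetical tile sizes, as in the earlier sketches
constexpr int AK1       = 8;

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; }

int main()
{
    const int KRaw = 8;                                                // tail K from the first sketch
    const int K    = integer_divide_ceil(KRaw, KPerBlock) * KPerBlock; // 32
    const int KPad = K - KRaw;                                         // 24
    const int AK0  = K / AK1;                                          // 4

    // The tail descriptor right-pads KRaw by KPad so the block-wise loop still
    // sees a K that is a whole number of KPerBlock tiles.
    assert(K == 32 && KPad == 24 && AK0 == 4);
    return 0;
}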
@@ -501,8 +661,12 @@ struct DeviceGemmXdlSplitKCShuffle
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
         index_t BatchCount_;
+        bool has_tail_;
+        bool is_valid_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        AGridDesc_AK0_M_AK1_Tail a_grid_desc_ak0_m_ak1_tail_;
+        BGridDesc_BK0_N_BK1_Tail b_grid_desc_bk0_n_bk1_tail_;
         CGridDesc_M_N c_grid_desc_m_n_;
         CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;

@@ -534,10 +698,23 @@ struct DeviceGemmXdlSplitKCShuffle
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
 
+            if(arg.has_tail_)
+            {
+                std::cout << "arg.a_grid_desc_ak0_m_ak1_tail_{"
+                          << arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0) << ", "
+                          << arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I1) << ", "
+                          << arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I2) << "}" << std::endl;
+
+                std::cout << "arg.b_grid_desc_bk0_n_bk1_tail_{"
+                          << arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I0) << ", "
+                          << arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I1) << ", "
+                          << arg.b_grid_desc_bk0_n_bk1_tail_.GetLength(I2) << "}" << std::endl;
+            }
+
-            if(!GridwiseGemm::CheckValidity(
-                   arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_, arg.c_grid_desc_m_n_))
+            if(!arg.is_valid_)
             {
                 throw std::runtime_error(
                     "wrong! GridwiseBatchedGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting");

@@ -546,127 +723,233 @@ struct DeviceGemmXdlSplitKCShuffle
             const index_t grid_size =
                 GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
 
             const auto K0 = arg.a_grid_desc_ak0_m_ak1_.GetLength(I0);
 
             const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0);
 
             float ave_time = 0;
 
-            if(has_main_k0_block_loop)
+            if(arg.has_tail_)
             {
-                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
-                    GridwiseGemm, ADataType /* TODO: distiguish A/B datatype */, CDataType,
-                    AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
-                    DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
-                    CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    ComputePtrOffsetOfStridedBatch, Block2CTileMap,
-                    true>;
-
-                if(nrepeat == 0)
-                {
-                    launch_kernel(kernel, dim3(grid_size), dim3(BlockSize), 0,
-                                  arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
-                                  arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
-                                  arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
-                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                  arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
-                }
-                else
-                {
-                    ave_time = launch_and_time_kernel(
-                        kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0,
-                        arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
-                        arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
-                        arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
-                        arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                        arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
-                }
+                const auto K0_tail = arg.a_grid_desc_ak0_m_ak1_tail_.GetLength(I0);
+
+                const bool tail_has_main_k0_block_loop =
+                    GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
+
+                const auto Run = [&](const auto& kernel) {
+                    if(nrepeat == 0)
+                    {
+                        launch_kernel(kernel, dim3(grid_size), dim3(BlockSize), 0,
+                                      arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
+                                      arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
+                                      arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
+                                      arg.a_grid_desc_ak0_m_ak1_tail_, arg.b_grid_desc_bk0_n_bk1_tail_,
+                                      arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                      arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
+
+                        return 0.0f;
+                    }
+                    else
+                    {
+                        return launch_and_time_kernel(
+                            kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0,
+                            arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
+                            arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
+                            arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
+                            arg.a_grid_desc_ak0_m_ak1_tail_, arg.b_grid_desc_bk0_n_bk1_tail_,
+                            arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                            arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
+                    }
+                };
+
+                if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
+                {
+                    const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
+                        GridwiseGemm, ADataType /* TODO: distiguish A/B datatype */, CDataType,
+                        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
+                        DeviceOp::AGridDesc_AK0_M_AK1_Tail, DeviceOp::BGridDesc_BK0_N_BK1_Tail,
+                        CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch, Block2CTileMap,
+                        true, true>;
+
+                    ave_time = Run(kernel);
+                }
+                else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
+                {
+                    const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
+                        GridwiseGemm, ADataType /* TODO: distiguish A/B datatype */, CDataType,
+                        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
+                        DeviceOp::AGridDesc_AK0_M_AK1_Tail, DeviceOp::BGridDesc_BK0_N_BK1_Tail,
+                        CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch, Block2CTileMap,
+                        true, false>;
+
+                    ave_time = Run(kernel);
+                }
+                else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
+                {
+                    const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
+                        GridwiseGemm, ADataType /* TODO: distiguish A/B datatype */, CDataType,
+                        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
+                        DeviceOp::AGridDesc_AK0_M_AK1_Tail, DeviceOp::BGridDesc_BK0_N_BK1_Tail,
+                        CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch, Block2CTileMap,
+                        false, true>;
+
+                    ave_time = Run(kernel);
+                }
+                else
+                {
+                    const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
+                        GridwiseGemm, ADataType /* TODO: distiguish A/B datatype */, CDataType,
+                        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
+                        DeviceOp::AGridDesc_AK0_M_AK1_Tail, DeviceOp::BGridDesc_BK0_N_BK1_Tail,
+                        CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch, Block2CTileMap,
+                        false, false>;
+
+                    ave_time = Run(kernel);
+                }
             }
             else
             {
-                const auto kernel = kernel_batched_gemm_xdl_cshuffle_v1<
-                    GridwiseGemm, ADataType /* TODO: distiguish A/B datatype */, CDataType,
-                    AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
-                    DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
-                    CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    ComputePtrOffsetOfStridedBatch, Block2CTileMap,
-                    false>;
-
-                if(nrepeat == 0)
-                {
-                    launch_kernel(kernel, dim3(grid_size), dim3(BlockSize), 0,
-                                  arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
-                                  arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
-                                  arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
-                                  arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                                  arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
-                }
-                else
-                {
-                    ave_time = launch_and_time_kernel(
-                        kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0,
-                        arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
-                        arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
-                        arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
-                        arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                        arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
-                }
+                const auto Run = [&](const auto& kernel) {
+                    if(nrepeat == 0)
+                    {
+                        launch_kernel(kernel, dim3(grid_size), dim3(BlockSize), 0,
+                                      arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
+                                      arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
+                                      arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
+                                      arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                                      arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
+
+                        return 0.0f;
+                    }
+                    else
+                    {
+                        return launch_and_time_kernel(
+                            kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0,
+                            arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.BatchCount_,
+                            arg.a_element_op_, arg.b_element_op_, arg.c_element_op_,
+                            arg.a_grid_desc_ak0_m_ak1_, arg.b_grid_desc_bk0_n_bk1_,
+                            arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
+                            arg.compute_ptr_offset_of_batch_, arg.block_2_ctile_map_);
+                    }
+                };
+
+                if(has_main_k0_block_loop)
+                {
+                    const auto kernel = ck::kernel_batched_gemm_xdl_cshuffle_v1<
+                        GridwiseGemm, ADataType /* TODO: distiguish A/B datatype */, CDataType,
+                        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
+                        CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch, Block2CTileMap,
+                        true>;
+
+                    ave_time = Run(kernel);
+                }
+                else
+                {
+                    const auto kernel = ck::kernel_batched_gemm_xdl_cshuffle_v1<
+                        GridwiseGemm, ADataType /* TODO: distiguish A/B datatype */, CDataType,
+                        AElementwiseOperation, BElementwiseOperation, CElementwiseOperation,
+                        DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::BGridDesc_BK0_N_BK1,
+                        CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        ComputePtrOffsetOfStridedBatch, Block2CTileMap,
+                        false>;
+
+                    ave_time = Run(kernel);
+                }
             }
 
             return ave_time;
         }
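One behavioural detail of the Run lambdas in this file that the plain split-K file does not have: when nrepeat == 0 the kernel is launched once through launch_kernel and the reported time is simply 0.0f, so a warm-up or correctness-only run skips the timing loop. A trivial caller-side sketch of that contract, where run_device_op is an invented stand-in and not an API of the library:

#include <cstdio>

// Invented stand-in for Invoker::Run(arg, nrepeat) of the device op.
float run_device_op(int nrepeat)
{
    if(nrepeat == 0)
    {
        // corresponds to launch_kernel(...): fire once, report no timing
        return 0.0f;
    }
    // corresponds to launch_and_time_kernel(...): dummy averaged time for the sketch
    return 1.25f;
}

int main()
{
    const float warmup = run_device_op(0);  // correctness / warm-up run, untimed
    const float timed  = run_device_op(10); // averaged over 10 launches in the real code
    std::printf("warmup reported %.2f ms, timed run %.2f ms\n", warmup, timed);
    return 0;
}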
@@ -781,7 +1064,7 @@ struct DeviceGemmXdlSplitKCShuffle
         return str.str();
     }
 };
 
 } // namespace device
 } // namespace tensor_operation
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp

@@ -142,6 +142,7 @@ __global__ void
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
 
 template <typename FloatAB,
           typename FloatGemmAcc,
           typename FloatCShuffle,