gaoqiong / composable_kernel · Commits

Commit 7910f486, authored May 01, 2022 by Jianfeng yan

DeviceGemmXdlSplit and DeviceGemmXdlSplitKCShuffle both work for arbitrary K

parent b5a9f642

Showing 3 changed files with 541 additions and 302 deletions
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp            +40  −85
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp  +500 −217
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp       +1   −0
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp (+40 −85)
@@ -19,6 +19,11 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
+/*
+ * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ *
+ * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ */
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
           ...
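The block comment added in this hunk describes kernel_batched_gemm_xdlops_v2r3 as a wrapper around GridwiseGemm::Run that realizes batched GEMM. As a rough host-side illustration only (plain C++ with hypothetical names, not the ck kernel), the batching amounts to offsetting the A, B and C base pointers once per batch and reusing one per-matrix GEMM; on the device side, ComputePtrOffsetOfStridedBatch plays that pointer-offset role:

    #include <cstddef>

    // Hypothetical stand-in for a single-matrix GEMM (row-major, C = A * B).
    inline void gemm(const float* a, const float* b, float* c,
                     std::size_t M, std::size_t N, std::size_t K)
    {
        for(std::size_t m = 0; m < M; ++m)
            for(std::size_t n = 0; n < N; ++n)
            {
                float acc = 0.f;
                for(std::size_t k = 0; k < K; ++k)
                    acc += a[m * K + k] * b[k * N + n];
                c[m * N + n] = acc;
            }
    }

    // The "wrapper" idea: one launch covers every batch; batch g only shifts the
    // base pointers by a per-batch stride before running the same GEMM.
    inline void batched_gemm(const float* a, const float* b, float* c,
                             std::size_t M, std::size_t N, std::size_t K,
                             std::size_t batch_count)
    {
        for(std::size_t g = 0; g < batch_count; ++g)
            gemm(a + g * M * K, b + g * K * N, c + g * M * N, M, N, K);
    }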
@@ -159,15 +164,12 @@ struct DeviceGemmXdlSplitK
     static constexpr auto K1Number = Number<K1>{};
 
-    // static constexpr index_t Getk
     static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
     {
         const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
         const index_t KSplitted    = K0 * K1;
         const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
-        // return std::make_pair<index_t, index_t>(actual_batch, KSplitted);
         return std::make_pair(actual_batch, KSplitted);
     }
     ...
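GetActualBatchAndKSplitted is the piece that lets the split-K path accept an arbitrary K: every split is given a K0PerBlock-aligned slice of KSplitted elements, and the number of splits actually used can be smaller than the requested KBatch. A standalone sketch of the same arithmetic (plain C++; the K1, K0PerBlock and K values are made-up examples, not taken from the diff):

    #include <cstdio>
    #include <utility>

    using index_t = int;

    index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

    std::pair<index_t, index_t> get_actual_batch_and_k_splitted(index_t K, index_t KBatch,
                                                                index_t K1, index_t K0PerBlock)
    {
        // Round the per-split K0 up to a multiple of K0PerBlock.
        const index_t K0           = integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;
        const index_t KSplitted    = K0 * K1;
        // For arbitrary K the requested KBatch may not be reachable; the actual
        // number of splits is however many KSplitted-sized slices cover K.
        const index_t actual_batch = integer_divide_ceil(K, KSplitted);
        return {actual_batch, KSplitted};
    }

    int main()
    {
        // Example: K = 1000, requested KBatch = 4, K1 = 8, K0PerBlock = 4:
        // K0 = ceil(1000 / 128) * 4 = 32, KSplitted = 256,
        // actual_batch = ceil(1000 / 256) = 4; the last split covers only 232 of 256.
        auto [batch, k_splitted] = get_actual_batch_and_k_splitted(1000, 4, 8, 4);
        std::printf("actual_batch = %d, KSplitted = %d\n", batch, k_splitted);

        // With K = 96 and the same tile sizes the requested 4 splits collapse to 3:
        // K0 = ceil(96 / 128) * 4 = 4, KSplitted = 32, actual_batch = ceil(96 / 32) = 3.
    }

The shorter last slice produced this way is what the MakeAGridDescriptor_K0_M_K1_Tail and MakeBGridDescriptor_K0_N_K1_Tail paths further down are there to handle.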
@@ -251,8 +253,8 @@ struct DeviceGemmXdlSplitK
     static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
     {
-        const index_t KPad    = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0      = KPad / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0      = KPadded / K1;
 
         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
             ...
@@ -267,7 +269,7 @@ struct DeviceGemmXdlSplitK
         const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
             a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
+            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
     ...
@@ -295,9 +297,9 @@ struct DeviceGemmXdlSplitK
     static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
     {
-        const index_t KPad    = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0      = KPad / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0      = KPadded / K1;
 
         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             ...
@@ -312,7 +314,7 @@ struct DeviceGemmXdlSplitK
         const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
             b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
+            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
             make_tuple(Sequence<0>{}, Sequence<1>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}));
     ...
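The four hunks above rename KPad to KPadded in the *_Tail descriptors, which describe the last K-split; its length is generally not a multiple of K0PerBlock * K1, so it is rounded up and the K axis is right-padded until the tail again decomposes into whole K0PerBlock x K1 tiles. A standalone check of that arithmetic (plain C++, continuing the example values used above; the ck descriptor and transform calls themselves are not reproduced):

    #include <cassert>

    using index_t = int;

    index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

    int main()
    {
        const index_t K1         = 8;
        const index_t K0PerBlock = 4;

        // Continuing the example above: the last K-split holds K = 232 elements.
        const index_t K       = 232;
        const index_t KPadded = integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
        const index_t K0      = KPadded / K1;

        // make_right_pad_transform(K, KPadded - K) extends the K axis by the pad
        // amount so the tail slice spans a whole number of K0PerBlock x K1 tiles.
        assert(KPadded == 256);
        assert(KPadded - K == 24);
        assert(K0 == 32);
        assert(KPadded % (K1 * K0PerBlock) == 0);
    }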
@@ -672,26 +674,9 @@ struct DeviceGemmXdlSplitK
             const bool tail_has_main_k0_block_loop =
                 GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);
 
-            if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
+            const auto Run = [&](const auto& kernel)
             {
-                const auto kernel = kernel_batched_gemm_xdlops_v2r3<
-                    GridwiseGemm,
-                    ADataType, // TODO: distiguish A/B datatype
-                    CDataType,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
-                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
-                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
-                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
-                    AElementwiseOperation,
-                    BElementwiseOperation,
-                    CElementwiseOperation,
-                    ComputePtrOffsetOfStridedBatch,
-                    remove_reference_t<Block2CTileMap>,
-                    true,
-                    true>;
-
-                ave_time = launch_and_time_kernel(kernel,
+                return launch_and_time_kernel(kernel,
                                                   nrepeat,
                                                   dim3(grid_size),
                                                   dim3(BlockSize),
                                                   ...
@@ -710,6 +695,30 @@ struct DeviceGemmXdlSplitK
                                                   ...
                                                   arg.c_element_op_,
                                                   arg.compute_ptr_offset_of_batch_,
                                                   arg.block_2_ctile_map_);
+            };
+
+            if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
+            {
+                const auto kernel = kernel_batched_gemm_xdlops_v2r3<
+                    GridwiseGemm,
+                    ADataType, // TODO: distiguish A/B datatype
+                    CDataType,
+                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1>,
+                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1>,
+                    remove_reference_t<DeviceGemmXdlSplitK::AGridDesc_K0_M_K1_Tail>,
+                    remove_reference_t<DeviceGemmXdlSplitK::BGridDesc_K0_N_K1_Tail>,
+                    remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
+                    AElementwiseOperation,
+                    BElementwiseOperation,
+                    CElementwiseOperation,
+                    ComputePtrOffsetOfStridedBatch,
+                    remove_reference_t<Block2CTileMap>,
+                    true,
+                    true>;
+
+                ave_time = Run(kernel);
             }
             else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
             {
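The Run lambda introduced in the two hunks above factors the launch-and-time call out of the four has_main_k0_block_loop / tail_has_main_k0_block_loop branches, so each branch now only instantiates its kernel_batched_gemm_xdlops_v2r3 variant and calls Run(kernel). A compilable sketch of the same dispatch shape (hypothetical stand-ins only; launch_and_time_kernel and the real kernel template are ck-specific and not shown):

    #include <cstdio>

    template <bool HasMainLoop, bool TailHasMainLoop>
    void kernel_stub(int grid_size, int block_size)
    {
        // Stand-in for one compile-time variant of the GPU kernel.
        std::printf("launch grid=%d block=%d main=%d tail_main=%d\n",
                    grid_size, block_size, int(HasMainLoop), int(TailHasMainLoop));
    }

    float run_gemm(bool has_main_k0_block_loop, bool tail_has_main_k0_block_loop,
                   int grid_size, int block_size)
    {
        float ave_time = 0.f;

        // One place that knows how to launch and time; every branch reuses it.
        const auto Run = [&](auto kernel) {
            kernel(grid_size, block_size); // real code: launch_and_time_kernel(kernel, ...)
            return 0.f;                    // real code: measured average time
        };

        if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
            ave_time = Run(kernel_stub<true, true>);
        else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
            ave_time = Run(kernel_stub<true, false>);
        else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
            ave_time = Run(kernel_stub<false, true>);
        else
            ave_time = Run(kernel_stub<false, false>);

        return ave_time;
    }

    int main() { run_gemm(true, false, 64, 256); }

Most of this file's 85 removed lines come from exactly this refactor: the four copies of the launch argument list collapse into the single copy inside Run, as the remaining hunks below show.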
@@ -730,25 +739,7 @@ struct DeviceGemmXdlSplitK
                     ...
                     true,
                     false>;
-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.BatchCount_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.a_grid_desc_k0_m_k1_tail_,
-                                                  arg.b_grid_desc_k0_n_k1_tail_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.compute_ptr_offset_of_batch_,
-                                                  arg.block_2_ctile_map_);
+                ave_time = Run(kernel);
             }
             else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
             {
@@ -769,25 +760,7 @@ struct DeviceGemmXdlSplitK
                     ...
                     false,
                     true>;
-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.BatchCount_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.a_grid_desc_k0_m_k1_tail_,
-                                                  arg.b_grid_desc_k0_n_k1_tail_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.compute_ptr_offset_of_batch_,
-                                                  arg.block_2_ctile_map_);
+                ave_time = Run(kernel);
             }
             else
             {
@@ -808,25 +781,7 @@ struct DeviceGemmXdlSplitK
                     ...
                     false,
                     false>;
-                ave_time = launch_and_time_kernel(kernel,
-                                                  nrepeat,
-                                                  dim3(grid_size),
-                                                  dim3(BlockSize),
-                                                  0,
-                                                  arg.p_a_grid_,
-                                                  arg.p_b_grid_,
-                                                  arg.p_c_grid_,
-                                                  arg.BatchCount_,
-                                                  arg.a_grid_desc_k0_m_k1_,
-                                                  arg.b_grid_desc_k0_n_k1_,
-                                                  arg.a_grid_desc_k0_m_k1_tail_,
-                                                  arg.b_grid_desc_k0_n_k1_tail_,
-                                                  arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
-                                                  arg.a_element_op_,
-                                                  arg.b_element_op_,
-                                                  arg.c_element_op_,
-                                                  arg.compute_ptr_offset_of_batch_,
-                                                  arg.block_2_ctile_map_);
+                ave_time = Run(kernel);
             }
         }
         else
         ...
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp (+500 −217)

This diff is collapsed.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v1.hpp (+1 −0)
@@ -142,6 +142,7 @@ __global__ void
     ignore = block_2_ctile_map;
 #endif // end of if (defined(__gfx908__) || defined(__gfx90a__))
 }
+
 template <typename FloatAB,
           typename FloatGemmAcc,
           typename FloatCShuffle,
           ...