gaoqiong / composable_kernel

Commit 8e3c41a5, authored May 03, 2022 by Jianfeng yan

minor changes

parent 7910f486
Showing 2 changed files with 61 additions and 58 deletions (+61 -58):

include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp            (+55 -55)
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp  (+6 -3)
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk.hpp
@@ -173,11 +173,10 @@ struct DeviceGemmXdlSplitK
         return std::make_pair(actual_batch, KSplitted);
     }

-    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
+    static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
     {
-        assert(K % (K1 * K0PerBlock) == 0);
-        const index_t K0 = K / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0 = KPadded / K1;

         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
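For intuition about the padding introduced above: the _Tail builder rounds K up to the next multiple of K1 * K0PerBlock, so the last split-K slice can always be covered by whole K0PerBlock x K1 tiles. A minimal standalone sketch of that arithmetic with illustrative numbers (the helper name is hypothetical; in the header, K1 and K0PerBlock are compile-time parameters of the device op):

#include <cstdio>

// Hypothetical standalone copy of the _Tail rounding: round K up to a
// multiple of K1 * K0PerBlock, then express the result as K0 = KPadded / K1.
int PaddedK0(int K, int K1, int K0PerBlock)
{
    const int tile    = K1 * K0PerBlock;
    const int KPadded = (K + tile - 1) / tile * tile; // integer_divide_ceil(K, tile) * tile
    return KPadded / K1;
}

int main()
{
    // K = 100, K1 = 8, K0PerBlock = 4: tile = 32, KPadded = 128, K0 = 16.
    std::printf("K0 = %d\n", PaddedK0(100, 8, 4));
    return 0;
}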
@@ -190,12 +189,18 @@ struct DeviceGemmXdlSplitK
             }
         }();

+        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
+            a_grid_desc_m_k,
+            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
-                a_grid_desc_m_k,
+                a_grid_desc_m_kpad,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(M, PadM)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -204,7 +209,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_k,
+                a_grid_desc_m_kpad,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -212,11 +217,11 @@ struct DeviceGemmXdlSplitK
         }
     }

-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
     {
-        assert(K % (K1 * K0PerBlock) == 0);
-        const index_t K0 = K / K1;
+        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
+        const index_t K0 = KPadded / K1;

         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -229,12 +234,18 @@ struct DeviceGemmXdlSplitK
             }
         }();

+        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
+            b_grid_desc_k_n,
+            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
+            make_tuple(Sequence<0>{}, Sequence<1>{}),
+            make_tuple(Sequence<0>{}, Sequence<1>{}));
+
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
+                b_grid_desc_kpad_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(N, PadN)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -243,7 +254,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
+                b_grid_desc_kpad_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -251,10 +262,13 @@ struct DeviceGemmXdlSplitK
         }
     }

-    static auto MakeAGridDescriptor_K0_M_K1_Tail(index_t M, index_t K, index_t StrideA)
+    static auto MakeAGridDescriptor_K0_M_K1(index_t M, index_t K, index_t StrideA)
     {
-        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0 = KPadded / K1;
+        // return MakeAGridDescriptor_K0_M_K1_Tail(M, K, StrideA);
+        assert(K % (K1 * K0PerBlock) == 0);
+        const index_t K0 = K / K1;

         const auto a_grid_desc_m_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
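After this swap, MakeAGridDescriptor_K0_M_K1 is the fast path that assumes each split-K slice divides evenly into K0PerBlock x K1 tiles, while the _Tail variant above pads the remainder. A hedged sketch of the routing a caller would need (the helper is hypothetical; the condition mirrors the assert in the fast path):

// Hypothetical routing helper: slices that fail the fast path's assert
// (typically the last split-K batch) must use the _Tail builder instead.
inline bool NeedsTailDescriptor(int KSlice, int K1, int K0PerBlock)
{
    return KSlice % (K1 * K0PerBlock) != 0;
}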
@@ -267,18 +281,12 @@ struct DeviceGemmXdlSplitK
             }
         }();

-        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
-            a_grid_desc_m_k,
-            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPadded - K)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadM = (MPerBlock - M % MPerBlock) % MPerBlock;
             return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
+                a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(M, PadM)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -287,7 +295,7 @@ struct DeviceGemmXdlSplitK
         else
         {
             return transform_tensor_descriptor(
-                a_grid_desc_m_kpad,
+                a_grid_desc_m_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(M)),
                 make_tuple(Sequence<1>{}, Sequence<0>{}),
@@ -295,11 +303,12 @@ struct DeviceGemmXdlSplitK
         }
     }

-    static auto MakeBGridDescriptor_K0_N_K1_Tail(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
     {
-        const index_t KPadded = math::integer_divide_ceil(K, K1 * K0PerBlock) * K1 * K0PerBlock;
-        const index_t K0 = KPadded / K1;
+        // return MakeBGridDescriptor_K0_N_K1_Tail(K, N, StrideB);
+        assert(K % (K1 * K0PerBlock) == 0);
+        const index_t K0 = K / K1;

         const auto b_grid_desc_k_n = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
@@ -312,18 +321,12 @@ struct DeviceGemmXdlSplitK
             }
         }();

-        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
-            b_grid_desc_k_n,
-            make_tuple(make_right_pad_transform(K, KPadded - K), make_pass_through_transform(N)),
-            make_tuple(Sequence<0>{}, Sequence<1>{}),
-            make_tuple(Sequence<0>{}, Sequence<1>{}));
-
         if constexpr(GemmSpec == GemmSpecialization::MNPadding)
         {
             const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
             return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
+                b_grid_desc_k_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_right_pad_transform(N, PadN)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -332,7 +335,7 @@ struct DeviceGemmXdlSplitK
         else
        {
             return transform_tensor_descriptor(
-                b_grid_desc_kpad_n,
+                b_grid_desc_k_n,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
                            make_pass_through_transform(N)),
                 make_tuple(Sequence<0>{}, Sequence<1>{}),
@@ -674,28 +677,26 @@ struct DeviceGemmXdlSplitK
             const bool tail_has_main_k0_block_loop =
                 GridwiseGemm::CalculateHasMainK0BlockLoop(K0_tail);

             const auto Run = [&](const auto& kernel) {
                 return launch_and_time_kernel(kernel,
                                               nrepeat,
                                               dim3(grid_size),
                                               dim3(BlockSize),
                                               0,
                                               arg.p_a_grid_,
                                               arg.p_b_grid_,
                                               arg.p_c_grid_,
                                               arg.BatchCount_,
                                               arg.a_grid_desc_k0_m_k1_,
                                               arg.b_grid_desc_k0_n_k1_,
                                               arg.a_grid_desc_k0_m_k1_tail_,
                                               arg.b_grid_desc_k0_n_k1_tail_,
                                               arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
                                               arg.a_element_op_,
                                               arg.b_element_op_,
                                               arg.c_element_op_,
                                               arg.compute_ptr_offset_of_batch_,
                                               arg.block_2_ctile_map_);
             };

             if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
@@ -718,7 +719,6 @@ struct DeviceGemmXdlSplitK
                     true>;
-
                 ave_time = Run(kernel);
             }
             else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
             {
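The branch above is one arm of a four-way dispatch: has_main_k0_block_loop and tail_has_main_k0_block_loop are folded into the kernel as compile-time booleans, so each combination selects a distinct instantiation. A simplified sketch of that pattern (FakeKernel and Dispatch are stand-ins for the real kernel template and the if/else-if chain):

#include <cstdio>

// Stand-in for the real kernel template, which receives the two loop
// flags as non-type template parameters.
template <bool MainLoop, bool TailLoop>
void FakeKernel()
{
    std::printf("main=%d tail=%d\n", MainLoop, TailLoop);
}

// Mirrors the shape of the dispatch in the diff: runtime flags choose
// one of four compile-time instantiations.
void Dispatch(bool has_main_k0_block_loop, bool tail_has_main_k0_block_loop)
{
    if(has_main_k0_block_loop && tail_has_main_k0_block_loop)
        FakeKernel<true, true>();
    else if(has_main_k0_block_loop && !tail_has_main_k0_block_loop)
        FakeKernel<true, false>();
    else if(!has_main_k0_block_loop && tail_has_main_k0_block_loop)
        FakeKernel<false, true>();
    else
        FakeKernel<false, false>();
}

int main() { Dispatch(true, false); }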
include/ck/tensor_operation/gpu/device/device_gemm_xdl_splitk_c_shuffle.hpp
@@ -20,9 +20,10 @@ namespace tensor_operation {
 namespace device {

 /*
- * \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
+ * \brief Wrapper function of GridwiseGemm::Run to realize a customized BatchedGemm for splitK.
  *
- * \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ * The main difference from \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3
+ * is that there are 2 different tensor descriptors for matrix A and B.
 */
 template <typename GridwiseGemm,
           typename FloatAB,
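The updated brief matches the launch in the first file, which threads both a main and a tail descriptor pair (a_grid_desc_k0_m_k1_ alongside a_grid_desc_k0_m_k1_tail_, and likewise for B) into the kernel. A hedged sketch of the idea; the selection rule below is an assumption, since the diff does not show which batch reads the tail descriptors:

// Hypothetical per-batch descriptor selection; Desc stands in for the
// grid-descriptor types carried by the Argument struct.
template <typename Desc>
const Desc& PickDescriptor(int batch_id, int batch_count,
                           const Desc& main_desc, const Desc& tail_desc)
{
    // Assumption: only the final split-K batch owns the K remainder.
    return (batch_id == batch_count - 1) ? tail_desc : main_desc;
}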
@@ -174,7 +175,7 @@ struct DeviceGemmXdlSplitKCShuffle
     template <index_t K1>
     static auto GetActualBatchAndKSplitted(index_t K, index_t KBatch)
     {
-        const index_t K0PerBlock = KPerBlock / K1;
+        const index_t K0PerBlock = KPerBlock / K1;
         const index_t K0 = math::integer_divide_ceil(K, KPerBlock * KBatch) * K0PerBlock;
         const index_t KSplitted = K0 * K1;
         const index_t actual_batch = math::integer_divide_ceil(K, KSplitted);
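A worked example of the arithmetic in GetActualBatchAndKSplitted (standalone sketch with illustrative numbers; in the real class, KPerBlock is a compile-time parameter):

#include <cstdio>

// Hypothetical standalone copy of GetActualBatchAndKSplitted.
void ActualBatchAndKSplitted(int K, int KBatch, int KPerBlock, int K1)
{
    const int K0PerBlock   = KPerBlock / K1;
    const int KPerSplit    = KPerBlock * KBatch;
    const int K0           = (K + KPerSplit - 1) / KPerSplit * K0PerBlock; // integer_divide_ceil
    const int KSplitted    = K0 * K1;
    const int actual_batch = (K + KSplitted - 1) / KSplitted;
    // K = 1000, KBatch = 4, KPerBlock = 32, K1 = 8:
    // K0PerBlock = 4, K0 = ceil(1000/128) * 4 = 32, KSplitted = 256,
    // actual_batch = ceil(1000/256) = 4.
    std::printf("KSplitted = %d, actual_batch = %d\n", KSplitted, actual_batch);
}

int main()
{
    ActualBatchAndKSplitted(1000, 4, 32, 8);
    return 0;
}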
@@ -193,6 +194,7 @@ struct DeviceGemmXdlSplitKCShuffle
     template <>
     static auto MakeAGridDescriptor_AK0_M_AK1<false>(index_t MRaw, index_t K, index_t StrideA)
     {
+        // return MakeAGridDescriptor_AK0_M_AK1<true>(MRaw, K, StrideA);
         assert(K % KPerBlock == 0);
         assert(K % AK1 == 0);
@@ -243,6 +245,7 @@ struct DeviceGemmXdlSplitKCShuffle
     template <>
     static auto MakeBGridDescriptor_BK0_N_BK1<false>(index_t K, index_t NRaw, index_t StrideB)
     {
+        // return MakeBGridDescriptor_BK0_N_BK1<true>(K, NRaw, StrideB);
         assert(K % KPerBlock == 0);
         assert(K % BK1 == 0);