gaoqiong / composable_kernel · Commits

Commit 060c4f3a
Authored Mar 06, 2023 by aska-0096

    Skip B Lds Gemm + MulD

Parent: 04c6a978
Showing 4 changed files with 191 additions and 139 deletions.
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp (+2, -2)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp (+60, -51)
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp (+2, -19)
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp (+127, -67)
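For orientation before the per-file diffs: this commit lets the WMMA GEMM (and its multiple-D variant) skip staging the B operand in LDS. When BEnableLds is false, each thread copies its B fragment straight from global memory into registers, so B's grid, block, and wave descriptors switch from a K0 x N x K1 layout to a per-thread KWmma x NRepeat x NWave x KRow x NPerWmma x K1 layout. The sketch below is illustrative only (plain C++, not Composable Kernel code); the formulas mirror the diff (K0PerBlock = KPerBlock / K1, KWmmaPerBlock = KPerBlock / WmmaK, B_KRow = WmmaK / K1), while the concrete tile sizes are assumed values.

```cpp
#include <cstdio>

// Illustrative constants (assumed, not from a specific CK instance).
constexpr int KPerBlock = 64;
constexpr int K1        = 8;  // innermost vector length
constexpr int WmmaK     = 16; // K extent consumed by one WMMA instruction

int main()
{
    // LDS path: the B block tile is addressed as K0 x N x K1.
    constexpr int K0PerBlock = KPerBlock / K1;
    // Direct path: the per-thread tile is addressed as KWmma x ... x K1,
    // where one WMMA step covers KRow rows of K1 elements (WmmaK = KRow * K1).
    constexpr int KWmmaPerBlock = KPerBlock / WmmaK;
    constexpr int KRow          = WmmaK / K1;

    std::printf("LDS path:    K0PerBlock = %d   (K = K0 * K1)\n", K0PerBlock);
    std::printf("direct path: KWmmaPerBlock = %d, KRow = %d (K = KWmma * KRow * K1)\n",
                KWmmaPerBlock, KRow);
}
```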
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
```diff
@@ -87,8 +87,8 @@ using DeviceOpInstance =
         8, 16, 16, 1, 8, 1,
         S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>,
@@ -105,7 +105,7 @@ using DeviceOpInstance =
         true, 1, 1,
         S<1, 128, 1, 2>, S<1, 16, 1, 16>, 8>;

 int main(int argc, char* argv[])
```
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
```diff
@@ -105,6 +105,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
     static constexpr auto matrix_padder =
         MatrixPadder<GemmSpec, index_t, index_t, index_t>{MPerBlock, NPerBlock, KPerBlock};

     // Describe how data read from Global memory
     static auto MakeAGridDescriptor(index_t MRaw, index_t KRaw, index_t StrideA)
     {
         const auto a_grid_desc_m_k = [&]() {
```
```diff
@@ -115,12 +116,13 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
             return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
         }
 #ifdef ENABLE_COLMAJOR
         else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
         {
-            return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
+            const auto a_grid_desc_mraw_kraw =
+                make_naive_tensor_descriptor(make_tuple(MRaw, KRaw), make_tuple(I1, StrideA));
+            return matrix_padder.PadADescriptor_M_K(a_grid_desc_mraw_kraw);
         }
 #endif
     }();

     const auto M = a_grid_desc_m_k.GetLength(I0);
```
```diff
@@ -155,42 +157,56 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
-    static auto MakeBGridDescriptor_K0_N_K1(index_t K, index_t N, index_t StrideB)
+    static auto MakeBGridDescriptor(index_t KRaw, index_t NRaw, index_t StrideB)
     {
-        assert(K % K1 == 0);
-
-        const index_t K0 = K / K1;
-
-        const auto b_grid_desc_k_n = [&]() {
+        const auto b_grid_desc_n_k = [&]() {
             if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
+                const auto b_grid_desc_nraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(I1, StrideB));
+                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
             }
             else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
             {
-                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
+                const auto b_grid_desc_nraw_kraw =
+                    make_naive_tensor_descriptor(make_tuple(NRaw, KRaw), make_tuple(StrideB, I1));
+                return matrix_padder.PadBDescriptor_N_K(b_grid_desc_nraw_kraw);
             }
         }();

-        if constexpr(GemmSpec == GemmSpecialization::MNPadding)
+        const auto N = b_grid_desc_n_k.GetLength(I0);
+        const auto K = b_grid_desc_n_k.GetLength(I1);
+
+        assert(K % K1 == 0);
+
+        if constexpr(BEnableLds)
         {
-            const auto PadN = (NPerBlock - N % NPerBlock) % NPerBlock;
+            const index_t K0 = K / K1;

             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
+                b_grid_desc_n_k,
                 make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_right_pad_transform(N, PadN)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
+                           make_pass_through_transform(N)),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
                 make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
         }
         else
         {
+            constexpr auto B_KRow = WmmaK / K1;
+            const auto B_KWmma    = K / WmmaK;
+            const auto N0         = N / NPerBlock;
+
             return transform_tensor_descriptor(
-                b_grid_desc_k_n,
-                make_tuple(make_unmerge_transform(make_tuple(K0, K1Number)),
-                           make_pass_through_transform(N)),
-                make_tuple(Sequence<0>{}, Sequence<1>{}),
-                make_tuple(Sequence<0, 2>{}, Sequence<1>{}));
+                b_grid_desc_n_k,
+                make_tuple(make_unmerge_transform(make_tuple(B_KWmma, Number<B_KRow>{}, K1Number)),
+                           make_unmerge_transform(
+                               make_tuple(N0 * NRepeat, Number<NWaves>{}, Number<NPerWmma>{}))),
+                make_tuple(Sequence<1>{}, Sequence<0>{}),
+                make_tuple(Sequence<0, 3, 5>{}, Sequence<1, 2, 4>{}));
         }
     }
```
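The rewritten MakeBGridDescriptor above produces a 3-D (K0, N, K1) view when B goes through LDS and a 6-D (KWmma, N0 * NRepeat, NWaves, KRow, NPerWmma, K1) view when it does not. Below is a small standalone check of that factorization; the formulas come from the diff, while the concrete sizes are assumptions chosen so that NPerBlock = NRepeat * NWaves * NPerWmma.

```cpp
#include <cassert>
#include <cstdio>

int main()
{
    // Assumed problem and tile sizes; the formulas below mirror the diff.
    const int N = 2048, K = 512;
    const int NPerBlock = 64, NRepeat = 2, NWaves = 2, NPerWmma = 16;
    const int WmmaK = 16, K1 = 8;

    assert(K % K1 == 0);
    const int B_KRow  = WmmaK / K1;    // K rows inside one WMMA step
    const int B_KWmma = K / WmmaK;     // WMMA steps along K
    const int N0      = N / NPerBlock; // blocks along N

    // Dimension order produced by the two unmerge transforms:
    // (KWmma, N0 * NRepeat, NWaves, KRow, NPerWmma, K1)
    std::printf("B grid: %d x %d x %d x %d x %d x %d\n",
                B_KWmma, N0 * NRepeat, NWaves, B_KRow, NPerWmma, K1);

    // The factorization covers the whole N x K operand exactly.
    assert(B_KWmma * B_KRow * K1 == K);
    assert((N0 * NRepeat) * NWaves * NPerWmma == N);
}
```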
```diff
@@ -245,7 +261,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
     // Gridwise descriptor, mapping to whole given problem.
     using AGridDesc          = decltype(MakeAGridDescriptor(1, 1, 1));
-    using BGridDesc_K0_N_K1  = decltype(MakeBGridDescriptor_K0_N_K1(1, 1, 1));
+    using BGridDesc          = decltype(MakeBGridDescriptor(1, 1, 1));
     using DsGridDesc_M_N     = remove_cvref_t<decltype(MakeDsGridDescriptor_M_N({}, {}, {}))>;
     using EGridDesc_M_N      = decltype(MakeEGridDescriptor_M_N<ELayout>(1, 1, 1));
```
```diff
@@ -260,7 +276,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         EDataType,
         // InMemory Data Descriptor
         AGridDesc,
-        BGridDesc_K0_N_K1,
+        BGridDesc,
         DsGridDesc_M_N,
         EGridDesc_M_N,
         // ElementwiseOp Family
```
```diff
@@ -329,7 +345,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
               p_ds_grid_{},
               p_e_grid_{static_cast<EDataType*>(p_e_grid)},
               a_grid_desc{},
-              b_grid_desc_k0_n_k1_{},
+              b_grid_desc{},
               ds_grid_desc_m_n_{},
               e_grid_desc_m_n_{},
               ds_grid_desc_mblock_mperblock_nblock_nperblock{},
```
```diff
@@ -342,7 +358,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
               cde_element_op_{cde_element_op}
         {
             a_grid_desc = DeviceOp::MakeAGridDescriptor(M, K, StrideA);
-            b_grid_desc_k0_n_k1_ = DeviceOp::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
+            b_grid_desc = DeviceOp::MakeBGridDescriptor(K, N, StrideB);

             static_for<0, NumDTensor, 1>{}([&](auto i) {
                 using DLayout   = remove_cvref_t<tuple_element_t<i.value, DsLayout>>;
                 using DDataType = remove_cvref_t<tuple_element_t<i.value, DsDataType>>;
```
```diff
@@ -359,7 +375,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
             block_2_ctile_map_ =
                 GridwiseOp::MakeDefaultBlock2CTileMap(e_grid_desc_m_n_, M01, N01);

             if(GridwiseOp::CheckValidity(a_grid_desc,
-                                         b_grid_desc_k0_n_k1_,
+                                         b_grid_desc,
                                          ds_grid_desc_m_n_,
                                          e_grid_desc_m_n_,
                                          block_2_ctile_map_))
```
```diff
@@ -382,7 +398,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         //  Tensor Descriptors
         AGridDesc a_grid_desc;
-        BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
+        BGridDesc b_grid_desc;
         DsGridDesc_M_N ds_grid_desc_m_n_;
         EGridDesc_M_N e_grid_desc_m_n_;
         typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
```
```diff
@@ -410,24 +426,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-#if 0
-            {
-                std::cout << "arg.a_grid_desc{" << arg.a_grid_desc.GetLength(I0)
-                          << ", " << arg.a_grid_desc.GetLength(I1) << ", "
-                          << arg.a_grid_desc.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
-                          << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
-            }
-#endif
-
             if(!GridwiseOp::CheckValidity(arg.a_grid_desc,
-                                          arg.b_grid_desc_k0_n_k1_,
+                                          arg.b_grid_desc,
                                           arg.ds_grid_desc_m_n_,
                                           arg.e_grid_desc_m_n_,
                                           arg.block_2_ctile_map_))
```
```diff
@@ -439,8 +439,17 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.e_grid_desc_m_n_);

-            const auto K = arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
+            const auto K = [&]() {
+                if constexpr(AEnableLds)
+                {
+                    return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I2);
+                }
+                else
+                {
+                    return arg.a_grid_desc.GetLength(I0) * arg.a_grid_desc.GetLength(I3) *
+                           arg.a_grid_desc.GetLength(I5);
+                }
+            }();

             float ave_time = 0;
```
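The new K computation wraps an if constexpr in an immediately-invoked lambda so K can stay const while its formula depends on AEnableLds: with LDS the A descriptor is (K0, M, K1) and K = K0 * K1; without it, K is split across dimensions 0, 3, and 5. A minimal standalone sketch of the same idiom (descriptor lengths are assumed values, not taken from the repo):

```cpp
#include <cstdio>

// Sketch of the idiom above: an immediately-invoked lambda picks the
// formula at compile time while `K` stays const. Lengths are assumed.
template <bool AEnableLds>
int compute_k(const int len[6])
{
    const auto K = [&]() {
        if constexpr(AEnableLds)
            return len[0] * len[2]; // K0 * K1
        else
            return len[0] * len[3] * len[5]; // KWmma * KRow * K1
    }();
    return K;
}

int main()
{
    const int lds_desc[6]    = {64, 128, 8, 1, 1, 1}; // (K0, M, K1, -, -, -)
    const int direct_desc[6] = {32, 2, 2, 2, 16, 8};  // (KWmma, .., KRow, ., K1)
    std::printf("K with LDS:    %d\n", compute_k<true>(lds_desc));     // 512
    std::printf("K without LDS: %d\n", compute_k<false>(direct_desc)); // 512
}
```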
```diff
@@ -453,7 +462,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
                     typename GridwiseOp::DsGridPointer,
                     EDataType,
                     remove_reference_t<typename DeviceOp::AGridDesc>,
-                    remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
+                    remove_reference_t<typename DeviceOp::BGridDesc>,
                     remove_reference_t<
                         typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                     remove_reference_t<
```
```diff
@@ -475,7 +484,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
                     arg.p_ds_grid_,
                     arg.p_e_grid_,
                     arg.a_grid_desc,
-                    arg.b_grid_desc_k0_n_k1_,
+                    arg.b_grid_desc,
                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.a_element_op_,
```
```diff
@@ -492,7 +501,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
                     typename GridwiseOp::DsGridPointer,
                     EDataType,
                     remove_reference_t<typename DeviceOp::AGridDesc>,
-                    remove_reference_t<typename DeviceOp::BGridDesc_K0_N_K1>,
+                    remove_reference_t<typename DeviceOp::BGridDesc>,
                     remove_reference_t<
                         typename GridwiseOp::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock>,
                     remove_reference_t<
```
```diff
@@ -514,7 +523,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
                     arg.p_ds_grid_,
                     arg.p_e_grid_,
                     arg.a_grid_desc,
-                    arg.b_grid_desc_k0_n_k1_,
+                    arg.b_grid_desc,
                     arg.ds_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.e_grid_desc_mblock_mperblock_nblock_nperblock,
                     arg.a_element_op_,
```
```diff
@@ -555,7 +564,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         }

         return GridwiseOp::CheckValidity(arg.a_grid_desc,
-                                         arg.b_grid_desc_k0_n_k1_,
+                                         arg.b_grid_desc,
                                          arg.ds_grid_desc_m_n_,
                                          arg.e_grid_desc_m_n_,
                                          arg.block_2_ctile_map_);
```
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
```diff
@@ -344,22 +344,6 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-#if 0
-            {
-                std::cout << "arg.a_grid_desc_{" << arg.a_grid_desc_.GetLength(I0)
-                          << ", " << arg.a_grid_desc_.GetLength(I1) << ", "
-                          << arg.a_grid_desc_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0)
-                          << ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", "
-                          << arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl;
-                std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0)
-                          << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << ", "
-                          << arg.c_grid_desc_m_n_.GetLength(I2) << "}" << std::endl;
-            }
-#endif
-
             if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_,
                                             arg.b_grid_desc_k0_n_k1_,
                                             arg.c_grid_desc_m_n_,
```
```diff
@@ -372,7 +356,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
             const index_t grid_size =
                 arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_);

-            const auto GetK = [&]() {
+            const auto K = [&]() {
                 if constexpr(AEnableLds)
                 {
                     return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I2);
```
```diff
@@ -382,8 +366,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
                     return arg.a_grid_desc_.GetLength(I0) * arg.a_grid_desc_.GetLength(I3) *
                            arg.a_grid_desc_.GetLength(I5);
                 }
-            };
-            const auto K = GetK();
+            }();

             float ave_time = 0;
```
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
```diff
@@ -46,7 +46,7 @@ __global__ void
             const CDEElementwiseOperation cde_element_op,
             const index_t batch_count,
             const AGridDesc_AK0_M_AK1 a_grid_desc,
-            const BGridDesc_BK0_N_BK1 b_grid_desc_k0_n_k1,
+            const BGridDesc_BK0_N_BK1 b_grid_desc,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
             const EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
```
```diff
@@ -85,7 +85,7 @@ __global__ void
                                       p_e_grid + e_batch_offset,
                                       p_shared,
                                       a_grid_desc,
-                                      b_grid_desc_k0_n_k1,
+                                      b_grid_desc,
                                       ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                       e_grid_desc_mblock_mperblock_nblock_nperblock_,
                                       a_element_op,
```
```diff
@@ -99,7 +99,7 @@ __global__ void
     ignore = p_e_grid;
     ignore = batch_count;
     ignore = a_grid_desc;
-    ignore = b_grid_desc_k0_n_k1;
+    ignore = b_grid_desc;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock_;
     ignore = a_element_op;
```
```diff
@@ -116,7 +116,7 @@ template <typename GridwiseOp,
           typename DsPointer,
           typename EDataType,
           typename AGridDesc,
-          typename BGridDesc_K0_N_K1,
+          typename BGridDesc,
           typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
```
```diff
@@ -136,7 +136,7 @@ __global__ void
             EDataType* __restrict__ p_e_grid,
             const index_t batch_count,
             const AGridDesc a_grid_desc,
-            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+            const BGridDesc b_grid_desc,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
             const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
```
```diff
@@ -177,7 +177,7 @@ __global__ void
                                       p_e_grid + e_batch_offset,
                                       p_shared,
                                       a_grid_desc,
-                                      b_grid_desc_k0_n_k1,
+                                      b_grid_desc,
                                       ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                       e_grid_desc_mblock_mperblock_nblock_nperblock,
                                       a_element_op,
```
```diff
@@ -194,7 +194,7 @@ __global__ void
     ignore = b_element_op;
     ignore = cde_element_op;
     ignore = a_grid_desc;
-    ignore = b_grid_desc_k0_n_k1;
+    ignore = b_grid_desc;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = block_2_etile_map;
```
```diff
@@ -208,7 +208,7 @@ template <typename GridwiseOp,
           typename DsPointer,
           typename EDataType,
           typename AGridDesc,
-          typename BGridDesc_K0_N_K1,
+          typename BGridDesc,
           typename DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
           typename AElementwiseOperation,
```
```diff
@@ -226,7 +226,7 @@ __global__ void
             DsPointer p_ds_grid,
             EDataType* __restrict__ p_e_grid,
             const AGridDesc a_grid_desc,
-            const BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1,
+            const BGridDesc b_grid_desc,
             const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
                 ds_grid_desc_mblock_mperblock_nblock_nperblock,
             const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
```
```diff
@@ -245,7 +245,7 @@ __global__ void
                                       p_e_grid,
                                       p_shared,
                                       a_grid_desc,
-                                      b_grid_desc_k0_n_k1,
+                                      b_grid_desc,
                                       ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                       e_grid_desc_mblock_mperblock_nblock_nperblock,
                                       a_element_op,
```
```diff
@@ -258,7 +258,7 @@ __global__ void
     ignore = p_ds_grid;
     ignore = p_e_grid;
     ignore = a_grid_desc;
-    ignore = b_grid_desc_k0_n_k1;
+    ignore = b_grid_desc;
     ignore = ds_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = e_grid_desc_mblock_mperblock_nblock_nperblock;
     ignore = a_element_op;
```
```diff
@@ -277,7 +277,7 @@ template < // DataType Family
           typename EDataType,
           // InMemory Data Descriptor
           typename AGridDesc,
-          typename BGridDesc_K0_N_K1,
+          typename BGridDesc,
           typename DsGridDesc_M_N,
           typename EGridDesc_M_N,
           // ElementwiseOp Family
```
```diff
@@ -385,6 +385,40 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         return a_block_desc;
     }

+    __host__ __device__ static constexpr auto MakeBBlockDescriptor()
+    {
+        constexpr auto b_block_desc = [&]() {
+            if constexpr(BEnableLds)
+            {
+                // K0->N->K1 Per Block
+                constexpr auto K0PerBlock    = KPerBlock / K1;
+                constexpr auto max_lds_align = K1;
+
+                if constexpr(BBlockLdsExtraN)
+                {
+                    return make_naive_tensor_descriptor(
+                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                        make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+                }
+                else
+                {
+                    return make_naive_tensor_descriptor_aligned(
+                        make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                }
+            }
+            else
+            {
+                constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
+                // KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<KWmmaPerblock>{}, Number<NRepeat>{}, I1, I1, I1, K1),
+                    make_tuple(Number<NRepeat>{} * K1, K1, K1, K1, K1, I1));
+            }
+        }();
+
+        return b_block_desc;
+    }
+
     __host__ __device__ static constexpr auto MakeABlockSliceCopyStep()
     {
         constexpr auto a_block_copy_step = [&]() {
```
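The BBlockLdsExtraN branch above pads the LDS row stride from NPerBlock * K1 to (NPerBlock + 1) * K1 elements, the usual trick for spreading rows across LDS banks. A standalone illustration follows; the bank math assumes 32 banks of 4-byte words and the concrete sizes are assumptions, only the stride formula comes from the diff.

```cpp
#include <cstdio>

int main()
{
    const int NPerBlock = 64, K1 = 8, Banks = 32;

    for (int pad = 0; pad <= 1; ++pad)
    {
        // Elements per K0 row: NPerBlock * K1 without padding,
        // (NPerBlock + 1) * K1 with the BBlockLdsExtraN padding.
        const int row_stride = (NPerBlock + pad) * K1;
        std::printf("pad=%d:", pad);
        // Bank hit by the first element of each of the first four rows.
        for (int row = 0; row < 4; ++row)
            std::printf("  row%d -> bank %2d", row, (row * row_stride) % Banks);
        std::printf("\n");
    }
    // Without padding every row starts in bank 0; with it they stagger.
}
```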
```diff
@@ -478,44 +512,56 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         return a_wave_desc;
     }

-    template <typename BBlockDesc_BK0_N_BK1>
-    __host__ __device__ static constexpr auto
-    MakeBBlockDescriptor_K0_N0_N1_N2_K1(const BBlockDesc_BK0_N_BK1&)
+    template <typename BBlockDesc_>
+    __host__ __device__ static constexpr auto MakeBWaveDescriptor(const BBlockDesc_&)
     {
-        constexpr auto B_K0 = BBlockDesc_BK0_N_BK1{}.GetLength(I0);
-        constexpr auto B_K1 = BBlockDesc_BK0_N_BK1{}.GetLength(I2);
-
-        return transform_tensor_descriptor(
-            BBlockDesc_BK0_N_BK1{},
-            make_tuple(make_pass_through_transform(Number<B_K0>{}),
-                       make_unmerge_transform(
-                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
-                       make_pass_through_transform(Number<B_K1>{})),
-            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
-            make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
-    }
-
-    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K1()
-    {
-        constexpr auto max_lds_align = K1;
-        constexpr auto K0PerBlock    = KPerBlock / K1;
-
-        // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0perblock_nperblock_k1 = [&]() {
-            if constexpr(BBlockLdsExtraN)
+        constexpr auto b_wave_desc = [&]() {
+            if constexpr(BEnableLds)
             {
-                return make_naive_tensor_descriptor(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
-                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
+                // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
+                constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
+                constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+
+                return transform_tensor_descriptor(
+                    BBlockDesc_{},
+                    make_tuple(make_pass_through_transform(Number<B_K0>{}),
+                               make_unmerge_transform(make_tuple(
+                                   Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWmma>{})),
+                               make_pass_through_transform(Number<B_K1>{})),
+                    make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                    make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
             }
             else
             {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1), max_lds_align);
+                // KWmma_NRepeat_NWave_KRow_NPerWmma_K1 -> K0_NRepeat_Nwaves_NPerWmma_K1
+                constexpr auto KWmma = BBlockDesc_{}.GetLength(I0);
+                constexpr auto B_K1  = BBlockDesc_{}.GetLength(I5);
+
+                // Workaround, Freeze transform
+                return transform_tensor_descriptor(
+                    BBlockDesc_{},
+                    make_tuple(make_freeze_transform(I0),
+                               make_pass_through_transform(Number<KWmma>{}),
+                               make_pass_through_transform(Number<NRepeat>{}),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(I1),
+                               make_pass_through_transform(Number<B_K1>{})),
+                    make_tuple(Sequence<3>{},
+                               Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<4>{},
+                               Sequence<5>{}),
+                    make_tuple(Sequence<>{},
+                               Sequence<0>{},
+                               Sequence<1>{},
+                               Sequence<2>{},
+                               Sequence<3>{},
+                               Sequence<4>{}));
             }
         }();

-        return b_block_desc_k0perblock_nperblock_k1;
+        return b_wave_desc;
     }

     __host__ __device__ static constexpr auto
```
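The "freeze" workaround in the BEnableLds == false branch exists because the wave-level GEMM expects a 5-D B view while the per-thread block descriptor is 6-D (KWmma, NRepeat, NWave, KRow, NPerWmma, K1); make_freeze_transform pins source dimension 3 to a constant and drops it from the output (the empty Sequence<>{}). A plain-C++ picture of that coordinate bookkeeping, illustrative only and with the frozen dimension's meaning inferred from the surrounding comments:

```cpp
#include <array>
#include <cstdio>

// Reinsert the frozen index to go from the 5-D wave view
// (K0, NRepeat, NWave, NPerWmma, K1) back to the 6-D thread view
// (KWmma, NRepeat, NWave, KRow, NPerWmma, K1). Illustrative only.
std::array<int, 6> unfreeze(const std::array<int, 5>& idx5, int frozen_krow)
{
    return {idx5[0], idx5[1], idx5[2], frozen_krow, idx5[3], idx5[4]};
}

int main()
{
    const auto idx6 = unfreeze({3, 1, 0, 7, 2}, /*frozen_krow=*/0);
    std::printf("6-D coordinate: (%d, %d, %d, %d, %d, %d)\n",
                idx6[0], idx6[1], idx6[2], idx6[3], idx6[4], idx6[5]);
}
```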
```diff
@@ -551,7 +597,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
     template <typename Block2CTileMap>
     __host__ __device__ static constexpr bool
     CheckValidity(const AGridDesc& a_grid_desc,
-                  const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                  const BGridDesc& b_grid_desc,
                   const DsGridDesc_M_N& ds_grid_desc_m_n,
                   const EGridDesc_M_N& e_grid_desc_m_n,
                   const Block2CTileMap& block_2_ctile_map)
```
```diff
@@ -581,17 +627,17 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         const auto GetBProblemsizeNK = [&]() {
             if constexpr(BEnableLds)
             {
-                return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1),
-                                  b_grid_desc_k0_n_k1.GetLength(I0) *
-                                      b_grid_desc_k0_n_k1.GetLength(I2));
+                return make_tuple(b_grid_desc.GetLength(I1),
+                                  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I2));
             }
             else
             {
-                return make_tuple(b_grid_desc_k0_n_k1.GetLength(I1) *
-                                      b_grid_desc_k0_n_k1.GetLength(I2) *
-                                      b_grid_desc_k0_n_k1.GetLength(I4),
-                                  b_grid_desc_k0_n_k1.GetLength(I0) *
-                                      b_grid_desc_k0_n_k1.GetLength(I3) *
-                                      b_grid_desc_k0_n_k1.GetLength(I5));
+                return make_tuple(b_grid_desc.GetLength(I1) * b_grid_desc.GetLength(I2) *
+                                      b_grid_desc.GetLength(I4),
+                                  b_grid_desc.GetLength(I0) * b_grid_desc.GetLength(I3) *
+                                      b_grid_desc.GetLength(I5));
             }
         };
```
```diff
@@ -702,7 +748,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
               : 0;
     static constexpr auto b_block_space_size_aligned =
         BEnableLds
-            ? math::integer_least_multiple(
-                  GetBBlockDescriptor_K0PerBlock_NPerBlock_K1().GetElementSpaceSize(),
-                  max_lds_align)
+            ? math::integer_least_multiple(MakeBBlockDescriptor().GetElementSpaceSize(),
+                                           max_lds_align)
             : 0;
```
```diff
@@ -737,7 +783,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                                EDataType* __restrict__ p_e_grid,
                                void* __restrict__ p_shared,
                                const AGridDesc& a_grid_desc,
-                               const BGridDesc_K0_N_K1& b_grid_desc_k0_n_k1,
+                               const BGridDesc& b_grid_desc,
                                const DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
                                    ds_grid_desc_mblock_mperblock_nblock_nperblock,
                                const EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock&
```
```diff
@@ -753,7 +799,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc.GetElementSpaceSize());
         const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_grid_desc_k0_n_k1.GetElementSpaceSize());
+            p_b_grid, b_grid_desc.GetElementSpaceSize());
         const auto ds_grid_buf = generate_tuple(
             [&](auto i) {
                 return make_dynamic_buffer<AddressSpaceEnum::Global>(
```
```diff
@@ -789,7 +835,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         }();

         constexpr auto a_block_desc = MakeABlockDescriptor();
-        constexpr auto b_block_desc = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+        constexpr auto b_block_desc = MakeBBlockDescriptor();

         auto a_block_trait = [&]() {
             // A matrix blockwise copy
```
```diff
@@ -886,7 +932,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                                                 BBlockTransferThreadClusterArrangeOrder,
                                                 BDataType,
                                                 BDataType,
-                                                decltype(b_grid_desc_k0_n_k1),
+                                                decltype(b_grid_desc),
                                                 decltype(b_block_desc),
                                                 BBlockTransferSrcAccessOrder,
                                                 Sequence<0, 1, 2>,
```
```diff
@@ -898,7 +944,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                                                 1,
                                                 BThreadTransferSrcResetCoordinateAfterRun,
                                                 true>(
-                    b_grid_desc_k0_n_k1,
+                    b_grid_desc,
                     make_multi_index(0, n_block_data_idx_on_grid, 0),
                     b_element_op,
                     b_block_desc,
```
```diff
@@ -909,22 +955,36 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
             }
             else
             {
-                constexpr auto K0PerBlock = KPerBlock / K1;
-                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, BDataType>(
-                    b_block_desc.GetElementSpaceSize());
+                // Thread-wise copy
+                // KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
+                constexpr auto KWmmaPerBlock = KPerBlock / WmmaK;
+                auto b_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, ADataType>(
+                    b_block_desc.GetElementSpaceSize());

-                auto b_blockwise_copy =
-                    ThreadwiseTensorSliceTransfer_v4<BDataType,
-                                                     BDataType,
-                                                     decltype(b_grid_desc_k0_n_k1),
-                                                     decltype(b_block_desc),
-                                                     Sequence<Number<K0PerBlock>{}, /* ... */>,
-                                                     Sequence<0, 1, 2>,
-                                                     2,
-                                                     BBlockTransferSrcScalarPerVector,
-                                                     1>(
-                        make_multi_index(
-                            0,
-                            get_thread_local_1d_id() / 32 * 16 + get_thread_local_1d_id() % 16,
-                            0));
+                // Limitation: NumDim of Src and Dst descriptor should be identical
+                auto b_blockwise_copy =
+                    ThreadwiseTensorSliceTransfer_v2<BDataType,
+                                                     BDataType,
+                                                     decltype(b_grid_desc),
+                                                     decltype(b_block_desc),
+                                                     Sequence<Number<KWmmaPerBlock>{},
+                                                              Number<NRepeat>{},
+                                                              I1,
+                                                              I1,
+                                                              I1,
+                                                              Number<K1Value>{}>,
+                                                     Sequence<0, 1, 2, 3, 4, 5>,
+                                                     5,
+                                                     BBlockTransferSrcScalarPerVector,
+                                                     BThreadTransferSrcResetCoordinateAfterRun,
+                                                     true>(
+                        b_grid_desc,
+                        make_multi_index(0,
+                                         n_block_data_idx_on_grid / (NWaves * NPerWmma),
+                                         get_thread_local_1d_id() / 32,
+                                         (get_thread_local_1d_id() % 32) / 16,
+                                         get_thread_local_1d_id() % 16,
+                                         0));

                 return make_tuple(b_block_buf, b_blockwise_copy);
             }
```
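The start coordinate handed to ThreadwiseTensorSliceTransfer_v2 above decomposes the flat thread id with the literals 32 and 16: lane / 32 picks the wave along N, (lane % 32) / 16 picks the K row inside the WMMA input pair, and lane % 16 picks the lane within the 16-wide WMMA tile. A standalone decode of that arithmetic; the interpretation of each field is inferred from the surrounding names.

```cpp
#include <cstdio>

int main()
{
    // Decompose flat thread ids the same way the diff does:
    // tid / 32, (tid % 32) / 16, tid % 16.
    for (int tid : {0, 15, 16, 31, 32, 63})
    {
        const int nwave  = tid / 32;        // wave index along N
        const int krow   = (tid % 32) / 16; // K row within the WMMA step
        const int lane_n = tid % 16;        // lane within the 16-wide tile
        std::printf("tid=%2d -> (NWave=%d, KRow=%d, NPerWmma lane=%2d)\n",
                    tid, nwave, krow, lane_n);
    }
}
```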
```diff
@@ -945,7 +1005,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
             BDataType,
             AccDataType,
             decltype(MakeAWaveDescriptor(a_block_desc)),
-            decltype(MakeBBlockDescriptor_K0_N0_N1_N2_K1(b_block_desc)),
+            decltype(MakeBWaveDescriptor(b_block_desc)),
             MPerBlock,
             NPerBlock,
             KPerBlock,
```
```diff
@@ -973,7 +1033,7 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
             a_grid_buf,
             a_block_buf,
             a_block_slice_copy_step,
-            b_grid_desc_k0_n_k1,
+            b_grid_desc,
             b_block_desc,
             b_blockwise_copy,
             b_grid_buf,
```