Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
3ccfb0ae
Commit
3ccfb0ae
authored
May 19, 2023
by
aska-0096
Browse files
(2/5) Bilinear GEMM passes. Perf bug: skipping LDS for A gives lower performance than skipping LDS for B
parent
c713d224
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
134 additions
and
182 deletions
+134
-182
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
+9
-9
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
.../gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
+29
-57
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+9
-39
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+86
-76
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+1
-1
No files found.
example/02_gemm_bilinear/gemm_bilinear_wmma_fp16.cpp
View file @
3ccfb0ae
...
@@ -80,33 +80,33 @@ using DeviceOpInstance =
...
@@ -80,33 +80,33 @@ using DeviceOpInstance =
BElementOp
,
BElementOp
,
CDEElementOp
,
CDEElementOp
,
GemmSpec
,
GemmSpec
,
2
,
1
,
128
,
128
,
64
,
64
,
128
,
64
,
64
,
8
,
64
,
4
,
16
,
16
,
16
,
16
,
2
,
1
,
4
,
4
,
S
<
4
,
32
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
4
,
8
,
4
,
true
,
true
,
S
<
4
,
32
,
1
>
,
S
<
4
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
2
,
8
,
4
,
8
,
4
,
true
,
true
,
1
,
1
,
1
,
1
,
S
<
1
,
32
,
1
,
4
>
,
S
<
1
,
64
,
1
,
2
>
,
8
>
;
8
>
;
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp
View file @
3ccfb0ae
...
@@ -87,6 +87,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -87,6 +87,7 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
I6
=
Number
<
6
>
{};
// K1 = Max Vector Access Pixels
// K1 = Max Vector Access Pixels
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
static
constexpr
auto
K1Number
=
Number
<
K1
>
{};
...
@@ -98,8 +99,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -98,8 +99,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
tru
e
;
static
constexpr
auto
AEnableLds_manu
=
fals
e
;
static
constexpr
auto
BEnableLds_manu
=
tru
e
;
static
constexpr
auto
BEnableLds_manu
=
fals
e
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
...
@@ -144,18 +145,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -144,18 +145,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
else
else
{
{
constexpr
auto
A_KRow
=
WmmaK
/
K1
;
constexpr
auto
A_KRow
=
2
;
constexpr
auto
A_K0PerWmma
=
WmmaK
/
A_KRow
/
K1Number
;
const
auto
A_KWmma
=
K
/
WmmaK
;
const
auto
A_KWmma
=
K
/
WmmaK
;
const
auto
M0
=
M
/
MPerBlock
;
const
auto
M0
=
M
/
MPerBlock
;
// 0 1 0 1 2 3 4 5 6
// M - K <-> A_KWmma - MBlock*MRepeat - MWaves - A_K0PerWmma - A_KRow - MPerWmma - A_K1
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
a_grid_desc_m_k
,
a_grid_desc_m_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A_KWmma
,
Number
<
A_KRow
>
{},
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
A_KWmma
,
Number
<
A_K0PerWmma
>
{},
Number
<
A_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_unmerge_transform
(
make_tuple
(
M0
*
MRepeat
,
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{}))),
make_tuple
(
M0
*
MRepeat
,
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
}
}
...
@@ -195,18 +199,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -195,18 +199,21 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
}
}
else
else
{
{
constexpr
auto
B_KRow
=
WmmaK
/
K1
;
constexpr
auto
B_KRow
=
2
;
constexpr
auto
B_K0PerWmma
=
WmmaK
/
B_KRow
/
K1Number
;
const
auto
B_KWmma
=
K
/
WmmaK
;
const
auto
B_KWmma
=
K
/
WmmaK
;
const
auto
N0
=
N
/
NPerBlock
;
const
auto
N0
=
N
/
NPerBlock
;
// 0 1 0 1 2 3 4 5 6
// N - K <-> B_KWmma - NBlock*NRepeat - NWaves - B_K0PerWmma - B_KRow - NPerWmma - B_K1
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
b_grid_desc_n_k
,
b_grid_desc_n_k
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B_KWmma
,
Number
<
B_KRow
>
{},
K1Number
)),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
B_KWmma
,
Number
<
B_K0PerWmma
>
{},
Number
<
B_KRow
>
{},
K1Number
)),
make_unmerge_transform
(
make_unmerge_transform
(
make_tuple
(
N0
*
NRepeat
,
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{}))),
make_tuple
(
N0
*
NRepeat
,
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{}))),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
1
>
{},
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
3
,
5
>
{},
Sequence
<
1
,
2
,
4
>
{}));
make_tuple
(
Sequence
<
0
,
3
,
4
,
6
>
{},
Sequence
<
1
,
2
,
5
>
{}));
}
}
}
}
...
@@ -438,14 +445,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -438,14 +445,11 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
else
else
{
{
return
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I3
)
*
return
arg
.
a_grid_desc
.
GetLength
(
I0
)
*
arg
.
a_grid_desc
.
GetLength
(
I3
)
*
arg
.
a_grid_desc
.
GetLength
(
I
5
);
arg
.
a_grid_desc
.
GetLength
(
I
4
)
*
arg
.
a_grid_desc
.
GetLength
(
I6
);
}
}
}();
}();
float
ave_time
=
0
;
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_mupltipe_d_wmma_cshuffle
<
const
auto
kernel
=
kernel_gemm_mupltipe_d_wmma_cshuffle
<
GridwiseOp
,
GridwiseOp
,
ADataType
,
ADataType
,
...
@@ -462,9 +466,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -462,9 +466,9 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
BElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
CDEElementwiseOperation
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
true
>
;
// Last Option is W/O
has_main_k_block_loop
>
;
// Last Option is W/O
ave_time
=
return
launch_and_time_kernel
(
stream_config
,
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
...
@@ -482,48 +486,16 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
...
@@ -482,48 +486,16 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
};
if
(
GridwiseOp
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_mupltipe_d_wmma_cshuffle
<
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
GridwiseOp
,
ADataType
,
BDataType
,
typename
GridwiseOp
::
DsGridPointer
,
EDataType
,
remove_reference_t
<
typename
DeviceOp
::
AGridDesc
>
,
remove_reference_t
<
typename
DeviceOp
::
BGridDesc
>
,
remove_reference_t
<
typename
GridwiseOp
::
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
remove_reference_t
<
typename
GridwiseOp
::
EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
remove_reference_t
<
typename
GridwiseOp
::
DefaultBlock2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_ds_grid_
,
arg
.
p_e_grid_
,
arg
.
a_grid_desc
,
arg
.
b_grid_desc
,
arg
.
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
e_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
cde_element_op_
,
arg
.
block_2_ctile_map_
);
}
}
return
ave_time
;
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
3ccfb0ae
...
@@ -382,11 +382,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -382,11 +382,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg
.
a_grid_desc_
.
GetLength
(
I4
)
*
arg
.
a_grid_desc_
.
GetLength
(
I6
);
arg
.
a_grid_desc_
.
GetLength
(
I4
)
*
arg
.
a_grid_desc_
.
GetLength
(
I6
);
}
}
}();
}();
auto
launch_kernel
=
[
&
](
auto
has_main_k_block_loop
)
{
float
ave_time
=
0
;
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
const
auto
kernel
=
kernel_gemm_wmma
<
const
auto
kernel
=
kernel_gemm_wmma
<
GridwiseGemm
,
GridwiseGemm
,
ADataType
,
ADataType
,
...
@@ -400,9 +396,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -400,9 +396,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
BElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
true
>
;
// Last Option is W/O
has_main_k_block_loop
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
return
launch_and_time_kernel
(
stream_config
,
kernel
,
kernel
,
dim3
(
grid_size
),
dim3
(
grid_size
),
dim3
(
BlockSize
),
dim3
(
BlockSize
),
...
@@ -417,42 +413,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -417,42 +413,16 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
arg
.
b_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
arg
.
block_2_ctile_map_
);
};
if
(
GridwiseGemm
::
CalculateHasMainKBlockLoop
(
K
))
{
return
launch_kernel
(
integral_constant
<
bool
,
true
>
{});
}
}
else
else
{
{
const
auto
kernel
=
kernel_gemm_wmma
<
return
launch_kernel
(
integral_constant
<
bool
,
false
>
{});
GridwiseGemm
,
ADataType
,
BDataType
,
CDataType
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
AGridDesc
>
,
remove_reference_t
<
DeviceGemmWmma_CShuffle
::
BGridDesc
>
,
remove_reference_t
<
typename
GridwiseGemm
::
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
remove_reference_t
<
typename
GridwiseGemm
::
DefaultBlock2CTileMap
>
,
false
>
;
ave_time
=
launch_and_time_kernel
(
stream_config
,
kernel
,
dim3
(
grid_size
),
dim3
(
BlockSize
),
0
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
a_grid_desc_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_mblock_mperblock_nblock_nperblock
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
block_2_ctile_map_
);
}
}
return
ave_time
;
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
3ccfb0ae
...
@@ -379,10 +379,23 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -379,10 +379,23 @@ struct GridwiseGemmMultipleD_Wmma
else
else
{
{
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
// KWmma->MRepeat->MWave->KRow->MPerWmma->K1 Per Thread
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1
;
// KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
K1
),
make_tuple
(
Number
<
KWmmaPerblock
>
{},
make_tuple
(
Number
<
MRepeat
>
{}
*
K1
,
K1
,
K1
,
K1
,
K1
,
I1
));
Number
<
MRepeat
>
{},
I1
,
Number
<
K0PerWmma
>
{},
I1
,
I1
,
K1
),
make_tuple
(
Number
<
MRepeat
>
{}
*
Number
<
K0PerWmma
>
{}
*
K1
,
Number
<
K0PerWmma
>
{}
*
K1
,
Number
<
K0PerWmma
>
{}
*
K1
,
K1
,
K1
,
K1
,
I1
));
}
}
}();
}();
...
@@ -413,10 +426,23 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -413,10 +426,23 @@ struct GridwiseGemmMultipleD_Wmma
else
else
{
{
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerblock
=
KPerBlock
/
WmmaK
;
// KWmma->NRepeat->NWave->NRow->NPerWmma->K1 Per Thread
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1
;
// KWmma->NRepeat->NWave->K0PerWmma->KRow->NPerWmma->K1 Per Thread
return
make_naive_tensor_descriptor
(
return
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KWmmaPerblock
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
K1
),
make_tuple
(
Number
<
KWmmaPerblock
>
{},
make_tuple
(
Number
<
NRepeat
>
{}
*
K1
,
K1
,
K1
,
K1
,
K1
,
I1
));
Number
<
NRepeat
>
{},
I1
,
Number
<
K0PerWmma
>
{},
I1
,
I1
,
K1
),
make_tuple
(
Number
<
NRepeat
>
{}
*
Number
<
K0PerWmma
>
{}
*
K1
,
Number
<
K0PerWmma
>
{}
*
K1
,
Number
<
K0PerWmma
>
{}
*
K1
,
K1
,
K1
,
K1
,
I1
));
}
}
}();
}();
...
@@ -436,7 +462,7 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -436,7 +462,7 @@ struct GridwiseGemmMultipleD_Wmma
{
{
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
return
make_multi_index
(
KWmmaPerBlock
,
0
,
0
,
0
,
0
,
0
);
return
make_multi_index
(
KWmmaPerBlock
,
0
,
0
,
0
,
0
,
0
,
0
);
}
}
}();
}();
...
@@ -456,7 +482,7 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -456,7 +482,7 @@ struct GridwiseGemmMultipleD_Wmma
{
{
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
return
make_multi_index
(
KWmmaPerBlock
,
0
,
0
,
0
,
0
,
0
);
return
make_multi_index
(
KWmmaPerBlock
,
0
,
0
,
0
,
0
,
0
,
0
);
}
}
}();
}();
...
@@ -471,45 +497,33 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -471,45 +497,33 @@ struct GridwiseGemmMultipleD_Wmma
constexpr
auto
a_wave_desc
=
[
&
]()
{
constexpr
auto
a_wave_desc
=
[
&
]()
{
if
constexpr
(
AEnableLds
)
if
constexpr
(
AEnableLds
)
{
{
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_MPerWmma_AK1
// AK0_M_AK1 -> AK0_MRepeat_Mwaves_
AKRow_
MPerWmma_AK1
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K0
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
A_KRow
=
I1
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
ABlockDesc_
{},
ABlockDesc_
{},
make_tuple
(
make_
pass_through_transform
(
Number
<
A_K0
>
{}),
make_tuple
(
make_
unmerge_transform
(
make_tuple
(
Number
<
A_K0
>
{}
,
A_KRow
)
),
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{})),
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWmma
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
5
>
{}));
}
}
else
else
{
{
// KWmma_MRepeat_MWave_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
// KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_MWaves_KRow_MPerWmma_K1
constexpr
auto
KWmma
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
KWmma
=
ABlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I5
);
constexpr
auto
K0PerWmma
=
ABlockDesc_
{}.
GetLength
(
I3
);
constexpr
auto
A_KRow
=
ABlockDesc_
{}.
GetLength
(
I4
);
constexpr
auto
A_K1
=
ABlockDesc_
{}.
GetLength
(
I6
);
// Workaround, Freeze transform
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KWmma
*
K0PerWmma
>
{},
return
transform_tensor_descriptor
(
Number
<
MRepeat
>
{},
ABlockDesc_
{},
I1
,
make_tuple
(
make_freeze_transform
(
I0
),
Number
<
A_KRow
>
{},
make_pass_through_transform
(
Number
<
KWmma
>
{}),
I1
,
make_pass_through_transform
(
Number
<
MRepeat
>
{}),
Number
<
A_K1
>
{}));
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
Number
<
A_K1
>
{})),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
}
}
}();
}();
...
@@ -525,42 +539,31 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -525,42 +539,31 @@ struct GridwiseGemmMultipleD_Wmma
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
// BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K0
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I2
);
constexpr
auto
B_KRow
=
I1
;
return
transform_tensor_descriptor
(
return
transform_tensor_descriptor
(
BBlockDesc_
{},
BBlockDesc_
{},
make_tuple
(
make_
pass_through_transform
(
Number
<
B_K0
>
{}),
make_tuple
(
make_
unmerge_transform
(
make_tuple
(
Number
<
B_K0
>
{}
,
B_KRow
)
),
make_unmerge_transform
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWmma
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
,
3
>
{},
Sequence
<
4
>
{}));
make_tuple
(
Sequence
<
0
,
3
>
{},
Sequence
<
1
,
2
,
4
>
{},
Sequence
<
5
>
{}));
}
}
else
else
{
{
// KWmma_
N
Repeat_
N
Wave_KRow_
N
PerWmma_K1 -> K0_
N
Repeat_
N
waves_
N
PerWmma_K1
// KWmma_NRepeat_NWave_K0PerWmma_KRow_NPerWmma_K1 -> K0_NRepeat_NWaves_KRow_NPerWmma_K1
constexpr
auto
KWmma
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
KWmma
=
BBlockDesc_
{}.
GetLength
(
I0
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I5
);
constexpr
auto
K0PerWmma
=
BBlockDesc_
{}.
GetLength
(
I3
);
constexpr
auto
B_KRow
=
BBlockDesc_
{}.
GetLength
(
I4
);
constexpr
auto
B_K1
=
BBlockDesc_
{}.
GetLength
(
I6
);
// Workaround, Freeze transform
// Workaround, Freeze transform
return
transform_tensor_descriptor
(
return
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KWmma
*
K0PerWmma
>
{},
BBlockDesc_
{},
Number
<
NRepeat
>
{},
make_tuple
(
make_freeze_transform
(
I0
),
I1
,
make_pass_through_transform
(
Number
<
KWmma
>
{}),
Number
<
B_KRow
>
{},
make_pass_through_transform
(
Number
<
NRepeat
>
{}),
I1
,
make_pass_through_transform
(
I1
),
Number
<
B_K1
>
{}));
make_pass_through_transform
(
I1
),
make_pass_through_transform
(
Number
<
B_K1
>
{})),
make_tuple
(
Sequence
<
3
>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
4
>
{},
Sequence
<
5
>
{}),
make_tuple
(
Sequence
<>
{},
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
}
}
}();
}();
...
@@ -620,9 +623,9 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -620,9 +623,9 @@ struct GridwiseGemmMultipleD_Wmma
else
else
{
{
return
make_tuple
(
a_grid_desc
.
GetLength
(
I1
)
*
a_grid_desc
.
GetLength
(
I2
)
*
return
make_tuple
(
a_grid_desc
.
GetLength
(
I1
)
*
a_grid_desc
.
GetLength
(
I2
)
*
a_grid_desc
.
GetLength
(
I
4
),
a_grid_desc
.
GetLength
(
I
5
),
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I3
)
*
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I3
)
*
a_grid_desc
.
GetLength
(
I
5
));
a_grid_desc
.
GetLength
(
I
4
)
*
a_grid_desc
.
GetLength
(
I6
));
}
}
};
};
...
@@ -635,9 +638,9 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -635,9 +638,9 @@ struct GridwiseGemmMultipleD_Wmma
else
else
{
{
return
make_tuple
(
b_grid_desc
.
GetLength
(
I1
)
*
b_grid_desc
.
GetLength
(
I2
)
*
return
make_tuple
(
b_grid_desc
.
GetLength
(
I1
)
*
b_grid_desc
.
GetLength
(
I2
)
*
b_grid_desc
.
GetLength
(
I
4
),
b_grid_desc
.
GetLength
(
I
5
),
b_grid_desc
.
GetLength
(
I0
)
*
b_grid_desc
.
GetLength
(
I3
)
*
b_grid_desc
.
GetLength
(
I0
)
*
b_grid_desc
.
GetLength
(
I3
)
*
b_grid_desc
.
GetLength
(
I
5
));
b_grid_desc
.
GetLength
(
I
4
)
*
b_grid_desc
.
GetLength
(
I6
));
}
}
};
};
...
@@ -837,7 +840,8 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -837,7 +840,8 @@ struct GridwiseGemmMultipleD_Wmma
return
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I2
);
return
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I2
);
}
}
else
{
else
{
return
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I3
)
*
a_grid_desc
.
GetLength
(
I5
);
return
a_grid_desc
.
GetLength
(
I0
)
*
a_grid_desc
.
GetLength
(
I3
)
*
a_grid_desc
.
GetLength
(
I4
)
*
a_grid_desc
.
GetLength
(
I6
);
}
}
}();
}();
...
@@ -888,8 +892,9 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -888,8 +892,9 @@ struct GridwiseGemmMultipleD_Wmma
else
else
{
{
// Thread-wise copy
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves ->
WmmaK/K1
-> MPerWmma -> K1
// KPerBlock/WmmaK -> MRepeat -> MWaves ->
K0PerWmma -> KRow
-> MPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1Value
;
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
a_block_desc
.
GetElementSpaceSize
());
a_block_desc
.
GetElementSpaceSize
());
...
@@ -902,11 +907,12 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -902,11 +907,12 @@ struct GridwiseGemmMultipleD_Wmma
Sequence
<
Number
<
KWmmaPerBlock
>
{},
Sequence
<
Number
<
KWmmaPerBlock
>
{},
Number
<
MRepeat
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
K0PerWmma
>
{},
I1
,
I1
,
I1
,
I1
,
Number
<
K1Value
>
{}
>
,
Number
<
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
5
,
6
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
>
(
...
@@ -914,6 +920,7 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -914,6 +920,7 @@ struct GridwiseGemmMultipleD_Wmma
make_multi_index
(
0
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
/
(
MWaves
*
MPerWmma
),
m_block_data_idx_on_grid
/
(
MWaves
*
MPerWmma
),
get_thread_local_1d_id
()
/
32
,
get_thread_local_1d_id
()
/
32
,
0
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
get_thread_local_1d_id
()
%
16
,
0
));
0
));
...
@@ -967,7 +974,8 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -967,7 +974,8 @@ struct GridwiseGemmMultipleD_Wmma
// Thread-wise copy
// Thread-wise copy
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
// KPerBlock/WmmaK -> NRepeat -> NWaves -> WmmaK/K1 -> NPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
auto
b_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1Value
;
auto
b_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
BDataType
>
(
b_block_desc
.
GetElementSpaceSize
());
b_block_desc
.
GetElementSpaceSize
());
// Limitation: NumDim of Src and Dst descriptor should be identical
// Limitation: NumDim of Src and Dst descriptor should be identical
...
@@ -979,11 +987,12 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -979,11 +987,12 @@ struct GridwiseGemmMultipleD_Wmma
Sequence
<
Number
<
KWmmaPerBlock
>
{},
Sequence
<
Number
<
KWmmaPerBlock
>
{},
Number
<
NRepeat
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
Number
<
K0PerWmma
>
{},
I1
,
I1
,
I1
,
I1
,
Number
<
K1Value
>
{}
>
,
Number
<
K1Value
>
{}
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
,
6
>
,
5
,
6
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
BThreadTransferSrcResetCoordinateAfterRun
,
BThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
true
>
(
...
@@ -991,6 +1000,7 @@ struct GridwiseGemmMultipleD_Wmma
...
@@ -991,6 +1000,7 @@ struct GridwiseGemmMultipleD_Wmma
make_multi_index
(
0
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
/
(
NWaves
*
NPerWmma
),
n_block_data_idx_on_grid
/
(
NWaves
*
NPerWmma
),
get_thread_local_1d_id
()
/
32
,
get_thread_local_1d_id
()
/
32
,
0
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
(
get_thread_local_1d_id
()
%
32
)
/
16
,
get_thread_local_1d_id
()
%
16
,
get_thread_local_1d_id
()
%
16
,
0
));
0
));
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
3ccfb0ae
...
@@ -655,7 +655,7 @@ struct GridwiseGemm_Wmma
...
@@ -655,7 +655,7 @@ struct GridwiseGemm_Wmma
else
else
{
{
// Thread-wise copy
// Thread-wise copy
// KPerBlock/WmmaK -> MRepeat -> MWaves ->
WmmaK/K1
-> MPerWmma -> K1
// KPerBlock/WmmaK -> MRepeat -> MWaves ->
K0PerWmma -> KRow
-> MPerWmma -> K1
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
KWmmaPerBlock
=
KPerBlock
/
WmmaK
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1Value
;
constexpr
auto
K0PerWmma
=
WmmaK
/
2
/
K1Value
;
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
auto
a_block_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
ADataType
>
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment