gaoqiong / composable_kernel_ROCM · Commits

Commit 0e70dfe9
Authored Apr 27, 2024 by Jing Zhang
Parent: 268c497c

    debugging

Showing 3 changed files with 27 additions and 6 deletions (+27, -6)
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp   +2  -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp        +12 -2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp                            +13 -2
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_wmma_cshuffle.hpp

@@ -101,8 +101,8 @@ struct DeviceGemmMultipleD_Wmma_CShuffle : public DeviceGemmMultipleD<ALayout,
         (MWaves == 1 && is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value) ? false : true;
     // If true, LDS is used unconditionally
-    static constexpr auto AEnableLds_manu = true;
-    static constexpr auto BEnableLds_manu = true;
+    static constexpr auto AEnableLds_manu = false;
+    static constexpr auto BEnableLds_manu = false;
     static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
     static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
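The flipped flags above no longer force LDS on; whether LDS is actually used still comes from the combined expressions kept as context in the hunk. Below is a minimal standalone sketch of that selection logic, not the CK header itself; the parameter values (MWaves, NumPrefetch, the layout types) are assumptions made up for illustration.

```cpp
#include <cstdio>
#include <type_traits>

// Stand-ins for the device-op template parameters (example values only).
constexpr int MWaves      = 1;
constexpr int NumPrefetch = 1;
struct ColumnMajor {};        // stand-in for tensor_layout::gemm::ColumnMajor
using BLayout = ColumnMajor;  // assume a column-major B operand for this example

// Heuristic from the context line: skip LDS for B when a single M-wave
// reads a column-major B.
constexpr bool BEnableLds_auto =
    (MWaves == 1 && std::is_same<ColumnMajor, BLayout>::value) ? false : true;

// Manual override ("If true, LDS is used unconditionally"); the commit turns it off.
constexpr bool BEnableLds_manu = false;

// Final decision: the heuristic, the manual override, or multi-stage prefetch
// (NumPrefetch > 1) can each force LDS on.
constexpr bool BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);

int main() { std::printf("BEnableLds = %d\n", BEnableLds); }
```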
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

@@ -375,8 +375,9 @@ struct GridwiseGemmMultipleD_Wmma
         }
         else
         {
+            constexpr auto A_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
             // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
@@ -422,8 +423,9 @@ struct GridwiseGemmMultipleD_Wmma
         }
         else
         {
+            constexpr auto B_KRow        = I2;
             constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
-            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
+            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
             // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
             return make_naive_tensor_descriptor(
                 make_tuple(Number<KWmmaPerblock>{},
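In both hunks the per-thread K decomposition now divides by the named row count instead of a literal 2. Here is a small compile-time check of the arithmetic; the numeric values (WmmaK, K1, KPerBlock) are assumptions chosen for the example, and only the formulas mirror the diff.

```cpp
// Example K decomposition; all numbers are illustrative assumptions.
constexpr int WmmaK     = 16; // K extent covered by one WMMA instruction (assumed)
constexpr int K1        = 8;  // innermost/vectorized K dimension (assumed)
constexpr int KPerBlock = 64; // K tile handled per workgroup (assumed)
constexpr int A_KRow    = 2;  // the value the commit writes in place of the literal 2

constexpr int KWmmaPerblock = KPerBlock / WmmaK;   // 64 / 16 = 4
constexpr int K0PerWmma     = WmmaK / A_KRow / K1; // 16 / 2 / 8 = 1

static_assert(K0PerWmma * A_KRow * K1 == WmmaK,
              "K0PerWmma * KRow * K1 should cover one WMMA K slab");
static_assert(KWmmaPerblock * WmmaK == KPerBlock,
              "KWmmaPerblock WMMA steps cover the block's K tile");

int main() {}
```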
@@ -497,7 +499,11 @@ struct GridwiseGemmMultipleD_Wmma
         // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
         constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto A_KRow = I2;
+#else
         constexpr auto A_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             ABlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
@@ -536,7 +542,11 @@ struct GridwiseGemmMultipleD_Wmma
         // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
         constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
         constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
+        constexpr auto B_KRow = I2;
+#else
         constexpr auto B_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             BBlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
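The same `#ifdef __gfx12__` guard is applied to the A and B row counts here and in gridwise_gemm_wmma.hpp below: I2 on gfx12, I1 otherwise. A minimal sketch of that arch-gated compile-time selection, using `std::integral_constant` in place of ck::Number; treating `__gfx12__` as an ordinary preprocessor symbol on the host is an assumption made for illustration.

```cpp
#include <type_traits>

// Illustrative only: pick the K-row count per WMMA at compile time,
// mirroring the #ifdef __gfx12__ guard in the diff.
using I1 = std::integral_constant<int, 1>;
using I2 = std::integral_constant<int, 2>;

#ifdef __gfx12__
using A_KRow = I2; // gfx12 path introduced by the commit
#else
using A_KRow = I1; // pre-gfx12 path kept as the default
#endif

static_assert(A_KRow::value == 1 || A_KRow::value == 2, "KRow is either 1 or 2");

int main() { return A_KRow::value; }
```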
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp

@@ -295,7 +295,12 @@ struct GridwiseGemm_Wmma
         // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
         constexpr auto A_K0 = ABlockDesc_{}.GetLength(I0);
         constexpr auto A_K1 = ABlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
         constexpr auto A_KRow = I2;
+#else
+        constexpr auto A_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             ABlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
@@ -310,6 +315,7 @@ struct GridwiseGemm_Wmma
             // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
             constexpr auto KWmma     = ABlockDesc_{}.GetLength(I0);
             constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
+            constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
             constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);
             // Err: merge transform cause non-constexpr issue
@@ -334,7 +340,7 @@ struct GridwiseGemm_Wmma
             return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
                                                                   Number<MRepeat>{},
                                                                   I1,
-                                                                  I1,
+                                                                  Number<A_KRow>{},
                                                                   I1,
                                                                   Number<A_K1>{}));
         }
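In the two hunks above, A_KRow is read back from the block descriptor (GetLength(I4)) and carried into the per-thread packed descriptor instead of being frozen to I1. A toy sketch of that pattern follows, with a hypothetical ToyDesc and made-up lengths standing in for ABlockDesc_ and its seven dimensions; only the idea of querying a compile-time extent and reusing it in the shape is taken from the diff.

```cpp
#include <tuple>
#include <type_traits>

template <int N>
using Number = std::integral_constant<int, N>;

// Hypothetical compile-time descriptor: lengths are integral constants,
// queried with GetLength(Number<i>{}) in the spirit of the ABlockDesc_ usage above.
template <typename... Lengths>
struct ToyDesc
{
    template <int I>
    static constexpr auto GetLength(Number<I>)
    {
        return std::tuple_element_t<I, std::tuple<Lengths...>>{};
    }
};

// KWmma, MRepeat, MWave, K0PerWmma, KRow, MPerWmma, K1 -- example lengths only.
using ABlockDesc = ToyDesc<Number<4>, Number<2>, Number<1>, Number<1>,
                           Number<2>, Number<16>, Number<8>>;

// Before: the KRow slot of the per-thread shape was hard-coded to 1.
// After (as in the diff): it comes from the descriptor's fifth dimension.
constexpr auto A_KRow = ABlockDesc::GetLength(Number<4>{});
static_assert(A_KRow == 2, "KRow is taken from the descriptor, not a literal 1");

int main() {}
```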
@@ -352,7 +358,11 @@ struct GridwiseGemm_Wmma
         // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
         constexpr auto B_K0 = BBlockDesc_{}.GetLength(I0);
         constexpr auto B_K1 = BBlockDesc_{}.GetLength(I2);
+#ifdef __gfx12__
         constexpr auto B_KRow = I2;
+#else
+        constexpr auto B_KRow = I1;
+#endif
         return transform_tensor_descriptor(
             BBlockDesc_{},
             make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
@@ -367,13 +377,14 @@ struct GridwiseGemm_Wmma
             // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
             constexpr auto KWmma     = BBlockDesc_{}.GetLength(I0);
             constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
+            constexpr auto B_KRow    = BBlockDesc_{}.GetLength(I4);
             constexpr auto B_K1      = BBlockDesc_{}.GetLength(I6);
             // Workaround, Freeze transform
             return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
                                                                   Number<NRepeat>{},
                                                                   I1,
-                                                                  I1,
+                                                                  Number<B_KRow>{},
                                                                   I1,
                                                                   Number<B_K1>{}));
         }