gaoqiong / composable_kernel_ROCM

Commit 268c497c, authored Apr 27, 2024 by Jing Zhang

    fixed lds_enabled

Parent: 26e8ba9f
Showing 4 changed files with 19 additions and 23 deletions (+19, -23)
  example/01_gemm/gemm_wmma_fp16.cpp                                  (+5, -5)
  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp       (+4, -7)
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp    (+1, -1)
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp         (+9, -10)
example/01_gemm/gemm_wmma_fp16.cpp
@@ -40,7 +40,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
    64,     // MPerBlock
    128,    // NPerBlock
    64,     // KPerBlock
    4,      // K1
    2,      // K1
    16,     // MPerWmma
    16,     // NPerWmma
    2,      // M-Repeat // M-PerWmma / M-Repeat = M-Wave
@@ -49,15 +49,15 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
    S<1, 0, 2>,
    S<1, 0, 2>,
    2,
    1,
    1,
    2,
    2,
    true,
    S<4, 32, 1>,
    S<1, 0, 2>,
    S<1, 0, 2>,
    2,
    1,
    1,
    2,
    2,
    true,
    1,      // C shuffle (M Repeat) Per store
    1,      // C shuffle (N Repeat) Per store
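The instance above uses MPerBlock = 64, NPerBlock = 128, KPerBlock = 64, MPerWmma = NPerWmma = 16 and MRepeat = 2, with K1 shown as both 4 and 2 across the change. As a rough aid, here is a minimal standalone sketch of the divisibility constraints those tile sizes imply; it is not part of the commit, and the constants below simply mirror the template parameters.

```cpp
// Sketch only: compile-time consistency checks for the tile sizes used above.
// Values mirror the example instance; K1 = 2 is assumed (the hunk shows both 4 and 2).
#include <cstdint>

constexpr int64_t MPerBlock = 64;
constexpr int64_t NPerBlock = 128;
constexpr int64_t KPerBlock = 64;
constexpr int64_t K1        = 2;
constexpr int64_t MPerWmma  = 16;
constexpr int64_t NPerWmma  = 16;
constexpr int64_t MRepeat   = 2;

// The K tile must split evenly into K1-wide vectors.
static_assert(KPerBlock % K1 == 0, "KPerBlock must be a multiple of K1");

// The M tile factors into MRepeat x MWave x MPerWmma; the leftover factor is the wave count.
static_assert(MPerBlock % (MRepeat * MPerWmma) == 0, "M tile must factor into repeats and WMMA rows");
constexpr int64_t MWave = MPerBlock / (MRepeat * MPerWmma); // 2 with these values

// The N tile must at least be a multiple of the WMMA fragment width.
static_assert(NPerBlock % NPerWmma == 0, "N tile must be a multiple of NPerWmma");

int main() { return 0; }
```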
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -70,9 +70,6 @@ struct BlockwiseGemmWMMA
    static constexpr index_t A_KRow = 2;
    static constexpr index_t B_KRow = 2;
    static constexpr index_t A_KRow_ = AEnableLds ? 1 : 2;
    static constexpr index_t B_KRow_ = BEnableLds ? 1 : 2;
    static constexpr index_t A_K1 = ABlockDesc{}.GetLength(I5);
    static constexpr index_t B_K1 = BBlockDesc{}.GetLength(I5);
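The definitions above tie the KRow factor to the LDS flags through the expression EnableLds ? 1 : 2, while the plain A_KRow/B_KRow constants are fixed at 2. A small standalone sketch of that selection follows; the function name is invented for illustration, and nothing beyond the visible expressions is assumed.

```cpp
// Sketch only: the KRow selection visible in the hunk above, as a standalone helper.
#include <cstdint>

constexpr int64_t krow_for(bool enable_lds)
{
    // Mirrors the "EnableLds ? 1 : 2" expressions in the diff: one K-row when the
    // operand is staged through LDS, two K-rows when it is read directly.
    return enable_lds ? 1 : 2;
}

static_assert(krow_for(true) == 1, "LDS path: single K-row");
static_assert(krow_for(false) == 2, "direct path: two K-rows");

int main() { return 0; }
```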
@@ -316,7 +313,7 @@ struct BlockwiseGemmWMMA
    // read A
    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                       make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                       make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                       a_block_buf,
                       a_thread_desc_,
                       make_tuple(I0, m0, I0, I0, I0, I0),
@@ -327,7 +324,7 @@ struct BlockwiseGemmWMMA
    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                       make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                       make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                       b_block_buf,
                       b_thread_desc_,
                       make_tuple(I0, n0, I0, I0, I0, I0),
@@ -373,7 +370,7 @@ struct BlockwiseGemmWMMA
    // read B
    b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                       make_tuple(Number<k * KPack / B_K1 / B_KRow_>{}, n0, I0, I0, I0, I0),
                       make_tuple(Number<k * KPack / B_K1 / B_KRow>{}, n0, I0, I0, I0, I0),
                       b_block_buf,
                       b_thread_desc_,
                       make_tuple(I0, n0, I0, I0, I0, I0),
@@ -381,7 +378,7 @@ struct BlockwiseGemmWMMA
    // read A
    a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                       make_tuple(Number<k * KPack / A_K1 / A_KRow_>{}, m0, I0, I0, I0, I0),
                       make_tuple(Number<k * KPack / A_K1 / A_KRow>{}, m0, I0, I0, I0, I0),
                       a_block_buf,
                       a_thread_desc_,
                       make_tuple(I0, m0, I0, I0, I0, I0),
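All four hunks in this file change the same sub-expression, the per-iteration source offset k * KPack / K1 / KRow, switching between the underscored (A_KRow_, B_KRow_) and plain (A_KRow, B_KRow) constants. Below is a standalone sketch of how that offset scales with the KRow factor, using assumed values for KPack and K1 rather than anything taken from the kernel.

```cpp
// Sketch only: the per-iteration K0 offset from the Run() calls above,
// evaluated for a couple of assumed configurations.
#include <cstdint>

constexpr int64_t k0_offset(int64_t k, int64_t k_pack, int64_t k1, int64_t k_row)
{
    // Number of K0 steps advanced after k WMMA iterations.
    return k * k_pack / k1 / k_row;
}

// Assumed: KPack = 16, K1 = 2.
static_assert(k0_offset(1, 16, 2, 2) == 4, "with KRow = 2, each iteration advances 4 K0 steps");
static_assert(k0_offset(1, 16, 2, 1) == 8, "with KRow = 1, each iteration advances 8 K0 steps");

int main() { return 0; }
```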
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
@@ -97,7 +97,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
    static constexpr auto AEnableLds_manu = false;
    static constexpr auto BEnableLds_manu = false;
    static constexpr auto AEnableLds = AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto AEnableLds = false; // AEnableLds_auto || AEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto BEnableLds = BEnableLds_auto || BEnableLds_manu || (NumPrefetch > 1);
    static constexpr auto matrix_padder =
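The hunk above appears to pin AEnableLds to false while keeping the original expression in a trailing comment; BEnableLds is still computed. For reference, here is a minimal standalone sketch of that original decision logic; the function name is hypothetical.

```cpp
// Sketch only: the LDS-enable decision shown in the commented-out expression above.
#include <cstdint>

constexpr bool enable_lds(bool enable_auto, bool enable_manu, int64_t num_prefetch)
{
    // LDS staging is used if either the automatic heuristic or the manual override
    // asks for it, or whenever more than one prefetch stage is requested.
    return enable_auto || enable_manu || (num_prefetch > 1);
}

static_assert(!enable_lds(false, false, 1), "single-stage prefetch with no flags bypasses LDS");
static_assert(enable_lds(false, false, 2), "NumPrefetch > 1 forces LDS staging");

int main() { return 0; }
```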
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
@@ -170,8 +170,9 @@ struct GridwiseGemm_Wmma
        }
        else
        {
            constexpr auto A_KRow        = I2;
            constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
            constexpr auto K0PerWmma     = WmmaK / A_KRow / K1;
            // KWmma->MRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
            return make_naive_tensor_descriptor(
                make_tuple(Number<KWmmaPerblock>{},
@@ -217,8 +218,10 @@ struct GridwiseGemm_Wmma
        }
        else
        {
            constexpr auto B_KRow        = I2;
            constexpr auto KWmmaPerblock = KPerBlock / WmmaK;
            constexpr auto K0PerWmma     = WmmaK / 2 / K1;
            constexpr auto K0PerWmma     = WmmaK / B_KRow / K1;
            // KWmma->NRepeat->MWave->K0PerWmma->KRow->MPerWmma->K1 Per Thread
            return make_naive_tensor_descriptor(
                make_tuple(Number<KWmmaPerblock>{},
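Both hunks above introduce a KRow constant equal to I2 and rewrite K0PerWmma as WmmaK / KRow / K1 in place of WmmaK / 2 / K1. The sketch below spells out the resulting K decomposition for assumed sizes (WmmaK = 16, K1 = 2, KPerBlock = 64); the numbers are illustrative, not taken from the repository.

```cpp
// Sketch only: K decomposition implied by the descriptors above,
// KPerBlock = KWmmaPerblock * WmmaK and WmmaK = KRow * K0PerWmma * K1.
#include <cstdint>

constexpr int64_t KPerBlock = 64; // assumed
constexpr int64_t WmmaK     = 16; // assumed
constexpr int64_t K1        = 2;  // assumed
constexpr int64_t KRow      = 2;  // the A_KRow / B_KRow = I2 constant introduced above

constexpr int64_t KWmmaPerblock = KPerBlock / WmmaK; // 4 WMMA steps per block tile
constexpr int64_t K0PerWmma     = WmmaK / KRow / K1; // 4 K0 groups per WMMA step

static_assert(KWmmaPerblock * WmmaK == KPerBlock, "K tile splits exactly into WMMA steps");
static_assert(KRow * K0PerWmma * K1 == WmmaK, "each WMMA step splits into KRow x K0 x K1");

int main() { return 0; }
```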
@@ -292,7 +295,7 @@ struct GridwiseGemm_Wmma
        // AK0_M_AK1 -> AK0_MRepeat_Mwaves_AKRow_MPerWmma_AK1
        constexpr auto A_K0   = ABlockDesc_{}.GetLength(I0);
        constexpr auto A_K1   = ABlockDesc_{}.GetLength(I2);
        constexpr auto A_KRow = I1;
        constexpr auto A_KRow = I2;
        return transform_tensor_descriptor(
            ABlockDesc_{},
            make_tuple(make_unmerge_transform(make_tuple(Number<A_K0>{}, A_KRow)),
@@ -307,7 +310,6 @@ struct GridwiseGemm_Wmma
        // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
        constexpr auto KWmma     = ABlockDesc_{}.GetLength(I0);
        constexpr auto K0PerWmma = ABlockDesc_{}.GetLength(I3);
        constexpr auto A_KRow    = ABlockDesc_{}.GetLength(I4);
        constexpr auto A_K1      = ABlockDesc_{}.GetLength(I6);
        // Err: merge transform cause non-constexpr issue
@@ -332,7 +334,7 @@ struct GridwiseGemm_Wmma
        return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
                                                              Number<MRepeat>{},
                                                              I1,
                                                              Number<A_KRow>{},
                                                              I1,
                                                              I1,
                                                              Number<A_K1>{}));
    }
@@ -350,7 +352,7 @@ struct GridwiseGemm_Wmma
        // BK0_N_BK1 -> BK0_NRepeat_Nwaves_NPerWmma_BK1
        constexpr auto B_K0   = BBlockDesc_{}.GetLength(I0);
        constexpr auto B_K1   = BBlockDesc_{}.GetLength(I2);
        constexpr auto B_KRow = I1;
        constexpr auto B_KRow = I2;
        return transform_tensor_descriptor(
            BBlockDesc_{},
            make_tuple(make_unmerge_transform(make_tuple(Number<B_K0>{}, B_KRow)),
@@ -365,14 +367,13 @@ struct GridwiseGemm_Wmma
        // KWmma_MRepeat_MWave_K0PerWmma_KRow_MPerWmma_K1 -> K0_MRepeat_Mwaves_MPerWmma_K1
        constexpr auto KWmma     = BBlockDesc_{}.GetLength(I0);
        constexpr auto K0PerWmma = BBlockDesc_{}.GetLength(I3);
        constexpr auto B_KRow    = BBlockDesc_{}.GetLength(I4);
        constexpr auto B_K1      = BBlockDesc_{}.GetLength(I6);
        // Workaround, Freeze transform
        return make_naive_tensor_descriptor_packed(make_tuple(Number<KWmma * K0PerWmma>{},
                                                              Number<NRepeat>{},
                                                              I1,
                                                              Number<B_KRow>{},
                                                              I1,
                                                              I1,
                                                              Number<B_K1>{}));
    }
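The A- and B-side transforms above unmerge the K0 axis into a (K0, KRow) pair, with KRow moving from I1 to I2. The following sketch shows the same split as plain index arithmetic rather than CK's descriptor machinery; the ordering convention is an assumption made for illustration.

```cpp
// Sketch only: splitting a flat K0 coordinate into (k0, row) with KRow = 2,
// analogous to the make_unmerge_transform calls above.
#include <cstdint>

struct K0Row
{
    int64_t k0;
    int64_t row;
};

constexpr K0Row unmerge_k0(int64_t flat, int64_t k_row)
{
    // Assumed ordering: flat = k0 * k_row + row.
    return {flat / k_row, flat % k_row};
}

static_assert(unmerge_k0(5, 2).k0 == 2 && unmerge_k0(5, 2).row == 1,
              "flat index 5 with KRow = 2 maps to (k0 = 2, row = 1)");

int main() { return 0; }
```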
@@ -781,8 +782,6 @@ struct GridwiseGemm_Wmma
        // GEMM
        constexpr auto KPack = math::integer_least_multiple(K1, WmmaK);
        static_assert(KPerBlock % KPack == 0, "");
        auto blockwise_gemm = BlockwiseGemmWMMA<BlockSize,
                                                ADataType,
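The final hunk derives KPack from K1 and WmmaK through math::integer_least_multiple and asserts that KPerBlock is a multiple of it. Below is a plain re-implementation of the least-multiple idea with assumed sizes, under the assumption that integer_least_multiple(a, b) rounds a up to the nearest multiple of b.

```cpp
// Sketch only: KPack derivation and the divisibility check from the hunk above,
// with a stand-in for math::integer_least_multiple and assumed sizes.
#include <cstdint>

constexpr int64_t least_multiple(int64_t a, int64_t b)
{
    // Smallest multiple of b that is >= a.
    return ((a + b - 1) / b) * b;
}

constexpr int64_t K1        = 2;  // assumed
constexpr int64_t WmmaK     = 16; // assumed
constexpr int64_t KPerBlock = 64; // assumed

constexpr int64_t KPack = least_multiple(K1, WmmaK); // 16 with these values

static_assert(KPerBlock % KPack == 0, "KPerBlock must be a multiple of KPack");

int main() { return 0; }
```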