gaoqiong / composable_kernel / Commits

Commit dc8309db
authored Mar 23, 2023 by aska-0096

Skip A_Lds sanity pass, Skip B_Lds scratch occurred

parent a4694341
Showing 4 changed files with 174 additions and 106 deletions
example/01_gemm/gemm_wmma_fp16.cpp                                   +5   -5
include/ck/host_utility/kernel_launch.hpp                            +1   -1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp        +115 -50
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp     +53  -50
example/01_gemm/gemm_wmma_fp16.cpp
@@ -42,8 +42,8 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     8,     // K1
     16,    // MPerWmma
     16,    // NPerWmma
-    8,     // M Repeat
-    1,     // N-Repeat
+    8,     // M-Repeat // M-PerWmma / M-Repeat = M-Wave
+    1,     // N-Repeat // N-PerWmma / N-Repeat = N-Wave
     S<4, 64, 1>,
     S<1, 0, 2>,
     S<1, 0, 2>,
@@ -51,16 +51,16 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
     8,
     8,
     true,
-    S<4, 16, 1>,
+    S<4, 64, 1>,
     S<1, 0, 2>,
     S<1, 0, 2>,
     2,
     8,
     8,
     true,
-    1,     // C shuffle (M Repeat) Per store
+    4,     // C shuffle (M Repeat) Per store
     1,     // C shuffle (N Repeat) Per store
-    S<1, 16, 1, 16>,
+    S<1, 32, 1, 8>,
     8>;
 // clang-format on
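The instance now uses MRepeat = 8 with NRepeat = 1, and one of the block-transfer thread cluster lengths changes from S<4, 16, 1> to S<4, 64, 1>. How the block tile decomposes into waves follows from these repeat counts and the per-WMMA tile sizes; the sketch below restates that relation with illustrative block-tile numbers (MPerBlock/NPerBlock are not taken from this file, and the authoritative formula is in DeviceGemmWmma_CShuffle):

    // Hypothetical standalone check of the wave decomposition implied by the repeat
    // counts above. MPerBlock/NPerBlock are illustrative only (not read from this
    // commit), and the assumed relations mirror how DeviceGemmWmma_CShuffle derives
    // its wave counts: MWaves = MPerBlock / (MRepeat * MPerWmma), and likewise for N.
    #include <cstdio>

    int main()
    {
        constexpr int MPerWmma  = 16,  NPerWmma  = 16; // from the instance above
        constexpr int MRepeat   = 8,   NRepeat   = 1;  // from the instance above
        constexpr int MPerBlock = 128, NPerBlock = 64; // illustrative block tile only

        constexpr int MWaves = MPerBlock / (MRepeat * MPerWmma); // 1 with these numbers
        constexpr int NWaves = NPerBlock / (NRepeat * NPerWmma); // 4 with these numbers

        std::printf("MWaves = %d, NWaves = %d\n", MWaves, NWaves);
        return 0;
    }

With MWaves == 1, the BEnableLds condition in device_gemm_wmma.hpp (shown further down) would evaluate to false, which is the kind of configuration the commit message's "Skip B_Lds" appears to refer to.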
include/ck/host_utility/kernel_launch.hpp
@@ -35,7 +35,7 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
     // warm up
     // kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
-    const int nrepeat = 1;
+    const int nrepeat = 100;
 #if DEBUG_LOG
     printf("Start running %d times...\n", nrepeat);
 #endif
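The only change here raises the timing repeat count from 1 to 100. As context, here is a minimal sketch, under stated assumptions, of how such a repeat count is typically used with HIP events to average the per-launch time; it is not the body of launch_and_time_kernel, only an illustration of the pattern:

    // Minimal sketch of repeat-based kernel timing with HIP events (assumption-level,
    // not the CK implementation). The HIP event API calls used here are standard.
    #include <hip/hip_runtime.h>

    template <typename Kernel, typename... Args>
    float time_kernel_sketch(Kernel kernel, dim3 grid_dim, dim3 block_dim,
                             hipStream_t stream, Args... args)
    {
        const int nrepeat = 100; // the value this commit switches to

        // warm up once, as the commented-out launch above suggests
        kernel<<<grid_dim, block_dim, 0, stream>>>(args...);

        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);

        hipEventRecord(start, stream);
        for(int i = 0; i < nrepeat; ++i)
        {
            kernel<<<grid_dim, block_dim, 0, stream>>>(args...);
        }
        hipEventRecord(stop, stream);
        hipEventSynchronize(stop);

        float total_ms = 0.f;
        hipEventElapsedTime(&total_ms, start, stop);

        hipEventDestroy(start);
        hipEventDestroy(stop);

        return total_ms / nrepeat; // average time per launch in milliseconds
    }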
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
@@ -298,13 +298,20 @@ struct BlockwiseGemmWMMA
         auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatB>(
             b_thread_desc_.GetElementSpaceSize());

-        static_for<0, KPerBlock / WmmaK, 1>{}([&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
+        // basic intrinsic to determine loopover direction
+        if constexpr(MRepeat < NRepeat)
+        {
+            static_for<0, KPerBlock / WmmaK, 1>{}(
+                [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
             static_for<0, MRepeat, 1>{}([&](auto m0) {
                 // read A
                 a_thread_copy_.Run(
                     a_block_desc_k0_m0_m1_m2_k1,
                     make_tuple(
                         Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{}, m0, I0, I0, I0),
                     a_block_buf,
                     a_thread_desc_,
                     make_tuple(I0, m0, I0, I0, I0),
@@ -348,8 +355,66 @@ struct BlockwiseGemmWMMA
                         c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                 });
             });
-        });
+            });
+        }
+        else
+        {
+            static_for<0, KPerBlock / WmmaK, 1>{}(
+                [&](auto k) { // k=0,1,2 instead of k=0,kpack*1, ...
+                    static_for<0, NRepeat, 1>{}([&](auto n0) {
+                        // read B
+                        b_thread_copy_.Run(
+                            b_block_desc_k0_n0_n1_n2_k1,
+                            make_tuple(
+                                Number<k * WmmaK / B_K1 * B_Data_Duplicated_Rate / 2>{},
+                                n0, I0, I0, I0),
+                            b_block_buf,
+                            b_thread_desc_,
+                            make_tuple(I0, n0, I0, I0, I0),
+                            b_thread_buf);
+
+                        static_for<0, MRepeat, 1>{}([&](auto m0) {
+                            // read A
+                            a_thread_copy_.Run(
+                                a_block_desc_k0_m0_m1_m2_k1,
+                                make_tuple(
+                                    Number<k * WmmaK / A_K1 * A_Data_Duplicated_Rate / 2>{},
+                                    m0, I0, I0, I0),
+                                a_block_buf,
+                                a_thread_desc_,
+                                make_tuple(I0, m0, I0, I0, I0),
+                                a_thread_buf);
+
+                            vector_type<FloatA, WmmaK> a_thread_vec;
+                            vector_type<FloatB, WmmaK> b_thread_vec;
+
+                            static_for<0, WmmaK, 1>{}([&](auto i) {
+                                b_thread_vec.template AsType<FloatB>()(i) =
+                                    b_thread_buf[Number<b_thread_desc_.CalculateOffset(
+                                        make_tuple(i / B_K1, n0, 0, 0, i % B_K1))>{}];
+                                a_thread_vec.template AsType<FloatA>()(i) =
+                                    a_thread_buf[Number<a_thread_desc_.CalculateOffset(
+                                        make_tuple(i / A_K1, m0, 0, 0, i % A_K1))>{}];
+                            });
+
+                            using wmma_input_type_a = typename vector_type<FloatA, WmmaK>::type;
+                            using wmma_input_type_b = typename vector_type<FloatB, WmmaK>::type;
+
+                            constexpr index_t c_offset =
+                                c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));
+
+                            wmma_gemm.template Run(
+                                a_thread_vec.template AsType<wmma_input_type_a>()(Number<0>{}),
+                                b_thread_vec.template AsType<wmma_input_type_b>()(Number<0>{}),
+                                c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
+                        });
+                    });
+                });
+        }
     }

     protected:
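The new if constexpr branch picks the loop order from the repeat counts: when MRepeat < NRepeat the A fragments are read in the outer m0 loop, otherwise B fragments are read in the outer n0 loop, so the operand indexed by the smaller repeat count is fetched once per outer iteration and reused across the inner loop. A self-contained illustration of that iteration order (plain loops standing in for static_for, printf standing in for the reads and the WMMA issue; not the CK code):

    // Illustration only: shows which index the outer loop runs over for a given
    // MRepeat/NRepeat pair, mirroring the branch added in blockwise_gemm_wmma.hpp.
    #include <cstdio>

    template <int MRepeat, int NRepeat>
    void run_tile()
    {
        if constexpr(MRepeat < NRepeat)
        {
            for(int m0 = 0; m0 < MRepeat; ++m0)     // outer: read A fragment m0 once
            {
                for(int n0 = 0; n0 < NRepeat; ++n0) // inner: read B fragment n0, issue WMMA
                {
                    std::printf("wmma(m0=%d, n0=%d)\n", m0, n0);
                }
            }
        }
        else
        {
            for(int n0 = 0; n0 < NRepeat; ++n0)     // outer: read B fragment n0 once
            {
                for(int m0 = 0; m0 < MRepeat; ++m0) // inner: read A fragment m0, issue WMMA
                {
                    std::printf("wmma(m0=%d, n0=%d)\n", m0, n0);
                }
            }
        }
    }

    int main()
    {
        run_tile<8, 1>(); // MRepeat = 8, NRepeat = 1 as in the example instance
        return 0;
    }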
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
@@ -89,8 +89,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     static constexpr auto AEnableLds = NWaves == 1 ? false : true;
     static constexpr auto BEnableLds = MWaves == 1 ? false : true;
-    // Unconditional enable double side LDS if uncommented following
+    // Force enable LDS if uncommented following
     // AEnableLds = true;
     // BEnableLds = true;
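These two flags appear to be what the commit message refers to: A skips LDS when NWaves == 1 and B skips LDS when MWaves == 1. The reading below is an assumption (the hunk itself only shows the conditions): with a single wave along the other dimension of the block tile, that operand is not shared across waves, so staging it through LDS can be bypassed and it is read straight into registers. A tiny standalone restatement of the condition:

    // Standalone restatement of AEnableLds/BEnableLds; the interpretation in the
    // comments is an assumption based on the commit message, not stated in this hunk.
    #include <cstdio>

    constexpr bool enable_lds(int waves_in_other_dim)
    {
        // same shape as: AEnableLds = NWaves == 1 ? false : true;
        return waves_in_other_dim == 1 ? false : true;
    }

    int main()
    {
        std::printf("NWaves=1 -> AEnableLds=%d (A bypasses LDS)\n", enable_lds(1));
        std::printf("NWaves=4 -> AEnableLds=%d (A staged through LDS)\n", enable_lds(4));
        return 0;
    }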
@@ -223,8 +222,8 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
     using CGridDesc_M_N = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

     // GridwiseGemm
-    using GridwiseGemm = GridwiseGemm_Wmma<
-        BlockSize,
+    using GridwiseGemm =
+        GridwiseGemm_Wmma<BlockSize,
         ADataType,
         BDataType,
         AccDataType,
@@ -572,7 +571,11 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
            << MRepeat << ", "
            << NRepeat
            << ">"
-           << " NumPrefetch: "
+           << " AEnableLds: "
+           << AEnableLds << ", "
+           << "BEnableLds: "
+           << BEnableLds << ", "
+           << "NumPrefetch: "
            << NumPrefetch << ", "
            << "LoopScheduler: "
            << LoopSchedToString[LoopSched] << ", "