Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
579f84c6
Commit
579f84c6
authored
Mar 06, 2023
by
aska-0096
Browse files
tempsave
parent
7e003d31
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
395 additions
and
119 deletions
+395
-119
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+4
-4
example/01_gemm/run_gemm_example.inc
example/01_gemm/run_gemm_example.inc
+1
-1
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
...emm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
+1
-1
include/ck/host_utility/kernel_launch.hpp
include/ck/host_utility/kernel_launch.hpp
+2
-2
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+2
-3
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+3
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
...ation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
+374
-107
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
...k/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
+5
-0
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+3
-1
No files found.
example/01_gemm/gemm_wmma_fp16.cpp
View file @
579f84c6
...
@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -37,13 +37,13 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
GemmDefault
,
GemmDefault
,
256
,
// BlockSize
256
,
// BlockSize
128
,
// MPerBlock
128
,
// MPerBlock
1
6
,
// NPerBlock
1
28
,
// NPerBlock
32
,
// KPerBlock
32
,
// KPerBlock
8
,
// K1
8
,
// K1
16
,
// MPerWmma
16
,
// MPerWmma
16
,
// NPerWmma
16
,
// NPerWmma
1
,
// M Repeat
2
,
// M Repeat
1
,
// N-Repeat
4
,
// N-Repeat
S
<
4
,
64
,
1
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
...
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
...
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true
,
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
128
,
1
,
2
>
,
S
<
1
,
64
,
1
,
4
>
,
8
>
;
8
>
;
// clang-format on
// clang-format on
...
...
example/01_gemm/run_gemm_example.inc
View file @
579f84c6
...
@@ -44,7 +44,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
...
@@ -44,7 +44,7 @@ bool run_gemm(const ProblemSize& problem_size, const ExecutionConfig& config)
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
-
5.
f
,
5.
f
}(
b_k_n
);
break
;
break
;
case
4
:
case
4
:
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
-
5
.
f
,
5
.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
ADataType
>
{
1
.
f
,
1
.
f
}(
a_m_k
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
);
ck
::
utils
::
FillUniformDistributionIntegerValue
<
BDataType
>
{
1.
f
,
1.
f
}(
b_k_n
);
break
;
break
;
default
:
default
:
...
...
example/32_batched_gemm_scale_softmax_gemm/batched_gemm_scale_softmax_gemm_permute_wmma_fp16.cpp
View file @
579f84c6
...
@@ -129,7 +129,7 @@ using DeviceGemmInstance =
...
@@ -129,7 +129,7 @@ using DeviceGemmInstance =
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
1
,
8
,
8
,
1
,
1
,
// be eight?
false
,
false
,
1
,
// CShuffleMWmmaPerWavePerShuffle
1
,
// CShuffleMWmmaPerWavePerShuffle
2
,
// CShuffleNWmmaPerWavePerShuffle
2
,
// CShuffleNWmmaPerWavePerShuffle
...
...
include/ck/host_utility/kernel_launch.hpp
View file @
579f84c6
...
@@ -33,9 +33,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
...
@@ -33,9 +33,9 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
printf
(
"Warm up 1 time
\n
"
);
printf
(
"Warm up 1 time
\n
"
);
#endif
#endif
// warm up
// warm up
kernel
<<<
grid_dim
,
block_dim
,
lds_byte
,
stream_config
.
stream_id_
>>>
(
args
...);
//
kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
const
int
nrepeat
=
1
00
;
const
int
nrepeat
=
1
;
#if DEBUG_LOG
#if DEBUG_LOG
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
printf
(
"Start running %d times...
\n
"
,
nrepeat
);
#endif
#endif
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
579f84c6
...
@@ -27,6 +27,8 @@ template <index_t BlockSize,
...
@@ -27,6 +27,8 @@ template <index_t BlockSize,
index_t
MRepeat
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
NRepeat
,
index_t
KPack
,
index_t
KPack
,
bool
AEnableLds
=
true
,
bool
BEnableLds
=
true
,
bool
TransposeC
=
false
>
bool
TransposeC
=
false
>
/* Option: Read from LDS, big buffer hold all threads required data
/* Option: Read from LDS, big buffer hold all threads required data
* Source
* Source
...
@@ -83,9 +85,6 @@ struct BlockwiseGemmWMMA
...
@@ -83,9 +85,6 @@ struct BlockwiseGemmWMMA
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
static
constexpr
bool
AEnableLds
=
NWaves
==
1
?
false
:
true
;
static
constexpr
bool
BEnableLds
=
MWaves
==
1
?
false
:
true
;
// Read from Lds, duplicate Twice, Read from VGPR, no duplication.
// Read from Lds, duplicate Twice, Read from VGPR, no duplication.
static
constexpr
index_t
A_Data_Duplicated_Rate
=
AEnableLds
?
2
:
1
;
static
constexpr
index_t
A_Data_Duplicated_Rate
=
AEnableLds
?
2
:
1
;
static
constexpr
index_t
B_Data_Duplicated_Rate
=
BEnableLds
?
2
:
1
;
static
constexpr
index_t
B_Data_Duplicated_Rate
=
BEnableLds
?
2
:
1
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
579f84c6
...
@@ -89,6 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
...
@@ -89,6 +89,9 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static
constexpr
auto
AEnableLds
=
NWaves
==
1
?
false
:
true
;
static
constexpr
auto
AEnableLds
=
NWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds
=
MWaves
==
1
?
false
:
true
;
// static constexpr auto AEnableLds = true;
// static constexpr auto BEnableLds = true;
static
constexpr
auto
matrix_padder
=
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
// Describe how data read from Global memory
// Describe how data read from Global memory
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp
View file @
579f84c6
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_v1.hpp
View file @
579f84c6
...
@@ -56,6 +56,8 @@ struct GridwiseGemmPipeline_v1<1, true, true>
...
@@ -56,6 +56,8 @@ struct GridwiseGemmPipeline_v1<1, true, true>
CThreadBuffer
&
c_thread_buf
,
CThreadBuffer
&
c_thread_buf
,
index_t
num_loop
)
index_t
num_loop
)
{
{
if
(
get_thread_local_1d_id
()
<
32
);
printf
(
"Mat-A Lds Enabled, Mat-B Lds Enabled
\n
"
);
// preload data into LDS
// preload data into LDS
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
a_blockwise_copy
.
RunRead
(
a_grid_desc
,
a_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
b_blockwise_copy
.
RunRead
(
b_grid_desc
,
b_grid_buf
);
...
@@ -304,6 +306,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
...
@@ -304,6 +306,9 @@ struct GridwiseGemmPipeline_v1<1, false, true>
},
},
Number<a_block_desc.GetLengths().GetSize()>{});
Number<a_block_desc.GetLengths().GetSize()>{});
#endif
#endif
if
(
get_thread_local_1d_id
()
<
32
);
printf
(
"Mat-A Lds Disabled, Mat-B Lds Enabled
\n
"
);
constexpr
auto
a_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
constexpr
auto
a_block_origin_idx
=
make_tuple
(
I0
,
I0
,
I0
,
I0
,
I0
,
I0
);
auto
a_block_buf_switch
=
a_block_buf
;
auto
a_block_buf_switch
=
a_block_buf
;
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
579f84c6
...
@@ -694,7 +694,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
...
@@ -694,7 +694,9 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
NPerWmma
,
NPerWmma
,
MRepeat
,
MRepeat
,
NRepeat
,
NRepeat
,
KPack
>
{};
KPack
,
AEnableLds
,
BEnableLds
>
{};
// Prepare Register for C matrix
// Prepare Register for C matrix
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
auto
c_thread_buf
=
blockwise_gemm
.
GetCThreadBuffer
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment