Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0b914465
Commit
0b914465
authored
Feb 29, 2024
by
Jing Zhang
Browse files
fixed wmma
parent
5db68230
Changes
7
Show whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
41 additions
and
42 deletions
+41
-42
cmake/EnableCompilerWarnings.cmake
cmake/EnableCompilerWarnings.cmake
+1
-1
example/01_gemm/gemm_wmma_fp16.cpp
example/01_gemm/gemm_wmma_fp16.cpp
+13
-13
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+11
-8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+1
-5
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+13
-13
library/include/ck/library/utility/check_err.hpp
library/include/ck/library/utility/check_err.hpp
+1
-1
script/cmake-ck-dev.sh
script/cmake-ck-dev.sh
+1
-1
No files found.
cmake/EnableCompilerWarnings.cmake
View file @
0b914465
...
...
@@ -66,7 +66,7 @@ else()
-Wunreachable-code
-Wunused
-Wno-reserved-identifier
-Werror
#
-Werror
-Wno-option-ignored
-Wsign-compare
-Wno-extra-semi-stmt
...
...
example/01_gemm/gemm_wmma_fp16.cpp
View file @
0b914465
...
...
@@ -34,24 +34,24 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
BElementOp
,
CElementOp
,
GemmDefault
,
1
,
// Prefetch stage
3
2
,
// BlockSize
1
6
,
// MPerBlock
1
6
,
// NPerBlock
2
,
// Prefetch stage
2
56
,
// BlockSize
1
28
,
// MPerBlock
25
6
,
// NPerBlock
64
,
// KPerBlock
8
,
// K1
16
,
// MPerWmma
16
,
// NPerWmma
1
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
1
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
8
,
1
>
,
4
,
// M-Repeat // M-PerWmma / M-Repeat = M-Wave
4
,
// N-Repeat // N-PerWmma / N-Repeat = N-Wave
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
true
,
S
<
4
,
8
,
1
>
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
...
...
@@ -60,7 +60,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmWmma_CShuffle
true
,
1
,
// C shuffle (M Repeat) Per store
1
,
// C shuffle (N Repeat) Per store
S
<
1
,
16
,
1
,
2
>
,
S
<
1
,
32
,
1
,
8
>
,
8
>
;
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
0b914465
...
...
@@ -108,7 +108,7 @@ struct BlockwiseGemmWMMA
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
0
,
WMMA_a_idx
,
0
);
return
make_tuple
(
0
,
0
,
waveId_m
,
wmma_gemm
.
GetSubGroupId
()
,
WMMA_a_idx
,
0
);
}
else
{
...
...
@@ -125,7 +125,7 @@ struct BlockwiseGemmWMMA
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return
make_tuple
(
0
,
0
,
waveId_n
,
0
,
WMMA_b_idx
,
0
);
return
make_tuple
(
0
,
0
,
waveId_n
,
wmma_gemm
.
GetSubGroupId
()
,
WMMA_b_idx
,
0
);
}
else
{
...
...
@@ -300,6 +300,9 @@ struct BlockwiseGemmWMMA
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_assert
(
KPack
%
(
A_K1
*
A_KRow
)
==
0
,
""
);
static_assert
(
KPack
%
(
B_K1
*
B_KRow
)
==
0
,
""
);
// basic intrinsic to determine loopover direction
if
constexpr
(
MRepeat
<
NRepeat
)
{
...
...
@@ -309,7 +312,7 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -319,7 +322,7 @@ struct BlockwiseGemmWMMA
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -365,7 +368,7 @@ struct BlockwiseGemmWMMA
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
B_K1
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -373,7 +376,7 @@ struct BlockwiseGemmWMMA
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
make_tuple
(
Number
<
k
*
KPack
/
A_K1
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
...
...
@@ -416,7 +419,7 @@ struct BlockwiseGemmWMMA
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
>
{},
Number
<
A_KRow
*
A_K1
>
{},
Number
<
KPack
/
A_KRow
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
...
...
@@ -425,7 +428,7 @@ struct BlockwiseGemmWMMA
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
>
{},
Number
<
B_KRow
*
B_K1
>
{},
Number
<
KPack
/
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
0b914465
...
...
@@ -135,7 +135,7 @@ struct GridwiseGemm_Wmma
static
constexpr
auto
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWmma
);
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
(
K1
==
16
)
?
32
:
16
;
static
constexpr
auto
WmmaK
=
16
;
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
@@ -841,10 +841,6 @@ struct GridwiseGemm_Wmma
constexpr
auto
NThreadPerSubGroup
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I5
);
constexpr
auto
MAccVgprs
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I6
);
static_assert
(
MSubGroup
==
2
,
""
);
static_assert
(
NThreadPerSubGroup
==
16
,
""
);
static_assert
(
MAccVgprs
==
8
,
""
);
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr
auto
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
=
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
();
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
0b914465
...
...
@@ -132,9 +132,9 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
static
constexpr
index_t
m_per_wmma
=
16
;
static
constexpr
index_t
n_per_wmma
=
16
;
static
constexpr
index_t
k_per_wmma
=
16
;
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
4
;
//
static constexpr index_t src_a_data_size = 2;
//
static constexpr index_t src_b_data_size = 2;
//
static constexpr index_t acc_data_size = 4;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
...
...
@@ -145,8 +145,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_gfx12,
// static constexpr index_t num_src_b_vgprs_per_wave = k_per_wmma / 2 * src_b_data_size / 4;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
/
wave_size
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
...
...
@@ -390,10 +389,12 @@ struct WmmaSelector
static_assert
(
selected_wmma
.
k_per_wmma
==
16
,
"WRONG! WMMA_M must equal to 16"
);
#if 0
static_assert(selected_wmma.wave_size * selected_wmma.num_acc_vgprs_per_wave *
selected_wmma.acc_data_size ==
selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
"WRONG! Invalid Number of Accumulator Register");
#endif
}
};
...
...
@@ -443,8 +444,6 @@ struct WmmaGemm
const
auto
NWave
=
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
.
GetLength
(
I4
);
static_assert
(
wmma_instr
.
num_acc_vgprs_per_wave
==
8
,
""
);
return
transform_tensor_descriptor
(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
,
make_tuple
(
...
...
@@ -553,6 +552,9 @@ struct WmmaGemm
__device__
static
auto
GetSubGroupId
()
{
static_assert
(
wmma_instr
.
num_thread_per_subgroups
*
wmma_instr
.
num_subgroups
==
wmma_instr
.
wave_size
,
""
);
return
(
GetLaneId
()
/
wmma_instr
.
num_thread_per_subgroups
)
%
wmma_instr
.
num_subgroups
;
}
...
...
@@ -567,13 +569,11 @@ struct WmmaGemm
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
// return TransposeC ? GetLaneIdUnderSubGroup() : GetSwizzledLaneIdLow();
return
GetLaneIdUnderSubGroup
();
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
// return TransposeC ? GetSwizzledLaneIdLow() : GetLaneIdUnderSubGroup();
return
GetLaneIdUnderSubGroup
();
}
...
...
library/include/ck/library/utility/check_err.hpp
View file @
0b914465
...
...
@@ -156,7 +156,7 @@ check_err(const Range& out,
{
max_err
=
err
>
max_err
?
err
:
max_err
;
err_count
++
;
if
(
err_count
<
5
)
//
if(err_count < 5)
{
std
::
cerr
<<
msg
<<
std
::
setw
(
12
)
<<
std
::
setprecision
(
7
)
<<
" out["
<<
i
<<
"] != ref["
<<
i
<<
"]: "
<<
o
<<
" != "
<<
r
<<
std
::
endl
;
...
...
script/cmake-ck-dev.sh
View file @
0b914465
...
...
@@ -10,7 +10,7 @@ cmake
-D
CMAKE_CXX_COMPILER
=
/opt/rocm/bin/hipcc
\
-D
CMAKE_CXX_FLAGS
=
"-std=c++17 -O3 -ftemplate-backtrace-limit=0 -fPIE -Wno-gnu-line-marker"
\
-D
CMAKE_BUILD_TYPE
=
Release
\
-D
BUILD_DEV
=
O
N
\
-D
BUILD_DEV
=
O
FF
\
-D
GPU_TARGETS
=
"gfx1200"
\
-D
CMAKE_VERBOSE_MAKEFILE:BOOL
=
ON
\
-D
USE_BITINT_EXTENSION_INT4
=
OFF
\
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment