Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
760b0c75
Commit
760b0c75
authored
Feb 23, 2024
by
Jing Zhang
Browse files
add wmma
parent
8831b0d8
Changes
6
Show whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
100 additions
and
11 deletions
+100
-11
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+1
-1
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+15
-6
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
.../ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
+5
-1
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
+53
-2
include/ck/utility/amd_wmma.hpp
include/ck/utility/amd_wmma.hpp
+25
-0
No files found.
example/01_gemm/CMakeLists.txt
View file @
760b0c75
...
...
@@ -27,7 +27,7 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_wavelet_fp16)
add_example_executable
(
example_gemm_xdl_skip_b_lds_fp16 gemm_xdl_skip_b_lds_fp16.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
)
if
(
GPU_TARGETS MATCHES
"gfx1100"
OR GPU_TARGETS MATCHES
"gfx1101"
OR GPU_TARGETS MATCHES
"gfx1102"
OR GPU_TARGETS MATCHES
"gfx1200"
)
add_custom_target
(
example_gemm_wmma
)
add_example_executable
(
example_gemm_wmma_fp16 gemm_wmma_fp16.cpp
)
add_example_dependencies
(
example_gemm_wmma example_gemm_wmma_fp16
)
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
760b0c75
...
...
@@ -35,7 +35,8 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
WmmaK
=
Number
<
16
>
{};
// static constexpr auto WmmaK = Number<16>{};
static
constexpr
auto
WmmaK
=
Number
<
8
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
...
...
@@ -141,6 +142,9 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
static_assert
(
WmmaK
%
A_K1
==
0
,
""
);
static_assert
(
WmmaK
%
B_K1
==
0
,
""
);
}
// Thread level, register decriptor. Vector-write
...
...
@@ -154,6 +158,10 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
constexpr
auto
NThreadPerSubGroup
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I1
];
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
static_assert
(
MSubGroup
==
1
,
""
);
static_assert
(
NThreadPerSubGroup
==
1
,
""
);
static_assert
(
MAccVgprs
==
8
,
""
);
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
...
...
@@ -224,6 +232,7 @@ struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1m2n0n1n2m3_CShuffle
// basic intrinsic to determine loopover direction
if
constexpr
(
MRepeat
<
NRepeat
)
{
static_assert
(
0
,
""
);
static_for
<
0
,
KPerBlock
/
WmmaK
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
View file @
760b0c75
...
...
@@ -411,7 +411,7 @@ struct DeviceGemmWmma_CShuffle : public DeviceGemm<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
is_navi3_supported
())
if
(
ck
::
is_navi3_supported
()
||
ck
::
is_navi4_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma.hpp
View file @
760b0c75
...
...
@@ -49,7 +49,7 @@ __global__ void
const
CElementwiseOperation
c_element_op
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
...
...
@@ -480,6 +480,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmma
constexpr
auto
NThreadPerSubGroup
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I5
);
constexpr
auto
MAccVgprs
=
c_block_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs_tmp
.
GetLength
(
I6
);
static_assert
(
MSubGroup
==
2
,
""
);
static_assert
(
NThreadPerSubGroup
==
16
,
""
);
static_assert
(
MAccVgprs
==
8
,
""
);
// LDS descriptor, shuffle and write out in MRepeat x NRepeat times
constexpr
auto
c_shuffle_block_desc_mshrepeat_mpershrepeat_nshrepeat_npershrepeat
=
GetCShuffleBlockDescriptor_MShRepeat_MPerShRepeat_NShRepeat_NPerShRepeat
();
...
...
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
View file @
760b0c75
...
...
@@ -11,12 +11,15 @@ namespace ck {
enum
struct
WmmaInstr
{
// gfx11
wmma_f32_16x16x16_f16
=
0
,
wmma_f32_16x16x16_bf16
,
wmma_f16_16x16x16_f16
,
wmma_bf16_16x16x16_bf16
,
wmma_i32_16x16x16_iu8
,
wmma_i32_16x16x16_iu4
wmma_i32_16x16x16_iu4
,
// gfx12
wmma_f32_16x16x16_f16_gfx12
,
};
/*
...
...
@@ -117,6 +120,47 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16,
}
};
// A-swizzled
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
,
WaveSize
,
typename
std
::
enable_if_t
<
WaveSize
==
32
||
WaveSize
==
64
>>
{
// Absolute fixing property
// * Data Pixel
static
constexpr
index_t
m_per_wmma
=
16
;
static
constexpr
index_t
n_per_wmma
=
16
;
static
constexpr
index_t
k_per_wmma
=
16
;
static
constexpr
index_t
src_a_data_size
=
2
;
static
constexpr
index_t
src_b_data_size
=
2
;
static
constexpr
index_t
acc_data_size
=
4
;
// * Thread mapping inside wave, num_thread_per_subgroups always alone N direction
static
constexpr
index_t
num_thread_per_subgroups
=
n_per_wmma
;
// Wave mode dependent propety
static
constexpr
index_t
wave_size
=
Number
<
WaveSize
>
{};
// * Fixed in Navi3x, Will be wave mode dependent on Navi4x
static
constexpr
index_t
num_src_a_vgprs_per_wave
=
m_per_wmma
*
src_a_data_size
/
4
;
static
constexpr
index_t
num_src_b_vgprs_per_wave
=
n_per_wmma
*
src_b_data_size
/
4
;
// * num_acc_vgprs_per_wave alone M direction
// * num_subgroups alone M direction
static
constexpr
index_t
num_acc_vgprs_per_wave
=
m_per_wmma
*
n_per_wmma
*
acc_data_size
/
wave_size
/
4
;
static
constexpr
index_t
num_subgroups
=
wave_size
/
num_thread_per_subgroups
;
template
<
index_t
MPerWmma
,
index_t
NPerWmma
,
class
FloatA
,
class
FloatB
,
class
FloatC
>
__device__
void
run
(
const
FloatA
&
a
,
const
FloatB
&
b
,
FloatC
&
reg_c
)
const
{
if
constexpr
(
wave_size
==
32
)
{
intrin_wmma_f32_16x16x16_f16_w32_gfx12
<
MPerWmma
,
NPerWmma
>::
Run
(
a
,
b
,
reg_c
);
}
else
if
constexpr
(
wave_size
==
64
)
{
}
}
};
template
<
index_t
WaveSize
>
struct
wmma_type
<
WmmaInstr
::
wmma_f32_16x16x16_bf16
,
WaveSize
,
...
...
@@ -300,7 +344,11 @@ struct WmmaSelector
template
<
>
static
constexpr
auto
GetWmma
<
half_t
,
half_t
,
float
,
16
,
16
>
()
{
#if 1
return
WmmaInstr
::
wmma_f32_16x16x16_f16_gfx12
;
#else
return
WmmaInstr
::
wmma_f32_16x16x16_f16
;
#endif
}
template
<
>
...
...
@@ -397,6 +445,8 @@ struct WmmaGemm
const
auto
NWave
=
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
.
GetLength
(
I4
);
static_assert
(
wmma_instr
.
num_acc_vgprs_per_wave
==
8
,
""
);
return
transform_tensor_descriptor
(
c_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
,
make_tuple
(
...
...
@@ -477,7 +527,8 @@ struct WmmaGemm
__host__
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
return
GetSwizzledLaneIdLow
();
// return GetSwizzledLaneIdLow();
return
GetLaneIdUnderSubGroup
();
}
__host__
__device__
static
auto
CalculateBThreadOriginDataIndex
()
...
...
include/ck/utility/amd_wmma.hpp
View file @
760b0c75
...
...
@@ -39,6 +39,31 @@ struct intrin_wmma_f32_16x16x16_f16_w32<16, 16>
}
};
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_f16_w32_gfx12
;
template
<
>
struct
intrin_wmma_f32_16x16x16_f16_w32_gfx12
<
16
,
16
>
{
template
<
class
FloatC
>
__device__
static
void
Run
(
const
half8_t
&
reg_a
,
const
half8_t
&
reg_b
,
FloatC
&
reg_c
)
{
// * Inline assembly need to elimate the duplicated data load, compiler won't help you
// delete them.
// amd_assembly_wmma_f32_16x16x16_f16_w32(
// reg_a, reg_b, reg_c.template AsType<float8_t>()(Number<0>{}));
#if defined(__gfx12__)
reg_c
.
template
AsType
<
float8_t
>()(
Number
<
0
>
{})
=
__builtin_amdgcn_wmma_f32_16x16x16_f16_w32_gfx12
(
reg_a
,
reg_b
,
reg_c
.
template
AsType
<
float8_t
>()[
Number
<
0
>
{}]);
#else
ignore
=
reg_a
;
ignore
=
reg_b
;
ignore
=
reg_c
;
#endif
}
};
// src: bf16, dst: fp32
template
<
index_t
MPerWave
,
index_t
NPerWave
>
struct
intrin_wmma_f32_16x16x16_bf16_w32
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment