gaoqiong / composable_kernel · Commits · d16063db

Commit d16063db, authored Nov 22, 2022 by aska-0096, parent 98ccb367

    tempsave

Showing 5 changed files with 551 additions and 171 deletions (+551 −171)
  example/01_gemm/CMakeLists.txt                                        +5    -0
  include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp       +433    -0
  include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp     +12   -12
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp     +60   -74
  include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp                   +41   -85
example/01_gemm/CMakeLists.txt

@@ -35,3 +35,8 @@ add_example_executable_no_testing(example_gemm_xdl_fp64 gemm_xdl_fp64.cpp)
 add_dependencies(example_gemm_xdl example_gemm_xdl_skip_b_lds_fp16)
 add_dependencies(example_gemm_xdl example_gemm_xdl_fp64)
+
+add_custom_target(example_gemm_wmma)
+add_example_executable(example_gemm_wmma_fp16 gemm_wmma_fp16.cpp)
+add_dependencies(example_gemm_wmma example_gemm_wmma_fp16)
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/utility/common_header.hpp"
#include "ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp"
#include "ck/tensor_operation/gpu/warp/wmma_gemm.hpp"
#include "ck/tensor_description/tensor_adaptor.hpp"

namespace ck {

enum struct LoopScheduler
{
    Default,
};

constexpr LoopScheduler make_default_loop_scheduler() { return LoopScheduler::Default; }

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t MPerWMMA,
          index_t NPerWMMA,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    using ThisThreadBlock = ThisThreadBlock<BlockSize>;

    static constexpr index_t WaveSize = get_warp_size();

    static constexpr index_t MPerBlock = AK0MK1BlockDesc{}.GetLength(I1);
    static constexpr index_t NPerBlock = BK0NK1BlockDesc{}.GetLength(I1);
    static constexpr index_t KPerBlock =
        BK0NK1BlockDesc{}.GetLength(I0) * BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr index_t A_K0 = AK0MK1BlockDesc{}.GetLength(I0);
    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
    static constexpr index_t A_K1 = AK0MK1BlockDesc{}.GetLength(I2);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr auto wmma_gemm = WMMAGemm<FloatAB, MPerWMMA, NPerWMMA, KPack>{};

    static constexpr index_t KPerThread = KPerBlock / wmma_gemm.K0PerWMMA;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerWMMA);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerWMMA);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              wmma_gemm.GetRegSizePerWMMA(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = ThisThreadBlock::GetThreadId();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }
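An aside on the index math (not part of the commit): the merge adaptor above is the inverse of flattening (waveId_m, waveId_n, lane) row-major over (MWaves, NWaves, WaveSize). A minimal host-side C++ model, with illustrative sizes assumed:

// Sketch only: models GetWaveIdx()'s merge transform with plain integer math.
#include <cstdio>
int main()
{
    const int MWaves = 2, NWaves = 2, WaveSize = 32; // assumed sizes
    for(int tid : {0, 37, 95, 127})
    {
        const int lane     = tid % WaveSize; // index within the wave
        const int wave     = tid / WaveSize; // linear wave id
        const int waveId_n = wave % NWaves;  // fastest-varying wave axis
        const int waveId_m = wave / NWaves;
        std::printf("tid %3d -> wave (%d, %d), lane %2d\n", tid, waveId_m, waveId_n, lane);
    }
}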
    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];

        const auto WMMA_a_idx = wmma_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, WMMA_a_idx[I1], KPerThread * WMMA_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_n = wave_idx[I1];

        const auto WMMA_b_idx = wmma_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, WMMA_b_idx[I1], KPerThread * WMMA_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk(WMMA_i, blk_i);

        constexpr auto mrepeat_mwave_mperWMMA_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerWMMA))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperWMMA_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerWMMA))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperWMMA_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperWMMA_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
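The two unmerge adaptors reduce to one affine formula per axis. A host-side sketch, with sizes assumed for illustration rather than taken from the commit:

// Sketch only: c_thread_m/c_thread_n as the unmerge adaptors compute them.
#include <cstdio>
int main()
{
    const int MWaves = 2, MPerWMMA = 16, NWaves = 2, NPerWMMA = 16; // assumed
    auto row = [&](int m0, int waveId_m, int blk_m) {
        return (m0 * MWaves + waveId_m) * MPerWMMA + blk_m; // unmerge(MRepeat, MWaves, MPerWMMA)
    };
    auto col = [&](int n0, int waveId_n, int blk_n) {
        return (n0 * NWaves + waveId_n) * NPerWMMA + blk_n; // unmerge(NRepeat, NWaves, NPerWMMA)
    };
    std::printf("c_thread = (%d, %d)\n", row(1, 0, 5), col(0, 1, 3)); // (37, 19)
}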
    template <index_t m0, index_t n0, index_t WMMA_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex8D(Number<m0>, Number<n0>, Number<WMMA_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();
        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = wmma_gemm.GetBeginOfThreadBlk4D(WMMA_i, blk_i);

        return make_tuple(Number<m0>{},
                          Number<n0>{},
                          waveId_m,
                          waveId_n,
                          blk_idx[I0],
                          blk_idx[I1],
                          blk_idx[I2],
                          blk_idx[I3]);
    }

    __host__ __device__ BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2()
    {
        static_assert(AK0MK1BlockDesc::IsKnownAtCompileTime() &&
                          BK0NK1BlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(ThisThreadBlock::GetNumOfThread() == MWaves * NWaves * WaveSize,
                      "ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0 &&
                          NPerBlock % (NPerWMMA * NRepeat) == 0,
                      "wrong!");
    }
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = wmma_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerWMMA>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerWMMA>{},
                                                           Number<NPerWMMA>{}));

        return wmma_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(
                make_unmerge_transform(make_tuple(M / (MWaves * MPerWMMA), MWaves, MPerWMMA)),
                make_unmerge_transform(make_tuple(N / (NWaves * NPerWMMA), NWaves, NPerWMMA))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return wmma_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto MakeABlockDescriptor_K0_M0_M1_M2_K1()
    {
        return transform_tensor_descriptor(
            AK0MK1BlockDesc{},
            make_tuple(make_pass_through_transform(make_tuple(Number<A_K0>{}, Number<A_K1>{})),
                       make_unmerge_transform(
                           make_tuple(Number<MRepeat>{}, Number<MWaves>{}, Number<MPerWMMA>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
    }

    __host__ __device__ static constexpr auto MakeBBlockDescriptor_K0_N0_N1_N2_K1()
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
            make_tuple(make_pass_through_transform(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
                       make_unmerge_transform(
                           make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerWMMA>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 4>{}, Sequence<1, 2, 3>{}));
    }

    static constexpr auto a_block_desc_k0_m0_m1_m2_k1 = MakeABlockDescriptor_K0_M0_M1_M2_K1();
    static constexpr auto b_block_desc_k0_n0_n1_n2_k1 = MakeBBlockDescriptor_K0_N0_N1_N2_K1();
    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_block_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto a_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            a_thread_desc_.GetElementSpaceSize());
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

        constexpr auto RepeatDiff = MRepeat - NRepeat;
        constexpr auto WmmaK      = wmma_gemm.k_per_wmma;

        static_for<0, KPerBlock / WmmaK, 1>{}([&](auto iWmmaK) {
            // Cut the MRepeat x NRepeat repeat rectangle down to a square; assumes MRepeat > NRepeat
            static_for<0, RepeatDiff, 1>{}([&](auto iCut) {
                static_for<0, NRepeat, 1>{}([&](auto iN) {
                    vector_type<FloatAB, WmmaK> a_thread_vec;
                    vector_type<FloatAB, WmmaK> b_thread_vec;

                    static_for<0, WmmaK, 1>{}([&](auto iK) {
                        a_thread_vec.template AsType<FloatAB>()(iK) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(iCut, 0, 0, iK))>{}];
                        b_thread_vec.template AsType<FloatAB>()(iK) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                make_tuple(iN, 0, 0, iK))>{}];
                    });

                    using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(iCut, iN, 0));

                    wmma_gemm.template Run(
                        a_thread_vec.template AsType<wmma_input_type>(),
                        b_thread_vec.template AsType<wmma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });

                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   make_tuple(Number<iWmmaK>{}, iCut, I0, I0, I0),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   a_thread_buf);
            });

            // Run FIFO-fashion loopover in the remaining square
            static_for<0, NRepeat, 1>{}([&](auto WmmaInnerloop) {
                static_for<WmmaInnerloop, NRepeat, 1>{}([&](auto iN) {
                    vector_type<FloatAB, WmmaK> a_thread_vec;
                    vector_type<FloatAB, WmmaK> b_thread_vec;

                    static_for<0, WmmaK, 1>{}([&](auto iK) {
                        a_thread_vec.template AsType<FloatAB>()(iK) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(WmmaInnerloop + RepeatDiff, 0, 0, iK))>{}];
                        b_thread_vec.template AsType<FloatAB>()(iK) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                make_tuple(iN, 0, 0, iK))>{}];
                    });

                    using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;

                    constexpr index_t c_offset = c_thread_desc_.CalculateOffset(
                        make_tuple(WmmaInnerloop + RepeatDiff, iN, 0));

                    wmma_gemm.template Run(
                        a_thread_vec.template AsType<wmma_input_type>(),
                        b_thread_vec.template AsType<wmma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });

                a_thread_copy_.Run(a_block_desc_k0_m0_m1_m2_k1,
                                   make_tuple(Number<iWmmaK>{}, WmmaInnerloop + RepeatDiff, I0, I0, I0),
                                   a_block_buf,
                                   a_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   a_thread_buf);

                static_for<WmmaInnerloop + RepeatDiff, MRepeat, 1>{}([&](auto iM) {
                    vector_type<FloatAB, WmmaK> a_thread_vec;
                    vector_type<FloatAB, WmmaK> b_thread_vec;

                    static_for<0, WmmaK, 1>{}([&](auto iK) {
                        a_thread_vec.template AsType<FloatAB>()(iK) =
                            a_thread_buf[Number<a_thread_desc_.CalculateOffset(
                                make_tuple(iM, 0, 0, iK))>{}];
                        b_thread_vec.template AsType<FloatAB>()(iK) =
                            b_thread_buf[Number<b_thread_desc_.CalculateOffset(
                                make_tuple(WmmaInnerloop, 0, 0, iK))>{}];
                    });

                    using wmma_input_type = typename vector_type<FloatAB, WmmaK>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(iM, WmmaInnerloop, 0));

                    wmma_gemm.template Run(
                        a_thread_vec.template AsType<wmma_input_type>(),
                        b_thread_vec.template AsType<wmma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });

                b_thread_copy_.Run(b_block_desc_k0_n0_n1_n2_k1,
                                   make_tuple(Number<iWmmaK>{}, WmmaInnerloop, I0, I0, I0),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);
            });
        });
    }
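The loop structure is easier to see with the index arithmetic stripped out. A host-side sketch of the order in which Run issues wmma tiles, assuming MRepeat > NRepeat as the comment in Run requires (sizes are illustrative):

// Sketch only: the (iM, iN) issue order of Run's cut-then-FIFO schedule.
#include <cstdio>
int main()
{
    const int MRepeat = 4, NRepeat = 2, RepeatDiff = MRepeat - NRepeat; // assumed
    for(int iCut = 0; iCut < RepeatDiff; ++iCut)         // rectangle-to-square phase
        for(int iN = 0; iN < NRepeat; ++iN)
            std::printf("wmma C(%d, %d)\n", iCut, iN);
    for(int i = 0; i < NRepeat; ++i)                     // FIFO sweep of the square
    {
        for(int iN = i; iN < NRepeat; ++iN)              // row i + RepeatDiff, columns i..
            std::printf("wmma C(%d, %d)\n", i + RepeatDiff, iN);
        for(int iM = i + RepeatDiff; iM < MRepeat; ++iM) // column i, rows i + RepeatDiff..
            std::printf("wmma C(%d, %d)\n", iM, i);
    }
}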
    protected:
    // K length of a single wmma instruction; needed at class scope by the
    // thread descriptors below as well as inside Run()
    static constexpr auto WmmaK = wmma_gemm.k_per_wmma;

    // A[M0, M1, M2, K0 = WmmaK]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<WmmaK>{}));

    // B[N0, N1, N2, K0 = WmmaK]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<WmmaK>{}));

    // C[M, N, NumRegWMMA]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, wmma_gemm.GetRegSizePerWMMA()));

    using AThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(a_block_desc_k0_m0_m1_m2_k1),
                                                         decltype(a_thread_desc_),
                                                         Sequence<1, 1, 1, WmmaK>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         A_K1,
                                                         A_K1>;

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(b_block_desc_k0_n0_n1_n2_k1),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, WmmaK>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    AThreadCopy a_thread_copy_{CalculateAThreadOriginDataIndex()};
    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};
template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename AK0MK1BlockDesc,
          typename BK0NK1BlockDesc,
          index_t MPerWMMA,
          index_t NPerWMMA,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack,
          LoopScheduler LoopSched>
constexpr auto BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2_Selector()
{
    if constexpr(LoopSched == LoopScheduler::Default)
    {
        return BlockwiseGemmWMMA_k0mk1_k0nk1_m0m1n0n1n2m2<BlockSize,
                                                          FloatAB,
                                                          FloatAcc,
                                                          AK0MK1BlockDesc,
                                                          BK0NK1BlockDesc,
                                                          MPerWMMA,
                                                          NPerWMMA,
                                                          MRepeat,
                                                          NRepeat,
                                                          KPack>{};
    }
}

} // namespace ck
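For orientation (not part of the commit), the compile-time shape bookkeeping this struct enforces works out as follows for one plausible tile configuration; every size here is an assumption for illustration:

// Sketch only: the wave/block consistency checks as plain constexpr arithmetic.
int main()
{
    constexpr int WaveSize  = 32;                               // w32 WMMA wave
    constexpr int MPerBlock = 128, NPerBlock = 64;
    constexpr int MPerWMMA  = 16, NPerWMMA = 16;
    constexpr int MRepeat   = 4, NRepeat = 2;
    constexpr int MWaves    = MPerBlock / (MRepeat * MPerWMMA); // 2
    constexpr int NWaves    = NPerBlock / (NRepeat * NPerWMMA); // 2
    static_assert(MPerBlock % (MPerWMMA * MRepeat) == 0, "M tiling must divide MPerBlock");
    static_assert(NPerBlock % (NPerWMMA * NRepeat) == 0, "N tiling must divide NPerBlock");
    static_assert(MWaves * NWaves * WaveSize == 128, "BlockSize must be 128 here");
    return 0;
}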
include/ck/tensor_operation/gpu/device/impl/device_gemm_wmma.hpp
@@ -36,10 +36,10 @@ template <typename ADataType,
           ck::index_t NPerBlock,
           ck::index_t K0PerBlock,
           ck::index_t K1,
-          ck::index_t MPerXDL,
-          ck::index_t NPerXDL,
-          ck::index_t MXdlPerWave,
-          ck::index_t NXdlPerWave,
+          ck::index_t MPerWMMA,
+          ck::index_t NPerWMMA,
+          ck::index_t MWmmaPerWave,
+          ck::index_t NWmmaPerWave,
           typename ABlockTransferThreadClusterLengths_K0_M_K1,
           typename ABlockTransferThreadClusterArrangeOrder,
           typename ABlockTransferSrcAccessOrder,

@@ -217,11 +217,11 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
           MPerBlock,
           NPerBlock,
           K0PerBlock,
-          MPerXDL,
-          NPerXDL,
+          MPerWMMA,
+          NPerWMMA,
           K1,
-          MXdlPerWave,
-          NXdlPerWave,
+          MWmmaPerWave,
+          NWmmaPerWave,
           ABlockTransferThreadClusterLengths_K0_M_K1,
           ABlockTransferThreadClusterArrangeOrder,
           ABlockTransferSrcAccessOrder,

@@ -543,10 +543,10 @@ struct DeviceGemmWmma : public DeviceGemm<ALayout,
            << NPerBlock << ", "
            << K0PerBlock << ", "
            << K1 << ", "
-           << MPerXDL << ", "
-           << NPerXDL << ", "
-           << MXdlPerWave << ", "
-           << NXdlPerWave
+           << MPerWMMA << ", "
+           << NPerWMMA << ", "
+           << MWmmaPerWave << ", "
+           << NWmmaPerWave
            << ">"
            << " NumPrefetch: "
            << NumPrefetch << ", "
include/ck/tensor_operation/gpu/grid/gridwise_gemm_wmma_v1r1.hpp
@@ -141,7 +141,7 @@ template <
           index_t CBlockTransferScalarPerVector_NWaveNPerWmma,
           index_t NumGemmKPrefetchStage = 1,
           PipelineVersion PipelineVer   = PipelineVersion::v1>
-struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
+struct GridwiseGemm_k0mk1_k0nk1_mn_wmma_v1
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};

@@ -160,52 +160,35 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
     using GridwiseGemmPipe = remove_cvref_t<decltype(
         GridwiseGemmPipeline_Selector<PipelineVer, NumGemmKPrefetchStage>())>;

-    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst()
+    __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst()
     {
         constexpr auto inst_max_size = 16 / sizeof(FloatAB);
         constexpr auto k1perinst     = (K1 < inst_max_size) ? K1 : inst_max_size;
         constexpr auto K10           = K1 / k1perinst;

         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0_m_k10_k1perinst = [&]() {
+        constexpr auto a_block_desc_k0_k10_m_k1perinst = [&]() {
             if constexpr(ABlockLdsExtraM)
+            // May have static err
             {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, K10, Number<MPerBlock>{}, k1perinst),
-                    k1perinst);
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K1),
+                    make_tuple(Number<MPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
+                // May have static err
                 return make_naive_tensor_descriptor_aligned(
                     make_tuple(Number<K0PerBlock>{}, Number<MPerBlock>{}, K10, k1perinst),
                     k1perinst);
             }
         }();

-        return a_block_desc_k0_m_k1;
+        return a_block_desc_k0_k10_m_k1perinst;
     }

-    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst()
+    __host__ __device__ static constexpr auto GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst()
     {
         constexpr auto inst_max_size = 16 / sizeof(FloatAB);
         constexpr auto k1perinst     = (K1 < inst_max_size) ? K1 : inst_max_size;
         constexpr auto K10           = K1 / k1perinst;

         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto b_block_desc_k0_n_k1 = [&]() {
+        constexpr auto b_block_desc_k0_k10_n_k1perinst = [&]() {
             if constexpr(BBlockLdsExtraN)
             {
-                return make_naive_tensor_descriptor_aligned(
-                    make_tuple(Number<K0PerBlock>{}, K10, Number<NPerBlock>{}, k1perinst),
-                    k1perinst);
+                return make_naive_tensor_descriptor(
+                    make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K1),
+                    make_tuple(Number<NPerBlock + 1>{} * K1, K1, I1));
             }
             else
             {
                 return make_naive_tensor_descriptor_aligned(
                     make_tuple(Number<K0PerBlock>{}, Number<NPerBlock>{}, K10, k1perinst),
                     k1perinst);
             }
         }();

-        return b_block_desc_k0_n_k1;
+        return b_block_desc_k0_k10_n_k1perinst;
     }

@@ -230,18 +213,20 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
     __host__ __device__ static constexpr index_t GetSharedMemoryNumberOfByte()
     {
         // LDS allocation for A and B: be careful of alignment
-        constexpr auto a_block_desc_k0_m_k1 = GetABlockDescriptor_K0PerBlock_MPerBlock_K1();
-        constexpr auto b_block_desc_k0_n_k1 = GetBBlockDescriptor_K0PerBlock_NPerBlock_K1();
+        constexpr auto a_block_desc_k0_k10_m_k1perinst =
+            GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst();
+        constexpr auto b_block_desc_k0_k10_n_k1perinst =
+            GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst();

-        constexpr auto max_lds_align = K1;
+        constexpr auto max_lds_align = a_block_desc_k0_k10_m_k1perinst.GetLength(I3);

         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
+            a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align);

         constexpr auto b_block_space_size_aligned = math::integer_least_multiple(
-            b_block_desc_k0_n_k1.GetElementSpaceSize(), max_lds_align);
+            b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize(), max_lds_align);

+        constexpr auto c_block_size = 0;
+
+#ifndef DISABLE_C_SHUFFLE
         // LDS allocation for C shuffle in LDS
         constexpr auto c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma =
             GetCBlockDescriptor_MBlock_NWmmaPerWave_MWaveMPerWmma_NBlock_NWmmaPerWave_NWaveNPerWmma();

@@ -249,7 +234,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
         constexpr auto c_block_size =
             c_block_desc_mblock_mwmmaperwave_mwavemperwmma_nblock_nwmmaperwave_nwavenperwmma
                 .GetElementSpaceSize();
+#endif

         return math::max((a_block_space_size_aligned + b_block_space_size_aligned) *
                              sizeof(FloatAB),
                          c_block_size * sizeof(FloatC));

@@ -423,42 +408,42 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
         const index_t n_block_data_idx_on_grid =
             __builtin_amdgcn_readfirstlane(block_work_idx[I1] * NPerBlock);

-        // lds max alignment
-        constexpr auto max_lds_align = K1;

         // A matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0_m_k10_k11 =
-            GetABlockDescriptor_K0PerBlock_MPerBlock_K10_K1PerInst();
+        constexpr auto a_block_desc_k0_k10_m_k1perinst =
+            GetABlockDescriptor_K0PerBlock_K10_MPerBlock_K1PerInst();

         // B matrix in LDS memory, dst of blockwise copy
-        constexpr auto a_block_desc_k0_m_k10_k11 =
-            GetBBlockDescriptor_K0PerBlock_NPerBlock_K10_K1PerInst();
+        constexpr auto b_block_desc_k0_k10_n_k1perinst =
+            GetBBlockDescriptor_K0PerBlock_K10_NPerBlock_K1PerInst();

+        // lds max alignment
+        constexpr auto max_lds_align = a_block_desc_k0_k10_m_k1perinst.GetLength(I3);

         // A matrix blockwise copy
         auto a_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<
             ThisThreadBlock,
+            /* typename SrcElementwiseOperation,              */ AElementwiseOperation,
+            /* typename DstElementwiseOperation,              */ ck::tensor_operation::element_wise::PassThrough,
+            /* InMemoryDataOperationEnum DstInMemOp,          */ InMemoryDataOperationEnum::Set,
+            /* typename BlockSliceLengths,                    */ Sequence<K0PerBlock, MPerBlock, K1>,
+            /* typename ThreadClusterLengths,                 */ ABlockTransferThreadClusterLengths_K0_M_K1,
+            /* typename ThreadClusterArrangeOrder,            */ ABlockTransferThreadClusterArrangeOrder,
+            /* typename SrcData,                              */ FloatAB,
+            /* typename DstData,                              */ FloatAB,
+            /* typename SrcDesc,                              */ decltype(a_grid_desc_k0_m_k1),
+            /* typename DstDesc,                              */ decltype(a_block_desc_k0_k10_m_k1perinst),
+            /* typename SrcDimAccessOrder,                    */ ABlockTransferSrcAccessOrder,
+            /* typename DstDimAccessOrder,                    */ Sequence<1, 0, 2>,
+            /* index_t SrcVectorDim,                          */ ABlockTransferSrcVectorDim,
+            /* index_t DstVectorDim,                          */ 2,
+            /* index_t SrcScalarPerVector,                    */ ABlockTransferSrcScalarPerVector,
+            /* index_t DstScalarPerVector,                    */ ABlockTransferDstScalarPerVector_K1,
+            /* index_t SrcScalarStrideInVector,               */ 1,
+            /* index_t DstScalarStrideInVector,               */ 1,
+            /* bool ThreadTransferSrcResetCoordinateAfterRun, */ AThreadTransferSrcResetCoordinateAfterRun,
+            /* bool ThreadTransferDstResetCoordinateAfterRun, */ true>(
                 a_grid_desc_k0_m_k1,
                 make_multi_index(0, m_block_data_idx_on_grid, 0),
                 a_element_op,
-                a_block_desc_k0_m_k1,
+                a_block_desc_k0_k10_m_k1perinst,
                 make_multi_index(0, 0, 0),
                 ck::tensor_operation::element_wise::PassThrough{});

@@ -474,7 +459,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
             FloatAB,
             FloatAB,
             decltype(b_grid_desc_k0_n_k1),
-            decltype(b_block_desc_k0_n_k1),
+            decltype(b_block_desc_k0_k10_n_k1perinst),
             BBlockTransferSrcAccessOrder,
             Sequence<1, 0, 2>,
             BBlockTransferSrcVectorDim,

@@ -488,7 +473,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
                 b_grid_desc_k0_n_k1,
                 make_multi_index(0, n_block_data_idx_on_grid, 0),
                 b_element_op,
-                b_block_desc_k0_n_k1,
+                b_block_desc_k0_k10_n_k1perinst,
                 make_multi_index(0, 0, 0),
                 ck::tensor_operation::element_wise::PassThrough{});

@@ -504,8 +489,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
             BlockwiseGemmWmmaops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1<
                 BlockSize,
                 FloatAB,
                 FloatAcc,
-                decltype(a_block_desc_k0_m_k1),
-                decltype(b_block_desc_k0_n_k1),
+                decltype(a_block_desc_k0_k10_m_k1perinst),
+                decltype(b_block_desc_k0_k10_n_k1perinst),
                 MPerWmma,
                 NPerWmma,
                 MWmmaPerWave,

@@ -516,14 +501,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size_aligned = math::integer_least_multiple(
-            a_block_desc_k0_m_k1.GetElementSpaceSize(), max_lds_align);
+            a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize(), max_lds_align);

         auto a_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
-            static_cast<FloatAB*>(p_shared), a_block_desc_k0_m_k1.GetElementSpaceSize());
+            static_cast<FloatAB*>(p_shared),
+            a_block_desc_k0_k10_m_k1perinst.GetElementSpaceSize());

         auto b_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<FloatAB*>(p_shared) + a_block_space_size_aligned,
-            b_block_desc_k0_n_k1.GetElementSpaceSize());
+            b_block_desc_k0_k10_n_k1perinst.GetElementSpaceSize());

         constexpr auto a_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(K0PerBlock, 0, 0);

@@ -532,13 +517,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
         const index_t K0BlockMainLoop = __builtin_amdgcn_readfirstlane(K0 / K0PerBlock);

         GridwiseGemmPipe::template Run<HasMainKBlockLoop>(a_grid_desc_k0_m_k1,
-                                                          a_block_desc_k0_m_k1,
+                                                          a_block_desc_k0_k10_m_k1perinst,
                                                           a_blockwise_copy,
                                                           a_grid_buf,
                                                           a_block_buf,
                                                           a_block_slice_copy_step,
                                                           b_grid_desc_k0_n_k1,
-                                                          b_block_desc_k0_n_k1,
+                                                          b_block_desc_k0_k10_n_k1perinst,
                                                           b_blockwise_copy,
                                                           b_grid_buf,
                                                           b_block_buf,

@@ -546,7 +531,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
                                                           blockwise_gemm,
                                                           c_thread_buf,
                                                           K0BlockMainLoop);

+#ifndef DISABLE_C_SHUFFLE
        // shuffle C and write out
        {
            static_assert(MWmmaPerWave % CShuffleMWmmaPerWavePerShuffle == 0 &&

@@ -809,6 +794,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_wmmaops_v3r3
                }
            });
        }
+#endif
    }
};
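The ABlockLdsExtraM/BBlockLdsExtraN branches trade LDS capacity for conflict-free access: the explicit stride of (MPerBlock + 1) * K1 staggers consecutive K0 slices across LDS banks. A host-side sketch of the element-count cost, with sizes assumed purely for illustration:

// Sketch only: element space of the padded vs. packed A-block LDS layout.
#include <cstdio>
int main()
{
    const int K0PerBlock = 4, MPerBlock = 128, K1 = 8; // assumed tile sizes
    // strides as in the ABlockLdsExtraM branch: ((MPerBlock + 1) * K1, K1, 1)
    const int s0     = (MPerBlock + 1) * K1;
    const int padded = (K0PerBlock - 1) * s0 + (MPerBlock - 1) * K1 + (K1 - 1) + 1;
    const int packed = K0PerBlock * MPerBlock * K1;
    std::printf("A-block LDS elements: packed %d, padded %d (+%d)\n",
                packed, padded, padded - packed);
}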
include/ck/tensor_operation/gpu/warp/wmma_gemm.hpp
@@ -25,15 +25,15 @@ struct wmma_type;
 template <>
 struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_w32>
 {
-    static constexpr index_t m_per_wave = 16;
-    static constexpr index_t n_per_wave = 16;
-    static constexpr index_t k_per_wave = 16;
+    static constexpr index_t m_per_wmma = 16;
+    static constexpr index_t n_per_wmma = 16;
+    static constexpr index_t k_per_wmma = 16;
     static constexpr index_t wave_size     = 32;
     static constexpr index_t lane_size     = 16;
     static constexpr index_t src_data_size = 2;
     static constexpr index_t acc_data_size = 4;
-    static constexpr index_t num_srcregs_per_wave = 8;
-    static constexpr index_t num_accregs_per_wave = 8;
+    static constexpr index_t num_srcregs_per_wmma = 8;
+    static constexpr index_t num_accregs_per_wmma = 8;

     template <index_t MPerWmma, index_t NPerWmma, class FloatA, class FloatB, class FloatC>
     __device__ void run(const FloatA& a, const FloatB& b, FloatC& reg_c) const

@@ -45,7 +45,7 @@ struct wmma_type<WmmaInstr::wmma_f32_16x16x16_f16_w32>
 template <typename src_type, typename dst_type, index_t MPerWmma, index_t NPerWmma>
 struct WmmaSelector
 {
-    template <typename src_type, typename dst_type, index_t MPerWmma_, index_t NPerWmma_>
+    template <typename src_type_, typename dst_type_, index_t MPerWmma_, index_t NPerWmma_>
     static constexpr auto GetWmma();

     template <>

@@ -89,21 +89,21 @@ struct WmmaSelector
     __host__ __device__ constexpr WmmaSelector()
     {
-        static_assert(selected_wmma.m_per_wave == selected_wmma.n_per_wave,
+        static_assert(selected_wmma.m_per_wmma == selected_wmma.n_per_wmma,
                       "WRONG! WMMA_M must equal to WMMA_N");
-        static_assert(selected_wmma.m_per_wave == selected_wmma.k_per_wave,
+        static_assert(selected_wmma.m_per_wmma == selected_wmma.k_per_wmma,
                       "WRONG! WMMA_M must equal to WMMA_K");
-        static_assert(selected_wmma.k_per_wave == 16, "WRONG! WMMA_M must equal to WMMA_N");
+        static_assert(selected_wmma.k_per_wmma == 16, "WRONG! WMMA_M must equal to WMMA_N");
-        static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wave *
-                              selected_wmma.acc_data_size ==
-                          selected_wmma.m_per_wave * selected_wmma.n_per_wave * 4,
-                      "WRONG! Number of Accumulator Register");
+        static_assert(selected_wmma.wave_size * selected_wmma.num_accregs_per_wmma *
+                              selected_wmma.acc_data_size ==
+                          selected_wmma.m_per_wmma * selected_wmma.n_per_wmma * 4,
+                      "WRONG! Number of Accumulator Register");
-        static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wave *
-                              selected_wmma.src_data_size ==
-                          selected_wmma.m_per_wave * selected_wmma.k_per_wave * 4,
-                      "WRONG! Number of Source Register");
+        static_assert(selected_wmma.lane_size * selected_wmma.num_srcregs_per_wmma *
+                              selected_wmma.src_data_size ==
+                          selected_wmma.m_per_wmma * selected_wmma.k_per_wmma * 4,
+                      "WRONG! Number of Source Register");
     }
 };

@@ -126,20 +126,12 @@ struct WmmaGemm
     using CIndex   = MultiIndex<2>;
     using CIndex4D = MultiIndex<4>;

-    __device__ static constexpr index_t GetNumBlks() { return wmma_instr.num_output_blks; }
-
-    __device__ static constexpr index_t GetNumXdlops()
-    {
-        return MPerWmma * NPerWmma /
-               (wmma_instr.m_per_blk * wmma_instr.n_per_blk * wmma_instr.num_output_blks);
-    }
-
     __host__ __device__ constexpr WmmaGemm()
     {
         static_assert(NPerWmma == 16 && MPerWmma == 16,
                       "Only support GemmNPerWmma == 16 and GemmMPerWmma == 16 for wmma");
-        static_assert(KPack % wmma_instr.k_per_wave == 0, "KPack cannot be divided by k_per_wave");
+        static_assert(KPack == wmma_instr.k_per_wmma, "KPack should be k_per_wmma");
     }

     // XDL output supporting C = A * B

@@ -267,79 +259,43 @@ struct WmmaGemm
 #ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
                 || (is_same<src_type, int4_t>::value && is_same<dst_type, int32_t>::value)
 #endif
                 ,
             "base type couple must be (half, float), (bhalf, float), (half, half), "
             "(bhalf, bhalf), (int8, int32) or (int4, int32)!");

-        static_for<0, KPack / wmma_instr.k_per_wave, 1>{}([&](auto k) {
-            if constexpr(!TransposeC)
-            {
-                wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[k], p_b_wave[k], p_c_thread);
-            }
-            else
-            {
-                wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[k], p_a_wave[k], p_c_thread);
-            }
-        });
+        if constexpr(!TransposeC)
+        {
+            wmma_instr.template run<MPerWmma, NPerWmma>(p_a_wave[0], p_b_wave[0], p_c_thread);
+        }
+        else
+        {
+            wmma_instr.template run<MPerWmma, NPerWmma>(p_b_wave[0], p_a_wave[0], p_c_thread);
+        }
     }

     __device__ static auto GetLaneId() { return get_thread_local_1d_id() % wmma_instr.wave_size; }

-    __device__ static auto GetBlkIdx()
-    {
-        const auto laneId = GetLaneId();
-
-        constexpr auto threadidx_to_blk_idx_adaptor = make_single_stage_tensor_adaptor(
-            make_tuple(make_merge_transform(
-                make_tuple(1, wmma_instr.num_input_blks, wmma_instr.num_threads_per_blk))),
-            make_tuple(Sequence<0, 1, 2>{}),
-            make_tuple(Sequence<0>{}));
-
-        const auto blk_idx =
-            threadidx_to_blk_idx_adaptor.CalculateBottomIndex(make_multi_index(laneId));
-
-        const auto blk_id = blk_idx[I1];
-        const auto blk_td = blk_idx[I2];
-
-        return make_tuple(blk_id, blk_td);
-    }
+    __device__ static auto GetLaneIdHigh() { return GetLaneId() / 16; }
+
+    __device__ static auto GetLaneIdLow() { return GetLaneId() % 16; }
+
+    __device__ static auto GetSwizzledLaneIdLow()
+    {
+        return ((GetLaneIdLow() & 1) << 3) | (GetLaneIdLow() >> 1);
+    }

     __host__ __device__ static auto CalculateAThreadOriginDataIndex()
     {
-        const auto laneId = GetLaneId();
-
-        const auto blk_idx = GetBlkIdx();
-        const auto blk_id  = blk_idx[I0];
-        const auto blk_td  = blk_idx[I1];
-
-        if constexpr(wmma_instr.is_k_reduction)
-        {
-            return make_tuple(blk_id, blk_td);
-        }
-        else
-        {
-            return make_tuple(0, laneId);
-        }
+        return make_tuple(0, GetSwizzledLaneIdLow());
     }

     __host__ __device__ static auto CalculateBThreadOriginDataIndex()
     {
-        const auto laneId = GetLaneId();
-
-        const auto blk_idx = GetBlkIdx();
-        const auto blk_id  = blk_idx[I0];
-        const auto blk_td  = blk_idx[I1];
-
-        if constexpr(wmma_instr.is_k_reduction)
-        {
-            return make_tuple(blk_id, blk_td);
-        }
-        else
-        {
-            return make_tuple(0, laneId);
-        }
+        return make_tuple(0, GetLaneIdLow());
     }

     __device__ static CIndex GetBeginOfThreadBlk(index_t xdlops_i, index_t blk_i)

@@ -365,12 +321,12 @@ struct WmmaGemm
         return TransposeC ? CIndex4D{blk_td, I0, blk_id, I0} : CIndex4D{I0, blk_id, I0, blk_td};
     }

-    static constexpr auto mfma        = MfmaSelector<base_type, MPerWmma, NPerWmma>{};
-    static constexpr auto wmma_instr  = mfma.selected_mfma;
-    static constexpr auto KPerXdlops  = mfma.GetKPerXdlops();
-    static constexpr auto K1PerXdlops = mfma.GetK1PerXdlops();
+    static constexpr auto wmma        = WmmaSelector<src_type, dst_type, MPerWmma, NPerWmma>{};
+    static constexpr auto wmma_instr  = wmma.selected_wmma;
+    static constexpr auto KPerXdlops  = wmma.GetKPerXdlops();
+    static constexpr auto K1PerXdlops = wmma.GetK1PerXdlops();
     static constexpr auto K0PerXdlops = KPerXdlops / K1PerXdlops;

     __host__ __device__ static constexpr auto GetCM0M1M2NThreadBlkLengths()
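A note on the new swizzle (illustration, not from the commit): ((low & 1) << 3) | (low >> 1) is a bijection on 0..15 that interleaves even and odd lane ids, presumably matching the A-operand lane layout the w32 WMMA instruction expects. A quick host-side check:

// Sketch only: enumerate GetSwizzledLaneIdLow() over all 16 low lane ids.
#include <cstdio>
int main()
{
    for(int low = 0; low < 16; ++low)
        std::printf("lane %2d -> %2d\n", low, ((low & 1) << 3) | (low >> 1));
}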