gaoqiong / composable_kernel · Commits · 8929bde2

Commit 8929bde2 authored Jul 20, 2022 by wangshaojie6
Parent 580e9484

    add second blockwisegemm with A is in VGPR. WIP: second gemm pipeline
Showing 6 changed files, with 529 additions and 65 deletions (+529 -65):

  include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_a_lds.hpp   +316 -0
  include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp   +1 -1
  include/ck/tensor_operation/gpu/element/element_wise_operation.hpp           +9 -0
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp     +81 -60
  include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_skip_lds.hpp     +117 -4
  include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp  +5 -0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_a_lds.hpp  (new file, 0 → 100644)
#pragma once

#include "common_header.hpp"
#include "threadwise_tensor_slice_transfer.hpp"
#include "xdlops_gemm.hpp"
#include "tensor_adaptor.hpp"

namespace ck {

template <index_t BlockSize,
          typename FloatAB,
          typename FloatAcc,
          typename BK0NK1BlockDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MRepeat,
          index_t NRepeat,
          index_t KPack>
struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_skip_a_lds
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr index_t WaveSize  = 64;
    static constexpr index_t KPerBlock = K0PerBlock * KPack;

    static constexpr index_t B_K0 = BK0NK1BlockDesc{}.GetLength(I0);
    static constexpr index_t B_K1 = BK0NK1BlockDesc{}.GetLength(I2);

    static constexpr auto xdlops_gemm = XdlopsGemm<FloatAB, MPerXDL, NPerXDL, KPack>{};

    static constexpr index_t KPerThread  = KPerBlock / xdlops_gemm.K0PerXdlops;
    static constexpr index_t K0PerThread = K0PerBlock / xdlops_gemm.K0PerXdlops;

    static constexpr index_t MWaves = MPerBlock / (MRepeat * MPerXDL);
    static constexpr index_t NWaves = NPerBlock / (NRepeat * NPerXDL);

    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr,
                              FloatAcc,
                              MRepeat * NRepeat,
                              xdlops_gemm.GetRegSizePerXdlops(),
                              true>
        c_thread_buf_;

    __host__ __device__ constexpr auto& GetCThreadBuffer() { return c_thread_buf_; }

    __device__ static auto GetWaveIdx()
    {
        const index_t thread_id = get_thread_local_1d_id();

        constexpr auto threadid_to_wave_idx_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_merge_transform(make_tuple(MWaves, NWaves, WaveSize))),
            make_tuple(Sequence<0, 1, 2>{}),
            make_tuple(Sequence<0>{}));

        return threadid_to_wave_idx_adaptor.CalculateBottomIndex(make_multi_index(thread_id));
    }

    __device__ static auto CalculateAThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];

        const auto xdlops_a_idx = xdlops_gemm.CalculateAThreadOriginDataIndex();

        return make_tuple(0, waveId_m, xdlops_a_idx[I1], KPerThread * xdlops_a_idx[I0]);
    }

    __device__ static auto CalculateBThreadOriginDataIndex()
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_n = wave_idx[I1];

        const auto xdlops_b_idx = xdlops_gemm.CalculateBThreadOriginDataIndex();

        return make_tuple(0, waveId_n, xdlops_b_idx[I1], KPerThread * xdlops_b_idx[I0]);
    }

    template <index_t m0, index_t n0, index_t xdlops_i, index_t blk_i>
    __device__ static auto
    CalculateCThreadOriginDataIndex(Number<m0>, Number<n0>, Number<xdlops_i>, Number<blk_i>)
    {
        const auto wave_idx = GetWaveIdx();

        const auto waveId_m = wave_idx[I0];
        const auto waveId_n = wave_idx[I1];

        const auto blk_idx = xdlops_gemm.GetBeginOfThreadBlk(xdlops_i, blk_i);

        constexpr auto mrepeat_mwave_mperxdl_to_m_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(MRepeat, MWaves, MPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        constexpr auto nrepeat_nwave_nperxdl_to_n_adaptor = make_single_stage_tensor_adaptor(
            make_tuple(make_unmerge_transform(make_tuple(NRepeat, NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}),
            make_tuple(Sequence<0, 1, 2>{}));

        const index_t c_thread_m = mrepeat_mwave_mperxdl_to_m_adaptor.CalculateBottomIndex(
            make_tuple(m0, waveId_m, blk_idx[I0]))[I0];
        const index_t c_thread_n = nrepeat_nwave_nperxdl_to_n_adaptor.CalculateBottomIndex(
            make_tuple(n0, waveId_n, blk_idx[I1]))[I0];

        return make_tuple(c_thread_m, c_thread_n);
    }
    __host__ __device__ BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_skip_a_lds()
    {
        static_assert(BK0NK1BlockDesc::IsKnownAtCompileTime(),
                      "wrong! Desc should be known at compile-time");

        static_assert(BlockSize == MWaves * NWaves * WaveSize,
                      "BlockSize != MWaves * NWaves * WaveSize\n");

        static_assert(MPerBlock % (MPerXDL * MRepeat) == 0 && NPerBlock % (NPerXDL * NRepeat) == 0,
                      "wrong!");
    }
    __host__ __device__ static constexpr auto GetCThreadDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCThreadDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_m0_m1_m2_n_tblk_lens = xdlops_gemm.GetCM0M1M2NThreadBlkLengths();

        constexpr auto M0 = c_m0_m1_m2_n_tblk_lens[I0];
        constexpr auto M1 = c_m0_m1_m2_n_tblk_lens[I1];
        constexpr auto M2 = c_m0_m1_m2_n_tblk_lens[I2];
        constexpr auto N  = c_m0_m1_m2_n_tblk_lens[I3];

        return make_naive_tensor_descriptor_packed(
            make_tuple(I1, Number<MRepeat>{}, Number<NRepeat>{}, I1, I1, M0, M1, M2, N));
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_block_desc_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto GetCBlockDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2()
    {
        constexpr auto c_block_desc_g_m0_n0_m1_n1_m2_n2 =
            make_naive_tensor_descriptor_packed(make_tuple(I1,
                                                           Number<MRepeat>{},
                                                           Number<NRepeat>{},
                                                           Number<MWaves>{},
                                                           Number<NWaves>{},
                                                           Number<MPerXDL>{},
                                                           Number<NPerXDL>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_block_desc_g_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_M_N& c_grid_desc_m_n)
    {
        const auto M = c_grid_desc_m_n.GetLength(I0);
        const auto N = c_grid_desc_m_n.GetLength(I1);

        const auto c_grid_desc_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_m_n,
            make_tuple(make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 2, 4>{}, Sequence<1, 3, 5>{}));

        return xdlops_gemm.MakeCDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m0_n0_m1_n1_m2_n2);
    }

    template <typename CGridDesc_G_M_N>
    __host__ __device__ static constexpr auto
    MakeCGridDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(const CGridDesc_G_M_N& c_grid_desc_g_m_n)
    {
        const auto G = c_grid_desc_g_m_n.GetLength(I0);
        const auto M = c_grid_desc_g_m_n.GetLength(I1);
        const auto N = c_grid_desc_g_m_n.GetLength(I2);

        const auto c_grid_desc_g_m0_n0_m1_n1_m2_n2 = transform_tensor_descriptor(
            c_grid_desc_g_m_n,
            make_tuple(make_pass_through_transform(G),
                       make_unmerge_transform(make_tuple(M / (MWaves * MPerXDL), MWaves, MPerXDL)),
                       make_unmerge_transform(make_tuple(N / (NWaves * NPerXDL), NWaves, NPerXDL))),
            make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
            make_tuple(Sequence<0>{}, Sequence<1, 3, 5>{}, Sequence<2, 4, 6>{}));

        return xdlops_gemm.MakeCDescriptor_G_M0_N0_M1_N1_M2_M3_M4_N2(
            c_grid_desc_g_m0_n0_m1_n1_m2_n2);
    }

    __host__ __device__ static constexpr auto MakeBBlockDescriptor_N0_N1_N2_K()
    {
        return transform_tensor_descriptor(
            BK0NK1BlockDesc{},
            make_tuple(
                make_merge_transform_v3_division_mod(make_tuple(Number<B_K0>{}, Number<B_K1>{})),
                make_unmerge_transform(
                    make_tuple(Number<NRepeat>{}, Number<NWaves>{}, Number<NPerXDL>{}))),
            make_tuple(Sequence<0, 2>{}, Sequence<1>{}),
            make_tuple(Sequence<3>{}, Sequence<0, 1, 2>{}));
    }

    __device__ void MoveABlockSliceWindow()
    {
        a_thread_copy_.MoveSrcSliceWindow(a_block_desc_m0_m1_m2_k,
                                          make_multi_index(0, 0, 0, K0PerBlock * KPack));
    }

    __device__ void ResetABlockStartWindow()
    {
        a_thread_copy_.SetSrcCoord(CalculateAThreadOriginDataIndex());
    }

    static constexpr auto b_block_desc_n0_n1_n2_k = MakeBBlockDescriptor_N0_N1_N2_K();
    template <typename ABlockBuffer, typename BBlockBuffer, typename CThreadBuffer>
    __device__ void Run(const ABlockBuffer& a_thread_buf,
                        const BBlockBuffer& b_block_buf,
                        CThreadBuffer& c_thread_buf) const
    {
        auto b_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
            b_thread_desc_.GetElementSpaceSize());

        static_for<0, MRepeat, 1>{}([&](auto m0) {
            static_for<0, NRepeat, 1>{}([&](auto n0) {
                // read B from LDS into the thread buffer (A already lives in VGPR)
                b_thread_copy_.Run(b_block_desc_n0_n1_n2_k,
                                   make_tuple(n0, I0, I0, I0),
                                   b_block_buf,
                                   b_thread_desc_,
                                   make_tuple(I0, I0, I0, I0),
                                   b_thread_buf);

                static_for<0, KPerThread, KPack>{}([&](auto k) {
                    vector_type<FloatAB, KPack> a_thread_vec;
                    vector_type<FloatAB, KPack> b_thread_vec;

                    constexpr index_t k0 = k / KPack;

                    static_for<0, KPack, 1>{}([&](auto i) {
                        a_thread_vec.template AsType<FloatAB>()(i) = a_thread_buf
                            [Number<a_thread_desc_.CalculateOffset(make_tuple(k0, m0, i))>{}];
                        b_thread_vec.template AsType<FloatAB>()(i) = b_thread_buf
                            [Number<b_thread_desc_.CalculateOffset(make_tuple(0, 0, 0, k + i))>{}];
                    });

                    using mfma_input_type =
                        typename vector_type<FloatAB, xdlops_gemm.K1PerXdlops>::type;

                    constexpr index_t c_offset =
                        c_thread_desc_.CalculateOffset(make_tuple(m0, n0, 0));

                    xdlops_gemm.template Run(
                        a_thread_vec.template AsType<mfma_input_type>(),
                        b_thread_vec.template AsType<mfma_input_type>(),
                        c_thread_buf.GetVectorTypeReference(Number<c_offset>{}));
                });
            });
        });
    }
    private:
    // B[N0, N1, N2, KPerThread]
    static constexpr auto b_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(I1, I1, I1, Number<KPerThread>{}));

    // A[K0PerThread, M0, KPack]
    static constexpr auto a_thread_desc_ =
        make_naive_tensor_descriptor_packed(make_tuple(Number<K0PerThread>{}, // KPerThread
                                                       Number<MRepeat>{},     // repeat
                                                       Number<KPack>{}));

    // C[M, N, NumRegXdlops]
    static constexpr auto c_thread_desc_ = make_naive_tensor_descriptor_packed(
        make_tuple(Number<MRepeat>{}, Number<NRepeat>{}, xdlops_gemm.GetRegSizePerXdlops()));

    using BThreadCopy = ThreadwiseTensorSliceTransfer_v4<FloatAB,
                                                         FloatAB,
                                                         decltype(b_block_desc_n0_n1_n2_k),
                                                         decltype(b_thread_desc_),
                                                         Sequence<1, 1, 1, KPerThread>,
                                                         Sequence<0, 1, 2, 3>,
                                                         3,
                                                         B_K1,
                                                         B_K1>;

    BThreadCopy b_thread_copy_{CalculateBThreadOriginDataIndex()};
};

} // namespace ck
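GetWaveIdx() above turns the flat thread id into a (wave-in-M, wave-in-N, lane) triple by merging MWaves x NWaves x WaveSize through a single-stage tensor adaptor. A minimal host-side sketch of the same index arithmetic, written with plain division and modulo instead of the ck adaptor machinery (the MWaves/NWaves values below are illustrative, not taken from this commit):

#include <array>
#include <cstdio>

// Illustrative tile configuration (assumed, not from the commit):
constexpr int MWaves   = 2;  // waves along M in the block
constexpr int NWaves   = 2;  // waves along N in the block
constexpr int WaveSize = 64; // lanes per wave

// Equivalent of the merge transform in GetWaveIdx(): split a flat thread id
// into [wave_m, wave_n, lane], with the lane varying fastest.
std::array<int, 3> wave_idx_of(int thread_id)
{
    const int lane   = thread_id % WaveSize;
    const int wave   = thread_id / WaveSize;
    const int wave_n = wave % NWaves;
    const int wave_m = wave / NWaves;
    return {wave_m, wave_n, lane};
}

int main()
{
    for(int tid : {0, 63, 64, 191, 255})
    {
        const auto idx = wave_idx_of(tid);
        std::printf("thread %3d -> wave_m %d, wave_n %d, lane %2d\n", tid, idx[0], idx[1], idx[2]);
    }
    return 0;
}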
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
@@ -20,7 +20,7 @@ template <index_t BlockSize,
           index_t MRepeat,
           index_t NRepeat,
           index_t KPack>
-struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
+struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_skip_b_lds
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
...
include/ck/tensor_operation/gpu/element/element_wise_operation.hpp
...
@@ -201,6 +201,15 @@ struct UnaryTypeConvert<ck::bhalf_t, float>
     }
 };
 
+template <>
+struct UnaryTypeConvert<ck::half_t, float>
+{
+    __host__ __device__ void operator()(ck::half_t& y, float& x) const
+    {
+        y = ck::type_convert<ck::half_t, float>(x);
+    }
+};
+
 } // namespace element_wise
 } // namespace tensor_operation
 } // namespace ck
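The new UnaryTypeConvert<ck::half_t, float> specialization follows the element-wise functor convention used throughout this header: operator()(dst, src) performs one scalar conversion, and the gridwise code later plugs it in as TypeConvertFp32ToFp16Functor. A standalone sketch of the same pattern, using double to float so it builds without ck::half_t (the UnaryConvert name is mine, not the library's):

#include <cstdio>

// Standalone analogue of the unary convert functor added in this commit;
// ck's version narrows float accumulators to ck::half_t instead.
template <typename Dst, typename Src>
struct UnaryConvert
{
    void operator()(Dst& y, const Src& x) const { y = static_cast<Dst>(x); }
};

int main()
{
    UnaryConvert<float, double> convert{};
    double acc = 3.14159265358979;
    float out  = 0.f;
    convert(out, acc); // out now holds the narrowed value
    std::printf("%f -> %f\n", acc, out);
    return 0;
}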
include/ck/tensor_operation/gpu/grid/gridwise_gemm_gemm_xdl_skip_lds.hpp
...
@@ -85,6 +85,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
     // gemm1 K1
     static constexpr auto AccK1 = I4;
+    static constexpr auto Gemm1K0PerBlock = Number<KPerBlock / AccK1>{};
 
     static constexpr index_t WaveSize = 64;
     static constexpr index_t M0Waves  = M0PerBlock / (M0XdlPerWave * M0PerXDL);
...
@@ -101,7 +102,8 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
     using ThisThreadBlock = ThisThreadBlock<BlockSize>;
 
-    using GridwiseGemmPipe = GridwiseGemmPipelineSkipLds;
+    using GridwiseGemmPipe0 = GridwiseGemmPipelineSkipBLds;
+    using GridwiseGemmPipe1 = GridwiseGemmPipelineAInVgpr;
 
     __host__ __device__ static constexpr auto GetABlockDescriptor_K0PerBlock_MPerBlock_K1()
     {
...
@@ -358,23 +360,22 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
     using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
         decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
     using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1));
 
-    using BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
-        decltype(MakeB0GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(BGridDesc_K0_N_K1{}));
+    using B0GridDesc_K0_K1_K2_N0_N1_N2_N3_K3 =
+        decltype(MakeB0GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(B0GridDesc_K0_N_K1{}));
 
     __host__ __device__ static constexpr auto
     MakeB1GridDescriptor_K0_K1_K2_N0_N1_N2_N3_K3(const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2&)
     {
     }
 
     using TypeConvertFp32ToFp16Functor =
         ck::tensor_operation::element_wise::UnaryTypeConvert<ck::half_t, float>;
 
     template <bool HasMainK0BlockLoop, typename Block2CTileMap = DefaultBlock2CTileMap>
     __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
-                               const FloatAB* __restrict__ p_b_grid,
                                const FloatAB* __restrict__ p_b0_grid,
                                const FloatAB* __restrict__ p_b1_grid,
                                FloatC* __restrict__ p_c_grid,
                                void* __restrict__ p_shared,
                                const AGridDesc_K0_M_K1& a_grid_desc_k0_m_k1,
-                               const BGridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+                               const B0GridDesc_K0_K1_K2_N0_N1_N2_N3_K3 b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
                                const B1GridDesc_K0_N_K1& b1_grid_desc_k0_n_k1,
                                const CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2& c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2,
                                const AElementwiseOperation& a_element_op,
                                const BElementwiseOperation& b_element_op,
...
@@ -383,12 +384,14 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
     {
         const auto a_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_a_grid, a_grid_desc_k0_m_k1.GetElementSpaceSize());
-        const auto b_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
-            p_b_grid, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
+        const auto b0_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b0_grid, b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetElementSpaceSize());
+        const auto b1_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_b1_grid, b1_grid_desc_k0_n_k1.GetElementSpaceSize());
         auto c_grid_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_c_grid, c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
 
-        const auto K0 = a_grid_desc_k0_m_k1.GetLength(I0);
+        const auto Gemm0K0 = a_grid_desc_k0_m_k1.GetLength(I0);
 
         // divide block work by [M, N]
         const auto block_work_idx =
...
@@ -474,13 +477,13 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
                wave_k_n_id[I0],
                wave_k_n_id[I1]);
         printf("mfma thread k per xdlops: %d K0PerThread: %d HasMainK0BlockLoop: %d K0: %d \t",
-               xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
+               xdlops_gemm.K0PerXdlops, K0PerThread, HasMainK0BlockLoop, b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3.GetLength(I0));
 #endif
 
         auto b_threadwise_copy =
             ThreadwiseTensorSliceTransfer_v2<FloatAB,
                                              FloatAB,
-                                             decltype(b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
+                                             decltype(b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3),
                                              decltype(b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3),
                                              Sequence<I1,
                                                       I1,
...
@@ -495,16 +498,14 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
                                              BBlockTransferSrcScalarPerVector,
                                              BThreadTransferSrcResetCoordinateAfterRun,
                                              true>(
-            b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+            b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
             make_multi_index(0,
                              wave_k_n_id[I0],
                              0,
                              block_work_idx[I1],
                              0,
                              wave_id[I1],
                              wave_k_n_id[I1],
                              0));
 
-        // GEMM definition
-        // c_mtx += transpose(a_mtx) * b_mtx
-        // a_mtx[K0PerBlock, MPerBlock] is in LDS
-        // b_mtx[K0PerBlock, NPerBlock] is in LDS
-        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
-        // register
+        // GEMM0 definition
+        // c_mtx += b_mtx * a_mtx
+        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+        // register
         // sanity check
         auto blockwise_gemm =
             BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1<BlockSize,
...
@@ -530,27 +531,27 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
         constexpr auto a_block_slice_copy_step  = make_multi_index(K0PerBlock * MultiK0, 0, 0);
         constexpr auto b_thread_slice_copy_step = make_multi_index(1, 0, 0, 0, 0, 0, 0, 0);
 
-        // gridwise GEMM pipeline
-        static_assert(std::is_default_constructible_v<GridwiseGemmPipe>);
-        const auto gridwise_gemm_pipeline = GridwiseGemmPipe{};
+        // gridwise GEMM0 pipeline
+        static_assert(std::is_default_constructible_v<GridwiseGemmPipe0>);
+        const auto gridwise_gemm_pipeline0 = GridwiseGemmPipe0{};
 
         const index_t num_k_block_main_loop = __builtin_amdgcn_readfirstlane(
             (a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
 
-        gridwise_gemm_pipeline.template Run<HasMainKBlockLoop, MultiK0>(
+        gridwise_gemm_pipeline0.template Run<HasMainKBlockLoop, MultiK0>(
             a_grid_desc_k0_m_k1,
             a_block_desc_k0_m_k1,
             a_blockwise_copy,
             a_grid_buf,
             a_block_buf,
             a_block_slice_copy_step,
-            b_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
+            b0_grid_desc_k0_k1_k2_n0_n1_n2_n3_k3,
             b_thread_desc_k0_k1_k2_n0_n1_n2_n3_k3,
             b_threadwise_copy,
             b_grid_buf,
             b_thread_buf,
             b_thread_slice_copy_step,
             blockwise_gemm,
             c_thread_buf,
...
@@ -589,35 +590,7 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
             make_tuple(Sequence<0, 2, 4, 5>{}, Sequence<1, 3, 7>{}, Sequence<6>{}),
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));
 
-        // A1 matrix blockwise copy
-        // actually a threadwise copy. this variant needs to support RunRead() and RunWrite()
-        // TODO ANT: real blockwise copy from c_block_desc to c_thread_desc
-        auto a1_blockwise_copy = ThreadwiseTensorSliceTransfer_v3r1<
-            Sequence<m0 * m1 * m2 * m3, n0 * n1 * n2, m4>{}, // ThreadSliceLengths
-            tensor_operation::element_wise::PassThrough,     // SrcElementwiseOperation
-            tensor_operation::element_wise::PassThrough,     // DstElementwiseOperation
-            InMemoryDataOperationEnum::Set,                  // DstInMemOp
-            FloatGemmAcc,                                    // SrcData
-            FloatAB,                                         // DstData
-            a1_thread_desc_k0_m_k1,                          // SrcDesc
-            a1_thread_desc_k0_m_k1,                          // DstDesc
-            Sequence<1, 0, 2>,                               // SrcDimAccessOrder
-            Sequence<1, 0, 2>,                               // DstDimAccessOrder
-            2,                                               // SrcVectorDim
-            2,                                               // DstVectorDim
-            m4,                                              // SrcScalarPerVector
-            m4,                                              // DstScalarPerVector
-            1,                                               // SrcScalarStrideInVector
-            1,                                               // DstScalarStrideInVector
-            false,                                           // ThreadTransferSrcResetCoordinateAfterRun
-            true,                                            // ThreadTransferDstResetCoordinateAfterRun
-            NumGemmKPrefetchStage>(a1_thread_desc_k0_m_k1,
-                                   make_multi_index(0, 0, 0),
-                                   tensor_operation::element_wise::PassThrough{},
-                                   a1_thread_desc_k0_m_k1,
-                                   make_multi_index(0, 0, 0),
-                                   tensor_operation::element_wise::PassThrough{});
 
         // B1 matrix blockwise copy
         auto b1_blockwise_copy = ThreadGroupTensorSliceTransfer_v4r1<ThisThreadBlock,
...
@@ -649,14 +622,62 @@ struct GridwiseGemmGemmXdlopsSkipLdsV1
                 make_multi_index(0, 0, 0),
                 tensor_operation::element_wise::PassThrough{});
 
-        auto a1_block_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
-            acc_thread_desc_m0_n0_m1_n1_m2_n2_n3_n4.GetElementSpaceSize());
+        auto a1_thread_buf = make_static_buffer<AddressSpaceEnum::Vgpr, FloatAB>(
+            acc_thread_desc_m0_n0_m1_n1_m2_m3_m4_n2.GetElementSpaceSize());
 
         // reuse LDS space for gemm0's a_block_buf
         auto b1_block_buf = make_dynamic_buffer<AddressSpaceEnum::Lds>(
             static_cast<FloatAB*>(p_shared), b1_block_desc_bk0_n_bk1.GetElementSpaceSize());
 
+        constexpr auto a1_thread_slice_copy_step =
+            make_multi_index(Gemm1K0PerBlock, 0, 0, 0, 0, 0, 0, 0);
+        constexpr auto b1_block_slice_copy_step = make_multi_index(Gemm1K0PerBlock, 0, 0);
+
+        // GEMM1 definition
+        // c_mtx += a_mtx * b_mtx
+        // c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+        // register
+        // sanity check
+        auto blockwise_gemm1 =
+            BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_skip_a_lds<BlockSize,
+                                                                        FloatAB,
+                                                                        FloatAcc,
+                                                                        decltype(b1_block_desc_bk0_n_bk1),
+                                                                        MPerBlock,
+                                                                        NPerBlock,
+                                                                        Gemm1K0PerBlock,
+                                                                        MPerXDL,
+                                                                        NPerXDL,
+                                                                        MXdlPerWave,
+                                                                        NXdlPerWave,
+                                                                        K1>{};
+
+        auto c1_thread_buf = blockwise_gemm1.GetCThreadBuffer();
+
+        // gridwise GEMM 0 pipeline
+        static_assert(std::is_default_constructible_v<GridwiseGemmPipe1>);
+        const auto gridwise_gemm_pipeline1 = GridwiseGemmPipe1{};
+
+        const index_t num_k_block_main_loop_1 = __builtin_amdgcn_readfirstlane(
+            (a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2)) / KPerBlock);
+
+        gridwise_gemm_pipeline1.template Run<TypeConvertFp32ToFp16Functor, MultiK0>(
+            a1_thread_desc_k0_m_k1,
+            c_thread_buf,
+            a1_thread_buf,
+            a1_thread_slice_copy_step,
+            b1_grid_desc_bk0_n_bk1,
+            b1_block_desc_bk0_n_bk1,
+            b1_blockwise_copy,
+            b1_grid_buf,
+            b1_block_buf,
+            b1_block_slice_copy_step,
+            blockwise_gemm,
+            c1_thread_buf,
+            num_k_block_main_loop);
 
         // output: register to global memory
...
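For orientation, the data flow this WIP kernel is assembling: GEMM0 keeps its fp32 accumulators in registers (c_thread_buf), those accumulators are narrowed to fp16 and reused directly as the A operand of GEMM1 (a1_thread_buf), and B1 is staged through the LDS buffer that GEMM0 no longer needs. A small host-side model of that flow, with made-up sizes and plain loops standing in for the blockwise GEMM and pipeline objects:

#include <cstdio>
#include <vector>

// Host-side model of the fused GEMM0 -> convert -> GEMM1 dataflow.
// All dimensions are illustrative only.
int main()
{
    constexpr int M = 4, K0 = 8, N0 = 4, N1 = 3;

    std::vector<float> a(M * K0, 1.0f);   // GEMM0 A (kept in registers on the GPU)
    std::vector<float> b0(K0 * N0, 0.5f); // GEMM0 B (streamed through threads)
    std::vector<float> b1(N0 * N1, 2.0f); // GEMM1 B (staged through LDS)

    // GEMM0: C0 = A * B0, accumulated in fp32 ("c_thread_buf").
    std::vector<float> c0(M * N0, 0.0f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N0; ++n)
            for(int k = 0; k < K0; ++k)
                c0[m * N0 + n] += a[m * K0 + k] * b0[k * N0 + n];

    // Convert: the fp32 accumulators become the A operand of GEMM1
    // ("a1_thread_buf"); plain float stands in for ck::half_t here.
    std::vector<float> a1(c0.begin(), c0.end());

    // GEMM1: C1 = A1 * B1 ("c1_thread_buf").
    std::vector<float> c1(M * N1, 0.0f);
    for(int m = 0; m < M; ++m)
        for(int n = 0; n < N1; ++n)
            for(int k = 0; k < N0; ++k)
                c1[m * N1 + n] += a1[m * N0 + k] * b1[k * N1 + n];

    // Each c0 entry is 8 * 1.0 * 0.5 = 4.0, so c1[0] = 4 * 4.0 * 2.0 = 32.0.
    std::printf("c1[0][0] = %f\n", c1[0]);
    return 0;
}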
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_skip_lds.hpp
...
@@ -18,7 +18,7 @@ __device__ void s_nop()
 #endif
 }
 
-struct GridwiseGemmPipelineSkipLds
+struct GridwiseGemmPipelineSkipBLds
 {
     __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
     {
...
@@ -32,6 +32,7 @@ struct GridwiseGemmPipelineSkipLds
     }
 
     template <bool HasMainLoop,
+              index_t MultK0,
               typename AGridDesc,
               typename ABlockDesc,
               typename ABlockTransfer,
...
@@ -45,8 +46,7 @@ struct GridwiseGemmPipelineSkipLds
               typename BThreadBuffer,
               typename BThreadTransferStep,
               typename BlockwiseGemm,
-              typename CThreadBuffer,
-              index_t MultK0>
+              typename CThreadBuffer>
     __device__ static void Run(const AGridDesc& a_grid_desc,
                                const ABlockDesc& a_block_desc,
                                ABlockTransfer& a_blockwise_copy,
...
@@ -57,7 +57,7 @@ struct GridwiseGemmPipelineSkipLds
                                const BThreadDesc& b_thread_desc,
                                BThreadTransfer& b_threadwise_copy,
                                const BGridBuffer& b_grid_buf,
-                               BThreadBuffer& b_thread_buf[MultK0],
+                               BThreadBuffer* b_thread_buf,
                                const BThreadTransferStep& b_thread_copy_step,
                                const BlockwiseGemm& blockwise_gemm,
                                CThreadBuffer& c_thread_buf,
...
@@ -142,4 +142,117 @@ struct GridwiseGemmPipelineSkipLds
     }
 };
 
+struct GridwiseGemmPipelineAInVgpr
+{
+    static constexpr I0 = Number<0>{};
+
+    __host__ __device__ static constexpr bool IsSupported(index_t num_loop)
+    {
+        // TODO: improve applicability
+        return num_loop >= 2;
+    }
+
+    __host__ __device__ static constexpr bool CalculateHasMainLoop(index_t num_loop)
+    {
+        return num_loop > 2;
+    }
+
+    template <typename TypeConvertFp32ToFp16Functor, typename AThreadDesc, index_t i_k0>
+    __device__ static void ConvertCopy(const AThreadDesc& a_thread_desc,
+                                       const AccThreadBuffer& acc_thread_buf,
+                                       AThreadBuffer& a_thread_buf)
+    {
+        constexpr auto i_k0_num = Number<i_k0>{};
+
+        static_for<0, n0, 1>{}([&](auto n) {
+            static_for<0, m4, 1>{}([&](auto m) {
+                constexpr auto acc_offset = a_thread_desc.CalculateOffset(
+                    make_tuple(i_k0_num, n, I0, I0, I0, I0, m, I0));
+                constexpr auto a_offset = a_thread_desc.CalculateOffset(
+                    make_tuple(I0, n, I0, I0, I0, I0, m, I0));
+
+                TypeConvertFp32ToFp16Functor(a1_thread_buf(Number<a_offset>{}),
+                                             acc_thread_buf(Number<acc_offset>{}));
+            });
+        });
+    }
+
+    template <typename TypeConvertFp32ToFp16Functor,
+              index_t MultK0,
+              typename AThreadDesc,
+              typename AccThreadBuffer,
+              typename AThreadBuffer,
+              typename AThreadTransferStep,
+              typename BGridDesc,
+              typename BBlockDesc,
+              typename BBlockTransfer,
+              typename BGridBuffer,
+              typename BBlockBuffer,
+              typename BBlockTransferStep,
+              typename BlockwiseGemm,
+              typename CThreadBuffer>
+    __device__ static void Run(const AThreadDesc& a_thread_desc,
+                               const AccThreadBuffer& acc_thread_buf,
+                               AThreadBuffer& a_thread_buf,
+                               const AThreadTransferStep& a_thread_transfer_step,
+                               const BGridDesc& b_grid_desc,
+                               const BBlockDesc& b_block_desc,
+                               BBlockTransfer& b_blockwise_copy,
+                               const BGridBuffer& b_grid_buf,
+                               BBlockBuffer& b_block_buf,
+                               const BBlockTransferStep& b_block_copy_step,
+                               const BlockwiseGemm& blockwise_gemm,
+                               CThreadBuffer& c_thread_buf)
+    {
+        static_for<0, MultiK0, 1>{}([&](auto i_k0) {});
+
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+        // Initialize C
+        c_thread_buf.Clear();
+
+        // a data write to lds
+        b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+        b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+        // main body
+        if constexpr(HasMainK0BlockLoop)
+        {
+            index_t i = 0;
+
+            do
+            {
+                block_sync_lds();
+
+                // GEMM i
+                blockwise_gemm.Run(a_thread_buf, b_block_buf, c_thread_buf);
+
+                block_sync_lds();
+
+                b_blockwise_copy.MoveSrcSliceWindow(b_grid_desc, b_block_copy_step);
+
+                // LDS write i + 1
+                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+                // global read i + 2
+                b_blockwise_copy.RunRead(b_grid_desc, b_grid_buf);
+
+                i += 1;
+            } while(i < (num_loop - 2));
+
+            // tail
+            {
+                block_sync_lds();
+
+                // GEMM num_loop - 2
+                blockwise_gemm.Run(a_thread_buf, b_block_buf, c_thread_buf);
+
+                block_sync_lds();
+
+                // LDS write num_loop - 1
+                b_blockwise_copy.RunWrite(b_block_desc, b_block_buf);
+
+                block_sync_lds();
+
+                // GEMM num_loop - 1
+                blockwise_gemm.Run(a_thread_buf, b_block_buf, c_thread_buf);
+            }
+        }
+    }
+};
+
 } // namespace ck
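The Run body of GridwiseGemmPipelineAInVgpr above follows the usual two-stage LDS pipeline for B: prefetch one tile, then in the steady state overlap the MFMA on tile i with the LDS write of tile i+1 and the global read of tile i+2, and finish with a two-iteration tail. A compact host-side skeleton of that schedule, with printf lambdas standing in for the copy and GEMM objects (the num_loop value is illustrative):

#include <cstdio>

int main()
{
    const int num_loop = 6; // number of K tiles; the pipeline assumes num_loop >= 2

    auto global_read = [](int i) { std::printf("global read  tile %d\n", i); };
    auto lds_write   = [](int i) { std::printf("LDS write    tile %d\n", i); };
    auto gemm        = [](int i) { std::printf("MFMA/GEMM on tile %d\n", i); };

    // Prologue: read tile 0, stage it in LDS, and prefetch tile 1.
    global_read(0);
    lds_write(0);
    global_read(1);

    // Main body: while computing tile i, stage tile i+1 and fetch tile i+2.
    int i = 0;
    do
    {
        gemm(i);
        lds_write(i + 1);
        global_read(i + 2);
        ++i;
    } while(i < num_loop - 2);

    // Tail: the last two tiles are already fetched; compute them without more reads.
    gemm(num_loop - 2);
    lds_write(num_loop - 1);
    gemm(num_loop - 1);
    return 0;
}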
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp
...
@@ -1180,6 +1180,11 @@ struct ThreadwiseTensorSliceTransfer_v4
         move_tensor_coordinate(SrcDesc{}, src_ref_coord_, src_slice_move_step_iter);
     }
 
+    __device__ void SetSrcCoord(const Index& src_ref_idx)
+    {
+        src_ref_coord_ = make_tensor_coordinate(SrcDesc{}, src_ref_idx);
+    }
+
     private:
     SrcCoord src_ref_coord_;
 };
...
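SetSrcCoord() lets a threadwise transfer rewind its source coordinate to a freshly computed origin, which is what ResetABlockStartWindow() in the new blockwise GEMM relies on to re-walk its A window. A minimal sketch of that reset-and-reuse pattern with an ordinary cursor type in place of ck's coordinate machinery (all names here are illustrative):

#include <cstdio>

// Toy stand-in for a threadwise copy object that walks a source window.
struct WindowCursor
{
    int coord = 0;

    void MoveSrcSliceWindow(int step) { coord += step; } // advance, as during the K loop
    void SetSrcCoord(int origin) { coord = origin; }     // rewind, like ResetABlockStartWindow()

    void Run() const { std::printf("copy from offset %d\n", coord); }
};

int main()
{
    WindowCursor copy{};
    copy.SetSrcCoord(0);

    for(int k = 0; k < 3; ++k) // first pass over the window
    {
        copy.Run();
        copy.MoveSrcSliceWindow(16);
    }

    copy.SetSrcCoord(0); // reset before a second pass reuses the same window
    copy.Run();
    return 0;
}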