Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
10fdada7
Commit
10fdada7
authored
Sep 09, 2021
by
Jing Zhang
Browse files
rename e0_e1
parent
95228cd7
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
106 additions
and
124 deletions
+106
-124
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
...rnel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
+24
-26
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
...ernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
+82
-90
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
...nel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
+0
-8
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_dlops_v3.hpp
View file @
10fdada7
...
@@ -12,14 +12,16 @@ template <index_t BlockSize,
...
@@ -12,14 +12,16 @@ template <index_t BlockSize,
typename
BlockMatrixA
,
typename
BlockMatrixA
,
typename
BlockMatrixB
,
typename
BlockMatrixB
,
typename
ThreadMatrixC
,
typename
ThreadMatrixC
,
index_t
KPerThread
,
index_t
HPerThread
,
index_t
WPerThread
,
index_t
EPerThreadLoop
,
index_t
EPerThreadLoop
,
index_t
ThreadGemmADataPerRead_K
,
index_t
ThreadGemmADataPerRead_K
,
index_t
ThreadGemmBDataPerRead_W
>
index_t
ThreadGemmBDataPerRead_W
>
struct
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
struct
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
struct
MatrixIndex
struct
MatrixIndex
{
{
index_t
k
;
index_t
k
;
...
@@ -27,6 +29,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -27,6 +29,10 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
index_t
w
;
index_t
w
;
};
};
static
constexpr
auto
KPerThread
=
ThreadMatrixC
{}.
GetLength
(
I0
);
static
constexpr
auto
HPerThread
=
ThreadMatrixC
{}.
GetLength
(
I2
);
static
constexpr
auto
WPerThread
=
ThreadMatrixC
{}.
GetLength
(
I3
);
// HACK: fix this @Jing Zhang
// HACK: fix this @Jing Zhang
static
constexpr
index_t
KPerThreadSubC
=
4
;
static
constexpr
index_t
KPerThreadSubC
=
4
;
...
@@ -39,16 +45,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -39,16 +45,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
static
constexpr
auto
c_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
static
constexpr
auto
c_thread_mtx_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
KPerThreadSubC
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
Number
<
KPerThreadSubC
>
{},
Number
<
1
>
{},
Number
<
HPerThread
>
{},
Number
<
WPerThread
>
{}));
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BlockMatrixA
,
decltype
(
a_thread_mtx_
),
Sequence
<
EPerThreadLoop
,
KPerThreadSubC
>
,
Sequence
<
0
,
1
>
,
1
,
ThreadGemmADataPerRead_K
,
1
>
;
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
()
__device__
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
()
:
c_thread_begin_mtx_idx_
{
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
())},
:
c_thread_begin_mtx_idx_
{
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_begin_mtx_idx_
.
k
*
KPerThread
)}
a_thread_copy_
{
make_tuple
(
0
,
c_thread_begin_mtx_idx_
.
k
*
KPerThread
)}
...
@@ -58,11 +54,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -58,11 +54,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
ThreadMatrixC
::
IsKnownAtCompileTime
(),
ThreadMatrixC
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
"wrong! Desc should be known at compile-time"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
static_assert
(
BlockMatrixA
{}.
GetLength
(
I0
)
==
BlockMatrixB
{}.
GetLength
(
I0
),
static_assert
(
BlockMatrixA
{}.
GetLength
(
I0
)
==
BlockMatrixB
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent
\n
"
);
"wrong! K dimension not consistent
\n
"
);
...
@@ -88,11 +79,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -88,11 +79,11 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
__device__
static
MatrixIndex
GetBeginOfThreadMatrixC
(
index_t
thread_id
)
__device__
static
MatrixIndex
GetBeginOfThreadMatrixC
(
index_t
thread_id
)
{
{
constexpr
index_t
H
=
BlockMatrixB
{}.
GetLength
(
Number
<
2
>
{});
constexpr
index_t
H
PerBlock
=
BlockMatrixB
{}.
GetLength
(
Number
<
2
>
{});
constexpr
index_t
W
=
BlockMatrixB
{}.
GetLength
(
Number
<
3
>
{});
constexpr
index_t
W
PerBlock
=
BlockMatrixB
{}.
GetLength
(
Number
<
3
>
{});
constexpr
auto
num_w_threads
=
W
/
WPerThread
;
constexpr
auto
num_w_threads
=
W
PerBlock
/
WPerThread
;
constexpr
auto
num_h_threads
=
H
/
HPerThread
;
constexpr
auto
num_h_threads
=
H
PerBlock
/
HPerThread
;
constexpr
auto
num_hw_threads
=
num_w_threads
*
num_h_threads
;
constexpr
auto
num_hw_threads
=
num_w_threads
*
num_h_threads
;
index_t
k_thread_id
=
thread_id
/
num_hw_threads
;
index_t
k_thread_id
=
thread_id
/
num_hw_threads
;
...
@@ -115,8 +106,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -115,8 +106,6 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
is_same
<
remove_cvref_t
<
typename
CThreadBuffer
::
type
>
,
remove_cvref_t
<
FloatC
>>::
value
&&
"wrong! inconsistent type"
);
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
a_block_mtx
=
BlockMatrixA
{};
constexpr
auto
EPerBlock
=
a_block_mtx
.
GetLength
(
I0
);
constexpr
auto
EPerBlock
=
a_block_mtx
.
GetLength
(
I0
);
...
@@ -166,8 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -166,8 +155,7 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
}
}
template
<
typename
ABlockSliceMoveStepIdx
>
template
<
typename
ABlockSliceMoveStepIdx
>
__device__
void
MoveASliceWindow
(
const
BlockMatrixA
&
,
__device__
void
MoveABlockSliceWindow
(
const
ABlockSliceMoveStepIdx
&
a_block_slice_move_step_idx
)
const
ABlockSliceMoveStepIdx
&
a_block_slice_move_step_idx
)
{
{
a_thread_copy_
.
MoveSrcSliceWindow
(
BlockMatrixA
{},
a_block_slice_move_step_idx
);
a_thread_copy_
.
MoveSrcSliceWindow
(
BlockMatrixA
{},
a_block_slice_move_step_idx
);
}
}
...
@@ -175,6 +163,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
...
@@ -175,6 +163,16 @@ struct BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
private:
private:
MatrixIndex
c_thread_begin_mtx_idx_
;
MatrixIndex
c_thread_begin_mtx_idx_
;
using
AThreadCopy
=
ThreadwiseTensorSliceTransfer_v4
<
FloatAB
,
FloatAB
,
BlockMatrixA
,
decltype
(
a_thread_mtx_
),
Sequence
<
EPerThreadLoop
,
KPerThreadSubC
>
,
Sequence
<
0
,
1
>
,
1
,
ThreadGemmADataPerRead_K
,
1
>
;
AThreadCopy
a_thread_copy_
;
AThreadCopy
a_thread_copy_
;
};
};
...
...
composable_kernel/include/tensor_operation/gridwise_gemm_dlops_v2.hpp
View file @
10fdada7
...
@@ -15,24 +15,24 @@ namespace ck {
...
@@ -15,24 +15,24 @@ namespace ck {
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
A
EK
GridDesc
,
typename
AGridDesc
_E0_E1_K
,
typename
B
ENHoWo
GridDesc
,
typename
BGridDesc
_E_N_Ho_Wo
,
typename
C
KNHoWo
GridDesc
,
typename
CGridDesc
_K_N_Ho_Wo
,
typename
CBlockIdTo
KNHoWo
BlockClusterAdaptor
,
typename
CBlockIdToBlockClusterAdaptor
_K_N_Ho_Wo
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
bool
HasDoubleTailKBlockLoop
>
__global__
void
__global__
void
#if CK_USE_LAUNCH_BOUNDS
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
#endif
kernel_gemm_dlops_v2
(
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_
a
_grid
,
const
FloatAB
*
__restrict__
p_
b
_grid
,
const
Float
AB
*
__restrict__
p_
b
_grid
,
Float
C
*
__restrict__
p_
c
_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
AGridDesc_E0_E1_K
a_e0_e1_k_grid_desc
,
const
AEK
GridDesc
a_e_k
_grid_desc
,
const
B
GridDesc
_E_N_Ho_Wo
b_e0_e1_n_ho_wo
_grid_desc
,
const
BENHoWo
GridDesc
b_e
_n_ho_wo_grid_desc
,
const
C
GridDesc
_K_N_Ho_Wo
c_k
_n_ho_wo_grid_desc
,
const
C
KNHoWoGridDesc
c_k_n_ho_wo_grid_desc
,
const
C
BlockIdToBlockClusterAdaptor_K_N_Ho_Wo
const
CBlockIdToKNHoWoBlockClusterAdaptor
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
)
c_blockid_to_k_n_ho_wo_block_cluster_adaptor
)
{
{
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
...
@@ -43,8 +43,8 @@ __global__ void
...
@@ -43,8 +43,8 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block
,
a_e_k_grid_desc
,
a_e
0_e1
_k_grid_desc
,
b_e_n_ho_wo_grid_desc
,
b_e
0_e1
_n_ho_wo_grid_desc
,
c_k_n_ho_wo_grid_desc
,
c_k_n_ho_wo_grid_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
@@ -56,10 +56,10 @@ __global__ void
...
@@ -56,10 +56,10 @@ __global__ void
template
<
typename
GridwiseGemm
,
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatAB
,
typename
FloatC
,
typename
FloatC
,
typename
A
EK
GridDesc
,
typename
AGridDesc
_E0_E1_K
,
typename
B
ENHoWo
GridDesc
,
typename
BGridDesc
_E_N_Ho_Wo
,
typename
C
KNHoWo
GridDesc
,
typename
CGridDesc
_K_N_Ho_Wo
,
typename
CBlockIdTo
KNHoWo
BlockClusterAdaptor
,
typename
CBlockIdToBlockClusterAdaptor
_K_N_Ho_Wo
,
bool
HasMainKBlockLoop
,
bool
HasMainKBlockLoop
,
bool
HasDoubleTailKBlockLoop
>
bool
HasDoubleTailKBlockLoop
>
__global__
void
__global__
void
...
@@ -69,19 +69,19 @@ __global__ void
...
@@ -69,19 +69,19 @@ __global__ void
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
kernel_gemm_dlops_v2
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_e_k_grid_desc
,
const
void
CONSTANT
*
p_a_e
0_e1
_k_grid_desc
,
const
void
CONSTANT
*
p_b_e_n_ho_wo_grid_desc
,
const
void
CONSTANT
*
p_b_e
0_e1
_n_ho_wo_grid_desc
,
const
void
CONSTANT
*
p_c_k_n_ho_wo_grid_desc
,
const
void
CONSTANT
*
p_c_k_n_ho_wo_grid_desc
,
const
void
CONSTANT
*
p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor
)
const
void
CONSTANT
*
p_c_blockid_to_k_n_ho_wo_block_cluster_adaptor
)
{
{
// first cast void CONSTANT void* to void*
// first cast void CONSTANT void* to void*
// second cast void* to Desc*
// second cast void* to Desc*
// the copy constructor of tensor descriptor doesn't take address_space(4)
// the copy constructor of tensor descriptor doesn't take address_space(4)
const
auto
a_e_k_grid_desc
=
*
reinterpret_cast
<
const
A
EK
GridDesc
*>
(
const
auto
a_e
0_e1
_k_grid_desc
=
*
reinterpret_cast
<
const
AGridDesc
_E0_E1_K
*>
(
cast_pointer_to_generic_address_space
(
p_a_e_k_grid_desc
));
cast_pointer_to_generic_address_space
(
p_a_e
0_e1
_k_grid_desc
));
const
auto
b_e_n_ho_wo_grid_desc
=
*
reinterpret_cast
<
const
B
ENHoWo
GridDesc
*>
(
const
auto
b_e
0_e1
_n_ho_wo_grid_desc
=
*
reinterpret_cast
<
const
BGridDesc
_E_N_Ho_Wo
*>
(
cast_pointer_to_generic_address_space
(
p_b_e_n_ho_wo_grid_desc
));
cast_pointer_to_generic_address_space
(
p_b_e
0_e1
_n_ho_wo_grid_desc
));
const
auto
c_k_n_ho_wo_grid_desc
=
*
reinterpret_cast
<
const
C
KNHoWo
GridDesc
*>
(
const
auto
c_k_n_ho_wo_grid_desc
=
*
reinterpret_cast
<
const
CGridDesc
_K_N_Ho_Wo
*>
(
cast_pointer_to_generic_address_space
(
p_c_k_n_ho_wo_grid_desc
));
cast_pointer_to_generic_address_space
(
p_c_k_n_ho_wo_grid_desc
));
constexpr
index_t
shared_block_size
=
constexpr
index_t
shared_block_size
=
...
@@ -93,8 +93,8 @@ __global__ void
...
@@ -93,8 +93,8 @@ __global__ void
p_b_grid
,
p_b_grid
,
p_c_grid
,
p_c_grid
,
p_shared_block
,
p_shared_block
,
a_e_k_grid_desc
,
a_e
0_e1
_k_grid_desc
,
b_e_n_ho_wo_grid_desc
,
b_e
0_e1
_n_ho_wo_grid_desc
,
c_k_n_ho_wo_grid_desc
,
c_k_n_ho_wo_grid_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasMainKBlockLoop
>
{},
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
{});
...
@@ -106,9 +106,9 @@ template <index_t BlockSize,
...
@@ -106,9 +106,9 @@ template <index_t BlockSize,
typename
FloatAcc
,
typename
FloatAcc
,
typename
FloatC
,
typename
FloatC
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
InMemoryDataOperationEnum_t
CGlobalMemoryDataOperation
,
typename
AGlobalDesc
,
typename
AGlobalDesc
_E0_E1_K
,
typename
BGlobalDesc
,
typename
BGlobalDesc
_E0_E1_N_Ho_Wo
,
typename
CGlobalDesc
,
typename
CGlobalDesc
_K_N_Ho_Wo
,
index_t
KPerBlock
,
index_t
KPerBlock
,
index_t
HoPerBlock
,
index_t
HoPerBlock
,
index_t
WoPerBlock
,
index_t
WoPerBlock
,
...
@@ -148,12 +148,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -148,12 +148,12 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
a_e
_
k_desc
=
make_naive_tensor_descriptor_aligned
(
constexpr
auto
a_e
1_k_bloc
k_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
E
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
E
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
// LDS allocation for A and B: be careful of alignment
// LDS allocation for A and B: be careful of alignment
constexpr
auto
a_block_space_size
=
constexpr
auto
a_block_space_size
=
math
::
integer_least_multiple
(
a_e
_
k_desc
.
GetElementSpaceSize
(),
max_lds_align
);
math
::
integer_least_multiple
(
a_e
1_k_bloc
k_desc
.
GetElementSpaceSize
(),
max_lds_align
);
return
a_block_space_size
*
sizeof
(
FloatAB
);
return
a_block_space_size
*
sizeof
(
FloatAB
);
}
}
...
@@ -163,9 +163,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -163,9 +163,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
const
FloatAB
*
__restrict__
p_b_global
,
const
FloatAB
*
__restrict__
p_b_global
,
FloatC
*
__restrict__
p_c_global
,
FloatC
*
__restrict__
p_c_global
,
FloatAB
*
__restrict__
p_shared_block
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGlobalDesc
&
a_e_k_global_desc
,
const
AGlobalDesc
_E0_E1_K
&
a_e
0_e1
_k_global_desc
,
const
BGlobalDesc
&
b_e_n_ho_wo_global_desc
,
const
BGlobalDesc
_E0_E1_N_Ho_Wo
&
b_e
0_e1
_n_ho_wo_global_desc
,
const
CGlobalDesc
&
c_k_n_ho_wo_global_desc
,
const
CGlobalDesc
_K_N_Ho_Wo
&
c_k_n_ho_wo_global_desc
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasMainKBlockLoop
>
,
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
integral_constant
<
bool
,
HasDoubleTailKBlockLoop
>
)
{
{
...
@@ -175,18 +175,18 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -175,18 +175,18 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
a_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_a_global
,
a_e_k_global_desc
.
GetElementSpaceSize
());
p_a_global
,
a_e
0_e1
_k_global_desc
.
GetElementSpaceSize
());
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
const
auto
b_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_b_global
,
b_e_n_ho_wo_global_desc
.
GetElementSpaceSize
());
p_b_global
,
b_e
0_e1
_n_ho_wo_global_desc
.
GetElementSpaceSize
());
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
auto
c_global_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_global
,
c_k_n_ho_wo_global_desc
.
GetElementSpaceSize
());
p_c_global
,
c_k_n_ho_wo_global_desc
.
GetElementSpaceSize
());
// const auto E = a_e_k_global_desc.GetLength(I0);
// const auto E = a_e
0_e1
_k_global_desc.GetLength(I0);
// const auto K = a_e_k_global_desc.GetLength(I1);
// const auto K = a_e
0_e1
_k_global_desc.GetLength(I1);
// const auto N = b_e_n_ho_wo_global_desc.GetLength(I1);
// const auto N = b_e
0_e1
_n_ho_wo_global_desc.GetLength(I1);
const
auto
Ho
=
b_e_n_ho_wo_global_desc
.
GetLength
(
I2
);
const
auto
Ho
=
b_e
0_e1
_n_ho_wo_global_desc
.
GetLength
(
I2
);
const
auto
Wo
=
b_e_n_ho_wo_global_desc
.
GetLength
(
I3
);
const
auto
Wo
=
b_e
0_e1
_n_ho_wo_global_desc
.
GetLength
(
I3
);
// divide block work by [M, N]
// divide block work by [M, N]
#if 0
#if 0
...
@@ -220,15 +220,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -220,15 +220,15 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// A matrix in LDS memory, dst of blockwise copy
// A matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
a_e_k_block_desc
=
make_naive_tensor_descriptor_aligned
(
constexpr
auto
a_e1_k_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
constexpr
auto
a_e_k_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
E
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
make_tuple
(
Number
<
E
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
constexpr
auto
a_e2_k_block_desc
=
make_naive_tensor_descriptor_aligned
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
KPerBlock
>
{}),
max_lds_align
);
// B matrix in LDS memory, dst of blockwise copy
// B matrix in LDS memory, dst of blockwise copy
// be careful of LDS alignment
// be careful of LDS alignment
constexpr
auto
b_e_n_ho_wo_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
constexpr
auto
b_e
2
_n_ho_wo_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerBlock
>
{},
Number
<
WoPerBlock
>
{}));
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerBlock
>
{},
Number
<
WoPerBlock
>
{}));
// c_thread_mtx definition: this is a mess
// c_thread_mtx definition: this is a mess
...
@@ -240,12 +240,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -240,12 +240,9 @@ struct GridwiseGemmDlops_km_kn_mn_v3
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
<
BlockSize
,
BlockwiseGemmDlops_km_kn_m0m1n0n1_v3
<
BlockSize
,
FloatAB
,
FloatAB
,
FloatAcc
,
FloatAcc
,
decltype
(
a_e_k_block_desc
),
decltype
(
a_e
2
_k_block_desc
),
decltype
(
b_e_n_ho_wo_block_desc
),
decltype
(
b_e
2
_n_ho_wo_block_desc
),
decltype
(
c_k_n_ho_wo_thread_desc
),
decltype
(
c_k_n_ho_wo_thread_desc
),
KPerThread
,
HoPerThread
,
WoPerThread
,
EPerThread
,
EPerThread
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferSrcScalarPerVector
,
ABlockTransferDstScalarPerVector_K
>
{};
ABlockTransferDstScalarPerVector_K
>
{};
...
@@ -275,8 +272,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -275,8 +272,8 @@ struct GridwiseGemmDlops_km_kn_mn_v3
ABlockTransferThreadClusterArrangeOrder
,
ABlockTransferThreadClusterArrangeOrder
,
FloatAB
,
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
a_e_k_global_desc
),
decltype
(
a_e
0_e1
_k_global_desc
),
decltype
(
a_e
_
k_desc
),
decltype
(
a_e
1_k_bloc
k_desc
),
ABlockTransferSrcAccessOrder
,
ABlockTransferSrcAccessOrder
,
Sequence
<
0
,
1
>
,
Sequence
<
0
,
1
>
,
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
...
@@ -286,30 +283,30 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -286,30 +283,30 @@ struct GridwiseGemmDlops_km_kn_mn_v3
1
,
1
,
1
,
1
,
AThreadTransferSrcResetCoordinateAfterRun
,
AThreadTransferSrcResetCoordinateAfterRun
,
true
>
(
a_e_k_global_desc
,
true
>
(
a_e
0_e1
_k_global_desc
,
make_multi_index
(
0
,
k_block_data_on_global
),
make_multi_index
(
0
,
k_block_data_on_global
),
a_e
_
k_desc
,
a_e
1_k_bloc
k_desc
,
make_multi_index
(
0
,
0
));
make_multi_index
(
0
,
0
));
constexpr
auto
b_e_n_ho_wo_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
constexpr
auto
b_e
2
_n_ho_wo_thread_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
Number
<
EPerBlock
>
{},
Number
<
1
>
{},
Number
<
HoPerThread
>
{},
Number
<
WoPerThread
>
{}));
auto
b_threadwise_transfer
=
auto
b_threadwise_transfer
=
ThreadwiseTensorSliceTransfer_v2
<
FloatAB
,
ThreadwiseTensorSliceTransfer_v2
<
FloatAB
,
FloatAB
,
FloatAB
,
decltype
(
b_e_n_ho_wo_global_desc
),
decltype
(
b_e
0_e1
_n_ho_wo_global_desc
),
decltype
(
b_e_n_ho_wo_thread_desc
),
decltype
(
b_e
2
_n_ho_wo_thread_desc
),
Sequence
<
EPerBlock
,
1
,
HoPerThread
,
WoPerThread
>
,
Sequence
<
EPerBlock
,
1
,
HoPerThread
,
WoPerThread
>
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcAccessOrder
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
BBlockTransferSrcScalarPerVector
,
BBlockTransferSrcScalarPerVector
,
1
,
1
,
true
>
(
true
>
(
b_e_n_ho_wo_global_desc
,
b_e
0_e1
_n_ho_wo_global_desc
,
make_multi_index
(
0
,
0
,
ho_thread_data_on_global
,
wo_thread_data_on_global
));
make_multi_index
(
0
,
0
,
ho_thread_data_on_global
,
wo_thread_data_on_global
));
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
auto
a_block_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Lds
>
(
p_shared_block
,
a_e
_
k_desc
.
GetElementSpaceSize
());
p_shared_block
,
a_e
1_k_bloc
k_desc
.
GetElementSpaceSize
());
// register allocation for output
// register allocation for output
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
...
@@ -327,34 +324,29 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -327,34 +324,29 @@ struct GridwiseGemmDlops_km_kn_mn_v3
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
EPerBlock
,
0
,
0
,
0
);
constexpr
auto
b_thread_slice_copy_step
=
make_multi_index
(
EPerBlock
,
0
,
0
,
0
);
// hack to control index calculation when iterating over A and B matrix for threadwise copy
// hack to control index calculation when iterating over A and B matrix for threadwise copy
constexpr
auto
a_e_k_global_step_hacks
=
AGlobalStepHacks
{};
constexpr
auto
a_e0_e1_k_global_step_hacks
=
AGlobalStepHacks
{};
constexpr
auto
b_e_n_ho_wo_global_step_hacks
=
BGlobalStepHacks
{};
constexpr
auto
b_e0_e1_n_ho_wo_global_step_hacks
=
BGlobalStepHacks
{};
// hack to control index calculation when move slice window for A and B matrix for
// threadwise copy
// constexpr auto a_e_k_global_move_slice_window_step_hack =
// AGlobalMoveSliceWindowStepHacks{}; constexpr auto
// b_e_n_ho_wo_global_move_slice_window_step_hack = BGlobalMoveSliceWindowStepHacks{};
// double regsiter buffer for b
// double regsiter buffer for b
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
StaticBuffer
<
AddressSpaceEnum_t
::
Vgpr
,
FloatAB
,
FloatAB
,
b_e_n_ho_wo_thread_desc
.
GetElementSpaceSize
(),
b_e
2
_n_ho_wo_thread_desc
.
GetElementSpaceSize
(),
true
>
true
>
b_thread_even_buf
,
b_thread_odd_buf
;
b_thread_even_buf
,
b_thread_odd_buf
;
// LDS double buffer: preload data
// LDS double buffer: preload data
{
{
a_blockwise_copy
.
RunRead
(
a_e_k_global_desc
,
a_global_buf
,
a_e_k_global_step_hacks
);
a_blockwise_copy
.
RunRead
(
a_e0_e1_k_global_desc
,
a_global_buf
,
a_e0_e1_k_global_step_hacks
);
b_threadwise_transfer
.
Run
(
b_e_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e
0_e1
_n_ho_wo_global_desc
,
b_global_buf
,
b_global_buf
,
b_e_n_ho_wo_thread_desc
,
b_e
2
_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_thread_even_buf
,
b_e_n_ho_wo_global_step_hacks
);
b_e
0_e1
_n_ho_wo_global_step_hacks
);
a_blockwise_copy
.
RunWrite
(
a_e
_
k_desc
,
a_block_buf
);
a_blockwise_copy
.
RunWrite
(
a_e
1_k_bloc
k_desc
,
a_block_buf
);
}
}
__syncthreads
();
__syncthreads
();
...
@@ -368,36 +360,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -368,36 +360,36 @@ struct GridwiseGemmDlops_km_kn_mn_v3
do
do
{
{
// even iteration
// even iteration
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_ho_wo_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e
0_e1
_n_ho_wo_global_desc
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e
0_e1
_n_ho_wo_global_desc
,
b_global_buf
,
b_global_buf
,
b_e_n_ho_wo_thread_desc
,
b_e
2
_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_thread_odd_buf
,
b_e_n_ho_wo_global_step_hacks
);
b_e
0_e1
_n_ho_wo_global_step_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
// TODO: @Zhang Jing: blockwise gemm should be able to move slice window
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveASliceWindow
(
a_e_k_block_desc
,
make_tuple
(
EPerBlock
,
0
));
blockwise_gemm
.
MoveA
Block
SliceWindow
(
make_tuple
(
EPerBlock
,
0
));
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_ho_wo_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e
0_e1
_n_ho_wo_global_desc
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e
0_e1
_n_ho_wo_global_desc
,
b_global_buf
,
b_global_buf
,
b_e_n_ho_wo_thread_desc
,
b_e
2
_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_even_buf
,
b_thread_even_buf
,
b_e_n_ho_wo_global_step_hacks
);
b_e
0_e1
_n_ho_wo_global_step_hacks
);
// LDS double buffer: GEMM on current data
// LDS double buffer: GEMM on current data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveASliceWindow
(
a_e_k_block_desc
,
make_tuple
(
EPerBlock
,
0
));
blockwise_gemm
.
MoveA
Block
SliceWindow
(
make_tuple
(
EPerBlock
,
0
));
e_block_data_begin
+=
2
*
EPerBlock
;
e_block_data_begin
+=
2
*
EPerBlock
;
...
@@ -407,20 +399,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
...
@@ -407,20 +399,20 @@ struct GridwiseGemmDlops_km_kn_mn_v3
// LDS double buffer: tail
// LDS double buffer: tail
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
if
constexpr
(
HasDoubleTailKBlockLoop
)
// if has 2 iteration left
{
{
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e_n_ho_wo_global_desc
,
b_threadwise_transfer
.
MoveSrcSliceWindow
(
b_e
0_e1
_n_ho_wo_global_desc
,
b_thread_slice_copy_step
);
b_thread_slice_copy_step
);
b_threadwise_transfer
.
Run
(
b_e_n_ho_wo_global_desc
,
b_threadwise_transfer
.
Run
(
b_e
0_e1
_n_ho_wo_global_desc
,
b_global_buf
,
b_global_buf
,
b_e_n_ho_wo_thread_desc
,
b_e
2
_n_ho_wo_thread_desc
,
make_tuple
(
I0
,
I0
,
I0
,
I0
),
make_tuple
(
I0
,
I0
,
I0
,
I0
),
b_thread_odd_buf
,
b_thread_odd_buf
,
b_e_n_ho_wo_global_step_hacks
);
b_e
0_e1
_n_ho_wo_global_step_hacks
);
// LDS double buffer: GEMM on 2nd-last data
// LDS double buffer: GEMM on 2nd-last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_even_buf
,
c_thread_buf
);
blockwise_gemm
.
MoveASliceWindow
(
a_e_k_block_desc
,
make_tuple
(
EPerBlock
,
0
));
blockwise_gemm
.
MoveA
Block
SliceWindow
(
make_tuple
(
EPerBlock
,
0
));
// LDS double buffer: GEMM on last data
// LDS double buffer: GEMM on last data
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
blockwise_gemm
.
Run
(
a_block_buf
,
b_thread_odd_buf
,
c_thread_buf
);
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_dlops_v3.hpp
View file @
10fdada7
...
@@ -26,14 +26,6 @@ template <typename FloatA,
...
@@ -26,14 +26,6 @@ template <typename FloatA,
struct
ThreadwiseGemmDlops_km_kn_mn_v3
struct
ThreadwiseGemmDlops_km_kn_mn_v3
{
{
__device__
ThreadwiseGemmDlops_km_kn_mn_v3
()
{
static_assert
(
AThreadDesc_E_K
::
IsKnownAtCompileTime
()
&&
BThreadDesc_E_N_Ho_Wo
::
IsKnownAtCompileTime
()
&&
CThreadDesc_K_N_Ho_Wo
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
}
template
<
typename
ABuffer
,
template
<
typename
ABuffer
,
typename
AOriginIdx
,
typename
AOriginIdx
,
typename
BBuffer
,
typename
BBuffer
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment