gaoqiong / composable_kernel · Commits

Commit 7b002f23, authored Dec 14, 2020 by Jing Zhang
parent 87a75734

    add v4r4 xdlops

Showing 4 changed files with 260 additions and 1017 deletions (+260 −1017)
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp            +20   -1003
composable_kernel/include/utility/common_header.hpp                                       +0    -3
driver/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp    +228  -0
driver/src/conv_driver.cpp                                                                +12   -11
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp  (view file @ 7b002f23)
...
...
@@ -51,702 +51,6 @@ struct make_block_work_sequence<MBlockWork, NBlockWork, NBlock1MBlock0>
    __device__ constexpr auto get() { return Sequence<NBlockWork, MBlockWork>{}; }
};

template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class ABlockCopyThreadSliceLengths_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_K_N_KPACK,
          class BBlockCopyThreadClusterLengths_K_N_KPACK,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation OutputMemOp,
          WorkgroupScheduleOrder WorkgroupSchdOrder,
          index_t ABlockCopySrcDataStride = 1,
          index_t BBlockCopySrcDataStride = 1>
struct GridwiseGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto b_k_n_kpack_global_desc = BGlobalDesc{};
        constexpr auto a_k_m_kpack_global_desc = AGlobalDesc{};
        constexpr auto c_m_n_global_desc       = CGlobalDesc{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto K     = b_k_n_kpack_global_desc.GetLengths()[0];
        constexpr auto N     = b_k_n_kpack_global_desc.GetLengths()[1];
        constexpr auto M     = a_k_m_kpack_global_desc.GetLengths()[1];
        constexpr auto KPACK = b_k_n_kpack_global_desc.GetLengths()[2];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr index_t MWaves = MPerBlock / MPerWave;
        constexpr index_t NWaves = NPerBlock / NPerWave;

        constexpr auto block_work_sequence =
            make_block_work_sequence<MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
        constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t k_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[0] * MPerBlock)
                                                   : (block_work_id[1] * MPerBlock);
        const index_t b_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[1] * NPerBlock)
                                                   : (block_work_id[0] * NPerBlock);

        // LDS mem
        constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
                                                ABlockCopyDstDataPerWrite_KPACK,
                                                KPACK * GemmDataPerReadM,
                                                KPACK * GemmDataPerReadN);

        // LDS
        // be careful of LDS alignment
        constexpr auto a_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});

        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_k_m_kpack_global_desc),
                                               decltype(a_k_m_kpack_block_desc),
                                               decltype(a_k_m_kpack_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_K_M_KPACK,
                                               ABlockCopyThreadClusterLengths_K_M_KPACK,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               ABlockCopyDstAccessOrder,
                                               ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (M dimension)
                                               2,                          // Dst dim to be written in vector form (KPACK dimension)
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_KPACK,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set,
                                               ABlockCopySrcDataStride>(
                {0, k_block_data_on_global, 0}, {0, 0, 0});

        constexpr auto b_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_k_n_kpack_global_desc),
                                               decltype(b_k_n_kpack_block_desc),
                                               decltype(b_k_n_kpack_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_K_N_KPACK,
                                               BBlockCopyThreadClusterLengths_K_N_KPACK,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               BBlockCopyDstAccessOrder,
                                               BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (N dimension)
                                               2,                          // Dst dim to be written in vector form (KPACK dimension)
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_KPACK,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set,
                                               BBlockCopySrcDataStride>(
                {0, b_block_data_on_global, 0}, {0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});

        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<BlockSize,
                                                                        decltype(a_k_m_block_mtx_desc),
                                                                        decltype(b_k_n_block_mtx_desc),
                                                                        ABFloat,
                                                                        MPerWave,
                                                                        NPerWave,
                                                                        MWaves,
                                                                        NWaves,
                                                                        GemmDataPerReadM,
                                                                        GemmDataPerReadN>{};

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_k_m_kpack_block_desc.GetElementSpace(), max_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_k_n_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block_double[2 * a_block_space];
        __shared__ ABFloat p_b_block_double[2 * b_block_space];

        // get zero-initialized output register of vector type
        auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        using blockwise_a_copy_src_step = Sequence<KPerBlock, 0, 0>;
        using blockwise_b_copy_src_step = Sequence<KPerBlock, 0, 0>;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                ABFloat* p_a_block_now =
                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                ABFloat* p_b_block_now =
                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                ABFloat* p_a_block_next =
                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                ABFloat* p_b_block_next =
                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                // Vectorize the pointer to match with how fp16/bfloat16 datatypes are
                // processed in gemm operation. fp16 type packs 4 fp16 values while
                // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
                // 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
                // we recast datatype from a single fp16 to 4 packed fp16/2 packed bfloat16
                // respectively.
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_now);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_now);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double + a_block_space);
                p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double + b_block_space);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
            }
        }

        // copy output: register to global memory
        {
            constexpr auto OutputLayout = blockwise_gemm.GetOutputLayout();
            constexpr index_t K0 = OutputLayout.M1();
            constexpr index_t K1 = OutputLayout.N1();
            constexpr index_t K2 = OutputLayout.M0();

            constexpr auto out_k0_k1_k2_b_global_desc = transform_tensor_descriptor(
                c_m_n_global_desc,
                make_tuple(UnMerge<Sequence<K0, K1, K2>>{}, PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}),
                make_tuple(Sequence<0, 1, 2>{}, Sequence<3>{}));

            // src descriptor
            constexpr auto out_k0_k1_k2_b_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<K0, 1, K2, 1>{});

            using OutThreadCopySliceLengths = Sequence<K0, 1, K2, 1>;

            constexpr index_t BlkSize = OutputLayout.GetBlkSize();
            constexpr index_t NumBlks = OutputLayout.GetNumBlks();

            // force unrolling the output loop to get rid of scratches
#pragma unroll
            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t k_thread_data_on_global =
                    k_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t b_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<
                    decltype(out_k0_k1_k2_b_thread_desc),
                    decltype(out_k0_k1_k2_b_global_desc),
                    OutThreadCopySliceLengths,
                    arithmetic_sequence_gen<0, 4, 1>::type,
                    3,
                    1,
                    1,
                    AddressSpace::Vgpr,
                    is_same<AccFloat, CFloat>::value ? AddressSpace::Global : AddressSpace::Generic,
                    OutputMemOp>({0, 0, 0, 0},
                                 {k_thread_data_on_global / (K2 * K1),
                                  k_thread_data_on_global % (K2 * K1) / K2,
                                  k_thread_data_on_global % K2,
                                  b_thread_data_on_global})
                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
            }
        }
    }
};
template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t NPerWave,
          index_t GemmDataPerReadM,
          index_t GemmDataPerReadN,
          class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_G_K_N_KPACK,
          class BBlockCopyThreadClusterLengths_G_K_N_KPACK,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation OutputMemOp,
          WorkgroupScheduleOrder WorkgroupSchdOrder,
          index_t ABlockCopySrcDataStride = 1,
          index_t BBlockCopySrcDataStride = 1>
struct GridwiseBatchedGemmTransposedANormalBNormalCXdlopsFp16Bfp16_v1
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
        constexpr auto b_g_k_n_kpack_global_desc = BGlobalDesc{};
        constexpr auto c_g_m_n_global_desc       = CGlobalDesc{};

        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto Gi    = b_g_k_n_kpack_global_desc.GetLengths()[0];
        constexpr auto Go    = c_g_m_n_global_desc.GetLengths()[0];
        constexpr auto K     = b_g_k_n_kpack_global_desc.GetLengths()[1];
        constexpr auto N     = b_g_k_n_kpack_global_desc.GetLengths()[2];
        constexpr auto M     = a_g_k_m_kpack_global_desc.GetLengths()[2];
        constexpr auto KPACK = b_g_k_n_kpack_global_desc.GetLengths()[3];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t NBlockWork = N / NPerBlock;

        constexpr index_t MWaves = MPerBlock / MPerWave;
        constexpr index_t NWaves = NPerBlock / NPerWave;

        constexpr auto block_work_sequence =
            make_batch_block_work_sequence<Gi, MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
        constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t group_id = block_work_id[0];

        const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[1] * MPerBlock)
                                                   : (block_work_id[2] * MPerBlock);
        const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[2] * NPerBlock)
                                                   : (block_work_id[1] * NPerBlock);

        // LDS mem
        constexpr index_t max_align = math::lcm(BBlockCopyDstDataPerWrite_KPACK,
                                                ABlockCopyDstDataPerWrite_KPACK,
                                                KPACK * GemmDataPerReadM,
                                                KPACK * GemmDataPerReadN);

        // LDS
        // be careful of LDS alignment
        constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, MPerBlock, KPACK>{}, Number<max_align>{});

        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_g_k_m_kpack_global_desc),
                                               decltype(a_g_k_m_kpack_block_desc),
                                               decltype(a_g_k_m_kpack_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_G_K_M_KPACK,
                                               ABlockCopyThreadClusterLengths_G_K_M_KPACK,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               ABlockCopyDstAccessOrder,
                                               ABlockCopySrcVectorReadDim, // Src dim to be read in vector form (K dimension)
                                               3,                          // Dst dim to be written in vector form (KPACK dimension)
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_KPACK,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set,
                                               ABlockCopySrcDataStride>(
                {group_id, 0, m_block_data_on_global, 0}, {0, 0, 0, 0});

        constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, NPerBlock, KPACK>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_g_k_n_kpack_global_desc),
                                               decltype(b_g_k_n_kpack_block_desc),
                                               decltype(b_g_k_n_kpack_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_G_K_N_KPACK,
                                               BBlockCopyThreadClusterLengths_G_K_N_KPACK,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               BBlockCopyDstAccessOrder,
                                               BBlockCopySrcVectorReadDim, // Src dim to be read in vector form (K dimension)
                                               3,                          // Dst dim to be written in vector form (KPACK dimension)
                                               BBlockCopySrcDataPerRead,   // N dimension
                                               BBlockCopyDstDataPerWrite_KPACK,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set,
                                               BBlockCopySrcDataStride>(
                {group_id, 0, n_block_data_on_global, 0}, {0, 0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, NPerBlock] is in LDS
        //     c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});

        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<BlockSize,
                                                                        decltype(a_k_m_block_mtx_desc),
                                                                        decltype(b_k_n_block_mtx_desc),
                                                                        ABFloat,
                                                                        MPerWave,
                                                                        NPerWave,
                                                                        MWaves,
                                                                        NWaves,
                                                                        GemmDataPerReadM,
                                                                        GemmDataPerReadN>{};

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_g_k_n_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block_double[2 * a_block_space];
        __shared__ ABFloat p_b_block_double[2 * b_block_space];

        // get zero-initialized output register of vector type
        auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();

        // LDS double buffer: preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block_double);
            b_blockwise_copy.Run(p_b_global, p_b_block_double);
        }

        using blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>;
        using blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0>;

        // LDS double buffer: main body
        for(index_t k_block_data_begin = 0; k_block_data_begin + 2 * KPerBlock < K;
            k_block_data_begin += 2 * KPerBlock)
        {
#pragma unroll
            for(index_t iloop = 0; iloop < 2; ++iloop)
            {
                const bool even_loop = (iloop % 2 == 0);

                ABFloat* p_a_block_now =
                    even_loop ? p_a_block_double : p_a_block_double + a_block_space;
                ABFloat* p_b_block_now =
                    even_loop ? p_b_block_double : p_b_block_double + b_block_space;

                ABFloat* p_a_block_next =
                    even_loop ? p_a_block_double + a_block_space : p_a_block_double;
                ABFloat* p_b_block_next =
                    even_loop ? p_b_block_double + b_block_space : p_b_block_double;

                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load next data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on current data
                // Vectorize the pointer to match with how fp16/bfloat16 datatypes are
                // processed in gemm operation. fp16 type packs 4 fp16 values while
                // bfloat16 packs 2 bfloat16 values. Since gemm's matrix A and B
                // 2D indexes are computed with vectorized value in mind (e.g. float, half2, half4),
                // we recast datatype from a single fp16 to 4 packed fp16/2 packed bfloat16
                // respectively.
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_now);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_now);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

                // LDS double buffer: store next data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block_next);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block_next);
            }
        }

        // LDS double buffer: tail
        {
            constexpr bool has_two_iteration_left = (K % (2 * KPerBlock) == 0);

            if(has_two_iteration_left) // if has 2 iteration left
            {
                ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
                ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

                a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step{}, True);
                b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step{}, True);

                __syncthreads();

                // LDS double buffer: load last data from device mem
                a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
                b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

                // LDS double buffer: GEMM on 2nd-last data
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

                // LDS double buffer: store last data to LDS
                a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer,
                                                      p_a_block_double + a_block_space);
                b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer,
                                                      p_b_block_double + b_block_space);

                __syncthreads();

                // LDS double buffer: GEMM on current data
                p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double + a_block_space);
                p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double + b_block_space);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
            }
            else // if has 1 iteration left
            {
                __syncthreads();

                // LDS double buffer: GEMM on last data
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_a_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_a_block_double);
                const typename vector_type<ABFloat, KPACK>::MemoryType* p_b_block_vec =
                    reinterpret_cast<const typename vector_type<ABFloat, KPACK>::MemoryType*>(
                        p_b_block_double);
                c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
            }
        }

        // copy output: register to global memory
        {
            ///\todo inconsistent layout of xdlops and tensor
            // xdlops layout
            //   M1 = num_groups;
            //   M0 = group_size;
            //   N1 = num_blks_per_wave;
            //   N0 = num_threads_per_blks;
            constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
            constexpr index_t M0 = CLayout.M1();
            constexpr index_t M1 = CLayout.N1();
            constexpr index_t M2 = CLayout.M0();

            constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
                c_g_m_n_global_desc,
                make_tuple(PassThrough<Go>{}, UnMerge<Sequence<M0, M1, M2>>{}, PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));

            // src descriptor
            constexpr auto c_g_m0_m1_m2_n_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});

            using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;

            constexpr index_t BlkSize = CLayout.GetBlkSize();
            constexpr index_t NumBlks = CLayout.GetNumBlks();

            // force unrolling the output loop to get rid of scratches
#pragma unroll
            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);

                const index_t m_thread_data_on_global =
                    m_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t n_thread_data_on_global =
                    n_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
                                                      decltype(c_g_m0_m1_m2_n_global_desc),
                                                      CThreadCopySliceLengths,
                                                      arithmetic_sequence_gen<0, 5, 1>::type,
                                                      4,
                                                      1,
                                                      1,
                                                      AddressSpace::Vgpr,
                                                      AddressSpace::Global,
                                                      OutputMemOp>(
                    {0, 0, 0, 0, 0},
                    {group_id,
                     m_thread_data_on_global / (M2 * M1),
                     m_thread_data_on_global % (M2 * M1) / M2,
                     m_thread_data_on_global % M2,
                     n_thread_data_on_global})
                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
            }
        }
    }
};
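// Aside (illustrative sketch, not part of the original header): the output-copy loops above
// split a flat row index into (m0, m1, m2) with exactly the arithmetic used by the
// UnMerge<Sequence<M0, M1, M2>> transform, i.e. m == (m0 * M1 + m1) * M2 + m2.
// A minimal compile-time check with made-up sizes (nothing here is taken from the real
// xdlops output layout):
namespace unmerge_index_sketch {
constexpr int M1 = 4, M2 = 16, m = 75;     // hypothetical layout sizes and flat index
constexpr int m0 = m / (M2 * M1);          // 75 / 64 -> 1
constexpr int m1 = m % (M2 * M1) / M2;     // 11 / 16 -> 0
constexpr int m2 = m % M2;                 // 75 % 16 -> 11
static_assert((m0 * M1 + m1) * M2 + m2 == m, "unmerge decomposition is lossless");
} // namespace unmerge_index_sketch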
template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
...
...
@@ -812,13 +116,13 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t g_block_data_on_global = block_work_id[0];
        const index_t g_block_data_on_global = block_work_id[Number<0>{}];

        const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[1] * MPerBlock)
                                                   : (block_work_id[2] * MPerBlock);
                                                   ? (block_work_id[Number<1>{}] * MPerBlock)
                                                   : (block_work_id[Number<2>{}] * MPerBlock);

        const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[2] * NPerBlock)
                                                   : (block_work_id[1] * NPerBlock);
                                                   ? (block_work_id[Number<2>{}] * NPerBlock)
                                                   : (block_work_id[Number<1>{}] * NPerBlock);

        constexpr index_t max_align = KPack;
...
...
@@ -826,7 +130,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
        constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});

        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v5<
            BlockSize,
            decltype(a_g_k_m_kpack_global_desc),
            decltype(a_g_k_m_kpack_block_desc),
...
...
@@ -843,8 +147,9 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>(
            {g_block_data_on_global, 0, m_block_data_on_global, 0}, {0, 0, 0, 0});
            InMemoryDataOperation::Set>(
            make_multi_index(g_block_data_on_global, 0, m_block_data_on_global, 0),
            make_multi_index(0, 0, 0, 0));

        constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, NPerBlock, KPack>{}, Number<max_align>{});
...
...
@@ -867,8 +172,9 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
            AddressSpace::Global,
            AddressSpace::Vgpr,
            AddressSpace::Lds,
            InMemoryDataOperation::Set>(
            {g_block_data_on_global, 0, n_block_data_on_global, 0}, {0, 0, 0, 0});
            InMemoryDataOperation::Set>(
            make_multi_index(g_block_data_on_global, 0, n_block_data_on_global, 0),
            make_multi_index(0, 0, 0, 0));

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
...
...
@@ -918,14 +224,11 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
        for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
            k_block_data_begin += KPerBlock)
        {
            ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
            // ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

            // load next data from device mem
            a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);

            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
            a_blockwise_copy.RunLoadThreadBuffer(p_a_global);
            b_blockwise_copy.RunLoadThreadBuffer(p_b_global);

            block_sync_lds();
...
...
@@ -943,7 +246,7 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
            block_sync_lds();

            // store next data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
            a_blockwise_copy.RunStoreThreadBuffer(p_a_block);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_block);
        }
...
...
@@ -1015,298 +318,12 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
                                                      AddressSpace::Vgpr,
                                                      AddressSpace::Global,
                                                      CGlobalMemoryOp>(
                    {0, 0, 0, 0, 0},
                    {g_block_data_on_global,
                     m_thread_data_on_global / (M2 * M1),
                     m_thread_data_on_global % (M2 * M1) / M2,
                     m_thread_data_on_global % M2,
                     n_thread_data_on_global})
                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
            }
        }
    }
};

template <index_t GridSize,
          index_t BlockSize,
          class ABFloat,
          class AccFloat,
          class CFloat,
          class AGlobalDesc,
          class BGlobalDesc,
          class CGlobalDesc,
          index_t MPerBlock,
          index_t BPerBlock,
          index_t KPerBlock,
          index_t MPerWave,
          index_t BPerWave,
          class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
          class ABlockCopyThreadClusterArrangeOrder,
          class ABlockCopySrcAccessOrder,
          class ABlockCopyDstAccessOrder,
          index_t ABlockCopySrcVectorReadDim,
          index_t ABlockCopySrcDataPerRead,
          index_t ABlockCopyDstDataPerWrite_KPACK,
          class BBlockCopyThreadSliceLengths_G_K_N1_B_KPack,
          class BBlockCopyThreadClusterLengths_G_K_N1_B_KPack,
          class BBlockCopyThreadClusterArrangeOrder,
          class BBlockCopySrcAccessOrder,
          class BBlockCopyDstAccessOrder,
          index_t BBlockCopySrcVectorReadDim,
          index_t BBlockCopySrcDataPerRead,
          index_t BBlockCopyDstDataPerWrite_KPACK,
          InMemoryDataOperation CGlobalMemoryOp,
          WorkgroupScheduleOrder WorkgroupSchdOrder>
struct GridwiseBatchGemmXdlops_gkmkpack_gkn1bkpack_gmn_v2
{
    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
                        const ABFloat* const __restrict__ p_b_global,
                        CFloat* const __restrict__ p_c_global) const
    {
        constexpr auto True = integral_constant<bool, true>{};

        constexpr auto a_g_k_m_kpack_global_desc    = AGlobalDesc{};
        constexpr auto b_g_k_n1_b_kpack_global_desc = BGlobalDesc{};
        constexpr auto c_g_m_n_global_desc          = CGlobalDesc{};

        constexpr auto G = c_g_m_n_global_desc.GetLengths()[0];
        constexpr auto M = c_g_m_n_global_desc.GetLengths()[1];
        constexpr auto N = c_g_m_n_global_desc.GetLengths()[2];

        constexpr auto K     = b_g_k_n1_b_kpack_global_desc.GetLengths()[1];
        constexpr auto in_N1 = b_g_k_n1_b_kpack_global_desc.GetLengths()[2];
        constexpr auto B     = b_g_k_n1_b_kpack_global_desc.GetLengths()[3];
        constexpr auto KPack = b_g_k_n1_b_kpack_global_desc.GetLengths()[4];

        // divide block work by [M, N]
        static_assert(M % MPerBlock == 0 && B % BPerBlock == 0 && K % KPerBlock == 0,
                      "wrong! cannot divide work evenly among block");

        constexpr index_t MBlockWork = M / MPerBlock;
        constexpr index_t BBlockWork = B / BPerBlock;

        constexpr index_t MWavePerBlock = MPerBlock / MPerWave;
        constexpr index_t BWavePerBlock = in_N1;

        static_assert((G * MBlockWork * BBlockWork) == GridSize, "Invalid GridSize");

        constexpr auto block_work_sequence =
            make_batch_block_work_sequence<G, MBlockWork, BBlockWork, WorkgroupSchdOrder>{}.get();
        constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);

        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());

        const index_t g_block_data_on_global = block_work_id[0];

        const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[1] * MPerBlock)
                                                   : (block_work_id[2] * MPerBlock);
        const index_t b_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
                                                   ? (block_work_id[2] * BPerBlock)
                                                   : (block_work_id[1] * BPerBlock);

        constexpr index_t max_align = KPack;

        // LDS be careful of LDS alignment
        constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});

        auto a_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(a_g_k_m_kpack_global_desc),
                                               decltype(a_g_k_m_kpack_block_desc),
                                               decltype(a_g_k_m_kpack_block_desc.GetLengths()),
                                               ABlockCopyThreadSliceLengths_G_K_M_KPACK,
                                               ABlockCopyThreadClusterLengths_G_K_M_KPACK,
                                               ABlockCopyThreadClusterArrangeOrder,
                                               ABlockCopySrcAccessOrder,
                                               ABlockCopyDstAccessOrder,
                                               ABlockCopySrcVectorReadDim, // Src dim to be read in vector form
                                               3,                          // Dst dim to be written in vector form (KPack dimension)
                                               ABlockCopySrcDataPerRead,
                                               ABlockCopyDstDataPerWrite_KPACK,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {g_block_data_on_global, 0, m_block_data_on_global, 0}, {0, 0, 0, 0});

        constexpr auto b_g_k_n1_b_kpack_block_desc = make_native_tensor_descriptor_aligned(
            Sequence<1, KPerBlock, in_N1, BPerBlock, KPack>{}, Number<max_align>{});

        // input blockwise copy
        auto b_blockwise_copy =
            BlockwiseGenericTensorSliceCopy_v4<BlockSize,
                                               decltype(b_g_k_n1_b_kpack_global_desc),
                                               decltype(b_g_k_n1_b_kpack_block_desc),
                                               decltype(b_g_k_n1_b_kpack_block_desc.GetLengths()),
                                               BBlockCopyThreadSliceLengths_G_K_N1_B_KPack,
                                               BBlockCopyThreadClusterLengths_G_K_N1_B_KPack,
                                               BBlockCopyThreadClusterArrangeOrder,
                                               BBlockCopySrcAccessOrder,
                                               BBlockCopyDstAccessOrder,
                                               BBlockCopySrcVectorReadDim, // Src dim to be read in vector form
                                               4,                          // Dst dim to be written in vector form (KPack dimension)
                                               BBlockCopySrcDataPerRead,
                                               BBlockCopyDstDataPerWrite_KPACK,
                                               AddressSpace::Global,
                                               AddressSpace::Vgpr,
                                               AddressSpace::Lds,
                                               InMemoryDataOperation::Set>(
                {g_block_data_on_global, 0, 0, b_block_data_on_global, 0}, {0, 0, 0, 0, 0});

        // GEMM definition
        //   c_mtx += transpose(a_mtx) * b_mtx
        //     a_mtx[KPerBlock, MPerBlock] is in LDS
        //     b_mtx[KPerBlock, BPerBlock * in_N1] is in LDS
        //     c_mtx[MPerBlock, BPerBlock * in_N1] is distributed among threads, and saved in register
        constexpr auto a_k_m_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
        constexpr auto b_k_n_block_mtx_desc =
            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<BPerBlock * in_N1>{});

        const auto blockwise_gemm =
            BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<BlockSize,
                                                                        decltype(a_k_m_block_mtx_desc),
                                                                        decltype(b_k_n_block_mtx_desc),
                                                                        ABFloat,
                                                                        MPerWave,
                                                                        BPerWave,
                                                                        MWavePerBlock,
                                                                        BWavePerBlock,
                                                                        1,
                                                                        1>{};

        constexpr index_t a_block_space =
            math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
        constexpr index_t b_block_space =
            math::integer_least_multiple(b_g_k_n1_b_kpack_block_desc.GetElementSpace(), max_align);

        __shared__ ABFloat p_a_block[a_block_space];
        __shared__ ABFloat p_b_block[b_block_space];

        // get zero-initialized output register of vector type
        auto c_thread_vec = blockwise_gemm.CreateOutputVecZero();

        // preload data into LDS
        {
            a_blockwise_copy.Run(p_a_global, p_a_block);
            b_blockwise_copy.Run(p_b_global, p_b_block);
        }

        constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
        constexpr auto blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0, 0>{};

        // main body
        for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
            k_block_data_begin += KPerBlock)
        {
            ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
            ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];

            // load next data from device mem
            a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
            b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);

            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);

            block_sync_lds();

            // GEMM on current data
            const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(p_a_block);
            const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(p_b_block);
            c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);

            block_sync_lds();

            // store next data to LDS
            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block);
        }

        // tail
        {
            block_sync_lds();

            // GEMM on last data
            const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(p_a_block);
            const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(p_b_block);
            c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
        }

        // copy output: register to global memory
        {
            ///\todo inconsistent layout of xdlops and tensor
            // xdlops layout
            //   M1 = num_groups;
            //   M0 = group_size;
            //   N1 = num_blks_per_wave;
            //   N0 = num_threads_per_blks;
            constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
            constexpr index_t M0 = CLayout.M1();
            constexpr index_t M1 = CLayout.N1();
            constexpr index_t M2 = CLayout.M0();

            constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
                c_g_m_n_global_desc,
                make_tuple(PassThrough<G>{},
                           UnMerge<Sequence<M / (M1 * M2), M1, M2>>{},
                           PassThrough<N>{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));

            // src descriptor
            constexpr auto c_g_m0_m1_m2_n_thread_desc =
                make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});

            using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;

            constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
            constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();

            // force unrolling the output loop to get rid of scratches
#pragma unroll
            for(index_t i = 0; i < NumBlks; ++i)
            {
                // calculate origin of thread output tensor on global memory
                //   blockwise GEMM c matrix starting index
                const auto c_thread_mtx_on_block =
                    blockwise_gemm.template GetBeginOfThreadMatrixC<MPerWave, B>(i);

                const index_t m_thread_data_on_global =
                    m_block_data_on_global + c_thread_mtx_on_block.row;
                const index_t n_thread_data_on_global =
                    b_block_data_on_global + c_thread_mtx_on_block.col;

                ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
                                                      decltype(c_g_m0_m1_m2_n_global_desc),
                                                      CThreadCopySliceLengths,
                                                      arithmetic_sequence_gen<0, 5, 1>::type,
                                                      4,
                                                      1,
                                                      1,
                                                      AddressSpace::Vgpr,
                                                      AddressSpace::Global,
                                                      CGlobalMemoryOp>(
                    {0, 0, 0, 0, 0},
                    {g_block_data_on_global,
                     m_thread_data_on_global / (M2 * M1),
                     m_thread_data_on_global % (M2 * M1) / M2,
                     m_thread_data_on_global % M2,
                     n_thread_data_on_global})
                    (make_multi_index(0, 0, 0, 0, 0),
                     make_multi_index(g_block_data_on_global,
                                      m_thread_data_on_global / (M2 * M1),
                                      m_thread_data_on_global % (M2 * M1) / M2,
                                      m_thread_data_on_global % M2,
                                      n_thread_data_on_global))
                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
            }
        }
...
...
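All of the xdlops GEMM kernels in the file above share one LDS double-buffering scheme: two block-sized LDS tiles alternate between "being consumed by the MFMA step" and "being prefetched", with a barrier in between. The outline below is a plain-C++ sketch of that control flow only; Load, Gemm, Store and Sync are hypothetical stand-ins for the blockwise copy objects, the xdlops blockwise GEMM and __syncthreads(), not CK APIs.

    #include <cstddef>

    template <class Load, class Gemm, class Store, class Sync>
    void double_buffered_gemm_loop(std::size_t K, std::size_t KPerBlock,
                                   float* lds_even, float* lds_odd,
                                   Load load, Gemm gemm, Store store, Sync sync)
    {
        store(load(), lds_even); // preload: first K-tile goes into the "even" buffer

        for(std::size_t k = 0; k + 2 * KPerBlock < K; k += 2 * KPerBlock)
        {
            for(int iloop = 0; iloop < 2; ++iloop)
            {
                const bool even = (iloop % 2 == 0);
                float* now  = even ? lds_even : lds_odd;  // tile the GEMM reads this iteration
                float* next = even ? lds_odd : lds_even;  // tile being prefetched for the next one

                auto staged = load(); // global memory -> registers (RunLoadThreadBuffer)
                sync();               // make sure `now` is fully written before it is read
                gemm(now);            // xdlops GEMM on the current tile
                store(staged, next);  // registers -> LDS (RunStoreThreadBuffer)
            }
        }
        // tail: the kernels above then finish with one or two remaining tiles,
        // depending on whether K is a multiple of 2 * KPerBlock.
    }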
composable_kernel/include/utility/common_header.hpp  (view file @ 7b002f23)
...
...
@@ -27,9 +27,6 @@
#include "amd_inline_asm.hpp"
#endif

#if CK_USE_AMD_XDLOPS
#include "amd_xdlops.hpp"
#include "amd_xdlops_inline_asm.hpp"
#endif
#endif
driver/include/gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.cpp → driver/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  (view file @ 7b002f23)
#include "common_header.hpp"
#include <unistd.h>
#include "device.hpp"
#include "host_tensor.hpp"
#include "gridwise_operation_wrapper.hpp"
#include "gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "float_types.h"
template
<
class
T
,
class
InDesc
,
...
...
@@ -10,8 +12,8 @@ template <class T,
          class ConvDilations,
          class InLeftPads,
          class InRightPads>
void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(InDesc,
                                                                            const Tensor<T>& in_nchw,
                                                                            WeiDesc,
                                                                            const Tensor<T>& wei_kcyx,
...
...
@@ -25,29 +27,32 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
{
    using namespace ck;

    // read params: problem description
    constexpr index_t G  = CK_PARAM_PROBLEM_G;
    constexpr index_t N  = CK_PARAM_PROBLEM_N;
    constexpr index_t K  = CK_PARAM_PROBLEM_K;
    constexpr index_t C  = CK_PARAM_PROBLEM_C;
    constexpr index_t Hi = CK_PARAM_PROBLEM_HI;
    constexpr index_t Wi = CK_PARAM_PROBLEM_WI;
    constexpr index_t Ho = CK_PARAM_PROBLEM_HO;
    constexpr index_t Wo = CK_PARAM_PROBLEM_WO;
    constexpr index_t Y  = CK_PARAM_PROBLEM_Y;
    constexpr index_t X  = CK_PARAM_PROBLEM_X;

    constexpr auto I0 = Number<0>{};
    constexpr auto I1 = Number<1>{};
    constexpr auto I2 = Number<2>{};
    constexpr auto I3 = Number<3>{};

    constexpr auto in_nchw_desc =
        make_native_tensor_descriptor(InDesc::GetLengths(), InDesc::GetStrides());
    constexpr auto wei_kcyx_desc =
        make_native_tensor_descriptor(WeiDesc::GetLengths(), WeiDesc::GetStrides());
    constexpr auto out_nkhw_desc =
        make_native_tensor_descriptor(OutDesc::GetLengths(), OutDesc::GetStrides());

    constexpr index_t ConvStrideH = CK_PARAM_PROBLEM_CONV_STRIDE_H;
    constexpr index_t ConvStrideW = CK_PARAM_PROBLEM_CONV_STRIDE_W;

    // read params: problem description
    constexpr index_t G = 1;

    constexpr index_t ConvDilationH = CK_PARAM_PROBLEM_CONV_DILATION_H;
    constexpr index_t ConvDilationW = CK_PARAM_PROBLEM_CONV_DILATION_W;

    constexpr index_t N  = out_nkhw_desc.GetLength(I0);
    constexpr index_t K  = out_nkhw_desc.GetLength(I1);
    constexpr index_t Ho = out_nkhw_desc.GetLength(I2);
    constexpr index_t Wo = out_nkhw_desc.GetLength(I3);

    constexpr index_t InLeftPadH = CK_PARAM_PROBLEM_IN_LEFT_PAD_H;
    constexpr index_t InLeftPadW = CK_PARAM_PROBLEM_IN_LEFT_PAD_W;

    constexpr index_t C  = in_nchw_desc.GetLength(I1);
    constexpr index_t Hi = in_nchw_desc.GetLength(I2);
    constexpr index_t Wi = in_nchw_desc.GetLength(I3);

    constexpr index_t InRightPadH = CK_PARAM_PROBLEM_IN_RIGHT_PAD_H;
    constexpr index_t InRightPadW = CK_PARAM_PROBLEM_IN_RIGHT_PAD_W;

    constexpr index_t Y = wei_kcyx_desc.GetLength(I2);
    constexpr index_t X = wei_kcyx_desc.GetLength(I3);

    constexpr auto CPerGroup = C / G;
...
...
@@ -58,31 +63,27 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
    constexpr auto out_n_k_ho_wo_desc =
        make_native_tensor_descriptor_packed(Sequence<N, K, Ho, Wo>{});

    using ConvStrides   = Sequence<ConvStrideH, ConvStrideW>;
    using ConvDilations = Sequence<ConvDilationH, ConvDilationW>;

    using InLeftPads  = Sequence<InLeftPadH, InLeftPadW>;
    using InRightPads = Sequence<InRightPadH, InRightPadW>;

    // read params: tuning parameters
    constexpr index_t GemmMPerBlock = CK_PARAM_TUNABLE_GEMM_M_PER_BLOCK;
    constexpr index_t GemmNPerBlock = CK_PARAM_TUNABLE_GEMM_N_PER_BLOCK;
    constexpr index_t GemmKPerBlock = CK_PARAM_TUNABLE_GEMM_K_PER_BLOCK;
    constexpr index_t GemmMPerWave  = CK_PARAM_TUNABLE_GEMM_M_PER_WAVE;
    constexpr index_t GemmNPerWave  = CK_PARAM_TUNABLE_GEMM_N_PER_WAVE;
    constexpr index_t GemmKPack     = CK_PARAM_TUNABLE_GEMM_KPACK;

    constexpr index_t GemmMPerBlock = 128;
    constexpr index_t GemmNPerBlock = 128;
    constexpr index_t GemmKPerBlock = 8;
    constexpr index_t GemmMPerWave  = 64;
    constexpr index_t GemmNPerWave  = 64;
    constexpr index_t GemmKPack     = 1;

    // read params: dependent parameters
    constexpr index_t BlockSize = CK_PARAM_DEPENDENT_BLOCK_SIZE;
    constexpr index_t GridSize  = CK_PARAM_DEPENDENT_GRID_SIZE;

    constexpr index_t BlockSize = 256;

    constexpr index_t GemmM = K;
    constexpr index_t GemmN = N * Ho * Wo;

    constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                 math::integer_divide_ceil(GemmN, GemmNPerBlock);

    // A matrix copy
    constexpr index_t GemmABlockCopyClusterLengths_GemmK =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
    constexpr index_t GemmABlockCopyClusterLengths_GemmM =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_M;
    constexpr index_t GemmABlockCopyClusterLengths_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;

    constexpr index_t GemmABlockCopyClusterLengths_GemmK     = 4;
    constexpr index_t GemmABlockCopyClusterLengths_GemmM     = 64;
    constexpr index_t GemmABlockCopyClusterLengths_GemmKPack = 1;

    constexpr index_t GemmABlockCopyThreadSliceLengths_GemmK =
        GemmKPerBlock / GemmABlockCopyClusterLengths_GemmK;
...
...
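For orientation (a sketch, not part of the diff): with the block tile hard-coded above, the implicit GEMM is tiled as GemmM = K rows by GemmN = N * Ho * Wo columns, and GridSize is simply the number of 128 x 128 tiles. The compile-time check below uses made-up problem sizes; every value in it is hypothetical.

    constexpr int integer_divide_ceil_sketch(int a, int b) { return (a + b - 1) / b; }

    // assumed shapes: N = 128, K = 256, Ho = Wo = 28, with the 128 x 128 block tile from above
    static_assert(integer_divide_ceil_sketch(256, 128) == 2, "GemmM tiles");
    static_assert(integer_divide_ceil_sketch(128 * 28 * 28, 128) == 784, "GemmN tiles");
    static_assert(2 * 784 == 1568, "GridSize = 2 * 784 workgroups");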
@@ -107,19 +108,13 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
    using GemmABlockCopySrcAccessOrder = Sequence<0, 2, 1, 3>; // [GemmG, GemmM, GemmK, GemmKPack]
    using GemmABlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmM, GemmKPack]

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_KPACK;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_A_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;

    constexpr index_t GemmABlockCopySrcDataPerRead_GemmKPack  = 1;
    constexpr index_t GemmABlockCopyDstDataPerWrite_GemmKPack = 1;

    // B matrix Copy
    constexpr index_t GemmBBlockCopyClusterLengths_GemmK =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_K;
    constexpr index_t GemmBBlockCopyClusterLengths_GemmN =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_N;
    constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_CLUSTER_LENGTHS_GEMM_KPACK;

    constexpr index_t GemmBBlockCopyClusterLengths_GemmK     = 4;
    constexpr index_t GemmBBlockCopyClusterLengths_GemmN     = 64;
    constexpr index_t GemmBBlockCopyClusterLengths_GemmKPack = 1;

    constexpr index_t GemmBBlockCopyThreadSliceLengths_GemmK =
        GemmKPerBlock / GemmBBlockCopyClusterLengths_GemmK;
...
...
@@ -144,22 +139,20 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
    using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
    using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_SRC_DATA_PER_READ_GEMM_N;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack =
        CK_PARAM_DEPENDENT_GEMM_B_BLOCK_COPY_DST_DATA_PER_WRITE_GEMM_KPACK;

    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN      = 1;
    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 1;

    // gridwise GEMM
    constexpr auto wkgrp_schd_order = NBlock1MBlock0;

    constexpr auto gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<

    using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;

    using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
        GridSize,
        BlockSize,
        FLOAT,       // Input data type
        FLOAT_ACCUM, // Acc data type
        FLOAT,       // Output data type
        TDevice,     // Input data type
        TDevice,     // Acc data type
        TDevice,     // Output data type
        decltype(in_n_c_hi_wi_desc),
        decltype(wei_k_cpergroup_y_x_desc),
        decltype(out_n_k_ho_wo_desc),
...
...
@@ -188,6 +181,48 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw
        GemmBBlockCopyDstAccessOrder,
        GemmBBlockCopySrcDataPerRead_GemmN,
        GemmBBlockCopyDstDataPerWrite_GemmKPack,
        wkgrp_schd_order>{};

    gridwise_conv.Run(p_in_global, p_wei_global, p_out_global);

        wkgrp_schd_order>;

    std::size_t data_sz = sizeof(T);

    DeviceMem in_nchw_device_buf(data_sz * in_nchw.mDesc.GetElementSpace());
    DeviceMem wei_kcyx_device_buf(data_sz * wei_kcyx.mDesc.GetElementSpace());
    DeviceMem out_nkhw_device_buf(data_sz * out_nkhw.mDesc.GetElementSpace());

    in_nchw_device_buf.ToDevice(in_nchw.mData.data());
    wei_kcyx_device_buf.ToDevice(wei_kcyx.mData.data());
    out_nkhw_device_buf.ToDevice(out_nkhw.mData.data());

    for(index_t i = 0; i < 5; ++i)
    {
        std::cout << "Start running " << nrepeat << " times..." << std::endl;

        KernelTimer timer;
        timer.Start();

        for(index_t j = 0; j < nrepeat; ++j)
        {
            launch_kernel(run_gridwise_operation<gridwise_conv,
                                                 const TDevice* const __restrict__,
                                                 const TDevice* const __restrict__,
                                                 TDevice* const __restrict__>,
                          dim3(GridSize),
                          dim3(BlockSize),
                          0,
                          0,
                          static_cast<TDevice*>(in_nchw_device_buf.GetDeviceBuffer()),
                          static_cast<TDevice*>(wei_kcyx_device_buf.GetDeviceBuffer()),
                          static_cast<TDevice*>(out_nkhw_device_buf.GetDeviceBuffer()));
        }

        timer.End();

        float ave_time = timer.GetElapsedTime() / nrepeat;

        float perf = (float)calculate_convolution_flops(InDesc{}, WeiDesc{}, OutDesc{}) /
                     (std::size_t(1000) * 1000 * 1000) / ave_time;

        std::cout << "Average time : " << ave_time << " ms, " << perf << " TFlop/s" << std::endl;
    }

    out_nkhw_device_buf.FromDevice(out_nkhw.mData.data());
}
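A quick sanity check on the reported units (a sketch, not part of the diff): the perf line above divides the FLOP count by 1e9 and then by the average time in milliseconds, and GFLOP per millisecond is the same thing as TFLOP per second. The FLOP model below, 2 * N * K * C * Y * X * Ho * Wo (two ops per multiply-accumulate), is an assumption about what calculate_convolution_flops() returns, and every size is made up.

    #include <cassert>

    int main()
    {
        const double N = 128, K = 256, C = 256, Y = 3, X = 3, Ho = 28, Wo = 28; // hypothetical
        const double flops = 2.0 * N * K * C * Y * X * Ho * Wo;                 // ~1.18e11 FLOP

        const double ave_time_ms = 1.5;                                         // hypothetical
        const double tflops = flops / (1000.0 * 1000.0 * 1000.0) / ave_time_ms; // ~78.9 TFlop/s
        assert(tflops > 78.0 && tflops < 80.0);
        return 0;
    }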
driver/src/conv_driver.cpp  (view file @ 7b002f23)
...
...
@@ -13,6 +13,7 @@
#include "device_tensor.hpp"
#include "device_convolution_forward_implicit_gemm_v4r1_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp"
#include "device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw.hpp"
#include "device_dummy_static_transform.hpp"
#include "device_dummy_dynamic_transform_v1.hpp"
...
...
@@ -111,7 +112,7 @@ int main(int argc, char* argv[])
                                                                 RightPads{},
                                                                 nrepeat);
#elif 1
    device_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw(in_nchw_desc,
    gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(in_nchw_desc,
                                                                           in_nchw,
                                                                           wei_kcyx_desc,
                                                                           wei_kcyx,
...
...
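On the call-site switch above (a sketch rather than the real driver): conv_driver.cpp picks which device wrapper to compile through a chain of preprocessor branches, and the hunk only shows the `#elif 1` arm. The `#if 0` opener and the function bodies below are placeholders.

    #include <cstdio>

    static void run_v4r4()        { std::puts("v4r4 (non-xdlops) path"); }
    static void run_v4r4_xdlops() { std::puts("v4r4 xdlops path"); }

    int main()
    {
    #if 0
        run_v4r4();        // disabled arm: dropped by the preprocessor
    #elif 1
        run_v4r4_xdlops(); // enabled arm: the call this commit redirects to the xdlops driver
    #endif
        return 0;
    }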