gaoqiong / composable_kernel · Commit 888b7c78

Authored Dec 16, 2020 by Jing Zhang (parent 821ec5ae):

    add original gridwise gemm

Showing 4 changed files with 291 additions and 4 deletions:

  composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  (+2, -1)
  composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp  (+283, -0)
  composable_kernel/include/utility/float_type.amd.hpp.in  (+1, -0)
  driver/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp  (+5, -3)
composable_kernel/include/kernel_algorithm/gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

@@ -158,7 +158,8 @@ struct GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw
 ...
             make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}));

         // gridwise batch-GEMM
-        constexpr auto gridwise_gemm = GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2<
+        // constexpr auto gridwise_gemm = GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2<
+        constexpr auto gridwise_gemm = GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2_org<
             GridSize,
             BlockSize,
             ABFloat,
 ...
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp

@@ -51,6 +51,289 @@ struct make_block_work_sequence<MBlockWork, NBlockWork, NBlock1MBlock0>
     __device__ constexpr auto get() { return Sequence<NBlockWork, MBlockWork>{}; }
 };

+template <index_t GridSize,
+          index_t BlockSize,
+          class ABFloat,
+          class AccFloat,
+          class CFloat,
+          class AGlobalDesc,
+          class BGlobalDesc,
+          class CGlobalDesc,
+          index_t MPerBlock,
+          index_t NPerBlock,
+          index_t KPerBlock,
+          index_t MPerWave,
+          index_t NPerWave,
+          class ABlockCopyThreadSliceLengths_G_K_M_KPACK,
+          class ABlockCopyThreadClusterLengths_G_K_M_KPACK,
+          class ABlockCopyThreadClusterArrangeOrder,
+          class ABlockCopySrcAccessOrder,
+          class ABlockCopyDstAccessOrder,
+          index_t ABlockCopySrcVectorReadDim,
+          index_t ABlockCopySrcDataPerRead,
+          index_t ABlockCopyDstDataPerWrite_KPACK,
+          class BBlockCopyThreadSliceLengths_G_K_N_KPACK,
+          class BBlockCopyThreadClusterLengths_G_K_N_KPACK,
+          class BBlockCopyThreadClusterArrangeOrder,
+          class BBlockCopySrcAccessOrder,
+          class BBlockCopyDstAccessOrder,
+          index_t BBlockCopySrcVectorReadDim,
+          index_t BBlockCopySrcDataPerRead,
+          index_t BBlockCopyDstDataPerWrite_KPACK,
+          InMemoryDataOperation CGlobalMemoryOp,
+          WorkgroupScheduleOrder WorkgroupSchdOrder>
+struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2_org
+{
+    __device__ void Run(const ABFloat* const __restrict__ p_a_global,
+                        const ABFloat* const __restrict__ p_b_global,
+                        CFloat* const __restrict__ p_c_global) const
+    {
+        constexpr auto True = integral_constant<bool, true>{};
+
+        constexpr auto a_g_k_m_kpack_global_desc = AGlobalDesc{};
+        constexpr auto b_g_k_n_kpack_global_desc = BGlobalDesc{};
+        constexpr auto c_g_m_n_global_desc       = CGlobalDesc{};
+
+        constexpr auto G     = c_g_m_n_global_desc.GetLengths()[0];
+        constexpr auto M     = c_g_m_n_global_desc.GetLengths()[1];
+        constexpr auto N     = c_g_m_n_global_desc.GetLengths()[2];
+        constexpr auto K     = b_g_k_n_kpack_global_desc.GetLengths()[1];
+        constexpr auto KPack = b_g_k_n_kpack_global_desc.GetLengths()[3];
+
+        // divide block work by [M, N]
+        static_assert(M % MPerBlock == 0 && N % NPerBlock == 0 && K % KPerBlock == 0,
+                      "wrong! cannot divide work evenly among block");
+
+        constexpr index_t MBlockWork = M / MPerBlock;
+        constexpr index_t NBlockWork = N / NPerBlock;
+
+        constexpr index_t MWavePerBlock = MPerBlock / MPerWave;
+        constexpr index_t NWavePerBlock = NPerBlock / NPerWave;
+
+        constexpr auto block_work_sequence =
+            make_batch_block_work_sequence<G, MBlockWork, NBlockWork, WorkgroupSchdOrder>{}.get();
+        constexpr auto block_work_desc = make_cluster_descriptor(block_work_sequence);
+
+        const auto block_work_id = block_work_desc.CalculateClusterIndex(get_block_1d_id());
+
+        const index_t g_block_data_on_global = block_work_id[Number<0>{}];
+        const index_t m_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
+                                                   ? (block_work_id[Number<1>{}] * MPerBlock)
+                                                   : (block_work_id[Number<2>{}] * MPerBlock);
+        const index_t n_block_data_on_global = (WorkgroupSchdOrder == MBlock1NBlock0)
+                                                   ? (block_work_id[Number<2>{}] * NPerBlock)
+                                                   : (block_work_id[Number<1>{}] * NPerBlock);
+
+        constexpr index_t max_align = KPack;
+
+        // LDS: be careful of LDS alignment
+        constexpr auto a_g_k_m_kpack_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence<1, KPerBlock, MPerBlock, KPack>{}, Number<max_align>{});
+
+        auto a_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
+            BlockSize,
+            decltype(a_g_k_m_kpack_global_desc),
+            decltype(a_g_k_m_kpack_block_desc),
+            decltype(a_g_k_m_kpack_block_desc.GetLengths()),
+            ABlockCopyThreadSliceLengths_G_K_M_KPACK,
+            ABlockCopyThreadClusterLengths_G_K_M_KPACK,
+            ABlockCopyThreadClusterArrangeOrder,
+            ABlockCopySrcAccessOrder,
+            ABlockCopyDstAccessOrder,
+            ABlockCopySrcVectorReadDim, // Src dim to be read in vector form
+            3,                          // Dst dim to be written in vector form (KPack dimension)
+            ABlockCopySrcDataPerRead,
+            ABlockCopyDstDataPerWrite_KPACK,
+            AddressSpace::Global,
+            AddressSpace::Vgpr,
+            AddressSpace::Lds,
+            InMemoryDataOperation::Set>(
+            make_multi_index(g_block_data_on_global, 0, m_block_data_on_global, 0),
+            make_multi_index(0, 0, 0, 0));
+
+        constexpr auto b_g_k_n_kpack_block_desc = make_native_tensor_descriptor_aligned(
+            Sequence<1, KPerBlock, NPerBlock, KPack>{}, Number<max_align>{});
+
+        // input blockwise copy
+        auto b_blockwise_copy = BlockwiseGenericTensorSliceCopy_v4<
+            BlockSize,
+            decltype(b_g_k_n_kpack_global_desc),
+            decltype(b_g_k_n_kpack_block_desc),
+            decltype(b_g_k_n_kpack_block_desc.GetLengths()),
+            BBlockCopyThreadSliceLengths_G_K_N_KPACK,
+            BBlockCopyThreadClusterLengths_G_K_N_KPACK,
+            BBlockCopyThreadClusterArrangeOrder,
+            BBlockCopySrcAccessOrder,
+            BBlockCopyDstAccessOrder,
+            BBlockCopySrcVectorReadDim, // Src dim to be read in vector form
+            3,                          // Dst dim to be written in vector form (KPack dimension)
+            BBlockCopySrcDataPerRead,
+            BBlockCopyDstDataPerWrite_KPACK,
+            AddressSpace::Global,
+            AddressSpace::Vgpr,
+            AddressSpace::Lds,
+            InMemoryDataOperation::Set>(
+            make_multi_index(g_block_data_on_global, 0, n_block_data_on_global, 0),
+            make_multi_index(0, 0, 0, 0));
+
+        // GEMM definition
+        //   c_mtx += transpose(a_mtx) * b_mtx
+        //   a_mtx[KPerBlock, MPerBlock] is in LDS
+        //   b_mtx[KPerBlock, NPerBlock] is in LDS
+        //   c_mtx[MPerBlock, NPerBlock] is distributed among threads, and saved in
+        //   register
+        constexpr auto a_k_m_block_mtx_desc =
+            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<MPerBlock>{});
+        constexpr auto b_k_n_block_mtx_desc =
+            make_ConstantMatrixDescriptor_packed(Number<KPerBlock>{}, Number<NPerBlock>{});
+
+        const auto blockwise_gemm = BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_xdlops<
+            BlockSize,
+            decltype(a_k_m_block_mtx_desc),
+            decltype(b_k_n_block_mtx_desc),
+            ABFloat,
+            MPerWave,
+            NPerWave,
+            MWavePerBlock,
+            NWavePerBlock,
+            1,
+            1>{};
+
+        constexpr index_t a_block_space =
+            math::integer_least_multiple(a_g_k_m_kpack_block_desc.GetElementSpace(), max_align);
+        constexpr index_t b_block_space =
+            math::integer_least_multiple(b_g_k_n_kpack_block_desc.GetElementSpace(), max_align);
+
+        __shared__ ABFloat p_a_block[a_block_space];
+        __shared__ ABFloat p_b_block[b_block_space];
+
+        // get zero-initialized output register of vector type
+        constexpr index_t c_thread_size = MPerBlock * NPerBlock / BlockSize;
+        auto c_thread_vec               = GetRegBuffer<AccFloat, c_thread_size>();
+
+        // preload data into LDS
+        {
+            a_blockwise_copy.Run(p_a_global, p_a_block);
+            b_blockwise_copy.Run(p_b_global, p_b_block);
+        }
+
+        constexpr auto blockwise_a_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
+        constexpr auto blockwise_b_copy_src_step = Sequence<0, KPerBlock, 0, 0>{};
+
+        // main body
+        for(index_t k_block_data_begin = 0; k_block_data_begin < K - KPerBlock;
+            k_block_data_begin += KPerBlock)
+        {
+            ABFloat p_a_thread_buffer[a_blockwise_copy.GetThreadBufferSize()];
+            ABFloat p_b_thread_buffer[b_blockwise_copy.GetThreadBufferSize()];
+
+            // load next data from device mem
+            a_blockwise_copy.MoveSrcSliceWindow(blockwise_a_copy_src_step, True);
+            b_blockwise_copy.MoveSrcSliceWindow(blockwise_b_copy_src_step, True);
+
+            a_blockwise_copy.RunLoadThreadBuffer(p_a_global, p_a_thread_buffer);
+            b_blockwise_copy.RunLoadThreadBuffer(p_b_global, p_b_thread_buffer);
+
+            block_sync_lds();
+
+            // GEMM on current data
+            const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
+                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
+                    p_a_block);
+            const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
+                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
+                    p_b_block);
+            c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
+
+            block_sync_lds();
+
+            // store next data to LDS
+            a_blockwise_copy.RunStoreThreadBuffer(p_a_thread_buffer, p_a_block);
+            b_blockwise_copy.RunStoreThreadBuffer(p_b_thread_buffer, p_b_block);
+        }
+
+        // tail
+        {
+            block_sync_lds();
+
+            // GEMM on last data
+            const typename vector_type<ABFloat, KPack>::MemoryType* p_a_block_vec =
+                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
+                    p_a_block);
+            const typename vector_type<ABFloat, KPack>::MemoryType* p_b_block_vec =
+                reinterpret_cast<const typename vector_type<ABFloat, KPack>::MemoryType*>(
+                    p_b_block);
+            c_thread_vec = blockwise_gemm.Run(p_a_block_vec, p_b_block_vec, c_thread_vec);
+        }
+
+        // copy output: register to global memory
+        {
+            ///\todo inconsistent layout of xdlops and tensor
+            // xdlops layout:
+            //   M1 = num_groups;
+            //   M0 = group_size;
+            //   N1 = num_blks_per_wave;
+            //   N0 = num_threads_per_blks;
+            constexpr auto CLayout = blockwise_gemm.GetOutputLayout();
+            constexpr index_t M0   = CLayout.M1();
+            constexpr index_t M1   = CLayout.N1();
+            constexpr index_t M2   = CLayout.M0();
+
+            constexpr auto c_g_m0_m1_m2_n_global_desc = transform_tensor_descriptor(
+                c_g_m_n_global_desc,
+                make_tuple(
+                    PassThrough<G>{}, UnMerge<Sequence<M / (M1 * M2), M1, M2>>{}, PassThrough<N>{}),
+                make_tuple(Sequence<0>{}, Sequence<1>{}, Sequence<2>{}),
+                make_tuple(Sequence<0>{}, Sequence<1, 2, 3>{}, Sequence<4>{}));
+
+            // src descriptor
+            constexpr auto c_g_m0_m1_m2_n_thread_desc =
+                make_native_tensor_descriptor_packed(Sequence<1, M0, 1, M2, 1>{});
+
+            using CThreadCopySliceLengths = Sequence<1, M0, 1, M2, 1>;
+
+            constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
+            constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();
+
+            // force unrolling the output loop to get rid of scratches
+#pragma unroll
+            for(index_t i = 0; i < NumBlks; ++i)
+            {
+                // calculate origin of thread output tensor on global memory:
+                // blockwise GEMM c matrix starting index
+                const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
+
+                const index_t m_thread_data_on_global =
+                    m_block_data_on_global + c_thread_mtx_on_block.row;
+                const index_t n_thread_data_on_global =
+                    n_block_data_on_global + c_thread_mtx_on_block.col;
+
+                ThreadwiseGenericTensorSliceCopy_v4r2<
+                    decltype(c_g_m0_m1_m2_n_thread_desc),
+                    decltype(c_g_m0_m1_m2_n_global_desc),
+                    CThreadCopySliceLengths,
+                    arithmetic_sequence_gen<0, 5, 1>::type,
+                    4,
+                    1,
+                    1,
+                    AddressSpace::Vgpr,
+                    AddressSpace::Global,
+                    CGlobalMemoryOp>(
+                    make_multi_index(0, 0, 0, 0, 0),
+                    make_multi_index(g_block_data_on_global,
+                                     m_thread_data_on_global / (M2 * M1),
+                                     m_thread_data_on_global % (M2 * M1) / M2,
+                                     m_thread_data_on_global % M2,
+                                     n_thread_data_on_global))
+                    .Run(c_thread_vec.n + i * BlkSize, p_c_global);
+            }
+        }
+    }
+};
+
 template <index_t GridSize,
           index_t BlockSize,
           class ABFloat,
 ...
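Note: the core of the new `_org` struct is the K-loop above, a double-buffered pipeline: tile 0 is preloaded into LDS, each iteration fetches the next K-slice into per-thread registers (RunLoadThreadBuffer) while the xdlops GEMM consumes the slice already in LDS, and the staged slice is written back to LDS (RunStoreThreadBuffer) only after a barrier. Below is a minimal single-threaded C++ sketch of that schedule; every name and size in it is a hypothetical stand-in, not code from this commit.

#include <cstdio>
#include <vector>

int main()
{
    constexpr int K = 8, KPerBlock = 2, TileElems = 4;
    std::vector<float> global_mem(K * TileElems, 1.0f); // stand-in for p_a_global
    std::vector<float> lds(KPerBlock * TileElems);      // stand-in for __shared__ p_a_block
    std::vector<float> stage(KPerBlock * TileElems);    // stand-in for p_a_thread_buffer

    float acc = 0.0f;

    auto load = [&](std::vector<float>& dst, int k_begin) { // RunLoadThreadBuffer analogue
        for(int i = 0; i < KPerBlock * TileElems; ++i)
            dst[i] = global_mem[k_begin * TileElems + i];
    };
    auto compute = [&] { // blockwise_gemm.Run analogue
        for(float v : lds)
            acc += v;
    };

    load(lds, 0); // preload tile 0 into "LDS"
    for(int k = 0; k < K - KPerBlock; k += KPerBlock)
    {
        load(stage, k + KPerBlock); // fetch the next tile into "registers"
        compute();                  // GEMM on the tile currently in "LDS"
        lds = stage;                // overwrite "LDS" only after the compute
    }
    compute(); // tail: GEMM on the last tile
    std::printf("acc = %f (expect %d)\n", acc, K * TileElems);
    return 0;
}

The two block_sync_lds() calls in the kernel enforce exactly the ordering this serial sketch gets for free: no wave may overwrite the LDS tile while another is still multiplying it.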
composable_kernel/include/utility/float_type.amd.hpp.in

@@ -214,6 +214,7 @@ union float_vec128_t
     StaticallyIndexedArray<float_vec32_t, 4> s32;
     StaticallyIndexedArray<float_vec64_t, 2> s64;
     StaticallyIndexedArray<float128_t, 1> s128;
+    float n[128];

     __host__ __device__ constexpr float_vec128_t() {}

     template <index_t vs>
 ...
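Note: the added member gives float_vec128_t a flat scalar view over the same 128-float storage as its vector members; it is what lets the new GEMM hand `c_thread_vec.n + i * BlkSize` to the output copy as a plain pointer. A scaled-down sketch of the same union idiom follows (hypothetical types, not from the commit; reading an inactive union member is technically undefined in ISO C++, but it is the established pattern in this codebase and the toolchains it targets).

#include <cstdio>

struct float4_t
{
    float x, y, z, w;
};

union float_vec8_t // scaled-down stand-in for the real float_vec128_t
{
    float4_t s4[2]; // vector view, like the s32/s64/s128 members
    float n[8];     // flat scalar view, the analogue of the added `float n[128];`
    constexpr float_vec8_t() : n{} {}
};

int main()
{
    float_vec8_t u;
    u.s4[1] = {1.0f, 2.0f, 3.0f, 4.0f}; // write through the vector view
    std::printf("%f\n", *(u.n + 5));    // read through the scalar view: prints 2.000000
    return 0;
}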
driver/include/device_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw.hpp

@@ -80,6 +80,8 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
     constexpr index_t GridSize = math::integer_divide_ceil(GemmM, GemmMPerBlock) *
                                  math::integer_divide_ceil(GemmN, GemmNPerBlock);

+    static_assert(GridSize == 1568, "");
+
     // A matrix copy
     constexpr index_t GemmABlockCopyClusterLengths_GemmK = 4;
     constexpr index_t GemmABlockCopyClusterLengths_GemmM = 64;
 ...
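Note: GridSize launches one workgroup per MPerBlock x NPerBlock tile of the GEMM output, and the new static_assert pins it to 1568 while this configuration is being brought up. A minimal sketch of the arithmetic follows; the Gemm sizes are hypothetical values chosen only so the product happens to be 1568 (they are not taken from the commit), and math::integer_divide_ceil is assumed to be the usual round-up division.

#include <cstdio>

constexpr int integer_divide_ceil(int a, int b) { return (a + b - 1) / b; } // assumed semantics

int main()
{
    constexpr int GemmM = 6272, GemmMPerBlock = 128; // hypothetical: ceil -> 49
    constexpr int GemmN = 2048, GemmNPerBlock = 64;  // hypothetical: ceil -> 32
    constexpr int GridSize = integer_divide_ceil(GemmM, GemmMPerBlock) *
                             integer_divide_ceil(GemmN, GemmNPerBlock);
    static_assert(GridSize == 1568, ""); // 49 * 32 == 1568
    std::printf("GridSize = %d workgroups, one per output tile\n", GridSize);
    return 0;
}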
@@ -139,13 +141,13 @@ void gridwise_convolution_forward_implicit_gemm_v4r4_xdlops_nchw_kcyx_nkhw(
     using GemmBBlockCopySrcAccessOrder = Sequence<0, 1, 3, 2>; // [GemmG, GemmK, GemmKPack, GemmN]
     using GemmBBlockCopyDstAccessOrder = Sequence<0, 1, 2, 3>; // [GemmG, GemmK, GemmN, GemmKPack]

-    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN      = 1;
-    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 1;
+    constexpr index_t GemmBBlockCopySrcDataPerRead_GemmN      = 4;
+    constexpr index_t GemmBBlockCopyDstDataPerWrite_GemmKPack = 4;

     // gridwise GEMM
     constexpr auto wkgrp_schd_order = NBlock1MBlock0;

-    using TDevice = typename conditional<is_same<half_float::half, T>::value, half_t, T>::type;
+    using TDevice = float;

     using gridwise_conv = GridwiseConvolutionForwardImplicitGemm_v4r4_xdlops_nchw_kcyx_nkhw<
         GridSize,
 ...
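Note: raising GemmBBlockCopySrcDataPerRead_GemmN and GemmBBlockCopyDstDataPerWrite_GemmKPack from 1 to 4 makes each thread move four contiguous elements per transaction, which the backend can lower to one wide load/store (e.g. a 128-bit dwordx4 for fp32) instead of four scalar ones. The host-side sketch below contrasts the two copy widths; buffer and function names are hypothetical, not part of the commit.

#include <array>
#include <cstdio>
#include <cstring>

// DataPerRead = 1: one scalar transaction per element
void copy_scalar(const float* src, float* dst, int n)
{
    for(int i = 0; i < n; ++i)
        dst[i] = src[i];
}

// DataPerRead = 4: one 4-wide transaction per four elements
// (on real hardware this requires n % 4 == 0 and suitably aligned pointers)
void copy_vec4(const float* src, float* dst, int n)
{
    for(int i = 0; i < n; i += 4)
        std::memcpy(dst + i, src + i, 4 * sizeof(float));
}

int main()
{
    std::array<float, 8> src{{0, 1, 2, 3, 4, 5, 6, 7}};
    std::array<float, 8> dst{};
    copy_scalar(src.data(), dst.data(), 4);       // first half, element by element
    copy_vec4(src.data() + 4, dst.data() + 4, 4); // second half, one wide copy
    std::printf("dst[5] = %f\n", dst[5]);         // prints 5.000000
    return 0;
}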