gaoqiong / composable_kernel

Commit 2178d1d8, authored Apr 26, 2021 by Chao Liu

Merge remote-tracking branch 'origin/no_array' into no_raw_index_calculation

Parents: fa163f3b, 7484a103
Showing 16 changed files with 1316 additions and 661 deletions (+1316, -661).
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp (+196, -190)
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp (+81, -93)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp (+42, -31)
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp (+43, -56)
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp (+59, -0)
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp (+279, -56)
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp (+95, -123)
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp (+107, -82)
composable_kernel/include/utility/amd_inline_asm.hpp (+69, -0)
composable_kernel/include/utility/buffer.hpp (+72, -0)
composable_kernel/include/utility/common_header.hpp (+1, -0)
composable_kernel/include/utility/config.amd.hpp.in (+3, -3)
composable_kernel/include/utility/float_type.amd.hpp.in (+237, -9)
composable_kernel/include/utility/sequence_helper.hpp (+15, -0)
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp (+2, -2)
driver/src/conv_driver.cpp (+15, -16)
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp

This diff is collapsed.
composable_kernel/include/tensor_operation/blockwise_gemm_v3.hpp

@@ -6,12 +6,10 @@
 namespace ck {

 // blockwise GEMM: C[M, N] += transpose(A[K, M]) * B[K, N]
 // A and B are visable to the whole block, C is distributed among each thread
 // If following number are power of 2, index calculation shall be greatly reduced:
 //   KPerThread, HPerThread, MLevel0ThreadCluster, NLevel0ThreadCluster,
 //   MLevel1ThreadCluster, NLevel1ThreadCluster
 template <index_t BlockSize,
+          typename FloatA,
+          typename FloatB,
+          typename FloatC,
           typename BlockMatrixA,
           typename BlockMatrixB,
           typename ThreadMatrixC,
           ...

@@ -30,9 +28,34 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         index_t w;
     };

-    index_t mMyThreadOffsetA;
+    // HACK: fix this @Jing Zhang
+    static constexpr index_t KPerThreadSubC = 4;
+
+    static constexpr auto a_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
+
+    static constexpr auto b_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+
+    static constexpr auto c_thread_mtx_ = make_dynamic_naive_tensor_descriptor_packed_v2(
+        make_tuple(Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
+
+    using AThreadCopy = ThreadwiseDynamicTensorSliceTransfer_v4<FloatA,
+                                                                FloatA,
+                                                                BlockMatrixA,
+                                                                decltype(a_thread_mtx_),
+                                                                Sequence<EPerThreadLoop, KPerThreadSubC>,
+                                                                Sequence<0, 1>,
+                                                                1,
+                                                                ThreadGemmADataPerRead_K,
+                                                                AddressSpace::Generic,
+                                                                AddressSpace::Vgpr,
+                                                                1>;
+
+    __device__ BlockwiseGemm_km_kn_m0m1n0n1_v3()
+        : c_thread_begin_mtx_idx_{GetBeginOfThreadMatrixC(get_thread_local_1d_id())},
+          a_thread_copy_{make_tuple(0, c_thread_begin_mtx_idx_.k * KPerThread)}
+    {
         static_assert(BlockMatrixA::IsKnownAtCompileTime() &&
                       BlockMatrixB::IsKnownAtCompileTime() &&
                       ...

@@ -61,11 +84,6 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(BlockSize == KThreadCluster * HThreadCluster * WThreadCluster,
                       "wrong! wrong blocksize\n");

-        auto c_thread_mtx_index = GetBeginOfThreadMatrixC(get_thread_local_1d_id());
-
-        mMyThreadOffsetA =
-            BlockMatrixA{}.CalculateOffset(make_tuple(0, c_thread_mtx_index.k * KPerThread));
     }

     __device__ static constexpr auto GetThreadMatrixCLengths()
     ...

@@ -91,37 +109,18 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         return MatrixIndex{k_thread_id, h_thread_id, w_thread_id};
     }

-    template <typename SrcDesc, typename DstDesc, index_t NSliceRow, index_t NSliceCol, index_t DataPerAccess>
-    struct ThreadwiseSliceCopy_a
-    {
-        template <typename Data>
-        __device__ static void Run(const Data* p_src, Data* p_dst)
-        {
-            static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
-                          "wrong! Desc should be known at compile-time");
-
-            using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
-
-            static_for<0, NSliceRow, 1>{}([&](auto i) {
-                static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
-                    constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
-                    constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
-
-                    *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
-                        *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
-                });
-            });
-        }
-    };
-
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void Run_naive(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
+    template <typename ABlockBuffer, typename BThreadBuffer, typename CThreadBuffer>
+    __device__ void Run(const ABlockBuffer& a_block_buf,
+                        const BThreadBuffer& b_thread_buf,
+                        CThreadBuffer& c_thread_buf) const
     {
+        static_assert(is_same<remove_cv_t<remove_reference_t<typename ABlockBuffer::type>>,
+                              remove_cv_t<remove_reference_t<FloatA>>>::value &&
+                          is_same<remove_cv_t<remove_reference_t<typename BThreadBuffer::type>>,
+                                  remove_cv_t<remove_reference_t<FloatB>>>::value &&
+                          is_same<remove_cv_t<remove_reference_t<typename CThreadBuffer::type>>,
+                                  remove_cv_t<remove_reference_t<FloatC>>>::value &&
+                          "wrong! inconsistent type");
+
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         ...

@@ -132,8 +131,7 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         constexpr auto EPerBlock = a_block_mtx.GetLength(I0);

-        constexpr auto KPerThreadSubC = 4;
-
         // HACK: fix this @Jing Zhang
         constexpr auto HoPerThreadSubC = 2;
         constexpr auto WoPerThreadSubC = 2;
         ...

@@ -141,63 +139,53 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v3
         static_assert(HPerThread % HoPerThreadSubC == 0, "");
         static_assert(WPerThread % WoPerThreadSubC == 0, "");

-        // thread A, B for GEMM
-        constexpr auto a_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<EPerThreadLoop>{}, Number<KPerThreadSubC>{}));
-
-        constexpr auto b_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<EPerThreadLoop>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-
-        constexpr auto c_thread_mtx = make_dynamic_naive_tensor_descriptor_packed_v2(
-            make_tuple(Number<KPerThreadSubC>{}, Number<1>{}, Number<HPerThread>{}, Number<WPerThread>{}));
-
-        FloatA p_a_thread[a_thread_mtx.GetElementSpaceSize()];
-
-        constexpr auto a_thread_copy = ThreadwiseSliceCopy_a<BlockMatrixA,
-                                                             decltype(a_thread_mtx),
-                                                             EPerThreadLoop,
-                                                             KPerThreadSubC,
-                                                             ThreadGemmADataPerRead_K>{};
-
-        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<decltype(a_thread_mtx),
-                                                                    decltype(b_thread_mtx),
-                                                                    decltype(c_thread_mtx),
+        // thread A buffer for GEMM
+        StaticBuffer<FloatA, a_thread_mtx_.GetElementSpaceSize()> a_thread_buf;
+
+        constexpr auto threadwise_gemm = ThreadwiseGemm_km_kn_mn_v3<FloatA,
+                                                                    FloatB,
+                                                                    FloatC,
+                                                                    decltype(a_thread_mtx_),
+                                                                    decltype(b_thread_mtx_),
+                                                                    decltype(c_thread_mtx_),
                                                                     HoPerThreadSubC,
                                                                     WoPerThreadSubC>{};

-        // loop over k
-#pragma unroll
-        for(index_t e_begin = 0; e_begin < EPerBlock; e_begin += EPerThreadLoop)
-        {
-#pragma unroll
-            for(index_t k_begin = 0; k_begin < KPerThread; k_begin += KPerThreadSubC)
-            {
-                a_thread_copy.Run(p_a_block +
-                                      a_block_mtx.CalculateOffset(make_tuple(e_begin, k_begin)) +
-                                      mMyThreadOffsetA,
-                                  p_a_thread);
-
-#pragma unroll
-                for(index_t h_begin = 0; h_begin < HPerThread; h_begin += HoPerThreadSubC)
-                {
-#pragma unroll
-                    for(index_t w_begin = 0; w_begin < WPerThread; w_begin += WoPerThreadSubC)
-                    {
-                        threadwise_gemm.Run(p_a_thread,
-                                            p_b_thread + b_thread_mtx.CalculateOffset(make_tuple(e_begin, 0, h_begin, w_begin)),
-                                            p_c_thread + c_thread_mtx.CalculateOffset(make_tuple(k_begin, 0, h_begin, w_begin)));
-                    }
-                }
-            }
-        }
+        static_for<0, EPerBlock, EPerThreadLoop>{}([&](auto e_begin) {
+            static_for<0, KPerThread, KPerThreadSubC>{}([&](auto k_begin) {
+                a_thread_copy_.Run(a_block_mtx,
+                                   make_tuple(e_begin, k_begin),
+                                   a_block_buf,
+                                   a_thread_mtx_,
+                                   make_tuple(I0, I0),
+                                   a_thread_buf);
+
+                static_for<0, HPerThread, HoPerThreadSubC>{}([&](auto h_begin) {
+                    static_for<0, WPerThread, WoPerThreadSubC>{}([&](auto w_begin) {
+                        threadwise_gemm.Run(a_thread_buf,
+                                            make_tuple(I0, I0),
+                                            b_thread_buf,
+                                            make_tuple(e_begin, I0, h_begin, w_begin),
+                                            c_thread_buf,
+                                            make_tuple(k_begin, I0, h_begin, w_begin));
+                    });
+                });
+            });
+        });
     }

-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ void Run(const FloatA* p_a_block, const FloatB* p_b_thread, FloatC* p_c_thread) const
-    {
-        Run_naive(p_a_block, p_b_thread, p_c_thread);
-    }
+    template <typename ABlockSliceMoveStepIdx>
+    __device__ void MoveASliceWindow(const BlockMatrixA&,
+                                     const ABlockSliceMoveStepIdx& a_block_slice_move_step_idx)
+    {
+        a_thread_copy_.MoveSrcSliceWindow(BlockMatrixA{}, a_block_slice_move_step_idx);
+    }

     private:
     MatrixIndex c_thread_begin_mtx_idx_;
+    AThreadCopy a_thread_copy_;
 };

 } // namespace ck
 ...
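Note: the comment at the top of this file states the contract all the blockwise and threadwise GEMMs here implement, C[M, N] += transpose(A[K, M]) * B[K, N], with A stored K-major over M and B K-major over N. As a reference point, a minimal host-side sketch of that same math (plain C++, independent of the ck descriptors, with hypothetical row-major layouts) is:

    #include <vector>
    #include <cstddef>

    // Reference for C[M, N] += transpose(A[K, M]) * B[K, N].
    // A is laid out as a[k * M + m], B as b[k * N + n], C as c[m * N + n].
    // The kernels in this commit tile this exact loop nest over blocks,
    // threads, and per-thread sub-tiles; the math itself is unchanged.
    void gemm_km_kn_mn_reference(const std::vector<float>& a, // K x M
                                 const std::vector<float>& b, // K x N
                                 std::vector<float>& c,       // M x N, accumulated into
                                 std::size_t M, std::size_t N, std::size_t K)
    {
        for(std::size_t k = 0; k < K; ++k)
            for(std::size_t m = 0; m < M; ++m)
                for(std::size_t n = 0; n < N; ++n)
                    c[m * N + n] += a[k * M + m] * b[k * N + n];
    }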
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -5,9 +5,10 @@
 #include "dynamic_multi_index_transform_helper.hpp"
 #include "dynamic_tensor_descriptor.hpp"
 #include "dynamic_tensor_descriptor_helper.hpp"
-#include "blockwise_gemm_v2.hpp"
 #include "blockwise_dynamic_tensor_slice_transfer.hpp"
 #include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "blockwise_gemm_v2.hpp"
+#include "threadwise_dynamic_tensor_slice_set.hpp"

 namespace ck {
 ...

@@ -256,19 +257,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             make_tuple(Number<MRepeat * MPerThread>{}, Number<NRepeat * NPerThread>{}));

         const auto blockwise_gemm =
-            BlockwiseGemm_km_kn_m0m1n0n1_v1<BlockSize,
-                                            decltype(a_k_m_block_desc),
-                                            decltype(b_k_n_block_desc),
-                                            decltype(c_m0m1_n0n1_thread_desc),
-                                            MPerThread, NPerThread, KPerThread,
-                                            MLevel0Cluster, NLevel0Cluster,
-                                            MLevel1Cluster, NLevel1Cluster,
-                                            MPerThread, NPerThread>{};
+            BlockwiseGemm_km_kn_m0m1n0n1_v1r1<BlockSize,
+                                              FloatAB, FloatAB, FloatAcc,
+                                              decltype(a_k_m_block_desc),
+                                              decltype(b_k_n_block_desc),
+                                              decltype(c_m0m1_n0n1_thread_desc),
+                                              MPerThread, NPerThread, KPerThread,
+                                              MLevel0Cluster, NLevel0Cluster,
+                                              MLevel1Cluster, NLevel1Cluster,
+                                              MPerThread, NPerThread>{};

         // LDS allocation for A and B: be careful of alignment
         constexpr auto a_block_space_size =
         ...

@@ -281,10 +285,13 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
-        FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
+        auto c_thread_buf =
+            make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());

         // zero out threadwise output
-        threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
+        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+                                           decltype(c_m0m1_n0n1_thread_desc),
+                                           Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
+            .Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         ...

@@ -300,6 +307,18 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         constexpr auto b_k_n_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

+        FloatAB* p_a_block_even = p_a_block_double;
+        FloatAB* p_b_block_even = p_b_block_double;
+
+        FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
+        FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
+
+        auto a_block_even_buf = make_dynamic_buffer(p_a_block_even);
+        auto b_block_even_buf = make_dynamic_buffer(p_b_block_even);
+
+        auto a_block_odd_buf = make_dynamic_buffer(p_a_block_odd);
+        auto b_block_odd_buf = make_dynamic_buffer(p_b_block_odd);
+
         // LDS double buffer: preload data into LDS
         {
             a_blockwise_copy.RunRead(a_k_m_global_desc, p_a_global, a_k_m_global_iterator_hacks);
             ...

@@ -311,12 +330,6 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         if constexpr(HasMainKBlockLoop)
         {
-            FloatAB* p_a_block_even = p_a_block_double;
-            FloatAB* p_b_block_even = p_b_block_double;
-
-            FloatAB* p_a_block_odd = p_a_block_double + a_block_space_size;
-            FloatAB* p_b_block_odd = p_b_block_double + b_block_space_size;
-
             index_t k_block_data_begin = 0;

             // LDS double buffer: main body
             ...

@@ -340,7 +353,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_even, p_b_block_even, p_c_thread);
+                blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_odd);
                 ...

@@ -363,7 +376,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                 b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block_odd, p_b_block_odd, p_c_thread);
+                blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);

                 // LDS double buffer: store next data to LDS
                 a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_even);
                 ...

@@ -390,7 +403,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             b_blockwise_copy.RunRead(b_k_n_global_desc, p_b_global, b_k_n_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);

             // LDS double buffer: store last data to LDS
             a_blockwise_copy.RunWrite(a_k_m_block_desc, p_a_block_double + a_block_space_size);
             ...

@@ -399,16 +412,14 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double + a_block_space_size,
-                               p_b_block_double + b_block_space_size,
-                               p_c_thread);
+            blockwise_gemm.Run(a_block_odd_buf, b_block_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             __syncthreads();

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block_double, p_b_block_double, p_c_thread);
+            blockwise_gemm.Run(a_block_even_buf, b_block_even_buf, c_thread_buf);
         }

         // output: register to global memory
         ...

@@ -461,7 +472,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                             n_thread_data_on_global % N1))
             .Run(c_m0_m1_n0_n1_thread_desc,
                  make_tuple(I0, I0, I0, I0),
-                 p_c_thread,
+                 c_thread_buf,
                  c_m0_m1_n0_n1_global_desc,
                  p_c_global,
                  c_m0_m1_n0_n1_global_tensor_iterator_hacks);
         ...
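Note: the "LDS double buffer" comments above describe a ping-pong scheme. While the block computes the GEMM on the even LDS tile, the blockwise copy fetches the next K-slice into the odd tile, and vice versa, with a __syncthreads() separating every write from the GEMM that consumes it. A stripped-down sketch of that control flow follows; TILE, load_tile and gemm_tile are hypothetical stand-ins for the ck blockwise copy and blockwise GEMM, kept trivial so only the buffering pattern is visible.

    #define TILE 64

    __device__ void load_tile(const float* src, float* dst, int tile_idx)
    {
        // each thread copies one element of the current K-slice into shared memory
        dst[threadIdx.x] = src[tile_idx * TILE + threadIdx.x];
    }

    __device__ void gemm_tile(const float* tile, float& c_acc)
    {
        // stand-in for the blockwise GEMM consuming the current tile
        for(int i = 0; i < TILE; ++i)
            c_acc += tile[i];
    }

    __global__ void double_buffered_loop(const float* a_global, float* c_out, int num_tiles)
    {
        __shared__ float lds_even[TILE];
        __shared__ float lds_odd[TILE];

        float c_acc = 0.0f;

        load_tile(a_global, lds_even, 0); // preload first tile
        __syncthreads();

        int t = 0;
        for(; t + 2 < num_tiles; t += 2)
        {
            load_tile(a_global, lds_odd, t + 1); // fetch next tile while even tile is consumed
            gemm_tile(lds_even, c_acc);
            __syncthreads();

            load_tile(a_global, lds_even, t + 2);
            gemm_tile(lds_odd, c_acc);
            __syncthreads();
        }

        // tail, mirroring the HasMainKBlockLoop / "1 iteration left" branches above
        if(t + 1 < num_tiles)
        {
            load_tile(a_global, lds_odd, t + 1);
            gemm_tile(lds_even, c_acc);
            __syncthreads();
            gemm_tile(lds_odd, c_acc);
        }
        else
        {
            gemm_tile(lds_even, c_acc);
        }

        c_out[blockIdx.x * blockDim.x + threadIdx.x] = c_acc;
    }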
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm_v2.hpp

@@ -145,17 +145,19 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
             make_dynamic_naive_tensor_descriptor_packed_v2(
                 make_tuple(Number<KPerThread>{}, Number<1>{}, Number<HoPerThread>{}, Number<WoPerThread>{}));

-        const auto blockwise_gemm =
-            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
-                                            decltype(a_e_k_block_desc),
-                                            decltype(b_e_n_ho_wo_block_desc),
-                                            decltype(c_k_n_ho_wo_thread_desc),
-                                            KPerThread, HoPerThread, WoPerThread, EPerThread,
-                                            ABlockTransferSrcScalarPerVector,
-                                            ABlockTransferDstScalarPerVector_K>{};
+        auto blockwise_gemm =
+            BlockwiseGemm_km_kn_m0m1n0n1_v3<BlockSize,
+                                            FloatAB, FloatAB, FloatAcc,
+                                            decltype(a_e_k_block_desc),
+                                            decltype(b_e_n_ho_wo_block_desc),
+                                            decltype(c_k_n_ho_wo_thread_desc),
+                                            KPerThread, HoPerThread, WoPerThread, EPerThread,
+                                            ABlockTransferSrcScalarPerVector,
+                                            ABlockTransferDstScalarPerVector_K>{};

         auto c_thread_mtx_index = blockwise_gemm.GetBeginOfThreadMatrixC(get_thread_local_1d_id());
         ...

@@ -223,11 +225,16 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         FloatAB* p_a_block = p_shared_block;
+        auto a_block_buf   = make_dynamic_buffer(p_a_block);

         // register allocation for output
-        FloatAcc p_c_thread[c_k_n_ho_wo_thread_desc.GetElementSpaceSize()];
+        StaticBuffer<FloatAcc, c_k_n_ho_wo_thread_desc.GetElementSpaceSize()> c_thread_buf;

-        // zero out threadwise output
-        threadwise_matrix_set_zero_v3(c_k_n_ho_wo_thread_desc, p_c_thread);
+        // initialize output thread tensor
+        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+                                           decltype(c_k_n_ho_wo_thread_desc),
+                                           Sequence<KPerThread, 1, HoPerThread, WoPerThread>>{}
+            .Run(c_k_n_ho_wo_thread_desc, make_tuple(I0, I0, I0, I0), c_thread_buf, FloatAcc{0});

         constexpr auto b_thread_slice_copy_step = make_multi_index(EPerBlock, 0, 0, 0);
         ...

@@ -242,12 +249,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         constexpr auto b_e_n_ho_wo_global_move_slice_window_iterator_hack =
             BGlobalMoveSliceWindowIteratorHacks{};

-        constexpr auto b_thread_space_size = b_e_n_ho_wo_thread_desc.GetElementSpaceSize();
-
-        FloatAB p_b_thread[b_thread_space_size * 2];
-
-        FloatAB* p_b_thread_double = p_b_thread;
+        // double regsiter buffer for b
+        StaticBuffer<FloatAB, b_e_n_ho_wo_thread_desc.GetElementSpaceSize()> b_thread_even_buf,
+            b_thread_odd_buf;

-        // LDS double buffer: preload data into LDS
+        // LDS double buffer: preload data
         {
             a_blockwise_copy.RunRead(a_e_k_global_desc, p_a_global, a_e_k_global_iterator_hacks);
             ...

@@ -255,7 +261,7 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                          p_b_global,
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
-                                         p_b_thread_double,
+                                         b_thread_even_buf,
                                          b_e_n_ho_wo_global_iterator_hacks);

             a_blockwise_copy.RunWrite(a_e_k_desc, p_a_block);
             ...

@@ -263,13 +269,9 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
         __syncthreads();

-        index_t b_block_data_begin = 0;
-
-#if 1
         if constexpr(HasMainKBlockLoop)
         {
-            FloatAB* p_b_thread_even = p_b_thread_double;
-            FloatAB* p_b_thread_odd  = p_b_thread_double + b_thread_space_size;
-
+            index_t e_block_data_begin = 0;

             // LDS double buffer: main body
             // use Do-While loop instead of For loop to simplify control flow
             ...

@@ -283,16 +285,14 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                              p_b_global,
                                              b_e_n_ho_wo_thread_desc,
                                              make_tuple(I0, I0, I0, I0),
-                                             p_b_thread_odd,
+                                             b_thread_odd_buf,
                                              b_e_n_ho_wo_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                                   make_tuple(b_block_data_begin, 0)),
-                                   p_b_thread_even,
-                                   p_c_thread);
+                // TODO: @Zhang Jing: blockwise gemm should be able to move slice window
+                blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);

-                b_block_data_begin += EPerBlock;
+                blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));

                 b_threadwise_transfer.MoveSrcSliceWindow(b_e_n_ho_wo_global_desc,
                                                          b_thread_slice_copy_step);
                 ...

@@ -301,18 +301,17 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                              p_b_global,
                                              b_e_n_ho_wo_thread_desc,
                                              make_tuple(I0, I0, I0, I0),
-                                             p_b_thread_even,
+                                             b_thread_even_buf,
                                              b_e_n_ho_wo_global_iterator_hacks);

                 // LDS double buffer: GEMM on current data
-                blockwise_gemm.Run(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                                   make_tuple(b_block_data_begin, 0)),
-                                   p_b_thread_odd,
-                                   p_c_thread);
+                blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);

-                b_block_data_begin += EPerBlock;
-            } while(b_block_data_begin < E - 2 * EPerBlock);
+                blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));
+
+                e_block_data_begin += 2 * EPerBlock;
+            } while(e_block_data_begin < E - 2 * EPerBlock);
         }

         // LDS double buffer: tail
         ...

@@ -325,34 +324,23 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                                          p_b_global,
                                          b_e_n_ho_wo_thread_desc,
                                          make_tuple(I0, I0, I0, I0),
-                                         p_b_thread_double + b_thread_space_size,
+                                         b_thread_odd_buf,
                                          b_e_n_ho_wo_global_iterator_hacks);

             // LDS double buffer: GEMM on 2nd-last data
-            blockwise_gemm.Run(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                               make_tuple(b_block_data_begin, 0)),
-                               p_b_thread_double,
-                               p_c_thread);
+            blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);

-            b_block_data_begin += EPerBlock;
+            blockwise_gemm.MoveASliceWindow(a_e_k_block_desc, make_tuple(EPerBlock, 0));

             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                               make_tuple(b_block_data_begin, 0)),
-                               p_b_thread_double + b_thread_space_size,
-                               p_c_thread);
+            blockwise_gemm.Run(a_block_buf, b_thread_odd_buf, c_thread_buf);
         }
         else // if has 1 iteration left
         {
             // LDS double buffer: GEMM on last data
-            blockwise_gemm.Run(p_a_block + a_e_k_block_desc.CalculateOffset(
-                                               make_tuple(b_block_data_begin, 0)),
-                               p_b_thread_double,
-                               p_c_thread);
+            blockwise_gemm.Run(a_block_buf, b_thread_even_buf, c_thread_buf);
         }
-#endif

-#if 1
         // output: register to global memory
         {
             // hack to control index calculation when iterating over c_k_n_ho_wo_global tensor
             ...

@@ -380,12 +368,11 @@ struct GridwiseDynamicGemm_km_kn_mn_v3
                      k_thread_data_on_global, 0, ho_thread_data_on_global, wo_thread_data_on_global))
                 .Run(c_k_n_ho_wo_thread_desc,
                      make_tuple(I0, I0, I0, I0),
-                     p_c_thread,
+                     c_thread_buf,
                      c_k_n_ho_wo_global_desc,
                      p_c_global,
                      c_k_n_ho_wo_global_tensor_iterator_hacks);
         }
-#endif
     }

     // pass tensor descriptor by reference
     ...
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp (new file, mode 100644)

#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"

namespace ck {

// Assume:
//   1. Desc is known at compile-time
//   2. Buffer is StaticBuffer
//   3. OriginIdx is known at compile-time
//   4. use #-iterator
template <typename Data,
          typename Desc,
          typename SliceLengths,
          typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceSet_v1
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    template <typename OriginIdx, typename Buffer>
    __device__ void Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
    {
        static_assert(Desc::IsKnownAtCompileTime(),
                      "wrong! SrcDesc and DstDesc need to known at compile-time");

        static_assert(Buffer::IsStaticBuffer(), "wrong! DstBuffer need to be StaticBuffer");

        static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
                      "wrong! OriginIdx need to be known at compile-time");

        // Desc is known at compile-time
        constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};

        // OriginIdx is known at compile-time
        constexpr auto origin_idx = to_multi_index(OriginIdx{});

        static_ford<SliceLengths>{}([&](auto access_idx) {
            constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);

            constexpr bool is_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);

            constexpr index_t offset = coord.GetOffset();

            if constexpr(is_valid)
            {
                buf(Number<offset>{}) = initial_value;
            }
        });
    }
};

} // namespace ck
#endif
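Note: the gridwise kernels in this commit use this helper to zero-initialize the per-thread accumulator, replacing threadwise_matrix_set_zero_v2/v3. A small usage sketch, modelled on the call in gridwise_dynamic_gemm.hpp and assuming the ck headers are available (the 4x8 descriptor shape is hypothetical):

    // Sketch: zero-initialize a 4x8 per-thread accumulator held in a StaticBuffer.
    __device__ void zero_init_example()
    {
        constexpr auto c_thread_desc = make_dynamic_naive_tensor_descriptor_packed_v2(
            make_tuple(Number<4>{}, Number<8>{}));

        StaticBuffer<float, c_thread_desc.GetElementSpaceSize()> c_thread_buf;

        constexpr auto I0 = Number<0>{};

        // every element of the 4x8 slice starting at origin (0, 0) is set to 0.0f
        ThreadwiseDynamicTensorSliceSet_v1<float, decltype(c_thread_desc), Sequence<4, 8>>{}.Run(
            c_thread_desc, make_tuple(I0, I0), c_thread_buf, float{0});
    }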
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp

This diff is collapsed.
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp

@@ -6,100 +6,52 @@
 namespace ck {

-template <typename Float, typename Desc>
-__device__ void threadwise_matrix_set_zero_v2(Desc, Float* __restrict__ p_thread)
-{
-    static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I1 = Number<1>{};
-
-    constexpr auto desc = Desc{};
-
-    constexpr auto M = desc.GetLength(I0);
-    constexpr auto N = desc.GetLength(I1);
-
-    static_for<0, M, 1>{}([&](auto i) {
-        static_for<0, N, 1>{}([&](auto j) {
-            constexpr auto offset = desc.CalculateOffset(make_tuple(i, j));
-            p_thread[offset] = Float(0);
-        });
-    });
-}
-
-template <typename SrcDesc, typename DstDesc, index_t NSliceRow, index_t NSliceCol, index_t DataPerAccess>
-struct ThreadwiseMatrixSliceCopy_v2
-{
-    template <typename Data>
-    __device__ static void Run(const Data* p_src, Data* p_dst)
-    {
-        static_assert(SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
-                      "wrong! Desc should be known at compile-time");
-
-        using vector_t = typename vector_type_maker<Data, DataPerAccess>::type::type;
-
-        static_for<0, NSliceRow, 1>{}([&](auto i) {
-            static_for<0, NSliceCol, DataPerAccess>{}([&](auto j) {
-                constexpr auto src_offset = SrcDesc{}.CalculateOffset(make_tuple(i, j));
-                constexpr auto dst_offset = DstDesc{}.CalculateOffset(make_tuple(i, j));
-
-                *reinterpret_cast<vector_t*>(&p_dst[dst_offset]) =
-                    *reinterpret_cast<const vector_t*>(&p_src[src_offset]);
-            });
-        });
-    }
-};
-
 // C[M, N] += transpose(A[K, M]) * B[K, N]
 // Element of matrix can be vectorized data
-template <typename ADesc,
+// Assume:
+//   1. ADesc, BDesc, CDesc are known at compile-time
+//   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
+template <typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename ADesc,
           typename BDesc,
           typename CDesc,
           typename std::enable_if<ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                                       CDesc::IsKnownAtCompileTime(),
                                   bool>::type = false>
-struct ThreadwiseGemm_km_kn_mn_v1
+struct ThreadwiseGemm_km_kn_mn_v1r1
 {
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void
+    Run(const ABuffer& a_buf, AOriginIdx, const BBuffer& b_buf, BOriginIdx, CBuffer& c_buf, COriginIdx)
     {
         static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                           CDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

-        constexpr auto M = CDesc{}.GetLength(I0);
-        constexpr auto N = CDesc{}.GetLength(I1);
-        constexpr auto K = ADesc{}.GetLength(I0);
-
-        static_for<0, K, 1>{}([&](auto k) {
-            static_for<0, M, 1>{}([&](auto m) {
-                static_for<0, N, 1>{}([&](auto n) {
-                    constexpr auto a_offset = ADesc{}.CalculateOffset(make_tuple(k, m));
-                    constexpr auto b_offset = BDesc{}.CalculateOffset(make_tuple(k, n));
-                    constexpr auto c_offset = CDesc{}.CalculateOffset(make_tuple(m, n));
-
-                    p_c[c_offset] +=
-                        inner_product_with_conversion<FloatC>{}(p_a[a_offset], p_b[b_offset]);
-                });
-            });
-        });
-    }
-
-#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ static void Run_amd_asm(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
-    {
+        static_assert(
+            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
+                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
+                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
+            "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
+
+        static_assert(
+            is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
+                    remove_cv_t<remove_reference_t<FloatA>>>::value &&
+                is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
+                        remove_cv_t<remove_reference_t<FloatB>>>::value &&
+                is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
+                        remove_cv_t<remove_reference_t<FloatC>>>::value &&
+                "wrong! inconsistent type");

         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         ...

@@ -110,61 +62,81 @@ struct ThreadwiseGemm_km_kn_mn_v1
         constexpr auto N = CDesc{}.GetLength(I1);
         constexpr auto K = ADesc{}.GetLength(I0);

-        static_assert(N == 4 || N == 2, "wrong! this config not supported by asm yet");
+        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
+        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
+        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});

         static_for<0, K, 1>{}([&](auto k) {
             static_for<0, M, 1>{}([&](auto m) {
                 constexpr index_t a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(k, m));

#if 0
                 if constexpr(N == 2)
                 {
                     constexpr index_t b_offset_0 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
                     constexpr index_t b_offset_1 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
                     constexpr index_t c_offset_0 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
                     constexpr index_t c_offset_1 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));

                     amd_assembly_outer_product_1x2(a_buf[Number<a_offset>{}],
                                                    b_buf[Number<b_offset_0>{}],
                                                    b_buf[Number<b_offset_1>{}],
                                                    c_buf(Number<c_offset_0>{}),
                                                    c_buf(Number<c_offset_1>{}));
                 }
                 else if constexpr(N == 4)
                 {
                     constexpr index_t b_offset_0 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I0));
                     constexpr index_t b_offset_1 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I1));
                     constexpr index_t b_offset_2 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I2));
                     constexpr index_t b_offset_3 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, I3));
                     constexpr index_t c_offset_0 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I0));
                     constexpr index_t c_offset_1 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I1));
                     constexpr index_t c_offset_2 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I2));
                     constexpr index_t c_offset_3 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, I3));

                     amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
                                                    b_buf[Number<b_offset_0>{}],
                                                    b_buf[Number<b_offset_1>{}],
                                                    b_buf[Number<b_offset_2>{}],
                                                    b_buf[Number<b_offset_3>{}],
                                                    c_buf(Number<c_offset_0>{}),
                                                    c_buf(Number<c_offset_1>{}),
                                                    c_buf(Number<c_offset_2>{}),
                                                    c_buf(Number<c_offset_3>{}));
                 }
#else
                 static_for<0, N, 1>{}([&](auto n) {
                     constexpr index_t b_offset =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(k, n));
                     constexpr index_t c_offset =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(m, n));

                     amd_assembly_inner_product(a_buf[Number<a_offset>{}],
                                                b_buf[Number<b_offset>{}],
                                                c_buf(Number<c_offset>{}));
                 });
#endif
             });
         });
     }
-
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
-    {
-#if CK_THREADWISE_GEMM_USE_AMD_INLINE_ASM
-        Run_amd_asm(p_a, p_b, p_c);
-#else
-        Run_source(p_a, p_b, p_c);
-#endif
-    }
 };
 ...
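Note: the recurring change in this file (and the branch name no_raw_index_calculation) is that offsets are computed as compile-time constants and buffers are indexed with Number<>, so an access into a StaticBuffer can resolve to a fixed register instead of a run-time indexed load. A rough standard-C++ analogy of that distinction, using std::array as a hypothetical stand-in for the ck buffers:

    #include <array>
    #include <cstddef>

    // Compile-time offset: the element is selected at compile time,
    // analogous to buf(Number<Offset>{}) on a StaticBuffer.
    template <std::size_t Offset, std::size_t N>
    constexpr float& at_compile_time(std::array<float, N>& buf)
    {
        static_assert(Offset < N, "offset must be known and in range at compile time");
        return std::get<Offset>(buf);
    }

    // Run-time offset: a generic indexed load whose address must be computed,
    // analogous to the old p_c[c_offset] pointer arithmetic.
    inline float at_run_time(const std::array<float, 8>& buf, std::size_t offset)
    {
        return buf[offset];
    }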
composable_kernel/include/tensor_operation/threadwise_gemm_v3.hpp

@@ -6,35 +6,15 @@
 namespace ck {

-template <typename Float, typename Desc>
-__device__ void threadwise_matrix_set_zero_v3(Desc, Float* __restrict__ p_thread)
-{
-    static_assert(Desc::IsKnownAtCompileTime(), "wrong! Desc should be known at compile-time");
-
-    constexpr auto I0 = Number<0>{};
-    constexpr auto I2 = Number<2>{};
-    constexpr auto I3 = Number<3>{};
-
-    constexpr auto desc = Desc{};
-
-    constexpr auto K = desc.GetLength(I0);
-    constexpr auto H = desc.GetLength(I2);
-    constexpr auto W = desc.GetLength(I3);
-
-    static_for<0, K, 1>{}([&](auto i) {
-        static_for<0, H, 1>{}([&](auto j) {
-            static_for<0, W, 1>{}([&](auto k) {
-                constexpr auto offset = desc.CalculateOffset(make_tuple(i, 0, j, k));
-                p_thread[offset] = Float(0);
-            });
-        });
-    });
-}
-
 // C[M, N] += transpose(A[K, M]) * B[K, N]
 // Element of matrix can be vectorized data
-template <typename ADesc,
+// Assume:
+//   1. ADesc, BDesc, CDesc are known at compile-time
+//   2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
+template <typename FloatA,
+          typename FloatB,
+          typename FloatC,
+          typename ADesc,
           typename BDesc,
           typename CDesc,
           index_t H,
           ...

@@ -44,13 +24,37 @@ template <typename ADesc,
                                   bool>::type = false>
 struct ThreadwiseGemm_km_kn_mn_v3
 {
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ static void Run_source(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
+    template <typename ABuffer,
+              typename AOriginIdx,
+              typename BBuffer,
+              typename BOriginIdx,
+              typename CBuffer,
+              typename COriginIdx>
+    __device__ static void
+    Run(const ABuffer& a_buf, AOriginIdx, const BBuffer& b_buf, BOriginIdx, CBuffer& c_buf, COriginIdx)
     {
         static_assert(ADesc::IsKnownAtCompileTime() && BDesc::IsKnownAtCompileTime() &&
                           CDesc::IsKnownAtCompileTime(),
                       "wrong! Desc should be known at compile-time");

+        static_assert(
+            is_known_at_compile_time<remove_cv_t<remove_reference_t<AOriginIdx>>>::value &&
+                is_known_at_compile_time<remove_cv_t<remove_reference_t<BOriginIdx>>>::value &&
+                is_known_at_compile_time<remove_cv_t<remove_reference_t<COriginIdx>>>::value,
+            "wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time");
+
+        static_assert(
+            is_same<remove_cv_t<remove_reference_t<typename ABuffer::type>>,
+                    remove_cv_t<remove_reference_t<FloatA>>>::value &&
+                is_same<remove_cv_t<remove_reference_t<typename BBuffer::type>>,
+                        remove_cv_t<remove_reference_t<FloatB>>>::value &&
+                is_same<remove_cv_t<remove_reference_t<typename CBuffer::type>>,
+                        remove_cv_t<remove_reference_t<FloatC>>>::value &&
+                "wrong! inconsistent type");
+
         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};
         constexpr auto I2 = Number<2>{};
         ...

@@ -59,79 +63,100 @@ struct ThreadwiseGemm_km_kn_mn_v3
         constexpr auto E = ADesc{}.GetLength(I0);
         constexpr auto K = ADesc{}.GetLength(I1);

+        constexpr auto a_origin_idx = to_multi_index(AOriginIdx{});
+        constexpr auto b_origin_idx = to_multi_index(BOriginIdx{});
+        constexpr auto c_origin_idx = to_multi_index(COriginIdx{});
+
         static_for<0, E, 1>{}([&](auto e) {
             static_for<0, K, 1>{}([&](auto k) {
                 constexpr index_t a_offset = ADesc{}.CalculateOffset(a_origin_idx + make_tuple(e, k));

                 if constexpr(H == 2 && W == 2)
                 {
                     constexpr index_t b_offset_0 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
                     constexpr index_t b_offset_1 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 1));
                     constexpr index_t b_offset_2 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
                     constexpr index_t b_offset_3 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 1));
                     constexpr index_t c_offset_0 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
                     constexpr index_t c_offset_1 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 1));
                     constexpr index_t c_offset_2 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
                     constexpr index_t c_offset_3 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 1));

                     amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
                                                    b_buf[Number<b_offset_0>{}],
                                                    b_buf[Number<b_offset_1>{}],
                                                    b_buf[Number<b_offset_2>{}],
                                                    b_buf[Number<b_offset_3>{}],
                                                    c_buf(Number<c_offset_0>{}),
                                                    c_buf(Number<c_offset_1>{}),
                                                    c_buf(Number<c_offset_2>{}),
                                                    c_buf(Number<c_offset_3>{}));
                 }
                 else if constexpr(H == 4 && W == 1)
                 {
                     constexpr index_t b_offset_0 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 0, 0));
                     constexpr index_t b_offset_1 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 1, 0));
                     constexpr index_t b_offset_2 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 2, 0));
                     constexpr index_t b_offset_3 =
                         BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, 3, 0));
                     constexpr index_t c_offset_0 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 0, 0));
                     constexpr index_t c_offset_1 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 1, 0));
                     constexpr index_t c_offset_2 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 2, 0));
                     constexpr index_t c_offset_3 =
                         CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, 3, 0));

                     amd_assembly_outer_product_1x4(a_buf[Number<a_offset>{}],
                                                    b_buf[Number<b_offset_0>{}],
                                                    b_buf[Number<b_offset_1>{}],
                                                    b_buf[Number<b_offset_2>{}],
                                                    b_buf[Number<b_offset_3>{}],
                                                    c_buf(Number<c_offset_0>{}),
                                                    c_buf(Number<c_offset_1>{}),
                                                    c_buf(Number<c_offset_2>{}),
                                                    c_buf(Number<c_offset_3>{}));
                 }
                 else
                 {
                     static_for<0, H, 1>{}([&](auto h) {
                         static_for<0, W, 1>{}([&](auto w) {
                             constexpr index_t b_offset =
                                 BDesc{}.CalculateOffset(b_origin_idx + make_tuple(e, 0, h, w));
                             constexpr index_t c_offset =
                                 CDesc{}.CalculateOffset(c_origin_idx + make_tuple(k, 0, h, w));

#if 0
                             c_buf(Number<c_offset>{}) += inner_product_with_conversion<FloatC>{}(
                                 a_buf[Number<a_offset>{}], b_buf[Number<b_offset>{}]);
#else
                             amd_assembly_inner_product(a_buf[Number<a_offset>{}],
                                                        b_buf[Number<b_offset>{}],
                                                        c_buf(Number<c_offset>{}));
#endif
                         });
                     });
                 }
             });
         });
     }
-
-    template <typename FloatA, typename FloatB, typename FloatC>
-    __device__ static void Run(const FloatA* p_a, const FloatB* p_b, FloatC* p_c)
-    {
-        Run_source(p_a, p_b, p_c);
-    }
 };

 } // namespace ck
 ...
composable_kernel/include/utility/amd_inline_asm.hpp

@@ -5,6 +5,75 @@
 namespace ck {

+// c += inner_product(a, b)
+__device__ void amd_assembly_inner_product(const float& a, const float& b, float& c)
+{
+#if CK_USE_AMD_V_FMAC_F32
+    asm volatile("\n \
+    v_fmac_f32 %0, %1, %2 \n \
+    "
+                 : "=v"(c)
+                 : "v"(a), "v"(b), "0"(c));
+#else
+    asm volatile("\n \
+    v_mac_f32 %0, %1, %2 \n \
+    "
+                 : "=v"(c)
+                 : "v"(a), "v"(b), "0"(c));
+#endif
+}
+
+__device__ void amd_assembly_inner_product(const int8x4_t& a, const int8x4_t& b, int32_t& c)
+{
+#if 1
+    asm volatile("\n \
+    v_dot4_i32_i8 %0, %1, %2, %0 \n \
+    "
+                 : "=v"(c)
+                 : "v"(as_type<int32_t>(a)), "v"(as_type<int32_t>(b)), "0"(c));
+#else
+    c = __builtin_amdgcn_sdot4(as_type<int32_t>(a), as_type<int32_t>(b), c, false);
+#endif
+}
+
+__device__ void amd_assembly_inner_product(const int8x8_t& a, const int8x8_t& b, int32_t& c)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+
+    amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I0],
+                               vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I0],
+                               c);
+
+    amd_assembly_inner_product(vector_type<int8_t, 8>{a}.AsType<int8x4_t>()[I1],
+                               vector_type<int8_t, 8>{b}.AsType<int8x4_t>()[I1],
+                               c);
+}
+
+__device__ void amd_assembly_inner_product(const int8x16_t& a, const int8x16_t& b, int32_t& c)
+{
+    constexpr auto I0 = Number<0>{};
+    constexpr auto I1 = Number<1>{};
+    constexpr auto I2 = Number<2>{};
+    constexpr auto I3 = Number<3>{};
+
+    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I0],
+                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I0],
+                               c);
+
+    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I1],
+                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I1],
+                               c);
+
+    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I2],
+                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I2],
+                               c);
+
+    amd_assembly_inner_product(vector_type<int8_t, 16>{a}.AsType<int8x4_t>()[I3],
+                               vector_type<int8_t, 16>{b}.AsType<int8x4_t>()[I3],
+                               c);
+}
+
 // c0 += inner_product(a, b0)
 // c1 += inner_product(a, b1)
 __device__ void amd_assembly_outer_product_1x2(float a, float b0, float b1, float& c0, float& c1)
 ...
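Note: each of these overloads is a fused multiply-accumulate. The float version maps c += a * b onto v_fmac_f32 (or v_mac_f32), and the int8x4 version is a 4-element signed dot product accumulated into a 32-bit integer (v_dot4_i32_i8); the wider int8 overloads just split the input into 4-byte chunks. A host-side model of those semantics, useful for checking kernel results, could look like this (plain C++, not part of the commit):

    #include <cstdint>

    // Model of amd_assembly_inner_product(int8x4_t, int8x4_t, int32_t&):
    // a 4-wide signed 8-bit dot product accumulated into a 32-bit integer,
    // matching v_dot4_i32_i8 / __builtin_amdgcn_sdot4.
    inline void dot4_i32_i8_reference(const int8_t a[4], const int8_t b[4], int32_t& c)
    {
        for(int i = 0; i < 4; ++i)
            c += static_cast<int32_t>(a[i]) * static_cast<int32_t>(b[i]);
    }

    // Model of the float overload: a single multiply-accumulate.
    inline void fmac_f32_reference(float a, float b, float& c) { c += a * b; }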
composable_kernel/include/utility/buffer.hpp (new file, mode 100644)

#ifndef CK_BUFFER_HPP
#define CK_BUFFER_HPP

#include "statically_indexed_array.hpp"

namespace ck {

template <typename T, index_t N>
struct StaticBuffer : public StaticallyIndexedArray<T, N>
{
    using type = T;
    using base = StaticallyIndexedArray<T, N>;

    __host__ __device__ constexpr StaticBuffer() : base{} {}

    __host__ __device__ static constexpr bool IsStaticBuffer() { return true; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return false; }
};

template <typename T, index_t N>
__host__ __device__ constexpr auto make_static_buffer(Number<N>)
{
    return StaticBuffer<T, N>{};
}

template <typename T>
struct DynamicBuffer
{
    using type = T;

    T* p_data_;

    __host__ __device__ constexpr DynamicBuffer(T* p_data) : p_data_{p_data} {}

    __host__ __device__ constexpr const T& operator[](index_t i) const { return p_data_[i]; }

    __host__ __device__ constexpr T& operator()(index_t i) { return p_data_[i]; }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ constexpr const auto Get(index_t i) const
    {
        return *reinterpret_cast<const X*>(&p_data_[i]);
    }

    template <typename X,
              typename std::enable_if<
                  is_same<typename scalar_type<remove_cv_t<remove_reference_t<X>>>::type,
                          typename scalar_type<remove_cv_t<remove_reference_t<T>>>::type>::value,
                  bool>::type = false>
    __host__ __device__ void Set(index_t i, const X& x)
    {
        *reinterpret_cast<X*>(&p_data_[i]) = x;
    }

    __host__ __device__ static constexpr bool IsStaticBuffer() { return false; }

    __host__ __device__ static constexpr bool IsDynamicBuffer() { return true; }
};

template <typename T>
__host__ __device__ constexpr auto make_dynamic_buffer(T* p)
{
    return DynamicBuffer<T>{p};
}

} // namespace ck
#endif
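Note: StaticBuffer is a compile-time-indexed register buffer (indexed with Number<> so offsets stay constexpr), while DynamicBuffer wraps a raw pointer such as LDS or global memory with run-time indexing; the kernels above dispatch on IsStaticBuffer()/IsDynamicBuffer(). A small usage sketch, assuming the ck headers and the StaticallyIndexedArray indexing behavior (indexing details are inferred from the call sites in this commit, not restated from this file):

    // Sketch: the two buffer flavors introduced in this file.
    __device__ void buffer_usage_example(float* p_lds)
    {
        // 8-element register buffer, indexed with compile-time Number<>:
        // operator() gives a mutable reference, operator[] a const one.
        StaticBuffer<float, 8> reg_buf;
        reg_buf(Number<0>{}) = 1.0f;
        float x = reg_buf[Number<0>{}];

        // run-time indexed view over LDS or global memory
        auto lds_buf = make_dynamic_buffer(p_lds);
        lds_buf(3) = x;       // write through operator()
        float y = lds_buf[3]; // read through operator[]
        (void)y;
    }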
composable_kernel/include/utility/common_header.hpp

@@ -7,6 +7,7 @@
 #include "statically_indexed_array.hpp"
 #include "container_element_picker.hpp"
 #include "float_type.hpp"
+#include "buffer.hpp"
 #include "functional.hpp"
 #include "functional2.hpp"
 #include "functional3.hpp"
 ...
composable_kernel/include/utility/config.amd.hpp.in

@@ -14,11 +14,11 @@
 #define CK_DEVICE_BACKEND_AMD 1

 // GPU ID
-#if 1
+#if 0
 #define CK_AMD_GPU_GFX906 1
 #elif 0
 #define CK_AMD_GPU_GFX908 1
-#elif 0
+#elif 1
 #define CK_AMD_GPU_GFX1030 1
 #endif
 ...

@@ -28,7 +28,7 @@
 #endif

 // launch bounds
-#define CK_USE_LAUNCH_BOUNDS 0
+#define CK_USE_LAUNCH_BOUNDS 1

 #ifdef CK_USE_LAUNCH_BOUNDS
 #define CK_MAX_THREAD_PER_BLOCK 256
 ...
composable_kernel/include/utility/float_type.amd.hpp.in

 #ifndef CK_FLOAT_TYPE_AMD_HPP
 #define CK_FLOAT_TYPE_AMD_HPP

+#include "statically_indexed_array.hpp"
+
 namespace ck {

 using half_t = _Float16;
 ...

@@ -43,6 +45,15 @@ struct vector_type_maker<vector_type<T, N1>, N0>
     using type = vector_type<T, N0 * N1>;
 };

 template <typename T, index_t N>
 using vector_type_maker_t = typename vector_type_maker<T, N>::type;

+template <typename T, index_t N>
+__host__ __device__ constexpr auto make_vector_type(Number<N>)
+{
+    return typename vector_type_maker<T, N>::type{};
+}
+
 // scalar_type
 template <typename TV>
 struct scalar_type;
 ...
@@ -403,32 +414,249 @@ struct vector_type<T, 16>
}
};
template <typename T>
struct vector_type<T, 32>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
using type = d32_t;
union
{
d32_t d32_;
StaticallyIndexedArray<d1_t, 32> d1x32_;
StaticallyIndexedArray<d2_t, 16> d2x16_;
StaticallyIndexedArray<d4_t, 8> d4x8_;
StaticallyIndexedArray<d8_t, 4> d8x4_;
StaticallyIndexedArray<d16_t, 2> d16x2_;
StaticallyIndexedArray<d32_t, 1> d32x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x32_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x16_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x8_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x4_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x2_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x1_;
}
}
};
template <typename T>
struct vector_type<T, 64>
{
using d1_t = T;
typedef T d2_t __attribute__((ext_vector_type(2)));
typedef T d4_t __attribute__((ext_vector_type(4)));
typedef T d8_t __attribute__((ext_vector_type(8)));
typedef T d16_t __attribute__((ext_vector_type(16)));
typedef T d32_t __attribute__((ext_vector_type(32)));
typedef T d64_t __attribute__((ext_vector_type(64)));
using type = d64_t;
union
{
d64_t d64_;
StaticallyIndexedArray<d1_t, 64> d1x64_;
StaticallyIndexedArray<d2_t, 32> d2x32_;
StaticallyIndexedArray<d4_t, 16> d4x16_;
StaticallyIndexedArray<d8_t, 8> d8x8_;
StaticallyIndexedArray<d16_t, 4> d16x4_;
StaticallyIndexedArray<d32_t, 2> d32x2_;
StaticallyIndexedArray<d64_t, 1> d64x1_;
} data_;
__host__ __device__ constexpr vector_type() : data_{type{0}} {}
__host__ __device__ constexpr vector_type(type v) : data_{v} {}
template <typename X>
__host__ __device__ constexpr const auto& AsType() const
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
template <typename X>
__host__ __device__ constexpr auto& AsType()
{
static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
is_same<X, d64_t>::value,
"wrong!");
if constexpr(is_same<X, d1_t>::value)
{
return data_.d1x64_;
}
else if constexpr(is_same<X, d2_t>::value)
{
return data_.d2x32_;
}
else if constexpr(is_same<X, d4_t>::value)
{
return data_.d4x16_;
}
else if constexpr(is_same<X, d8_t>::value)
{
return data_.d8x8_;
}
else if constexpr(is_same<X, d16_t>::value)
{
return data_.d16x4_;
}
else if constexpr(is_same<X, d32_t>::value)
{
return data_.d32x2_;
}
else if constexpr(is_same<X, d64_t>::value)
{
return data_.d64x1_;
}
}
};
// fp32
using float2_t  = typename vector_type<float, 2>::type;
using float4_t  = typename vector_type<float, 4>::type;
using float8_t  = typename vector_type<float, 8>::type;
using float16_t = typename vector_type<float, 16>::type;
using float32_t = typename vector_type<float, 32>::type;
using float64_t = typename vector_type<float, 64>::type;

// fp16
using half2_t  = typename vector_type<half_t, 2>::type;
using half4_t  = typename vector_type<half_t, 4>::type;
using half8_t  = typename vector_type<half_t, 8>::type;
using half16_t = typename vector_type<half_t, 16>::type;
using half32_t = typename vector_type<half_t, 32>::type;
using half64_t = typename vector_type<half_t, 64>::type;

// bfp16
using ushort2_t  = typename vector_type<ushort, 2>::type;
using ushort4_t  = typename vector_type<ushort, 4>::type;
using ushort8_t  = typename vector_type<ushort, 8>::type;
using ushort16_t = typename vector_type<ushort, 16>::type;
using ushort32_t = typename vector_type<ushort, 32>::type;
using ushort64_t = typename vector_type<ushort, 64>::type;

// i32
using int32x2_t  = typename vector_type<int32_t, 2>::type;
using int32x4_t  = typename vector_type<int32_t, 4>::type;
using int32x8_t  = typename vector_type<int32_t, 8>::type;
using int32x16_t = typename vector_type<int32_t, 16>::type;
using int32x32_t = typename vector_type<int32_t, 32>::type;
using int32x64_t = typename vector_type<int32_t, 64>::type;

// i8
using int8x2_t  = typename vector_type<int8_t, 2>::type;
using int8x4_t  = typename vector_type<int8_t, 4>::type;
using int8x8_t  = typename vector_type<int8_t, 8>::type;
using int8x16_t = typename vector_type<int8_t, 16>::type;
using int8x32_t = typename vector_type<int8_t, 32>::type;
using int8x64_t = typename vector_type<int8_t, 64>::type;
// data type conversion
template <typename T>
...
...
composable_kernel/include/utility/sequence_helper.hpp

@@ -5,11 +5,26 @@
 namespace ck {

 template <index_t... Is>
 __host__ __device__ constexpr auto make_sequence(Number<Is>...)
 {
     return Sequence<Is...>{};
 }

+// F returns index_t
+template <typename F, index_t N>
+__host__ __device__ constexpr auto generate_sequence(F, Number<N>)
+{
+    return typename sequence_gen<N, F>::type{};
+}
+
+// F returns Number<>
+template <typename F, index_t N>
+__host__ __device__ constexpr auto generate_sequence_v2(F&& f, Number<N>)
+{
+    return unpack([&f](auto&&... xs) { return make_sequence(f(xs)...); },
+                  typename arithmetic_sequence_gen<0, N, 1>::type{});
+}
+
 } // namespace ck
 #endif
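Note: generate_sequence_v2 builds a Sequence<...> by applying a callable that returns Number<> to each index 0..N-1, and make_sequence packs the results back into a Sequence. A small hypothetical usage sketch, assuming the ck Number/Sequence utilities behave as used elsewhere in this commit:

    // Sketch: generate Sequence<0, 2, 4, 6> by doubling each index 0..3.
    // The lambda receives a Number<i> and must return a Number<> for the _v2 form.
    constexpr auto doubled =
        generate_sequence_v2([](auto i) { return Number<2 * i.value>{}; }, Number<4>{});

    static_assert(is_same<remove_cv_t<decltype(doubled)>, Sequence<0, 2, 4, 6>>::value, "");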
driver/include/device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk.hpp

@@ -53,7 +53,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     constexpr auto C0 = C / Number<InWeiVectorSize>{};
     constexpr auto C1 = Number<InWeiVectorSize>{};

-#if 1
+#if 0
     // run-time variables
     constexpr auto in_n_hi_wi_c0_desc =
         make_dynamic_naive_tensor_descriptor_packed_v2(make_multi_index(N, Hi, Wi, C0));
     ...

@@ -112,7 +112,7 @@ void device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk(
     wei_k_y_x_c_device_buf.ToDevice(wei_k_y_x_c.mData.data());
     out_n_ho_wo_k_device_buf.ToDevice(out_n_ho_wo_k.mData.data());

-#if 0
+#if 1
     // cdata = 16, BlockSize = 64, 16x64x4
     constexpr index_t BlockSize = 64;
     ...
driver/src/conv_driver.cpp

@@ -64,7 +64,7 @@ int main(int argc, char* argv[])
     using LeftPads  = Sequence<0, 0>;
     using RightPads = Sequence<0, 0>;

-#elif 0
+#elif 1
     constexpr index_t N  = 1;
     constexpr index_t C  = 16;
     constexpr index_t HI = 1080;
     ...

@@ -630,7 +630,7 @@ int main(int argc, char* argv[])
     print_array("ConvStrides", to_multi_index(ConvStrides{}));
     print_array("ConvDilations", to_multi_index(ConvDilations{}));

-#if 1
+#if 0
     using in_data_t = float;
     constexpr index_t in_vector_size = 1;
     using acc_data_t = float;
     ...

@@ -724,23 +724,22 @@ int main(int argc, char* argv[])
                                                     LeftPads{},
                                                     RightPads{},
                                                     nrepeat);
-#elif 1
+#elif 0
     device_dynamic_convolution_forward_implicit_gemm_v4r4_nchw_kcyx_nkhw<in_data_t,
                                                                          in_vector_size,
                                                                          acc_data_t,
                                                                          out_data_t>(in_nchw_desc,
                                                                                      in_nchw,
                                                                                      wei_kcyx_desc,
                                                                                      wei_kcyx,
                                                                                      out_nkhw_desc,
                                                                                      out_nkhw_device,
                                                                                      ConvStrides{},
                                                                                      ConvDilations{},
                                                                                      LeftPads{},
                                                                                      RightPads{},
                                                                                      nrepeat);
 #elif 0
     device_dynamic_convolution_forward_implicit_gemm_v4r4_nhwc_kyxc_nhwk<in_data_t,
                                                                          in_vector_size,
                                                                          ...