Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0374f8de
"driver/src/conv_driver.cpp" did not exist on "cb6475c77d74f9d9f0a5fb2c0b80d5008fe420da"
Commit
0374f8de
authored
Apr 29, 2021
by
Chao Liu
Browse files
blockwise gemm does 3d*3d=4d
parent
4a661578
Changes
4
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
712 additions
and
0 deletions
+712
-0
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
...ble_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
+175
-0
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
...kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
+449
-0
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
...or_operation/threadwise_dynamic_tensor_slice_transfer.hpp
+2
-0
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
...le_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
+86
-0
No files found.
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp
View file @
0374f8de
...
@@ -378,5 +378,180 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
...
@@ -378,5 +378,180 @@ struct BlockwiseGemm_km_kn_m0m1n0n1_v1r1
AThreadCopy
a_thread_copy_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
};
// C[M0, M1, N0, N1] += transpose(A[K, M0, M1]) * B[K, N0, N1]
// A and B are visable to the whole block, C is distributed among each thread
// Assume:
// 1. A:
// 1. ABlockDesc is known at compile-time
// 2. ABlockBuffer is DynamicBuffer
// 2. B:
// 1. ABlockDesc is known at compile-time
// 2. BBlockBuffer is DynamicBuffer
// 3. C:
// 1. CThreadDesc is known at compile-time
// 2. CThreadBuffer is StaticBuffer
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
ABlockDesc
,
typename
BBlockDesc
,
typename
CThreadDesc
,
index_t
M1PerThread
,
index_t
N1PerThread
,
index_t
KPerThreadLoop
,
index_t
MLevel0ThreadCluster
,
index_t
NLevel0ThreadCluster
,
index_t
MLevel1ThreadCluster
,
index_t
NLevel1ThreadCluster
,
index_t
AThreadCopyScalarPerVector_M1
,
index_t
BThreadCopyScalarPerVector_N1
,
typename
std
::
enable_if
<
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
()
&&
CThreadDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
{
using
AIndex
=
MultiIndex
<
3
>
;
using
BIndex
=
MultiIndex
<
3
>
;
using
CIndex
=
MultiIndex
<
4
>
;
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
public:
__device__
BlockwiseGemm_km0m1_kn0n1_m0m1n0n1_v1
()
:
c_thread_origin_data_idx_
{
CalculateCThreadOriginDataIndex
(
get_thread_local_1d_id
())},
a_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I0
],
c_thread_origin_data_idx_
[
I1
])},
b_thread_copy_
{
make_tuple
(
0
,
c_thread_origin_data_idx_
[
I2
],
c_thread_origin_data_idx_
[
I3
])}
{
static_assert
(
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
()
&&
CThreadDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
BlockSize
==
c_thread_cluster_desc_
.
GetElementSize
(),
"wrong! wrong blocksize"
);
static_assert
(
ABlockDesc
{}.
GetLength
(
I0
)
==
BBlockDesc
{}.
GetLength
(
I0
),
"wrong! K dimension not consistent"
);
}
__device__
static
CIndex
CalculateCThreadOriginDataIndex
(
index_t
thread_id
)
{
const
auto
thread_cluster_idx
=
c_thread_cluster_desc_
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
constexpr
index_t
MPerLevel0Cluster
=
M1PerThread
*
MLevel0ThreadCluster
;
constexpr
index_t
NPerLevel0Cluster
=
N1PerThread
*
NLevel0ThreadCluster
;
return
make_multi_index
(
0
,
thread_cluster_idx
[
I0
]
*
MPerLevel0Cluster
+
thread_cluster_idx
[
I2
]
*
M1PerThread
,
0
,
thread_cluster_idx
[
I1
]
*
NPerLevel0Cluster
+
thread_cluster_idx
[
I3
]
*
N1PerThread
);
}
__host__
__device__
static
constexpr
auto
GetCThreadClusterDescriptor
()
{
return
make_cluster_descriptor_v2
(
Sequence
<
MLevel1ThreadCluster
,
NLevel1ThreadCluster
,
MLevel0ThreadCluster
,
NLevel0ThreadCluster
>
{},
Sequence
<
0
,
1
,
2
,
3
>
{});
}
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
FloatA
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
FloatB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
constexpr
auto
threadwise_gemm
=
ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
<
FloatA
,
FloatB
,
FloatC
,
decltype
(
a_thread_desc_
),
decltype
(
b_thread_desc_
),
CThreadDesc
>
{};
constexpr
index_t
K
=
ABlockDesc
{}.
GetLength
(
I0
);
static_for
<
0
,
K
,
KPerThreadLoop
>
{}([
&
](
auto
k
)
{
a_thread_copy_
.
Run
(
ABlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
a_thread_buf
);
b_thread_copy_
.
Run
(
BBlockDesc
{},
make_tuple
(
k
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
);
threadwise_gemm
.
Run
(
a_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
b_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
),
c_thread_buf
,
make_tuple
(
I0
,
I0
,
I0
,
I0
));
});
}
private:
static
constexpr
auto
c_thread_cluster_desc_
=
GetCThreadClusterDescriptor
();
static
constexpr
index_t
M0_
=
ABlockDesc
{}.
GetLength
(
I1
);
static
constexpr
index_t
N0_
=
BBlockDesc
{}.
GetLength
(
I1
);
// A[K, M0, M1]
static
constexpr
auto
a_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThreadLoop
>
{},
Number
<
M0_
>
{},
Number
<
M1PerThread
>
{}));
// B[K, N0, N1]
static
constexpr
auto
b_thread_desc_
=
make_dynamic_naive_tensor_descriptor_packed_v2
(
make_tuple
(
Number
<
KPerThreadLoop
>
{},
Number
<
N0_
>
{},
Number
<
N1PerThread
>
{}));
using
AThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
ABlockDesc
,
decltype
(
a_thread_desc_
),
Sequence
<
KPerThreadLoop
,
M0_
,
M1PerThread
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
AThreadCopyScalarPerVector_M1
,
AddressSpace
::
Generic
,
AddressSpace
::
Vgpr
,
1
>
;
using
BThreadCopy
=
ThreadwiseDynamicTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
BBlockDesc
,
decltype
(
b_thread_desc_
),
Sequence
<
KPerThreadLoop
,
N0_
,
N1PerThread
>
,
Sequence
<
0
,
1
,
2
>
,
2
,
BThreadCopyScalarPerVector_N1
,
AddressSpace
::
Generic
,
AddressSpace
::
Vgpr
,
1
>
;
CIndex
c_thread_origin_data_idx_
;
AThreadCopy
a_thread_copy_
;
BThreadCopy
b_thread_copy_
;
};
}
// namespace ck
}
// namespace ck
#endif
#endif
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp
View file @
0374f8de
This diff is collapsed.
Click to expand it.
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp
View file @
0374f8de
...
@@ -1376,6 +1376,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
...
@@ -1376,6 +1376,8 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
{
{
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
static_assert
(
SrcDesc
::
IsKnownAtCompileTime
()
&&
DstDesc
::
IsKnownAtCompileTime
(),
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
"wrong! SrcDesc and DstDesc need to known at compile-time"
);
static_assert
(
SliceLengths
::
At
(
Number
<
SrcVectorDim
>
{})
%
SrcScalarPerVector
==
0
,
"wrong!"
);
}
}
template
<
typename
SrcRefToOriginDisplacement
,
template
<
typename
SrcRefToOriginDisplacement
,
...
...
composable_kernel/include/tensor_operation/threadwise_gemm_v2.hpp
View file @
0374f8de
...
@@ -140,5 +140,91 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
...
@@ -140,5 +140,91 @@ struct ThreadwiseGemm_km_kn_mn_v1r1
}
}
};
};
// C[M0, M1, N0, N1] += A[K, M0, M1] * B[K, N0, N1]
// Tensor element can be vectorized data
// Assume:
// 1. ADesc, BDesc, CDesc are known at compile-time
// 2. AOriginIdx, BOriginIdx, COriginIdx are known at compile-time
template
<
typename
FloatA
,
typename
FloatB
,
typename
FloatC
,
typename
ADesc
,
typename
BDesc
,
typename
CDesc
,
typename
std
::
enable_if
<
ADesc
::
IsKnownAtCompileTime
()
&&
BDesc
::
IsKnownAtCompileTime
()
&&
CDesc
::
IsKnownAtCompileTime
(),
bool
>
::
type
=
false
>
struct
ThreadwiseGemm_km0m1_kn0n1_m0m1n0n1
{
template
<
typename
ABuffer
,
typename
AOriginIdx
,
typename
BBuffer
,
typename
BOriginIdx
,
typename
CBuffer
,
typename
COriginIdx
>
__device__
static
void
Run
(
const
ABuffer
&
a_buf
,
AOriginIdx
,
const
BBuffer
&
b_buf
,
BOriginIdx
,
CBuffer
&
c_buf
,
COriginIdx
)
{
static_assert
(
ADesc
::
IsKnownAtCompileTime
()
&&
BDesc
::
IsKnownAtCompileTime
()
&&
CDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
AOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
BOriginIdx
>>>::
value
&&
is_known_at_compile_time
<
remove_cv_t
<
remove_reference_t
<
COriginIdx
>>>::
value
,
"wrong! AOriginIdx, BOriginIdx, COringinIdx should be known at compile-time"
);
static_assert
(
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
ABuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatA
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
BBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatB
>>>::
value
&&
is_same
<
remove_cv_t
<
remove_reference_t
<
typename
CBuffer
::
type
>>
,
remove_cv_t
<
remove_reference_t
<
FloatC
>>>::
value
&&
"wrong! inconsistent type"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I2
=
Number
<
2
>
{};
constexpr
auto
I3
=
Number
<
3
>
{};
constexpr
auto
K
=
ADesc
{}.
GetLength
(
I0
);
constexpr
auto
M0
=
CDesc
{}.
GetLength
(
I0
);
constexpr
auto
M1
=
CDesc
{}.
GetLength
(
I1
);
constexpr
auto
N0
=
CDesc
{}.
GetLength
(
I2
);
constexpr
auto
N1
=
CDesc
{}.
GetLength
(
I3
);
constexpr
auto
a_origin_idx
=
to_multi_index
(
AOriginIdx
{});
constexpr
auto
b_origin_idx
=
to_multi_index
(
BOriginIdx
{});
constexpr
auto
c_origin_idx
=
to_multi_index
(
COriginIdx
{});
static_for
<
0
,
K
,
1
>
{}([
&
](
auto
k
)
{
static_for
<
0
,
M0
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
M1
,
1
>
{}([
&
](
auto
m1
)
{
static_for
<
0
,
N0
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
N1
,
1
>
{}([
&
](
auto
n1
)
{
constexpr
index_t
a_offset
=
ADesc
{}.
CalculateOffset
(
a_origin_idx
+
make_multi_index
(
k
,
m0
,
m1
));
constexpr
index_t
b_offset
=
BDesc
{}.
CalculateOffset
(
b_origin_idx
+
make_multi_index
(
k
,
n0
,
n1
));
constexpr
index_t
c_offset
=
CDesc
{}.
CalculateOffset
(
c_origin_idx
+
make_multi_index
(
m0
,
m1
,
n0
,
n1
));
amd_assembly_inner_product
(
a_buf
[
Number
<
a_offset
>
{}],
b_buf
[
Number
<
b_offset
>
{}],
c_buf
(
Number
<
c_offset
>
{}));
});
});
});
});
});
}
};
}
// namespace ck
}
// namespace ck
#endif
#endif
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment