gaoqiong / composable_kernel

Commit aeb05cc4, authored Apr 22, 2021 by Chao Liu
Parent: d990eff6

    use vector type for holding C thread matrix data, but it causes register over-allocation
Showing 5 changed files with 315 additions and 25 deletions (+315 −25)
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp                          +1   −0
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp                      +13  −2
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp        +59  −0
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp   +16  −14
composable_kernel/include/utility/float_type.amd.hpp.in                                   +226 −9
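The change in a nutshell: the per-thread C accumulator tile moves from a plain FloatAcc array into a vector-typed static buffer, so every element is addressed with a compile-time offset and the tile can be read and written in vector-width chunks. Below is a minimal standalone sketch of the idea, not code from this commit; CThreadTile, the 4x4 shape, and the member names are illustrative. It compiles with clang/hipcc, which provide ext_vector_type.

// Sketch: hold a 4x4 per-thread C tile in one 16-wide clang extended
// vector instead of float[16]. Compile-time indexing keeps the tile in
// registers; an oversized tile therefore over-allocates registers,
// which is exactly the regression the commit message reports.
typedef float float16_sk __attribute__((ext_vector_type(16)));

struct CThreadTile
{
    float16_sk data = 0; // zero-initialized accumulator (scalar splat)

    // flattened (row, col) access with compile-time coordinates
    template <int Row, int Col>
    float get() const
    {
        static_assert(Row >= 0 && Row < 4 && Col >= 0 && Col < 4, "outside the 4x4 tile");
        return data[Row * 4 + Col];
    }

    template <int Row, int Col>
    void fma(float a, float b) // accumulate a * b into one element
    {
        static_assert(Row >= 0 && Row < 4 && Col >= 0 && Col < 4, "outside the 4x4 tile");
        data[Row * 4 + Col] += a * b;
    }
};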
composable_kernel/include/tensor_operation/blockwise_gemm_v2.hpp

@@ -2,6 +2,7 @@
 #define CK_BLOCKWISE_GEMM_V2_HPP

 #include "common_header.hpp"
+#include "threadwise_dynamic_tensor_slice_transfer.hpp"
 #include "threadwise_gemm_v2.hpp"

 namespace ck {
composable_kernel/include/tensor_operation/gridwise_dynamic_gemm.hpp

@@ -5,9 +5,10 @@
 #include "dynamic_multi_index_transform_helper.hpp"
 #include "dynamic_tensor_descriptor.hpp"
 #include "dynamic_tensor_descriptor_helper.hpp"
-#include "blockwise_gemm_v2.hpp"
 #include "blockwise_dynamic_tensor_slice_transfer.hpp"
 #include "threadwise_dynamic_tensor_slice_transfer.hpp"
+#include "blockwise_gemm_v2.hpp"
+#include "threadwise_dynamic_tensor_slice_set.hpp"

 namespace ck {

@@ -730,12 +731,22 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
         FloatAB* p_b_block_double = p_shared_block + 2 * a_block_space_size;

         // register allocation for output
+#if 0
         FloatAcc p_c_thread[c_m0m1_n0n1_thread_desc.GetElementSpaceSize()];
         auto c_thread_buf = make_dynamic_buffer<FloatAcc>(p_c_thread);

         // zero out threadwise output
         threadwise_matrix_set_zero_v2(c_m0m1_n0n1_thread_desc, p_c_thread);
+#else
+        auto c_thread_buf =
+            make_static_buffer<FloatAcc>(c_m0m1_n0n1_thread_desc.GetElementSpaceSize());
+
+        ThreadwiseDynamicTensorSliceSet_v1<FloatAcc,
+                                           decltype(c_m0m1_n0n1_thread_desc),
+                                           Sequence<MRepeat * MPerThread, NRepeat * NPerThread>>{}
+            .Run(c_m0m1_n0n1_thread_desc, make_tuple(I0, I0), c_thread_buf, FloatAcc{0});
+#endif

         constexpr auto a_block_slice_copy_step = make_multi_index(KPerBlock, 0);
         constexpr auto b_block_slice_copy_step = make_multi_index(KPerBlock, 0);

@@ -916,7 +927,7 @@ struct GridwiseDynamicGemm_km_kn_m0m1n0n1_v1
                                 n_thread_data_on_global % N1))
             .Run(c_m0_m1_n0_n1_thread_desc,
                  make_tuple(I0, I0, I0, I0),
-                 p_c_thread,
+                 c_thread_buf,
                  c_m0_m1_n0_n1_global_desc,
                  p_c_global,
                  c_m0_m1_n0_n1_global_tensor_iterator_hacks);
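The #if 0/#else switch above swaps make_dynamic_buffer, which wraps a run-time pointer, for make_static_buffer, whose contents are indexed only by compile-time offsets and so can live entirely in registers. A simplified sketch of that distinction follows; StaticBufferSketch and DynamicBufferSketch are illustrative stand-ins, not the composable_kernel types.

#include <cstddef>

// stand-in for a register-resident buffer: the size and every index are
// compile-time constants, so the compiler can keep elements in registers
template <typename T, std::size_t N>
struct StaticBufferSketch
{
    T data[N] = {};

    static constexpr bool IsStaticBuffer() { return true; }

    template <std::size_t I>
    T& At()
    {
        static_assert(I < N, "out of range");
        return data[I];
    }
};

// stand-in for a memory-backed buffer: it only wraps a pointer, so
// indices may be run-time values and every access is a load or store
template <typename T>
struct DynamicBufferSketch
{
    T* p;

    static constexpr bool IsStaticBuffer() { return false; }

    T& At(std::size_t i) { return p[i]; }
};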
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_set.hpp (new file, mode 100644)

#ifndef CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP
#define CK_THREADWISE_DYNAMIC_TENSOR_SET_HPP

#include "common_header.hpp"
#include "dynamic_tensor_descriptor.hpp"
#include "dynamic_tensor_descriptor_helper.hpp"

namespace ck {

// Assume:
//   1. Desc is known at compile-time
//   2. Buffer is StaticBuffer
//   3. OriginIdx is known at compile-time
//   4. use #-iterator
template <typename Data,
          typename Desc,
          typename SliceLengths,
          typename std::enable_if<Desc::IsKnownAtCompileTime(), bool>::type = false>
struct ThreadwiseDynamicTensorSliceSet_v1
{
    static constexpr index_t nDim = SliceLengths::Size();

    using Index = MultiIndex<nDim>;

    template <typename OriginIdx, typename Buffer>
    __device__ void
    Run(const Desc&, const OriginIdx&, Buffer& buf, const Data& initial_value) const
    {
        static_assert(Desc::IsKnownAtCompileTime(),
                      "wrong! Desc needs to be known at compile-time");

        static_assert(Buffer::IsStaticBuffer(), "wrong! Buffer needs to be a StaticBuffer");

        static_assert(is_known_at_compile_time<remove_cv_t<remove_reference_t<OriginIdx>>>::value,
                      "wrong! OriginIdx needs to be known at compile-time");

        // Desc is known at compile-time
        constexpr auto desc = remove_cv_t<remove_reference_t<Desc>>{};

        // OriginIdx is known at compile-time
        constexpr auto origin_idx = to_multi_index(OriginIdx{});

        static_ford<SliceLengths>{}([&](auto access_idx) {
            constexpr auto coord = make_dynamic_tensor_coordinate(desc, origin_idx + access_idx);

            constexpr bool is_valid =
                coordinate_has_valid_offset_assuming_visible_index_is_valid(desc, coord);

            constexpr index_t offset = coord.GetOffset();

            if constexpr(is_valid)
            {
                buf.template AsType<Data>()(Number<offset>{}) = initial_value;
            }
        });
    }
};

} // namespace ck
#endif
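Run above leans on static_ford<SliceLengths>, which unrolls the whole slice at compile time so that every coordinate, and therefore every buffer offset, is a constant expression; that is what lets the if constexpr guard and the Number<offset> indexing work. Here is a minimal 2-D sketch of that unrolling pattern in standard C++17; static_ford_2d, ford_impl, and row_pass are illustrative names, not CK's API.

#include <cstddef>
#include <type_traits>
#include <utility>

// visit one row: expand the column pack Js at compile time
template <std::size_t... Js, typename Row, typename F>
void row_pass(Row row, F& f, std::index_sequence<Js...>)
{
    (f(row, std::integral_constant<std::size_t, Js>{}), ...);
}

// expand the row pack Is, visiting every (row, col) pair
template <std::size_t N, typename F, std::size_t... Is>
void ford_impl(F& f, std::index_sequence<Is...>)
{
    (row_pass(std::integral_constant<std::size_t, Is>{}, f,
              std::make_index_sequence<N>{}),
     ...);
}

// fully unrolled M x N loop nest; f receives indices as types
template <std::size_t M, std::size_t N, typename F>
void static_ford_2d(F f)
{
    ford_impl<N>(f, std::make_index_sequence<M>{});
}

int main()
{
    float tile[6];
    // i and j carry their values in their types, so the subscript
    // below is a compile-time constant expression
    static_ford_2d<2, 3>([&](auto i, auto j) {
        tile[decltype(i)::value * 3 + decltype(j)::value] = 0.0f;
    });
    return 0;
}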
composable_kernel/include/tensor_operation/threadwise_dynamic_tensor_slice_transfer.hpp

@@ -38,11 +38,14 @@ struct lambda_scalar_step_in_vector
 } // namespace detail

 // Assume:
-// 1. src_desc is known at compile-time
-// 2. dst_desc is not known at compile-time
-// 3. src_slice_origin_idx is known at compile-time and it's 0
-// 4. dst_slice_origin_idx is not known at compile-time
-// TODO: support non-zero src_slice_origin_idx
+// 1. src:
+//    1. SrcDesc is known at compile-time
+//    2. SrcBuffer is StaticBuffer
+//    3. SrcSliceOriginIdx is known at compile-time
+// 2. dst:
+//    1. DstDesc is not known at compile-time
+//    2. DstBuffer is DynamicBuffer
+//    3. DstSliceOriginIdx is not known at compile-time
 template <typename SrcData,
           typename DstData,
           typename SrcDesc,

@@ -80,10 +83,10 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         dst_slice_origin_coord_ = make_dynamic_tensor_coordinate(dst_desc, dst_slice_origin_idx);
     }

-    template <typename SrcSliceOriginIdx, typename DstIteratorHacks>
+    template <typename SrcSliceOriginIdx, typename SrcBuffer, typename DstIteratorHacks>
     __device__ void Run(const SrcDesc&,
                         const SrcSliceOriginIdx&,
-                        const SrcData* p_src,
+                        const SrcBuffer& src_buf,
                         const DstDesc& dst_desc,
                         DstData* p_dst,
                         const DstIteratorHacks& dst_iterator_hacks)

@@ -97,7 +100,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
         // SrcDesc and src_slice_origin_idx are known at compile-time
         constexpr auto src_desc = remove_cv_t<remove_reference_t<SrcDesc>>{};
-        constexpr auto src_slice_origin_idx = SrcSliceOriginIdx{};
+        constexpr auto src_slice_origin_idx = to_multi_index(SrcSliceOriginIdx{});

         constexpr auto I0 = Number<0>{};
         constexpr auto I1 = Number<1>{};

@@ -189,12 +192,11 @@ struct ThreadwiseDynamicTensorSliceTransfer_v1r3
             typename vector_type_maker<DstData, DstScalarPerVector>::type::type;

         static_for<0, DstScalarPerVector, 1>{}([&](auto i) {
             constexpr index_t src_offset = src_desc.CalculateOffset(
-                to_multi_index(src_slice_origin_idx) + dst_data_idx +
+                src_slice_origin_idx + dst_data_idx +
                 i * dst_scalar_step_in_vector);

             dst_vector.template AsType<DstData>()(i) =
                 type_convert<DstData>{}(
-                    p_src[Number<src_offset>{}]);
+                    src_buf.template AsType<SrcData>()[Number<src_offset>{}]);
         });

         const bool is_dst_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(

@@ -1489,7 +1491,7 @@ struct ThreadwiseDynamicTensorSliceTransfer_v4
         // TODO: if SrcData and DstData are vector types, then static_cast may not compile
         static_for<0, SrcScalarPerVector, 1>{}([&](auto i) {
             dst_tmp_buf.template AsType<DstData>()(i) =
-                static_cast<DstData>(src_tmp_buf.template AsType<SrcData>()[i]);
+                type_convert<DstData>{}(src_tmp_buf.template AsType<SrcData>()[i]);
         });

         // copy data from dst_tmp_buf into dst_buf
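Two threads run through this file's changes: Run now reads from a StaticBuffer instead of a raw SrcData pointer, and static_cast gives way to type_convert, a functor that can be specialized per type pair where a plain cast is wrong or will not compile (the TODO above about vector types). A sketch of that functor pattern; type_convert_sk and the bfloat16 truncation specialization are illustrative, not CK's implementation.

#include <cstring>

// generic fallback: behave exactly like static_cast
template <typename Dst>
struct type_convert_sk
{
    template <typename Src>
    Dst operator()(Src s) const
    {
        return static_cast<Dst>(s);
    }
};

// a conversion a cast cannot express: store float as bfloat16 bits,
// kept in an unsigned short the way this codebase's bfp16 aliases do
template <>
struct type_convert_sk<unsigned short>
{
    unsigned short operator()(float s) const
    {
        unsigned int bits;
        std::memcpy(&bits, &s, sizeof(bits));           // defined type punning
        return static_cast<unsigned short>(bits >> 16); // truncate the mantissa
    }
};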
composable_kernel/include/utility/float_type.amd.hpp.in

@@ -403,32 +403,249 @@ struct vector_type<T, 16>
     }
 };

+template <typename T>
+struct vector_type<T, 32>
+{
+    using d1_t = T;
+    typedef T d2_t __attribute__((ext_vector_type(2)));
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d8_t __attribute__((ext_vector_type(8)));
+    typedef T d16_t __attribute__((ext_vector_type(16)));
+    typedef T d32_t __attribute__((ext_vector_type(32)));
+
+    using type = d32_t;
+
+    union
+    {
+        d32_t d32_;
+        StaticallyIndexedArray<d1_t, 32> d1x32_;
+        StaticallyIndexedArray<d2_t, 16> d2x16_;
+        StaticallyIndexedArray<d4_t, 8> d4x8_;
+        StaticallyIndexedArray<d8_t, 4> d8x4_;
+        StaticallyIndexedArray<d16_t, 2> d16x2_;
+        StaticallyIndexedArray<d32_t, 1> d32x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value || is_same<X, d32_t>::value,
+                      "wrong!");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x32_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x16_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x8_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x4_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x2_;
+        }
+        else if constexpr(is_same<X, d32_t>::value)
+        {
+            return data_.d32x1_;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value || is_same<X, d32_t>::value,
+                      "wrong!");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x32_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x16_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x8_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x4_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x2_;
+        }
+        else if constexpr(is_same<X, d32_t>::value)
+        {
+            return data_.d32x1_;
+        }
+    }
+};
+
+template <typename T>
+struct vector_type<T, 64>
+{
+    using d1_t = T;
+    typedef T d2_t __attribute__((ext_vector_type(2)));
+    typedef T d4_t __attribute__((ext_vector_type(4)));
+    typedef T d8_t __attribute__((ext_vector_type(8)));
+    typedef T d16_t __attribute__((ext_vector_type(16)));
+    typedef T d32_t __attribute__((ext_vector_type(32)));
+    typedef T d64_t __attribute__((ext_vector_type(64)));
+
+    using type = d64_t;
+
+    union
+    {
+        d64_t d64_;
+        StaticallyIndexedArray<d1_t, 64> d1x64_;
+        StaticallyIndexedArray<d2_t, 32> d2x32_;
+        StaticallyIndexedArray<d4_t, 16> d4x16_;
+        StaticallyIndexedArray<d8_t, 8> d8x8_;
+        StaticallyIndexedArray<d16_t, 4> d16x4_;
+        StaticallyIndexedArray<d32_t, 2> d32x2_;
+        StaticallyIndexedArray<d64_t, 1> d64x1_;
+    } data_;
+
+    __host__ __device__ constexpr vector_type() : data_{type{0}} {}
+
+    __host__ __device__ constexpr vector_type(type v) : data_{v} {}
+
+    template <typename X>
+    __host__ __device__ constexpr const auto& AsType() const
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
+                          is_same<X, d64_t>::value,
+                      "wrong!");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x64_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x32_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x16_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x8_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x4_;
+        }
+        else if constexpr(is_same<X, d32_t>::value)
+        {
+            return data_.d32x2_;
+        }
+        else if constexpr(is_same<X, d64_t>::value)
+        {
+            return data_.d64x1_;
+        }
+    }
+
+    template <typename X>
+    __host__ __device__ constexpr auto& AsType()
+    {
+        static_assert(is_same<X, d1_t>::value || is_same<X, d2_t>::value ||
+                          is_same<X, d4_t>::value || is_same<X, d8_t>::value ||
+                          is_same<X, d16_t>::value || is_same<X, d32_t>::value ||
+                          is_same<X, d64_t>::value,
+                      "wrong!");
+
+        if constexpr(is_same<X, d1_t>::value)
+        {
+            return data_.d1x64_;
+        }
+        else if constexpr(is_same<X, d2_t>::value)
+        {
+            return data_.d2x32_;
+        }
+        else if constexpr(is_same<X, d4_t>::value)
+        {
+            return data_.d4x16_;
+        }
+        else if constexpr(is_same<X, d8_t>::value)
+        {
+            return data_.d8x8_;
+        }
+        else if constexpr(is_same<X, d16_t>::value)
+        {
+            return data_.d16x4_;
+        }
+        else if constexpr(is_same<X, d32_t>::value)
+        {
+            return data_.d32x2_;
+        }
+        else if constexpr(is_same<X, d64_t>::value)
+        {
+            return data_.d64x1_;
+        }
+    }
+};
+
 // fp32
 using float2_t = typename vector_type<float, 2>::type;
 using float4_t = typename vector_type<float, 4>::type;
 using float8_t = typename vector_type<float, 8>::type;
+using float16_t = typename vector_type<float, 16>::type;
+using float32_t = typename vector_type<float, 32>::type;
+using float64_t = typename vector_type<float, 64>::type;

 // fp16
 using half2_t = typename vector_type<half_t, 2>::type;
 using half4_t = typename vector_type<half_t, 4>::type;
 using half8_t = typename vector_type<half_t, 8>::type;
 using half16_t = typename vector_type<half_t, 16>::type;
+using half32_t = typename vector_type<half_t, 32>::type;
+using half64_t = typename vector_type<half_t, 64>::type;

 // bfp16
 using ushort2_t = typename vector_type<ushort, 2>::type;
 using ushort4_t = typename vector_type<ushort, 4>::type;
 using ushort8_t = typename vector_type<ushort, 8>::type;
+using ushort16_t = typename vector_type<ushort, 16>::type;
+using ushort32_t = typename vector_type<ushort, 32>::type;
+using ushort64_t = typename vector_type<ushort, 64>::type;

 // i32
 using int32x2_t = typename vector_type<int32_t, 2>::type;
 using int32x4_t = typename vector_type<int32_t, 4>::type;
 using int32x8_t = typename vector_type<int32_t, 8>::type;
+using int32x16_t = typename vector_type<int32_t, 16>::type;
+using int32x32_t = typename vector_type<int32_t, 32>::type;
+using int32x64_t = typename vector_type<int32_t, 64>::type;

 // i8
 using int8x2_t = typename vector_type<int8_t, 2>::type;
 using int8x4_t = typename vector_type<int8_t, 4>::type;
 using int8x8_t = typename vector_type<int8_t, 8>::type;
 using int8x16_t = typename vector_type<int8_t, 16>::type;
+using int8x32_t = typename vector_type<int8_t, 32>::type;
+using int8x64_t = typename vector_type<int8_t, 64>::type;

 // data type conversion
 template <typename T>
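The new 32- and 64-wide specializations repeat the pattern of the existing ones: one wide storage vector sits in a union with StaticallyIndexedArray views at every narrower chunk width, and AsType<X>() dispatches to the view matching X. A self-contained sketch of that union trick at width 8, using a plain array instead of StaticallyIndexedArray; clang/hipcc only, and it relies on union type punning, which these compilers define.

#include <cstdio>

typedef float d4_sk __attribute__((ext_vector_type(4)));
typedef float d8_sk __attribute__((ext_vector_type(8)));

// one 32-byte storage vector, viewable as two float4 chunks
union v8_view
{
    d8_sk d8_;
    d4_sk d4x2_[2];
};

int main()
{
    v8_view v;
    v.d8_ = 1.0f;      // splat 1.0 into all eight lanes

    v.d4x2_[1] = 2.0f; // overwrite the upper half as one float4 chunk

    std::printf("%f %f\n", v.d8_[0], v.d8_[4]); // prints 1.000000 2.000000
    return 0;
}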