Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
yangql
composable_kernel-1
Commits
33b5a855
Commit
33b5a855
authored
May 16, 2019
by
Chao Liu
Browse files
adding implicit gemm v3
parent
5e5c27a6
Changes
13
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
172 additions
and
197 deletions
+172
-197
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
...er/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
+1
-1
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
...er/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
+1
-1
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
...er/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
+2
-2
driver/driver.hip.cpp
driver/driver.hip.cpp
+5
-5
src/include/Array.hip.hpp
src/include/Array.hip.hpp
+14
-0
src/include/ConstantTensorDescriptor.hip.hpp
src/include/ConstantTensorDescriptor.hip.hpp
+19
-41
src/include/Sequence.hip.hpp
src/include/Sequence.hip.hpp
+105
-118
src/include/blockwise_gemm.hip.hpp
src/include/blockwise_gemm.hip.hpp
+2
-1
src/include/blockwise_tensor_slice_op.hip.hpp
src/include/blockwise_tensor_slice_op.hip.hpp
+13
-20
src/include/functional.hip.hpp
src/include/functional.hip.hpp
+5
-3
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp
...plicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp
+1
-1
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp
...plicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp
+2
-2
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp
...plicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp
+2
-2
No files found.
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
View file @
33b5a855
...
@@ -140,7 +140,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
...
@@ -140,7 +140,7 @@ void device_convolution_implicit_gemm_v1_chwn_cyxk_khwn(InDesc,
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
constexpr
index_t
OutThreadCopyDataPerWrite_N
=
2
;
constexpr
index_t
OutThreadCopyDataPerWrite_N
=
2
;
#elif
0
#elif
1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 28x28, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal
// for 3x3, 14x14, v1r3, Pascal
...
...
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp
View file @
33b5a855
...
@@ -64,7 +64,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
...
@@ -64,7 +64,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_khwn(InDesc,
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
out_khwn_device_buf
.
ToDevice
(
out_khwn
.
mData
.
data
());
#if
0
#if
1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 34x34, v1r3, Pascal
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
...
...
driver/device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp
View file @
33b5a855
...
@@ -57,7 +57,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
...
@@ -57,7 +57,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
wei_cyxk_device_buf
.
ToDevice
(
wei_cyxk
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
out_nkhw_device_buf
.
ToDevice
(
out_nkhw
.
mData
.
data
());
#if
0
#if
1
// for 3x3, 34x34, v1r3, Pascal
// for 3x3, 34x34, v1r3, Pascal
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
...
@@ -162,7 +162,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
...
@@ -162,7 +162,7 @@ void device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw(InDesc,
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
constexpr
index_t
WeiBlockCopyDataPerRead_K
=
4
;
constexpr
index_t
OutThreadCopyDataPerWrite_W
=
2
;
constexpr
index_t
OutThreadCopyDataPerWrite_W
=
2
;
#elif
1
#elif
0
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
// for 3x3, 34x34, v1r3, Vega 20, WoPerBlock = 8
constexpr
index_t
BlockSize
=
256
;
constexpr
index_t
BlockSize
=
256
;
...
...
driver/driver.hip.cpp
View file @
33b5a855
...
@@ -13,7 +13,7 @@
...
@@ -13,7 +13,7 @@
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw.hpp"
#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
//
#include "device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw.hpp"
struct
GeneratorTensor_1
struct
GeneratorTensor_1
{
{
...
@@ -411,7 +411,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
...
@@ -411,7 +411,7 @@ void check_error(const Tensor<T>& ref, const Tensor<T>& result)
int
main
(
int
argc
,
char
*
argv
[])
int
main
(
int
argc
,
char
*
argv
[])
{
{
#if
0
#if
1
// 3x3, 34x34
// 3x3, 34x34
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -435,7 +435,7 @@ int main(int argc, char* argv[])
...
@@ -435,7 +435,7 @@ int main(int argc, char* argv[])
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
#elif
1
#elif
0
// 3x3 filter, 28x28 image
// 3x3 filter, 28x28 image
constexpr
index_t
N
=
128
;
constexpr
index_t
N
=
128
;
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
...
@@ -608,7 +608,7 @@ int main(int argc, char* argv[])
...
@@ -608,7 +608,7 @@ int main(int argc, char* argv[])
device_convolution_direct_v2_nchw_kcyx_nkhw
device_convolution_direct_v2_nchw_kcyx_nkhw
#elif 0
#elif 0
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif
0
#elif
1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
...
@@ -616,7 +616,7 @@ int main(int argc, char* argv[])
...
@@ -616,7 +616,7 @@ int main(int argc, char* argv[])
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
device_convolution_implicit_gemm_v1_nchw_cyxk_nkhw
#elif 0
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
#elif
1
#elif
0
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
device_convolution_implicit_gemm_v3_nchw_cyxk_nkhw
#endif
#endif
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
(
in_nchw_desc
,
in_nchw
,
wei_kcyx_desc
,
wei_kcyx
,
out_nkhw_desc
,
out_nkhw_device
,
nrepeat
);
...
...
src/include/Array.hip.hpp
View file @
33b5a855
...
@@ -66,3 +66,17 @@ __host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>&
...
@@ -66,3 +66,17 @@ __host__ __device__ auto reorder_array_given_old2new(const Array<TData, NSize>&
return
new_array
;
return
new_array
;
}
}
template
<
class
TData
,
index_t
NSize
>
__host__
__device__
constexpr
auto
operator
+
(
const
Array
<
TData
,
NSize
>&
a
,
const
Array
<
TData
,
NSize
>&
b
)
{
Array
<
TData
,
NSize
>
result
;
static_for
<
0
,
NSize
,
1
>
{}([
&
](
auto
I
)
{
constexpr
index_t
i
=
I
.
Get
();
result
[
i
]
=
a
[
i
]
+
b
[
i
];
});
return
result
;
}
src/include/ConstantTensorDescriptor.hip.hpp
View file @
33b5a855
...
@@ -88,32 +88,11 @@ struct ConstantTensorDescriptor
...
@@ -88,32 +88,11 @@ struct ConstantTensorDescriptor
return
accumulate_on_sequence
(
Lengths
{},
std
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
return
accumulate_on_sequence
(
Lengths
{},
std
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
}
}
#if 0
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct f_GetElementSpace_impl
{
template <class IDim>
__host__ __device__ constexpr index_t operator()(IDim idim) const
{
return (Type{}.GetLength(idim) - 1) * Type{}.GetStride(idim);
}
__host__ __device__ constexpr index_t operator()(index_t length, index_t stride) const
{
return (length - 1) * stride;
}
};
#endif
template
<
class
Align
=
Number
<
1
>
>
template
<
class
Align
=
Number
<
1
>
>
__host__
__device__
static
constexpr
index_t
GetElementSpace
(
Align
align
=
Align
{})
__host__
__device__
static
constexpr
index_t
GetElementSpace
(
Align
align
=
Align
{})
{
{
#if 0
index_t element_space_unaligned =
static_const_reduce_n<nDim>{}(f_GetElementSpace_impl{}, std::plus<index_t>{}) + 1;
#else
constexpr
index_t
element_space_unaligned
=
accumulate_on_sequence
(
constexpr
index_t
element_space_unaligned
=
accumulate_on_sequence
(
(
GetLengths
()
-
Number
<
1
>
{})
*
GetStrides
(),
std
::
plus
<
index_t
>
{},
Number
<
1
>
{});
(
GetLengths
()
-
Number
<
1
>
{})
*
GetStrides
(),
std
::
plus
<
index_t
>
{},
Number
<
1
>
{});
#endif
return
align
.
Get
()
*
((
element_space_unaligned
+
align
.
Get
()
-
1
)
/
align
.
Get
());
return
align
.
Get
()
*
((
element_space_unaligned
+
align
.
Get
()
-
1
)
/
align
.
Get
());
}
}
...
@@ -150,10 +129,7 @@ struct ConstantTensorDescriptor
...
@@ -150,10 +129,7 @@ struct ConstantTensorDescriptor
constexpr
auto
multi_id
=
Sequence
<
Is
...
>
{};
constexpr
auto
multi_id
=
Sequence
<
Is
...
>
{};
constexpr
auto
seq_tmp
=
return
accumulate_on_sequence
(
multi_id
*
GetStrides
(),
std
::
plus
<
index_t
>
{},
Number
<
0
>
{});
transform_sequences
(
std
::
multiplies
<
index_t
>
{},
multi_id
,
GetStrides
());
return
accumulate_on_sequence
(
seq_tmp
,
std
::
plus
<
index_t
>
{},
Number
<
0
>
{});
}
}
__host__
__device__
static
Array
<
index_t
,
nDim
>
GetMultiIndex
(
index_t
id
)
__host__
__device__
static
Array
<
index_t
,
nDim
>
GetMultiIndex
(
index_t
id
)
...
@@ -177,14 +153,14 @@ struct ConstantTensorDescriptor
...
@@ -177,14 +153,14 @@ struct ConstantTensorDescriptor
return
ConstantTensorDescriptor
<
Lengths
,
decltype
(
default_strides
)
>
{};
return
ConstantTensorDescriptor
<
Lengths
,
decltype
(
default_strides
)
>
{};
}
}
template
<
index_t
IDims
...
>
template
<
index_t
...
IDims
>
__host__
__device__
static
constexpr
auto
Extract
(
Number
<
IDims
>
...
extract_dims
)
__host__
__device__
static
constexpr
auto
Extract
(
Number
<
IDims
>
...
extract_dims
)
{
{
static_assert
(
sizeof
...(
IDims
)
<=
GetNumOfDimension
(),
static_assert
(
sizeof
...(
IDims
)
<=
GetNumOfDimension
(),
"wrong! too many number of dimensions to be extracted"
);
"wrong! too many number of dimensions to be extracted"
);
return
make_ConstantTensorDescriptor
(
Lengths
{}.
Extract
(
extract_dims
),
return
make_ConstantTensorDescriptor
(
Lengths
{}.
Extract
(
extract_dims
...
),
Strides
{}.
Extract
(
extract_dims
));
Strides
{}.
Extract
(
extract_dims
...
));
}
}
template
<
index_t
IDim
,
index_t
SliceLen
>
template
<
index_t
IDim
,
index_t
SliceLen
>
...
@@ -195,11 +171,11 @@ struct ConstantTensorDescriptor
...
@@ -195,11 +171,11 @@ struct ConstantTensorDescriptor
}
}
template
<
index_t
IDim
,
index_t
...
FoldIntervals
>
template
<
index_t
IDim
,
index_t
...
FoldIntervals
>
__host__
device__
static
constexpr
auto
Fold
(
Number
<
IDim
>
,
Number
<
FoldIntervals
>
...)
__host__
__
device__
static
constexpr
auto
Fold
(
Number
<
IDim
>
,
Number
<
FoldIntervals
>
...)
{
{
constexpr
auto
fold_intervals
=
Sequence
<
FoldIntervals
...
>
{};
constexpr
auto
fold_intervals
=
Sequence
<
FoldIntervals
...
>
{};
constexpr
fold_intervals_product
=
constexpr
index_t
fold_intervals_product
=
accumulate_on_sequence
(
fold_intervals
,
std
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
accumulate_on_sequence
(
fold_intervals
,
std
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
constexpr
auto
unfold_length
=
GetLength
(
Number
<
IDim
>
{});
constexpr
auto
unfold_length
=
GetLength
(
Number
<
IDim
>
{});
...
@@ -207,29 +183,31 @@ struct ConstantTensorDescriptor
...
@@ -207,29 +183,31 @@ struct ConstantTensorDescriptor
// length of the dimension to be folded needs to be dividable by fold_interval_product,
// length of the dimension to be folded needs to be dividable by fold_interval_product,
// otherwise, folding is invalid
// otherwise, folding is invalid
static_assert
(
unfold_length
%
fold_interval_product
==
0
,
static_assert
(
unfold_length
%
fold_interval
s
_product
==
0
,
"wrong! length on the dimension to be folded cannot be evenly divided!"
);
"wrong! length on the dimension to be folded cannot be evenly divided!"
);
// folded lengths
// folded lengths
constexpr
auto
fold_lengths
=
constexpr
auto
fold_lengths
=
Sequence
<
unfold_length
/
fold_interval_product
>
{}.
Append
(
fold_intervals
);
Sequence
<
unfold_length
/
fold_interval
s
_product
>
{}.
Append
(
fold_intervals
);
// folded strides
// folded strides
constexpr
auto
fold_strides
=
transform_sequences
(
mod_conv
::
scales
<
index_t
,
unfold_stride
>
{},
constexpr
auto
fold_strides
=
Number
<
unfold_stride
>
{}
*
reverse_scan_sequence
(
fold_intervals
.
PushBack
(
Number
<
1
>
{}),
std
::
multiplies
<
index_t
>
{});
reverse_scan_sequence
(
fold_intervals
.
PushBack
(
Number
<
1
>
{}),
std
::
multiplies
<
index_t
>
{});
// left and right lengths
// left and right lengths
constexpr
auto
lengths_pair
=
GetLengths
().
Split
(
Number
<
I
>
{});
constexpr
auto
lengths_pair
=
GetLengths
().
Split
(
Number
<
I
Dim
>
{});
constexpr
auto
left_lengths
=
lengths_pair
.
first
;
constexpr
auto
left_lengths
=
lengths_pair
.
first
;
constexpr
auto
right_lengths
=
lengths_pair
.
second
.
PopFront
();
constexpr
auto
right_lengths
=
lengths_pair
.
second
.
PopFront
();
// left and right strides
// left and right strides
constexpr
auto
strides_pair
=
GetStrides
().
Split
(
Number
<
I
>
{});
constexpr
auto
strides_pair
=
GetStrides
().
Split
(
Number
<
I
Dim
>
{});
constexpr
auto
left_strides
=
strides_pair
.
first
;
constexpr
auto
left_strides
=
strides_pair
.
first
;
constexpr
auto
right_strides
=
strides_pair
.
second
.
PopFront
();
constexpr
auto
right_strides
=
strides_pair
.
second
.
PopFront
();
return
make_ConstantTensorDescriptor
(
left_lengths
.
Append
(
fold_lengths
).
Append
(
right_lengths
),
return
make_ConstantTensorDescriptor
(
left_strides
.
Append
(
fold_strides
).
Append
(
right_strides
));
left_lengths
.
Append
(
fold_lengths
).
Append
(
right_lengths
),
left_strides
.
Append
(
fold_strides
).
Append
(
right_strides
));
}
}
template
<
index_t
FirstUnfoldDim
,
index_t
LastUnfoldDim
>
template
<
index_t
FirstUnfoldDim
,
index_t
LastUnfoldDim
>
...
@@ -264,8 +242,8 @@ struct ConstantTensorDescriptor
...
@@ -264,8 +242,8 @@ struct ConstantTensorDescriptor
constexpr
index_t
unfold_length
=
constexpr
index_t
unfold_length
=
accumulate_on_sequence
(
fold_lengths
,
std
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
accumulate_on_sequence
(
fold_lengths
,
std
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
constexpr
auto
new_
stride
s
=
constexpr
auto
new_
length
s
=
left_
stride
s
.
PopBack
(
Number
<
unfold_
strides
>
{}).
Append
(
right_
stride
s
);
left_
length
s
.
PopBack
(
Number
<
unfold_
length
>
{}).
Append
(
right_
length
s
);
// strides
// strides
constexpr
auto
strides_pair1
=
Strides
{}.
Split
(
Number
<
LastUnfoldDim
+
1
>
{});
constexpr
auto
strides_pair1
=
Strides
{}.
Split
(
Number
<
LastUnfoldDim
+
1
>
{});
...
@@ -281,7 +259,7 @@ struct ConstantTensorDescriptor
...
@@ -281,7 +259,7 @@ struct ConstantTensorDescriptor
constexpr
index_t
unfold_stride
=
fold_strides
.
Back
();
constexpr
index_t
unfold_stride
=
fold_strides
.
Back
();
constexpr
auto
new_strides
=
constexpr
auto
new_strides
=
left_strides
.
PushBack
(
Number
<
unfold_stride
s
>
{}).
Append
(
right_strides
);
left_strides
.
PushBack
(
Number
<
unfold_stride
>
{}).
Append
(
right_strides
);
return
make_ConstantTensorDescriptor
(
new_lengths
,
new_strides
);
return
make_ConstantTensorDescriptor
(
new_lengths
,
new_strides
);
}
}
...
@@ -289,7 +267,7 @@ struct ConstantTensorDescriptor
...
@@ -289,7 +267,7 @@ struct ConstantTensorDescriptor
template
<
index_t
...
IRs
>
template
<
index_t
...
IRs
>
__host__
__device__
static
constexpr
auto
ReorderGivenNew2Old
(
Sequence
<
IRs
...
>
/*new2old*/
)
__host__
__device__
static
constexpr
auto
ReorderGivenNew2Old
(
Sequence
<
IRs
...
>
/*new2old*/
)
{
{
static_assert
(
sizeof
...(
IRs
)
==
GetNum
ber
OfDimension
(),
"wrong! dimension is wrong"
);
static_assert
(
sizeof
...(
IRs
)
==
GetNumOfDimension
(),
"wrong! dimension is wrong"
);
constexpr
auto
map_new2old
=
Sequence
<
IRs
...
>
{};
constexpr
auto
map_new2old
=
Sequence
<
IRs
...
>
{};
return
make_ConstantTensorDescriptor
(
Lengths
{}.
ReorderGivenNew2Old
(
map_new2old
),
return
make_ConstantTensorDescriptor
(
Lengths
{}.
ReorderGivenNew2Old
(
map_new2old
),
Strides
{}.
ReorderGivenNew2Old
(
map_new2old
));
Strides
{}.
ReorderGivenNew2Old
(
map_new2old
));
...
...
src/include/Sequence.hip.hpp
View file @
33b5a855
...
@@ -2,14 +2,7 @@
...
@@ -2,14 +2,7 @@
#include "constant_integral.hip.hpp"
#include "constant_integral.hip.hpp"
#include "functional.hip.hpp"
#include "functional.hip.hpp"
struct
EmptySequence
struct
EmptySequence
;
{
template
<
class
Seq
>
__host__
__device__
constexpr
Seq
Append
(
Seq
)
const
{
return
{};
}
};
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
struct
Sequence
struct
Sequence
...
@@ -73,18 +66,18 @@ struct Sequence
...
@@ -73,18 +66,18 @@ struct Sequence
__host__
__device__
constexpr
auto
PopBack
()
const
;
__host__
__device__
constexpr
auto
PopBack
()
const
;
template
<
index_t
Xs
...>
template
<
index_t
...
Xs
>
__host__
__device__
constexpr
auto
Append
(
Sequence
<
Xs
...
>
)
const
__host__
__device__
constexpr
auto
Append
(
Sequence
<
Xs
...
>
)
const
{
{
return
Sequence
<
Is
...,
Xs
...
>
{};
return
Sequence
<
Is
...,
Xs
...
>
{};
}
}
__host__
__device__
constexpr
auto
Append
(
EmptySequence
)
const
{
return
Type
{};
}
__host__
__device__
constexpr
auto
Append
(
EmptySequence
)
const
;
template
<
index_t
...
Ns
>
template
<
index_t
...
Ns
>
__host__
__device__
constexpr
auto
Extract
(
Number
<
Ns
>
...)
const
__host__
__device__
constexpr
auto
Extract
(
Number
<
Ns
>
...)
const
{
{
return
Sequence
<
Type
{}.
Get
(
Number
<
Ns
>
)...
>
{};
return
Sequence
<
Get
(
Number
<
Ns
>
{}
)...
>
{};
}
}
template
<
index_t
N
>
template
<
index_t
N
>
...
@@ -93,8 +86,8 @@ struct Sequence
...
@@ -93,8 +86,8 @@ struct Sequence
template
<
class
FirstSeq
,
class
SecondSeq
>
template
<
class
FirstSeq
,
class
SecondSeq
>
__host__
__device__
constexpr
auto
operator
()(
FirstSeq
,
SecondSeq
)
const
__host__
__device__
constexpr
auto
operator
()(
FirstSeq
,
SecondSeq
)
const
{
{
constexpr
new_first
=
FirstSeq
{}.
PushBack
(
Number
<
Second
{}.
Front
()
>
{});
constexpr
index_t
new_first
=
FirstSeq
{}.
PushBack
(
Number
<
Second
Seq
{}.
Front
()
>
{});
constexpr
new_second
=
SecondSeq
{}.
PopFront
();
constexpr
index_t
new_second
=
SecondSeq
{}.
PopFront
();
static_if
<
(
N
>
0
)
>
{}([
&
](
auto
fwd
)
{
static_if
<
(
N
>
0
)
>
{}([
&
](
auto
fwd
)
{
return
split_impl
<
N
-
1
>
{}(
new_first
,
fwd
(
new_second
));
return
split_impl
<
N
-
1
>
{}(
new_first
,
fwd
(
new_second
));
...
@@ -102,26 +95,10 @@ struct Sequence
...
@@ -102,26 +95,10 @@ struct Sequence
}
}
};
};
// split one sequence to two sequnces: [0, I) and [I,
n
Size)
// split one sequence to two sequnces: [0, I) and [I,
m
Size)
// return type is std::pair
// return type is std::pair
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
Split
(
Number
<
I
>
)
const
__host__
__device__
constexpr
auto
Split
(
Number
<
I
>
)
const
;
{
static_assert
(
I
<=
nSize
,
"wrong! split position is too high!"
);
static_if
<
(
I
==
0
)
>
{}(
[
&
](
auto
fwd
)
{
return
std
::
make_pair
(
EmptySequence
<>
{},
fwd
(
Type
{}));
});
static_if
<
(
I
==
nSize
)
>
{}(
[
&
](
auto
fwd
)
{
return
std
::
make_pair
(
Type
<>
{},
fwd
(
EmptySequence
<>
{}));
});
static_if
<
(
I
>
0
&&
I
<
nSize
)
>
{}([
&
](
auto
fforwader
)
{
constexpr
auto
first
=
Sequence
<
Type
{}.
Front
()
>
{}
constexpr
auto
second
=
Type
{}.
PopFront
();
return
split_impl
<
I
-
1
>
{}(
first
,
fwd
(
second
));
});
}
template
<
index_t
I
,
index_t
X
>
template
<
index_t
I
,
index_t
X
>
__host__
__device__
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
)
const
__host__
__device__
constexpr
auto
Modify
(
Number
<
I
>
,
Number
<
X
>
)
const
...
@@ -135,22 +112,64 @@ struct Sequence
...
@@ -135,22 +112,64 @@ struct Sequence
}
}
};
};
template
<
index_t
IBegin
,
index_t
IEnd
,
index_t
Increment
>
struct
EmptySequence
__host__
__device__
auto
make_increasing_sequence
(
Number
<
IBegin
>
,
Number
<
IEnd
>
,
Number
<
Increment
>
)
{
{
static_assert
(
IBegin
<
IEnd
,
(
IEnd
-
IBegin
)
%
Increment
==
0
,
"wrong!"
);
__host__
__device__
static
constexpr
index_t
GetSize
()
{
return
0
;
}
// not implemented
template
<
index_t
I
>
__host__
__device__
constexpr
auto
PushFront
(
Number
<
I
>
)
const
{
return
Sequence
<
I
>
{};
}
template
<
index_t
I
>
__host__
__device__
constexpr
auto
PushBack
(
Number
<
I
>
)
const
{
return
Sequence
<
I
>
{};
}
template
<
class
Seq
>
__host__
__device__
constexpr
Seq
Append
(
Seq
)
const
{
return
Seq
{};
}
};
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
Append
(
EmptySequence
)
const
{
return
Type
{};
}
// split one sequence to two sequnces: [0, I) and [I, mSize)
// return type is std::pair
template
<
index_t
...
Is
>
template
<
index_t
I
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
Split
(
Number
<
I
>
)
const
{
static_assert
(
I
<=
GetSize
(),
"wrong! split position is too high!"
);
static_if
<
(
I
==
0
)
>
{}([
&
](
auto
fwd
)
{
return
std
::
make_pair
(
EmptySequence
{},
fwd
(
Type
{}));
});
static_if
<
(
I
==
GetSize
())
>
{}(
[
&
](
auto
fwd
)
{
return
std
::
make_pair
(
Type
{},
fwd
(
EmptySequence
{}));
});
static_if
<
(
I
>
0
&&
I
<
GetSize
())
>
{}(
[
&
](
auto
fwd
)
{
return
split_impl
<
I
>
{}(
EmptySequence
{},
fwd
(
Type
{}));
});
}
}
template
<
index_t
N
,
index_t
X
>
#if 0
__host__
__device__
auto
make_uniform_sequence
(
Number
<
N
>
,
Number
<
X
>
);
template <index_t IBegin, index_t IEnd, index_t Increment>
__host__ __device__ auto make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
{
{
static_assert(IBegin < IEnd, (IEnd - IBegin) % Increment == 0, "wrong!");
// not implemented
// not implemented
}
}
#endif
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
operator
+
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
const
__host__
__device__
constexpr
auto
operator
+
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
{
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
...
@@ -158,17 +177,18 @@ __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>) c
...
@@ -158,17 +177,18 @@ __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>) c
}
}
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
operator
-
(
Sequence
<
Xs
...
>
seq_x
,
Sequence
<
Ys
...
>
seq_y
)
const
__host__
__device__
constexpr
auto
operator
-
(
Sequence
<
Xs
...
>
seq_x
,
Sequence
<
Ys
...
>
seq_y
)
{
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
static_for
<
0
,
xs
.
GetSize
(),
1
>
{}([
&
](
auto
I
)
{
static_assert
(
seq_x
.
Get
(
I
)
>=
seq_y
.
Get
(
I
));
});
static_for
<
0
,
seq_x
.
GetSize
(),
1
>
{}(
[
&
](
auto
I
)
{
static_assert
(
seq_x
.
Get
(
I
)
>=
seq_y
.
Get
(
I
),
"wrong! going to undeflow"
);
});
return
Sequence
<
(
Xs
-
Ys
)...
>
{};
return
Sequence
<
(
Xs
-
Ys
)...
>
{};
}
}
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
operator
*
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
const
__host__
__device__
constexpr
auto
operator
*
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
{
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
...
@@ -176,7 +196,7 @@ __host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)co
...
@@ -176,7 +196,7 @@ __host__ __device__ constexpr auto operator*(Sequence<Xs...>, Sequence<Ys...>)co
}
}
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
operator
/
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
const
__host__
__device__
constexpr
auto
operator
/
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
{
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
...
@@ -184,15 +204,7 @@ __host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>) c
...
@@ -184,15 +204,7 @@ __host__ __device__ constexpr auto operator/(Sequence<Xs...>, Sequence<Ys...>) c
}
}
template
<
index_t
...
Xs
,
index_t
...
Ys
>
template
<
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
operator
%
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
const
__host__
__device__
constexpr
auto
operator
%
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
return
Sequence
<
(
Xs
%
Ys
)...
>
{};
}
template
<
index_t
...
Xs
,
index_t
...
Ys
>
__host__
__device__
constexpr
auto
operator
%
(
Sequence
<
Xs
...
>
,
Sequence
<
Ys
...
>
)
const
{
{
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
static_assert
(
sizeof
...(
Xs
)
==
sizeof
...(
Ys
),
"wrong! inconsistent size"
);
...
@@ -200,63 +212,79 @@ __host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>) c
...
@@ -200,63 +212,79 @@ __host__ __device__ constexpr auto operator%(Sequence<Xs...>, Sequence<Ys...>) c
}
}
template
<
index_t
...
Xs
,
index_t
Y
>
template
<
index_t
...
Xs
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
+
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
const
__host__
__device__
constexpr
auto
operator
+
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
{
{
return
seq_x
+
make_uniform_sequence
(
Number
<
sizeof
...(
Xs
)
>
,
Number
<
Y
>
{}
)
;
return
Sequence
<
(
Xs
+
Y
)...
>
{};
}
}
template
<
index_t
...
Xs
,
index_t
Y
>
template
<
index_t
...
Xs
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
-
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
const
__host__
__device__
constexpr
auto
operator
-
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
{
{
return
seq_x
-
make_uniform_sequence
(
Number
<
sizeof
...(
Xs
)
>
,
Number
<
Y
>
{});
constexpr
auto
seq_x
=
Sequence
<
Xs
...
>
{};
#if 0
static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
constexpr auto I = decltype(Iter){};
static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
});
#endif
return
Sequence
<
(
Xs
-
Y
)...
>
{};
}
}
template
<
index_t
...
Xs
,
index_t
Y
>
template
<
index_t
...
Xs
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
*
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
const
__host__
__device__
constexpr
auto
operator
*
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
{
{
return
seq_x
*
make_uniform_sequence
(
Number
<
sizeof
...(
Xs
)
>
,
Number
<
Y
>
{}
)
;
return
Sequence
<
(
Xs
*
Y
)...
>
{};
}
}
template
<
index_t
...
Xs
,
index_t
Y
>
template
<
index_t
...
Xs
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
/
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
const
__host__
__device__
constexpr
auto
operator
/
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
{
{
return
seq_x
/
make_uniform_sequence
(
Number
<
sizeof
...(
Xs
)
>
,
Number
<
Y
>
{}
)
;
return
Sequence
<
(
Xs
/
Y
)...
>
{};
}
}
template
<
index_t
...
Xs
,
index_t
Y
>
template
<
index_t
...
Xs
,
index_t
Y
>
__host__
__device__
constexpr
auto
operator
%
(
Sequence
<
Xs
...
>
seq_x
,
Number
<
Y
>
y
)
const
__host__
__device__
constexpr
auto
operator
%
(
Sequence
<
Xs
...
>
,
Number
<
Y
>
)
{
{
return
seq_x
%
make_uniform_sequence
(
Number
<
sizeof
...(
Xs
)
>
,
Number
<
Y
>
{}
)
;
return
Sequence
<
(
Xs
%
Y
)...
>
{};
}
}
template
<
index_t
X
,
index_t
...
Y
s
>
template
<
index_t
Y
,
index_t
...
X
s
>
__host__
__device__
constexpr
auto
operator
+
(
Number
<
X
>
,
Sequence
<
Y
s
...
>
)
const
__host__
__device__
constexpr
auto
operator
+
(
Number
<
Y
>
,
Sequence
<
X
s
...
>
)
{
{
return
make_uniform_sequence
(
Number
<
sizeof
...(
Ys
)
>
{},
Number
<
X
>
{})
+
Sequence
<
Ys
...
>
{};
return
Sequence
<
(
Y
+
Xs
)
...
>
{};
}
}
template
<
index_t
X
,
index_t
...
Y
s
>
template
<
index_t
Y
,
index_t
...
X
s
>
__host__
__device__
constexpr
auto
operator
-
(
Number
<
X
>
,
Sequence
<
Y
s
...
>
)
const
__host__
__device__
constexpr
auto
operator
-
(
Number
<
Y
>
,
Sequence
<
X
s
...
>
)
{
{
return
make_uniform_sequence
(
Number
<
sizeof
...(
Ys
)
>
{},
Number
<
X
>
{})
-
Sequence
<
Ys
...
>
{};
constexpr
auto
seq_x
=
Sequence
<
Xs
...
>
{};
static_for
<
0
,
sizeof
...(
Xs
),
1
>
{}([
&
](
auto
Iter
)
{
constexpr
auto
I
=
decltype
(
Iter
){};
static_assert
(
seq_x
.
Get
(
I
)
<=
Y
,
"wrong! going to underflow"
);
});
return
Sequence
<
(
Y
-
Xs
)...
>
{};
}
}
template
<
index_t
X
,
index_t
...
Y
s
>
template
<
index_t
Y
,
index_t
...
X
s
>
__host__
__device__
constexpr
auto
operator
*
(
Number
<
X
>
,
Sequence
<
Y
s
...
>
)
const
__host__
__device__
constexpr
auto
operator
*
(
Number
<
Y
>
,
Sequence
<
X
s
...
>
)
{
{
return
make_uniform_sequence
(
Number
<
sizeof
...(
Ys
)
>
{},
Number
<
X
>
{})
*
Sequence
<
Ys
...
>
{};
return
Sequence
<
(
Y
*
Xs
)
...
>
{};
}
}
template
<
index_t
X
,
index_t
...
Y
s
>
template
<
index_t
Y
,
index_t
...
X
s
>
__host__
__device__
constexpr
auto
operator
/
(
Number
<
X
>
,
Sequence
<
Y
s
...
>
)
const
__host__
__device__
constexpr
auto
operator
/
(
Number
<
Y
>
,
Sequence
<
X
s
...
>
)
{
{
return
make_uniform_sequence
(
Number
<
sizeof
...(
Ys
)
>
{},
Number
<
X
>
{})
/
Sequence
<
Ys
...
>
{};
return
Sequence
<
(
Y
/
Xs
)
...
>
{};
}
}
template
<
index_t
X
,
index_t
...
Y
s
>
template
<
index_t
Y
,
index_t
...
X
s
>
__host__
__device__
constexpr
auto
operator
%
(
Number
<
X
>
,
Sequence
<
Y
s
...
>
)
const
__host__
__device__
constexpr
auto
operator
%
(
Number
<
Y
>
,
Sequence
<
X
s
...
>
)
{
{
return
make_uniform_sequence
(
Number
<
sizeof
...(
Ys
)
>
{},
Number
<
X
>
{})
%
Sequence
<
Ys
...
>
{};
return
Sequence
<
(
Y
%
Xs
)
...
>
{};
}
}
template
<
index_t
I
,
index_t
...
Is
>
template
<
index_t
I
,
index_t
...
Is
>
...
@@ -268,7 +296,7 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
...
@@ -268,7 +296,7 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
#if 0
#if 0
// TODO: for some reason, compiler cannot instantiate this template
// TODO: for some reason, compiler cannot instantiate this template
template <index_t I, index_t
...
I
s
>
template <index_t
...
I
s
, index_t I>
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
{
{
static_assert(sizeof...(Is) > 0, "empty Sequence!");
static_assert(sizeof...(Is) > 0, "empty Sequence!");
...
@@ -356,8 +384,6 @@ __host__ __device__ constexpr auto
...
@@ -356,8 +384,6 @@ __host__ __device__ constexpr auto
}
}
#endif
#endif
#if 1
// TODO: fix these mess
template
<
class
F
,
index_t
...
Xs
>
template
<
class
F
,
index_t
...
Xs
>
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
)
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
Sequence
<
Xs
...
>
)
{
{
...
@@ -382,45 +408,6 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
...
@@ -382,45 +408,6 @@ transform_sequences(F f, Sequence<Xs...>, Sequence<Ys...>, Sequence<Zs...>)
return
Sequence
<
f
(
Xs
,
Ys
,
Zs
)...
>
{};
return
Sequence
<
f
(
Xs
,
Ys
,
Zs
)...
>
{};
}
}
#else
// TODO:: these doesn't compile
template
<
index_t
NRemain
>
struct
transform_sequences_impl
{
template
<
class
F
,
class
Y
,
class
...
Xs
>
__host__
__device__
constexpr
auto
operator
()(
F
f
,
Y
y
,
Xs
...
xs
)
const
{
static_assert
(
NRemain
>
1
,
"wrong! should have NRemain > 1"
);
constexpr
index_t
N
=
f
(
Xs
{}.
Get
(
Number
<
0
>
{})...);
constexpr
auto
y_new
=
y
.
PushBack
(
Number
<
N
>
{});
return
transform_sequences_impl
<
NRemain
-
1
>
{}(
f
,
y_new
,
xs
.
PopFront
()...);
}
};
template
<
>
struct
transform_sequences_impl
<
1
>
{
template
<
class
F
,
class
Y
,
class
...
Xs
>
__host__
__device__
constexpr
auto
operator
()(
F
f
,
Y
,
Xs
...)
const
{
constexpr
index_t
N
=
f
(
Xs
{}.
Get
(
Number
<
0
>
{})...);
return
Y
{}.
PushBack
(
Number
<
N
>
{});
}
};
template
<
class
F
,
class
X
,
class
...
Xs
>
__host__
__device__
constexpr
auto
transform_sequences
(
F
f
,
X
x
,
Xs
...
xs
)
{
constexpr
index_t
nSize
=
X
::
GetSize
();
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
y0
=
Sequence
<
f
(
X
{}.
Get
(
I0
),
Xs
{}.
Get
(
I0
)...)
>
{};
return
transform_sequences_impl
<
nSize
-
1
>
{}(
f
,
y0
,
x
.
PopFront
(),
xs
.
PopFront
()...);
}
#endif
template
<
index_t
...
Is
>
template
<
index_t
...
Is
>
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
PopFront
()
const
__host__
__device__
constexpr
auto
Sequence
<
Is
...
>::
PopFront
()
const
...
...
src/include/blockwise_gemm.hip.hpp
View file @
33b5a855
...
@@ -7,6 +7,7 @@
...
@@ -7,6 +7,7 @@
template
<
index_t
BlockSize
,
template
<
index_t
BlockSize
,
class
BlockMatrixA
,
class
BlockMatrixA
,
class
BlockMatrixB
,
class
BlockMatrixB
,
class
ThreadMatrixC
,
index_t
MPerThreadSubC
,
index_t
MPerThreadSubC
,
index_t
NPerThreadSubC
,
index_t
NPerThreadSubC
,
index_t
MLevel0Cluster
,
index_t
MLevel0Cluster
,
...
@@ -45,7 +46,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
...
@@ -45,7 +46,7 @@ struct BlockwiseGemmBlockABlockBThreadCTransANormalBNormalC_v2
N
%
(
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
)
==
0
,
N
%
(
NPerThreadSubC
*
NLevel0Cluster
*
NLevel1Cluster
)
==
0
,
"wrong! Cannot evenly divide work among
\n
"
);
"wrong! Cannot evenly divide work among
\n
"
);
static_assert
(
ThreadMatrixC
::
GetLengths
()
==
GetThreadMatrixCLengths
,
static_assert
(
ThreadMatrixC
::
GetLengths
()
==
GetThreadMatrixCLengths
()
,
"wrong! ThreadMatrixC lengths is wrong"
);
"wrong! ThreadMatrixC lengths is wrong"
);
auto
c_thread_mtx_index
=
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
auto
c_thread_mtx_index
=
GetBeginOfThreadMatrixC
(
get_thread_local_1d_id
());
...
...
src/include/blockwise_tensor_slice_op.hip.hpp
View file @
33b5a855
...
@@ -132,16 +132,15 @@ struct BlockwiseTensorSliceReorderCopy_v3
...
@@ -132,16 +132,15 @@ struct BlockwiseTensorSliceReorderCopy_v3
{
{
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
src_data_per_cluster_per_dims
=
transform_sequences
(
constexpr
auto
src_data_per_cluster_per_dims
=
std
::
multiplies
<
index_t
>
{},
thread_sub_tensor_lengths
,
SrcClusterLengths
{}
)
;
thread_sub_tensor_lengths
*
SrcClusterLengths
{};
constexpr
auto
repeat_lengths
=
constexpr
auto
repeat_lengths
=
transform_sequences
(
mod_conv
::
integer_divide_ceiler
<
index_t
>
{},
transform_sequences
(
mod_conv
::
integer_divide_ceiler
<
index_t
>
{},
SrcLengths
{},
SrcLengths
{},
src_data_per_cluster_per_dims
);
src_data_per_cluster_per_dims
);
constexpr
auto
thread_tensor_lengths
=
transform_sequences
(
constexpr
auto
thread_tensor_lengths
=
thread_sub_tensor_lengths
*
repeat_lengths
;
std
::
multiplies
<
index_t
>
{},
thread_sub_tensor_lengths
,
repeat_lengths
);
constexpr
auto
thread_tensor_desc
=
make_ConstantTensorDescriptor
(
thread_tensor_lengths
);
constexpr
auto
thread_tensor_desc
=
make_ConstantTensorDescriptor
(
thread_tensor_lengths
);
...
@@ -153,27 +152,24 @@ struct BlockwiseTensorSliceReorderCopy_v3
...
@@ -153,27 +152,24 @@ struct BlockwiseTensorSliceReorderCopy_v3
{
{
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
src_data_per_cluster_per_dims
=
transform_sequences
(
constexpr
auto
src_data_per_cluster_per_dims
=
std
::
multiplies
<
index_t
>
{},
thread_sub_tensor_lengths
,
SrcClusterLengths
{}
)
;
thread_sub_tensor_lengths
*
SrcClusterLengths
{};
constexpr
auto
repeat_lengths
=
constexpr
auto
repeat_lengths
=
transform_sequences
(
mod_conv
::
integer_divide_ceiler
<
index_t
>
{},
transform_sequences
(
mod_conv
::
integer_divide_ceiler
<
index_t
>
{},
SrcLengths
{},
SrcLengths
{},
src_data_per_cluster_per_dims
);
src_data_per_cluster_per_dims
);
constexpr
auto
thread_tensor_lengths
=
transform_sequences
(
constexpr
auto
thread_tensor_lengths
=
thread_sub_tensor_lengths
*
repeat_lengths
;
std
::
multiplies
<
index_t
>
{},
thread_sub_tensor_lengths
,
repeat_lengths
);
constexpr
auto
thread_tensor_desc
=
make_ConstantTensorDescriptor
(
thread_tensor_lengths
);
constexpr
auto
thread_tensor_desc
=
make_ConstantTensorDescriptor
(
thread_tensor_lengths
);
static_ford
<
decltype
(
repeat_lengths
)
>
{}([
&
](
auto
repeat_multi_id_
)
{
static_ford
<
decltype
(
repeat_lengths
)
>
{}([
&
](
auto
repeat_multi_id_
)
{
constexpr
auto
repeat_multi_id
=
decltype
(
repeat_multi_id_
){};
constexpr
auto
repeat_multi_id
=
decltype
(
repeat_multi_id_
){};
constexpr
auto
src_data_multi_id
=
transform_sequences
(
constexpr
auto
src_data_multi_id
=
repeat_multi_id
*
src_data_per_cluster_per_dims
;
std
::
multiplies
<
index_t
>
{},
repeat_multi_id
,
src_data_per_cluster_per_dims
);
constexpr
auto
clipboard_data_multi_id
=
transform_sequences
(
constexpr
auto
clipboard_data_multi_id
=
repeat_multi_id
*
thread_sub_tensor_lengths
;
std
::
multiplies
<
index_t
>
{},
repeat_multi_id
,
thread_sub_tensor_lengths
);
constexpr
index_t
src_offset
=
SrcDesc
{}.
Get1dIndex
(
src_data_multi_id
);
constexpr
index_t
src_offset
=
SrcDesc
{}.
Get1dIndex
(
src_data_multi_id
);
constexpr
index_t
clipboard_offset
=
constexpr
index_t
clipboard_offset
=
...
@@ -193,27 +189,24 @@ struct BlockwiseTensorSliceReorderCopy_v3
...
@@ -193,27 +189,24 @@ struct BlockwiseTensorSliceReorderCopy_v3
{
{
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
thread_sub_tensor_lengths
=
SrcSubLengths
{};
constexpr
auto
src_data_per_cluster_per_dims
=
transform_sequences
(
constexpr
auto
src_data_per_cluster_per_dims
=
std
::
multiplies
<
index_t
>
{},
thread_sub_tensor_lengths
,
SrcClusterLengths
{}
)
;
thread_sub_tensor_lengths
*
SrcClusterLengths
{};
constexpr
auto
repeat_lengths
=
constexpr
auto
repeat_lengths
=
transform_sequences
(
mod_conv
::
integer_divide_ceiler
<
index_t
>
{},
transform_sequences
(
mod_conv
::
integer_divide_ceiler
<
index_t
>
{},
SrcLengths
{},
SrcLengths
{},
src_data_per_cluster_per_dims
);
src_data_per_cluster_per_dims
);
constexpr
auto
thread_tensor_lengths
=
transform_sequences
(
constexpr
auto
thread_tensor_lengths
=
thread_sub_tensor_lengths
*
repeat_lengths
;
std
::
multiplies
<
index_t
>
{},
thread_sub_tensor_lengths
,
repeat_lengths
);
constexpr
auto
thread_tensor_desc
=
make_ConstantTensorDescriptor
(
thread_tensor_lengths
);
constexpr
auto
thread_tensor_desc
=
make_ConstantTensorDescriptor
(
thread_tensor_lengths
);
static_ford
<
decltype
(
repeat_lengths
)
>
{}([
&
](
auto
repeat_multi_id_
)
{
static_ford
<
decltype
(
repeat_lengths
)
>
{}([
&
](
auto
repeat_multi_id_
)
{
constexpr
auto
repeat_multi_id
=
decltype
(
repeat_multi_id_
){};
constexpr
auto
repeat_multi_id
=
decltype
(
repeat_multi_id_
){};
constexpr
auto
clipboard_data_multi_id
=
transform_sequences
(
constexpr
auto
clipboard_data_multi_id
=
repeat_multi_id
*
thread_sub_tensor_lengths
;
std
::
multiplies
<
index_t
>
{},
repeat_multi_id
,
thread_sub_tensor_lengths
);
constexpr
auto
src_data_multi_id
=
transform_sequences
(
constexpr
auto
src_data_multi_id
=
repeat_multi_id
*
src_data_per_cluster_per_dims
;
std
::
multiplies
<
index_t
>
{},
repeat_multi_id
,
src_data_per_cluster_per_dims
);
// reorder src_data_multi_id to get dst_data_multi_id
// reorder src_data_multi_id to get dst_data_multi_id
constexpr
auto
dst_data_multi_id
=
src_data_multi_id
.
ReorderGivenNew2Old
(
MapDst2Src
{});
constexpr
auto
dst_data_multi_id
=
src_data_multi_id
.
ReorderGivenNew2Old
(
MapDst2Src
{});
...
...
src/include/functional.hip.hpp
View file @
33b5a855
...
@@ -37,7 +37,8 @@ struct static_if<true>
...
@@ -37,7 +37,8 @@ struct static_if<true>
{
{
// This is a trick for compiler:
// This is a trick for compiler:
// Pass forwarder to lambda "f" as "auto" argument, and maks sure "f" will use it,
// Pass forwarder to lambda "f" as "auto" argument, and maks sure "f" will use it,
// this will make "f" a generic lambda, so that "f" won't be compiled until here
// this will make "f" a generic lambda, so that "f" won't be compiled until being
// instantiated here
f
(
forwarder
{});
f
(
forwarder
{});
return
Type
{};
return
Type
{};
}
}
...
@@ -65,7 +66,8 @@ struct static_if<false>
...
@@ -65,7 +66,8 @@ struct static_if<false>
{
{
// This is a trick for compiler:
// This is a trick for compiler:
// Pass forwarder to lambda "f" as "auto" argument, and maks sure "f" will use it,
// Pass forwarder to lambda "f" as "auto" argument, and maks sure "f" will use it,
// this will make "f" a generic lambda, so that "f" won't be compiled until here
// this will make "f" a generic lambda, so that "f" won't be compiled until being
// instantiated here
f
(
forwarder
{});
f
(
forwarder
{});
return
Type
{};
return
Type
{};
}
}
...
@@ -105,7 +107,7 @@ struct static_for
...
@@ -105,7 +107,7 @@ struct static_for
static_assert
((
NEnd
-
NBegin
)
%
Increment
==
0
,
static_assert
((
NEnd
-
NBegin
)
%
Increment
==
0
,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"
);
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"
);
static_if
<
(
NBegin
<
End
)
>
{}(
static_if
<
(
NBegin
<
N
End
)
>
{}(
[
&
](
auto
fwd
)
{
static_for_impl
<
NBegin
,
NEnd
-
NBegin
,
fwd
(
Increment
)
>
{}(
f
);
});
[
&
](
auto
fwd
)
{
static_for_impl
<
NBegin
,
NEnd
-
NBegin
,
fwd
(
Increment
)
>
{}(
f
);
});
}
}
};
};
...
...
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp
View file @
33b5a855
...
@@ -201,7 +201,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
...
@@ -201,7 +201,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
// choose GEMM implementation here
// choose GEMM implementation here
const
auto
run_blockwise_batch_gemm
=
[
&
](
auto
...
Xs
)
{
const
auto
run_blockwise_batch_gemm
=
[
&
](
auto
...
Xs
)
{
#if
0
#if
1
return
blockwise_batch_gemm
.
Run
(
Xs
...);
return
blockwise_batch_gemm
.
Run
(
Xs
...);
#elif 0
#elif 0
return
blockwise_batch_gemm
.
Run_asm
(
Xs
...);
return
blockwise_batch_gemm
.
Run_asm
(
Xs
...);
...
...
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_khwn.hip.hpp
View file @
33b5a855
...
@@ -142,7 +142,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
...
@@ -142,7 +142,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
decltype
(
map_chwn2nchw
),
decltype
(
map_chwn2nchw
),
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW
,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW
,
InBlockReorderDataPerRead_W
,
InBlockReorderDataPerRead_W
,
InBlockReorderDataPerWrite_N
>
{}
;
InBlockReorderDataPerWrite_N
>
({
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
})
;
// blockwise wei copy
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
// format is [CPerBlock, KPerBlock]
...
@@ -196,7 +196,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
...
@@ -196,7 +196,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_khwn
// choose GEMM implementation here
// choose GEMM implementation here
const
auto
run_blockwise_batch_gemm
=
[
&
](
auto
...
Xs
)
{
const
auto
run_blockwise_batch_gemm
=
[
&
](
auto
...
Xs
)
{
#if
0
#if
1
return
blockwise_batch_gemm
.
Run
(
Xs
...);
return
blockwise_batch_gemm
.
Run
(
Xs
...);
#elif 0
#elif 0
return
blockwise_batch_gemm
.
Run_asm
(
Xs
...);
return
blockwise_batch_gemm
.
Run_asm
(
Xs
...);
...
...
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw.hip.hpp
View file @
33b5a855
...
@@ -142,7 +142,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -142,7 +142,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
decltype
(
map_chwn2nchw
),
decltype
(
map_chwn2nchw
),
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW
,
InBlockReorderMapThreadCluster2SrcCluster_CHNW2NCHW
,
InBlockReorderDataPerRead_W
,
InBlockReorderDataPerRead_W
,
InBlockReorderDataPerWrite_N
>
{}
;
InBlockReorderDataPerWrite_N
>
({
0
,
0
,
0
,
0
},
{
0
,
0
,
0
,
0
})
;
// blockwise wei copy
// blockwise wei copy
// format is [CPerBlock, KPerBlock]
// format is [CPerBlock, KPerBlock]
...
@@ -196,7 +196,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
...
@@ -196,7 +196,7 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_nchw_cyxk_nkhw
// choose GEMM implementation here
// choose GEMM implementation here
const
auto
run_blockwise_batch_gemm
=
[
&
](
auto
...
Xs
)
{
const
auto
run_blockwise_batch_gemm
=
[
&
](
auto
...
Xs
)
{
#if
0
#if
1
return
blockwise_batch_gemm
.
Run
(
Xs
...);
return
blockwise_batch_gemm
.
Run
(
Xs
...);
#elif 0
#elif 0
return
blockwise_batch_gemm
.
Run_asm
(
Xs
...);
return
blockwise_batch_gemm
.
Run_asm
(
Xs
...);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment