Commit a6b95c39, authored May 18, 2019 by Chao Liu

rework sequence

Parent: df73287b

Showing 3 changed files with 166 additions and 228 deletions:

src/include/ConstantTensorDescriptor.hip.hpp (+6, -23)
src/include/Sequence.hip.hpp (+99, -136)
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp (+61, -69)
src/include/ConstantTensorDescriptor.hip.hpp (view file @ a6b95c39)
 #pragma once
 
 #include "common.hip.hpp"
 
-template <class PreviousStrides, class RemainLengths>
-__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, RemainLengths)
-{
-    constexpr index_t previous_stride = PreviousStrides{}.Front();
-    constexpr index_t current_length  = RemainLengths{}.Back();
-    constexpr index_t current_stride  = current_length * previous_stride;
-
-    return calculate_default_strides_impl(PreviousStrides{}.PushFront(Number<current_stride>{}),
-                                          RemainLengths{}.PopBack());
-}
-
-template <class PreviousStrides, index_t L0, index_t L1>
-__host__ __device__ constexpr auto calculate_default_strides_impl(PreviousStrides, Sequence<L0, L1>)
-{
-    constexpr index_t previous_stride = PreviousStrides{}.Front();
-    constexpr index_t current_stride  = L1 * previous_stride;
-
-    return PreviousStrides{}.PushFront(Number<current_stride>{});
-}
-
 template <class Lengths>
 __host__ __device__ constexpr auto calculate_default_strides(Lengths)
 {
-    return calculate_default_strides_impl(Sequence<1>{}, Lengths{});
+    return reverse_inclusive_scan_sequence(Lengths{}.PopFront().PushBack(Number<1>{}),
+                                           std::multiplies<index_t>{});
 }
 
 // this is ugly, only for 2d
...
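The rewritten calculate_default_strides is the heart of this commit: packed (row-major) strides are a reverse inclusive scan, under multiplication, of the lengths shifted left by one position and padded with a trailing 1. A minimal constexpr sketch of the same computation on std::array (illustration only, assuming C++17; the library version works on compile-time Sequence types):

#include <array>
#include <cstddef>

// For lengths {2, 3, 4} the packed strides are {12, 4, 1}: the stride of
// dimension i is the product of the lengths of dimensions i+1..N-1, i.e. a
// reverse inclusive scan of {3, 4, 1} under multiplication.
template <std::size_t N>
constexpr std::array<std::size_t, N> packed_strides(const std::array<std::size_t, N>& lengths)
{
    std::array<std::size_t, N> strides{};
    std::size_t running = 1; // scan accumulator, runs from the back
    for(std::size_t i = N; i-- > 0;)
    {
        strides[i] = running;
        running *= lengths[i];
    }
    return strides;
}

constexpr auto s = packed_strides<3>({2, 3, 4});
static_assert(s[0] == 12 && s[1] == 4 && s[2] == 1, "packed strides");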
@@ -58,6 +39,7 @@ template <class Lengths, class Strides>
 struct ConstantTensorDescriptor
 {
     using Type = ConstantTensorDescriptor;
 
     static constexpr index_t nDim = Lengths::GetSize();
 
     __host__ __device__ constexpr ConstantTensorDescriptor()
...
@@ -193,7 +175,8 @@ struct ConstantTensorDescriptor
         // folded strides
         constexpr auto fold_strides =
             Number<unfold_stride>{} *
-            reverse_scan_sequence(fold_intervals.PushBack(Number<1>{}), std::multiplies<index_t>{});
+            reverse_inclusive_scan_sequence(fold_intervals.PushBack(Number<1>{}),
+                                            std::multiplies<index_t>{});
 
         // left and right
         constexpr auto left = make_increasing_sequence(Number<0>{}, Number<IDim>{}, Number<1>{});
...
src/include/Sequence.hip.hpp (view file @ a6b95c39)
...
@@ -9,7 +9,8 @@ struct Sequence
     static constexpr index_t mSize = sizeof...(Is);
 
-    const index_t mData[mSize] = {Is...};
+    // the last element is dummy, to prevent compiler complain on empty Sequence
+    const index_t mData[mSize + 1] = {Is..., 0};
 
     __host__ __device__ static constexpr index_t GetSize() { return mSize; }
...
@@ -39,10 +40,7 @@ struct Sequence
         assert(false);
     }
 
-    __host__ __device__ constexpr auto Reverse() const
-    {
-        // not implemented
-    }
+    __host__ __device__ constexpr auto Reverse() const;
 
     __host__ __device__ constexpr index_t Front() const { return mData[0]; }
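The mSize + 1 padding works around the fact that a zero-length array member is ill-formed, which would make the empty Sequence<> unusable. The trick in isolation (a sketch, not the library code):

// An empty pack would otherwise declare `const index_t mData[0]`.
template <int... Is>
struct Seq
{
    static constexpr int size = sizeof...(Is);

    const int data[size + 1] = {Is..., 0}; // trailing 0 is padding, never read
};

constexpr Seq<> empty{};   // OK: data has length 1
constexpr Seq<7, 8> two{}; // data = {7, 8, 0}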
...
@@ -73,13 +71,13 @@ struct Sequence
     template <index_t... Ns>
     __host__ __device__ constexpr auto Extract(Number<Ns>...) const
     {
-        return Sequence<Get(Number<Ns>{})...>{};
+        return Sequence<Type{}.Get(Number<Ns>{})...>{};
     }
 
     template <index_t... Ns>
     __host__ __device__ constexpr auto Extract(Sequence<Ns...>) const
     {
-        return Sequence<Get(Number<Ns>{})...>{};
+        return Sequence<Type{}.Get(Number<Ns>{})...>{};
    }
 };
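Qualifying the call as Type{}.Get(...) rather than plain Get(...) presumably keeps Extract usable in a constant expression: a bare member call in a template argument implicitly goes through this, which is not a constant expression, while a freshly constructed Type{} is. A reduced sketch of the distinction (hypothetical, not the library code):

#include <type_traits>

struct S
{
    constexpr int Get(int i) const { return i + 1; }

    template <int N>
    constexpr auto Extract() const
    {
        // return std::integral_constant<int, Get(N)>{};  // error: uses `this`
        return std::integral_constant<int, S{}.Get(N)>{}; // OK: fresh constant object
    }
};

static_assert(S{}.Extract<3>().value == 4, "");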
...
@@ -89,43 +87,109 @@ struct sequence_merge;
 template <index_t... Xs, index_t... Ys>
 struct sequence_merge<Sequence<Xs...>, Sequence<Ys...>>
 {
-    using Type = Sequence<Xs..., Ys...>;
+    using SeqType = Sequence<Xs..., Ys...>;
 };
 
 template <index_t IBegin, index_t NSize, index_t Increment>
-struct increasing_sequence_gen
+struct increasing_sequence_gen_impl
 {
     static constexpr index_t NSizeLeft = NSize / 2;
 
-    using Type = typename sequence_merge<
-        typename increasing_sequence_gen<IBegin, NSizeLeft, Increment>::Type,
-        typename increasing_sequence_gen<IBegin + NSizeLeft * Increment, NSize - NSizeLeft, Increment>::Type>::Type;
+    using SeqType = typename sequence_merge<
+        typename increasing_sequence_gen_impl<IBegin, NSizeLeft, Increment>::SeqType,
+        typename increasing_sequence_gen_impl<IBegin + NSizeLeft * Increment, NSize - NSizeLeft, Increment>::SeqType>::SeqType;
 };
 
 template <index_t IBegin, index_t Increment>
-struct increasing_sequence_gen<IBegin, 1, Increment>
+struct increasing_sequence_gen_impl<IBegin, 1, Increment>
 {
-    using Type = Sequence<IBegin>;
+    using SeqType = Sequence<IBegin>;
 };
 
 template <index_t IBegin, index_t Increment>
-struct increasing_sequence_gen<IBegin, 0, Increment>
+struct increasing_sequence_gen_impl<IBegin, 0, Increment>
 {
-    using Type = Sequence<>;
+    using SeqType = Sequence<>;
 };
 
+template <index_t IBegin, index_t IEnd, index_t Increment>
+struct increasing_sequence_gen
+{
+    using SeqType =
+        typename increasing_sequence_gen_impl<IBegin, IEnd - IBegin, Increment>::SeqType;
+};
+
 template <index_t IBegin, index_t IEnd, index_t Increment>
 __host__ __device__ constexpr auto
 make_increasing_sequence(Number<IBegin>, Number<IEnd>, Number<Increment>)
 {
-    constexpr index_t NSize = (IEnd - IBegin) / Increment;
-    return increasing_sequence_gen<IBegin, NSize, Increment>{};
+    static_assert(IBegin <= IEnd && Increment > 0, "wrong!");
+    return typename increasing_sequence_gen<IBegin, IEnd, Increment>::SeqType{};
 }
 
+template <class, class>
+struct sequence_reverse_inclusive_scan;
+
+template <index_t I, index_t... Is, class Reduce>
+struct sequence_reverse_inclusive_scan<Sequence<I, Is...>, Reduce>
+{
+    using old_scan = typename sequence_reverse_inclusive_scan<Sequence<Is...>, Reduce>::SeqType;
+
+    static constexpr index_t new_reduce = Reduce{}(I, old_scan{}.Front());
+
+    using SeqType = typename sequence_merge<Sequence<new_reduce>, old_scan>::SeqType;
+};
+
+template <index_t I, class Reduce>
+struct sequence_reverse_inclusive_scan<Sequence<I>, Reduce>
+{
+    using SeqType = Sequence<I>;
+};
+
+template <class, class>
+struct sequence_extract;
+
+template <class Seq, index_t... Is>
+struct sequence_extract<Seq, Sequence<Is...>>
+{
+    using SeqType = Sequence<Seq{}.Get(Number<Is>{})...>;
+};
+
+template <class Seq, index_t I>
+struct sequence_split
+{
+    static constexpr index_t NSize = Seq{}.GetSize();
+
+    using range0 = typename increasing_sequence_gen<0, I, 1>::SeqType;
+    using range1 = typename increasing_sequence_gen<I, NSize, 1>::SeqType;
+
+    using SeqType0 = typename sequence_extract<Seq, range0>::SeqType;
+    using SeqType1 = typename sequence_extract<Seq, range1>::SeqType;
+};
+
+template <class Seq>
+struct sequence_reverse
+{
+    static constexpr index_t NSize = Seq{}.GetSize();
+
+    using seq_split = sequence_split<Seq, NSize / 2>;
+
+    using SeqType = typename sequence_merge<
+        typename sequence_reverse<typename seq_split::SeqType1>::SeqType,
+        typename sequence_reverse<typename seq_split::SeqType0>::SeqType>::SeqType;
+};
+
+template <index_t I>
+struct sequence_reverse<Sequence<I>>
+{
+    using SeqType = Sequence<I>;
+};
+
+template <index_t I0, index_t I1>
+struct sequence_reverse<Sequence<I0, I1>>
+{
+    using SeqType = Sequence<I1, I0>;
+};
+
 template <index_t... Xs, index_t... Ys>
 __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Sequence<Ys...>)
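increasing_sequence_gen now takes the half-open pair (IBegin, IEnd) and delegates to increasing_sequence_gen_impl, whose halving recursion keeps template instantiation depth at O(log N) instead of O(N). (For Increment == 1, IEnd - IBegin is exactly the element count of [IBegin, IEnd).) The same divide-and-conquer shape, sketched with std::index_sequence standing in for the library's Sequence (illustration only):

#include <cstddef>
#include <type_traits>
#include <utility>

template <class, class>
struct concat;

template <std::size_t... Xs, std::size_t... Ys>
struct concat<std::index_sequence<Xs...>, std::index_sequence<Ys...>>
{
    using type = std::index_sequence<Xs..., Ys...>;
};

// gen<B, N, S> yields B, B+S, ..., B+(N-1)*S by generating both halves and
// concatenating, mirroring increasing_sequence_gen_impl above.
template <std::size_t B, std::size_t N, std::size_t S>
struct gen
{
    static constexpr std::size_t NL = N / 2;

    using type = typename concat<typename gen<B, NL, S>::type,
                                 typename gen<B + NL * S, N - NL, S>::type>::type;
};

template <std::size_t B, std::size_t S>
struct gen<B, 1, S>
{
    using type = std::index_sequence<B>;
};

template <std::size_t B, std::size_t S>
struct gen<B, 0, S>
{
    using type = std::index_sequence<>;
};

static_assert(std::is_same<gen<2, 3, 2>::type, std::index_sequence<2, 4, 6>>::value, "");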
...
@@ -179,9 +243,9 @@ __host__ __device__ constexpr auto operator+(Sequence<Xs...>, Number<Y>)
 template <index_t... Xs, index_t Y>
 __host__ __device__ constexpr auto operator-(Sequence<Xs...>, Number<Y>)
 {
-#if 0 // doesn't compile
     constexpr auto seq_x = Sequence<Xs...>{};
+#if 0 // doesn't compile
     static_for<0, sizeof...(Xs), 1>{}([&](auto Iter) {
         constexpr auto I = decltype(Iter){};
         static_assert(seq_x.Get(I) >= Y, "wrong! going to underflow");
...
@@ -253,95 +317,12 @@ __host__ __device__ constexpr auto sequence_pop_front(Sequence<I, Is...>)
     return Sequence<Is...>{};
 }
 
-#if 0
-// TODO: for some reason, compiler cannot instantiate this template
-template <index_t... Is, index_t I>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<Is..., I>)
-{
-    static_assert(sizeof...(Is) > 0, "empty Sequence!");
-    return Sequence<Is...>{};
-}
-#else
-// TODO: delete these very ugly mess
-template <index_t I0, index_t I1>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1>)
-{
-    return Sequence<I0>{};
-}
-
-template <index_t I0, index_t I1, index_t I2>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2>)
-{
-    return Sequence<I0, I1>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3>)
-{
-    return Sequence<I0, I1, I2>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4>)
-{
-    return Sequence<I0, I1, I2, I3>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5>)
-{
-    return Sequence<I0, I1, I2, I3, I4>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6, index_t I7>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6, index_t I7, index_t I8>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7>{};
-}
-
-template <index_t I0, index_t I1, index_t I2, index_t I3, index_t I4, index_t I5, index_t I6, index_t I7, index_t I8, index_t I9>
-__host__ __device__ constexpr auto sequence_pop_back(Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8, I9>)
-{
-    return Sequence<I0, I1, I2, I3, I4, I5, I6, I7, I8>{};
-}
-#endif
+template <class Seq>
+__host__ __device__ constexpr auto sequence_pop_back(Seq)
+{
+    static_assert(Seq{}.GetSize() > 0, "empty Sequence!");
+    return sequence_pop_front(Seq{}.Reverse()).Reverse();
+}
 
 template <class F, index_t... Xs>
 __host__ __device__ constexpr auto transform_sequences(F f, Sequence<Xs...>)
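The nine hand-written sequence_pop_back overloads collapse into one generic function through the identity drop-last = reverse, then drop-first, then reverse. The same identity on a plain array (a sketch only; constexpr std::reverse and std::copy assume C++20):

#include <algorithm>
#include <array>
#include <cstddef>

template <std::size_t N>
constexpr std::array<std::size_t, N - 1> pop_back_via_reverse(std::array<std::size_t, N> a)
{
    static_assert(N > 0, "empty sequence!");
    std::reverse(a.begin(), a.end());             // {2,3,4} -> {4,3,2}
    std::array<std::size_t, N - 1> b{};
    std::copy(a.begin() + 1, a.end(), b.begin()); // drop the front -> {3,2}
    std::reverse(b.begin(), b.end());             // -> {2,3}
    return b;
}

static_assert(pop_back_via_reverse<3>({2, 3, 4}) == std::array<std::size_t, 2>{2, 3}, "");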
...
@@ -399,38 +380,20 @@ __host__ __device__ constexpr index_t
     return Reduce{}(a, I);
 }
 
-template <index_t NRemain>
-struct scan_sequence_impl
-{
-    template <class ScanedSeq, class RemainSeq, class Reduce>
-    __host__ __device__ constexpr auto operator()(ScanedSeq, RemainSeq, Reduce) const
-    {
-        static_assert(RemainSeq{}.GetSize() == NRemain, "wrong! RemainSeq and NRemain not consistent!");
-
-        constexpr index_t a = Reduce{}(ScanedSeq{}.Back(), RemainSeq{}.Front());
-
-        constexpr auto scaned_seq = ScanedSeq{}.PushBack(Number<a>{});
-
-        static_if<(NRemain > 1)>{}([&](auto fwd) {
-            return scan_sequence_impl<NRemain - 1>{}(scaned_seq, RemainSeq{}.PopFront(), fwd(Reduce{}));
-        }).else_([&](auto fwd) { return fwd(scaned_seq); });
-    }
-};
+template <index_t... Is>
+__host__ __device__ constexpr auto Sequence<Is...>::Reverse() const
+{
+    return typename sequence_reverse<Sequence<Is...>>::SeqType{};
+}
 
 template <class Seq, class Reduce>
-__host__ __device__ constexpr auto scan_sequence(Seq, Reduce)
+__host__ __device__ constexpr auto reverse_inclusive_scan_sequence(Seq, Reduce)
 {
-    constexpr auto scaned_seq = Sequence<Seq{}.Front()>{};
-    constexpr auto remain_seq = Seq{}.PopFront();
-    constexpr index_t remain_size = Seq::GetSize() - 1;
-
-    return scan_sequence_impl<remain_size>{}(scaned_seq, remain_seq, Reduce{});
+    return typename sequence_reverse_inclusive_scan<Seq, Reduce>::SeqType{};
 }
 
 template <class Seq, class Reduce>
-__host__ __device__ constexpr auto reverse_scan_sequence(Seq, Reduce)
+__host__ __device__ constexpr auto inclusive_scan_sequence(Seq, Reduce)
 {
-    return scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
+    return reverse_inclusive_scan_sequence(Seq{}.Reverse(), Reduce{}).Reverse();
 }
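The new names match the semantics: element i of a reverse inclusive scan reduces elements i..N-1, and the forward scan is recovered by reversing, scanning, and reversing back. A constexpr model on std::array (illustration only, assuming C++17; the library versions work on Sequence types):

#include <array>
#include <cstddef>

template <std::size_t N, class Op>
constexpr std::array<int, N> reverse_inclusive_scan(std::array<int, N> a, Op op)
{
    for(std::size_t i = N - 1; i-- > 0;)
        a[i] = op(a[i], a[i + 1]); // element i reduces elements i..N-1
    return a;
}

constexpr auto mul = [](int x, int y) { return x * y; };

constexpr auto r = reverse_inclusive_scan<3>({2, 3, 4}, mul);
static_assert(r[0] == 24 && r[1] == 12 && r[2] == 4, "suffix products");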
src/include/gridwise_convolution_implicit_gemm_v1r3_lds_double_buffer_chwn_cyxk_khwn.hip.hpp (view file @ a6b95c39)
...
@@ -80,29 +80,26 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                           Ho % HoPerBlock == 0 && Wo % WoPerBlock == 0,
                       "wrong! cannot evenly divide work for workgroup");
 
-        constexpr index_t KBlockWork = (K + KPerBlock - 1) / KPerBlock;
-        constexpr index_t HBlockWork = (Ho + HoPerBlock - 1) / HoPerBlock;
-        constexpr index_t WBlockWork = (Wo + WoPerBlock - 1) / WoPerBlock;
-        constexpr index_t NBlockWork = (N + NPerBlock - 1) / NPerBlock;
+        constexpr index_t KBlockWork = mod_conv::integer_divide_ceil(K, KPerBlock);
+        constexpr index_t HBlockWork = mod_conv::integer_divide_ceil(Ho, HoPerBlock);
+        constexpr index_t WBlockWork = mod_conv::integer_divide_ceil(Wo, WoPerBlock);
+        constexpr index_t NBlockWork = mod_conv::integer_divide_ceil(N, NPerBlock);
 
-        const index_t k_block_work_id = get_block_1d_id() / (HBlockWork * WBlockWork * NBlockWork);
-        index_t itmp = get_block_1d_id() - k_block_work_id * (HBlockWork * WBlockWork * NBlockWork);
-        const index_t h_block_work_id = itmp / (WBlockWork * NBlockWork);
-        itmp -= h_block_work_id * (WBlockWork * NBlockWork);
-        const index_t w_block_work_id = itmp / NBlockWork;
-        const index_t n_block_work_id = itmp - w_block_work_id * NBlockWork;
+        constexpr auto block_work_desc = make_ConstantTensorDescriptor(
+            Sequence<KBlockWork, HBlockWork, WBlockWork, NBlockWork>{});
 
-        const index_t k_block_data_begin  = k_block_work_id * KPerBlock;
-        const index_t ho_block_data_begin = h_block_work_id * HoPerBlock;
-        const index_t wo_block_data_begin = w_block_work_id * WoPerBlock;
-        const index_t n_block_data_begin  = n_block_work_id * NPerBlock;
+        const auto block_work_multi_id = block_work_desc.GetMultiIndex(get_block_1d_id());
+
+        const index_t k_block_data_begin  = block_work_multi_id[0] * KPerBlock;
+        const index_t ho_block_data_begin = block_work_multi_id[1] * HoPerBlock;
+        const index_t wo_block_data_begin = block_work_multi_id[2] * WoPerBlock;
+        const index_t n_block_data_begin  = block_work_multi_id[3] * NPerBlock;
 
         const index_t hi_block_data_begin = ho_block_data_begin;
         const index_t wi_block_data_begin = wo_block_data_begin;
 
         // global tensor view
-        constexpr auto wei_c_k_global_desc =
-            make_ConstantTensorDescriptor(Sequence<C, K>{}, Sequence<Y * X * K, 1>{});
+        constexpr auto wei_c_k_global_desc = wei_c_y_x_k_global_desc.Extract(I0, I3);
 
         // LDS tensor view
 
        // be careful of alignment
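Rather than hand-rolling div/mod chains, the flat block id is now decoded against a 4-d work-grid descriptor. A hedged model of the two helpers used above (mod_conv::integer_divide_ceil and GetMultiIndex exist in the library, but their exact signatures are assumed here; the decode below assumes row-major order, i.e. the last dimension varies fastest):

#include <array>
#include <cstdint>

using index_t = std::int32_t;

constexpr index_t integer_divide_ceil(index_t a, index_t b) { return (a + b - 1) / b; }

constexpr std::array<index_t, 4> get_multi_index(index_t id, std::array<index_t, 4> lengths)
{
    std::array<index_t, 4> multi{};
    for(int i = 3; i >= 0; --i)
    {
        multi[i] = id % lengths[i];
        id /= lengths[i];
    }
    return multi;
}

static_assert(integer_divide_ceil(10, 4) == 3, "");

constexpr auto m = get_multi_index(7, {2, 2, 2, 2}); // 7 == 0b0111
static_assert(m[0] == 0 && m[1] == 1 && m[2] == 1 && m[3] == 1, "");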
...
@@ -360,13 +357,12 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
         const index_t wo_thread_data_begin = c_thread_mtx_begin.col / NPerBlock;
         const index_t n_thread_data_begin  = c_thread_mtx_begin.col % NPerBlock;
 
-        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto f_dummy) {
-            // f_dummy do nothing but
-            // perfect forwarding.
-            // Using this trick to
-            // make this lambda a generic lambda, so it won't be compiled until
-            // being instantiated here
+        static_if<GemmNPerThreadSubC <= NPerBlock>{}([&](auto fwd) {
+            // fwd do nothing but perfect forwarding.
+            // Using this trick to make this lambda a generic lambda, so it won't be compiled until
+            // instantiated
             static_assert(
-                (f_dummy(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
+                (fwd(GemmNPerThreadSubC) <= NPerBlock && NPerBlock % GemmNPerThreadSubC == 0),
                 "wrong!");
 
             // output is a 10d tensor
...
@@ -374,37 +370,32 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
             constexpr index_t N1 = NPerBlock / N2;
 
             constexpr index_t W2 =
-                (GemmNLevel0Cluster * GemmNLevel1Cluster) / f_dummy(NPerBlock / GemmNPerThreadSubC);
+                (GemmNLevel0Cluster * GemmNLevel1Cluster) / fwd(NPerBlock / GemmNPerThreadSubC);
             constexpr index_t W1 = WoPerBlock / W2;
 
             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;
 
-            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2), W1, W2,
-                         N / f_dummy(N1 * N2), N1, N2>{});
+            constexpr auto out_10d_global_desc =
+                fwd(out_k_h_w_n_global_desc)
+                    .Fold(I3, Number<N1>{}, Number<N2>{})
+                    .Fold(I2, Number<W1>{}, Number<W2>{})
+                    .Fold(I0, Number<K1>{}, Number<K2>{});
 
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, 1, 1, N2>{});
+            constexpr auto out_10d_thread_desc =
+                fwd(out_k_h_w_n_thread_desc)
+                    .Fold(I3, Number<1>{}, Number<N2>{})
+                    .Fold(I2, Number<W1>{}, Number<1>{})
+                    .Fold(I0, Number<1>{}, Number<K2>{});
 
 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
             {
                 print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
-                                               "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                                               "a: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "a: out_10d_thread_desc");
 
                 print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
-                                               "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+                                               "a: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "a: out_10d_global_desc");
             }
 #endif
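Fold replaces the hand-built 10-d Sequence: folding dimension d with factors (f1, f2) splits its length L into (L / (f1 * f2), f1, f2) and derives the matching strides. The lengths-only effect of the three Fold calls above, as a sketch with made-up example sizes (the real Fold also computes strides; the deleted code on the left of this diff lists exactly these lengths):

#include <array>
#include <cstddef>

// 4-d (K, Ho, Wo, N) lengths -> 10-d lengths after
// Fold(I0, K1, K2), Fold(I2, W1, W2), Fold(I3, N1, N2).
constexpr std::array<std::size_t, 10>
fold_k_h_w_n(std::size_t K, std::size_t Ho, std::size_t Wo, std::size_t N,
             std::size_t K1, std::size_t K2,
             std::size_t W1, std::size_t W2,
             std::size_t N1, std::size_t N2)
{
    return {K / (K1 * K2), K1, K2,  // dim 0 folded
            Ho,                     // dim 1 untouched
            Wo / (W1 * W2), W1, W2, // dim 2 folded
            N / (N1 * N2), N1, N2}; // dim 3 folded
}

constexpr auto l = fold_k_h_w_n(128, 28, 28, 64, 8, 4, 7, 2, 8, 4);
static_assert(l[0] == 4 && l[4] == 2 && l[7] == 2, "");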
...
@@ -419,8 +410,8 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
                                   n_block_data_begin + n_thread_data_begin),
                 out_10d_thread_desc.GetLengths(),
                 Number<OutThreadCopyDataPerWrite_N>{});
-        }).else_([&](auto f_dummy) {
-            static_assert(f_dummy(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
+        }).else_([&](auto fwd) {
+            static_assert(fwd(GemmNPerThreadSubC) >= NPerBlock && NPerThread == NPerBlock &&
                               GemmNPerThreadSubC % NPerThread == 0,
                           "wrong!");
...
@@ -429,32 +420,33 @@ struct GridwiseConvolutionImplicitGemm_v1r3_lds_double_buffer_chwn_cyxk_khwn
             constexpr index_t W3 = GemmNPerThreadSubC / NPerBlock;
             constexpr index_t W2 = GemmNLevel0Cluster * GemmNLevel1Cluster;
-            constexpr index_t W1 = WoPerBlock / f_dummy(W2 * W3);
+            constexpr index_t W1 = WoPerBlock / fwd(W2 * W3);
 
             constexpr index_t K2 = GemmMPerThreadSubC;
             constexpr index_t K1 = KPerBlock / KPerThread;
 
-            constexpr auto out_10d_global_desc = make_ConstantTensorDescriptor(
-                Sequence<K / (K1 * K2), K1, K2, Ho, Wo / (W1 * W2 * W3), W1, W2, W3,
-                         N / N1, N1>{});
+            constexpr auto out_10d_global_desc =
+                fwd(out_k_h_w_n_global_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<W2>{}, Number<W3>{})
+                    .Fold(I0, Number<K1>{}, Number<K2>{});
 
-            constexpr auto out_10d_thread_desc = make_ConstantTensorDescriptor(
-                Sequence<KPerThread / K2, 1, K2, HoPerThread, 1, W1, 1, W3, 1, N1>{});
+            constexpr auto out_10d_thread_desc =
+                fwd(out_k_h_w_n_thread_desc)
+                    .Fold(I3, Number<N1>{})
+                    .Fold(I2, Number<W1>{}, Number<1>{}, Number<W3>{})
+                    .Fold(I0, Number<1>{}, Number<K2>{});
 
 #if 0
             if(get_thread_local_1d_id() == 0 && get_block_1d_id() == 0)
             {
                 print_ConstantTensorDescriptor(out_k_h_w_n_thread_desc,
-                                               "out_k_h_w_n_thread_desc");
-                print_ConstantTensorDescriptor(out_10d_thread_desc, "out_10d_thread_desc");
+                                               "b: out_k_h_w_n_thread_desc");
+                print_ConstantTensorDescriptor(out_10d_thread_desc, "b: out_10d_thread_desc");
 
                 print_ConstantTensorDescriptor(out_k_h_w_n_global_desc,
-                                               "out_k_h_w_n_global_desc");
-                print_ConstantTensorDescriptor(out_10d_global_desc, "out_10d_global_desc");
+                                               "b: out_k_h_w_n_global_desc");
+                print_ConstantTensorDescriptor(out_10d_global_desc, "b: out_10d_global_desc");
+
+                for(index_t i = 0; i < 64; ++i)
+                {
+                    printf("out %f, ", p_out_thread[i]);
+                }
             }
 #endif
...