Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
17f3d2d4
"...git@developer.sourcefind.cn:modelzoo/gpt2_migraphx.git" did not exist on "741ac4aedba78d7584822ce9afb03f5d5ea440e5"
Commit
17f3d2d4
authored
Apr 16, 2019
by
Chao Liu
Browse files
refactor ConstantTensorDescriptor and functional
parent
a2cf803c
Changes
22
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
383 additions
and
269 deletions
+383
-269
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
...er/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
+6
-6
driver/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
...er/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
+6
-6
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
...device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
+15
-15
driver/driver.hip.cpp
driver/driver.hip.cpp
+12
-12
src/include/Array.hip.hpp
src/include/Array.hip.hpp
+3
-1
src/include/ConstantTensorDescriptor.hip.hpp
src/include/ConstantTensorDescriptor.hip.hpp
+26
-45
src/include/Sequence.hip.hpp
src/include/Sequence.hip.hpp
+20
-0
src/include/blockwise_2d_tensor_op.hip.hpp
src/include/blockwise_2d_tensor_op.hip.hpp
+3
-6
src/include/blockwise_3d_tensor_op.hip.hpp
src/include/blockwise_3d_tensor_op.hip.hpp
+3
-6
src/include/blockwise_4d_tensor_op.hip.hpp
src/include/blockwise_4d_tensor_op.hip.hpp
+126
-18
src/include/blockwise_batched_gemm.hip.hpp
src/include/blockwise_batched_gemm.hip.hpp
+2
-3
src/include/blockwise_direct_convolution.hip.hpp
src/include/blockwise_direct_convolution.hip.hpp
+12
-15
src/include/functional.hip.hpp
src/include/functional.hip.hpp
+39
-9
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp
...ise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp
+23
-26
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
...plicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
+12
-14
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp
...ise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp
+25
-26
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
...implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
+7
-8
src/include/gridwise_direct_convolution_1.hip.hpp
src/include/gridwise_direct_convolution_1.hip.hpp
+9
-10
src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp
...lude/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp
+19
-24
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
...se_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
+15
-19
No files found.
driver/device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp
View file @
17f3d2d4
...
@@ -8,12 +8,12 @@
...
@@ -8,12 +8,12 @@
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
(
InDesc
,
void
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
index_t
nrepeat
)
index_t
nrepeat
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
driver/device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp
View file @
17f3d2d4
...
@@ -7,12 +7,12 @@
...
@@ -7,12 +7,12 @@
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
template
<
class
T
,
class
InDesc
,
class
WeiDesc
,
class
OutDesc
>
void
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
(
InDesc
,
void
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
(
InDesc
,
const
Tensor
<
T
>&
in_nchw
,
const
Tensor
<
T
>&
in_nchw
,
WeiDesc
,
WeiDesc
,
const
Tensor
<
T
>&
wei_kcyx
,
const
Tensor
<
T
>&
wei_kcyx
,
OutDesc
,
OutDesc
,
Tensor
<
T
>&
out_nkhw
,
Tensor
<
T
>&
out_nkhw
,
index_t
nrepeat
)
index_t
nrepeat
)
{
{
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
driver/device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp
View file @
17f3d2d4
...
@@ -52,7 +52,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
...
@@ -52,7 +52,7 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
in_nchw_vec
(
n
,
c
,
h
,
w
)
=
in_nchw_vec
(
n
,
c
,
h
,
w
)
=
vector_t
::
Pack
(
in_nchw
(
n
,
2
*
c
,
h
,
w
),
in_nchw
(
n
,
2
*
c
+
1
,
h
,
w
));
vector_t
::
Pack
(
in_nchw
(
n
,
2
*
c
,
h
,
w
),
in_nchw
(
n
,
2
*
c
+
1
,
h
,
w
));
#elif 1
#elif 1
in_nchw_vec
(
n
,
c
,
h
,
w
)
=
vector_t
::
Pack
(
in_nchw
(
n
,
4
*
c
,
h
,
w
),
in_nchw_vec
(
n
,
c
,
h
,
w
)
=
vector_t
::
Pack
(
in_nchw
(
n
,
4
*
c
,
h
,
w
),
in_nchw
(
n
,
4
*
c
+
1
,
h
,
w
),
in_nchw
(
n
,
4
*
c
+
1
,
h
,
w
),
in_nchw
(
n
,
4
*
c
+
2
,
h
,
w
),
in_nchw
(
n
,
4
*
c
+
2
,
h
,
w
),
in_nchw
(
n
,
4
*
c
+
3
,
h
,
w
));
in_nchw
(
n
,
4
*
c
+
3
,
h
,
w
));
...
@@ -114,37 +114,37 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
...
@@ -114,37 +114,37 @@ void device_direct_convolution_2_vectorized_nchw_kcyx_nkhw(InDesc,
constexpr index_t BlockSize = 128;
constexpr index_t BlockSize = 128;
#elif
0
#elif
0
// 3x3, 34x34, 128 thread, fp32, vector = 2
// 3x3, 34x34, 128 thread, fp32, vector = 2
constexpr
index_t
NPerBlock
=
2
;
constexpr
index_t
NPerBlock
=
2
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
index_t
CPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
index_t
HoPerBlock
=
2
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
NPerThread
=
2
;
constexpr
index_t
NPerThread
=
2
;
constexpr
index_t
KPerThread
=
4
;
constexpr
index_t
KPerThread
=
4
;
constexpr
index_t
CPerThread
=
1
;
constexpr
index_t
CPerThread
=
1
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
HoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
#elif 0
#elif 0
// 3x3, 34x34, 128 thread, int8, vector = 4
// 3x3, 34x34, 128 thread, int8, vector = 4
constexpr
index_t
NPerBlock
=
2
;
constexpr
index_t
NPerBlock
=
2
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
index_t
KPerBlock
=
32
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
index_t
CPerBlock
=
8
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
index_t
HoPerBlock
=
4
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
WoPerBlock
=
32
;
constexpr
index_t
NPerThread
=
1
;
constexpr
index_t
NPerThread
=
1
;
constexpr
index_t
KPerThread
=
8
;
constexpr
index_t
KPerThread
=
8
;
constexpr
index_t
CPerThread
=
2
;
constexpr
index_t
CPerThread
=
2
;
constexpr
index_t
HoPerThread
=
4
;
constexpr
index_t
HoPerThread
=
4
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
WoPerThread
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
InBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
WeiBlockCopyDataPerRead
=
2
;
constexpr
index_t
BlockSize
=
128
;
constexpr
index_t
BlockSize
=
128
;
...
...
driver/driver.hip.cpp
View file @
17f3d2d4
...
@@ -11,6 +11,7 @@
...
@@ -11,6 +11,7 @@
#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp"
#include "device_direct_convolution_2_nchw_kcyx_nkhw.hpp"
//#include "device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp"
//#include "device_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hpp"
#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v1_chwn_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v1_nchw_cyxk_khwn.hpp"
//#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
//#include "device_implicit_gemm_convolution_1_chwn_cyxk_khwn_padded.hpp"
#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
#include "device_convolution_implicit_gemm_v2_chwn_cyxk_khwn.hpp"
...
@@ -48,13 +49,10 @@ struct GeneratorTensor_3
...
@@ -48,13 +49,10 @@ struct GeneratorTensor_3
#if 0
#if 0
auto f_acc = std::plus<index_t>{};
auto f_acc = std::plus<index_t>{};
#else
#else
auto
f_acc
=
[](
auto
a
,
auto
b
){
return
10
*
a
+
b
;};
auto
f_acc
=
[](
auto
a
,
auto
b
)
{
return
10
*
a
+
b
;
};
#endif
#endif
return
std
::
accumulate
(
dims
.
begin
(),
return
std
::
accumulate
(
dims
.
begin
(),
dims
.
end
(),
index_t
(
0
),
f_acc
);
dims
.
end
(),
index_t
(
0
),
f_acc
);
}
}
};
};
...
@@ -376,7 +374,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
...
@@ -376,7 +374,7 @@ void host_winograd_3x3_convolution(const Tensor<TIn>& in_nchw,
std
::
size_t
ho
=
HoPerTile
*
htile
+
j
;
std
::
size_t
ho
=
HoPerTile
*
htile
+
j
;
for
(
int
i
=
0
;
i
<
WoPerTile
;
++
i
)
for
(
int
i
=
0
;
i
<
WoPerTile
;
++
i
)
{
{
std
::
size_t
wo
=
WoPerTile
*
wtile
+
i
;
std
::
size_t
wo
=
WoPerTile
*
wtile
+
i
;
out_nkhw
(
n
,
k
,
ho
,
wo
)
=
out_hold
(
n
,
k
,
htile
,
wtile
,
j
,
i
);
out_nkhw
(
n
,
k
,
ho
,
wo
)
=
out_hold
(
n
,
k
,
htile
,
wtile
,
j
,
i
);
}
}
}
}
...
@@ -435,13 +433,13 @@ int main(int argc, char* argv[])
...
@@ -435,13 +433,13 @@ int main(int argc, char* argv[])
constexpr index_t WPad = 0;
constexpr index_t WPad = 0;
#elif
0
#elif
0
// 3x3, 56x56
// 3x3, 56x56
constexpr
index_t
N
=
64
;
constexpr
index_t
N
=
64
;
constexpr
index_t
C
=
64
;
constexpr
index_t
C
=
64
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
HI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
WI
=
56
;
constexpr
index_t
K
=
128
;
constexpr
index_t
K
=
128
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
HPad
=
0
;
constexpr
index_t
WPad
=
0
;
constexpr
index_t
WPad
=
0
;
...
@@ -505,7 +503,7 @@ int main(int argc, char* argv[])
...
@@ -505,7 +503,7 @@ int main(int argc, char* argv[])
constexpr
index_t
C
=
256
;
constexpr
index_t
C
=
256
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
HI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
WI
=
28
;
constexpr
index_t
K
=
5
12
;
constexpr
index_t
K
=
12
8
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
Y
=
3
;
constexpr
index_t
X
=
3
;
constexpr
index_t
X
=
3
;
...
@@ -666,6 +664,8 @@ int main(int argc, char* argv[])
...
@@ -666,6 +664,8 @@ int main(int argc, char* argv[])
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
device_direct_convolution_2_vectorized_nchw_kcyx_nkhw
#elif 1
#elif 1
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
device_convolution_implicit_gemm_v1_chwn_cyxk_khwn
#elif 0
device_convolution_implicit_gemm_v1_nchw_cyxk_khwn
#elif 0
#elif 0
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
device_convolution_implicit_gemm_v2_chwn_cyxk_khwn
#endif
#endif
...
...
src/include/Array.hip.hpp
View file @
17f3d2d4
...
@@ -14,5 +14,7 @@ struct Array
...
@@ -14,5 +14,7 @@ struct Array
{
{
}
}
__host__
__device__
TData
operator
[](
index_t
i
)
const
{
return
mData
[
i
];
}
__host__
__device__
const
TData
&
operator
[](
index_t
i
)
const
{
return
mData
[
i
];
}
__host__
__device__
TData
&
operator
[](
index_t
i
)
{
return
mData
[
i
];
}
};
};
src/include/ConstantTensorDescriptor.hip.hpp
View file @
17f3d2d4
...
@@ -115,46 +115,27 @@ struct ConstantTensorDescriptor
...
@@ -115,46 +115,27 @@ struct ConstantTensorDescriptor
static_assert
(
Lengths
::
nDim
==
Strides
::
nDim
,
"nDim not consistent"
);
static_assert
(
Lengths
::
nDim
==
Strides
::
nDim
,
"nDim not consistent"
);
}
}
__host__
__device__
constexpr
index_t
GetDimension
()
const
{
return
nDim
;
}
__host__
__device__
static
constexpr
index_t
GetDimension
()
{
return
nDim
;
}
__host__
__device__
constexpr
Lengths
GetLengths
()
const
{
return
Lengths
{};
}
__host__
__device__
static
constexpr
Lengths
GetLengths
()
{
return
Lengths
{};
}
__host__
__device__
constexpr
Strides
GetStrides
()
const
{
return
Strides
{};
}
__host__
__device__
static
constexpr
Strides
GetStrides
()
{
return
Strides
{};
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
index_t
GetLength
(
Number
<
I
>
)
const
__host__
__device__
static
constexpr
index_t
GetLength
(
Number
<
I
>
)
{
{
return
Lengths
{}.
Get
(
Number
<
I
>
{});
return
Lengths
{}.
Get
(
Number
<
I
>
{});
}
}
template
<
index_t
I
>
template
<
index_t
I
>
__host__
__device__
constexpr
index_t
GetStride
(
Number
<
I
>
)
const
__host__
__device__
static
constexpr
index_t
GetStride
(
Number
<
I
>
)
{
{
return
Strides
{}.
Get
(
Number
<
I
>
{});
return
Strides
{}.
Get
(
Number
<
I
>
{});
}
}
// c++14 doesn't support constexpr lambdas, has to use this trick instead
__host__
__device__
static
constexpr
index_t
GetElementSize
()
struct
GetElementSize_f
{
{
template
<
class
IDim
>
return
accumulate_on_sequence
(
Lengths
{},
mod_conv
::
multiplies
<
index_t
>
{},
Number
<
1
>
{});
__host__
__device__
constexpr
index_t
operator
()(
IDim
idim
)
const
{
return
Type
{}.
GetLength
(
idim
);
}
};
__host__
__device__
constexpr
index_t
GetElementSize
()
const
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct
multiply
{
__host__
__device__
constexpr
index_t
operator
()(
index_t
a
,
index_t
b
)
const
{
return
a
*
b
;
}
};
return
static_const_reduce_n
<
nDim
>
{}(
GetElementSize_f
{},
multiply
{});
}
}
// c++14 doesn't support constexpr lambdas, has to use this trick instead
// c++14 doesn't support constexpr lambdas, has to use this trick instead
...
@@ -168,25 +149,16 @@ struct ConstantTensorDescriptor
...
@@ -168,25 +149,16 @@ struct ConstantTensorDescriptor
};
};
template
<
class
Align
=
Number
<
1
>
>
template
<
class
Align
=
Number
<
1
>
>
__host__
__device__
constexpr
index_t
GetElementSpace
(
Align
align
=
Align
{})
const
__host__
__device__
static
constexpr
index_t
GetElementSpace
(
Align
align
=
Align
{})
{
{
// c++14 doesn't support constexpr lambdas, has to use this trick instead
struct
add
{
__host__
__device__
constexpr
index_t
operator
()(
index_t
a
,
index_t
b
)
const
{
return
a
+
b
;
}
};
index_t
element_space_unaligned
=
index_t
element_space_unaligned
=
static_const_reduce_n
<
nDim
>
{}(
GetElementSpace_f
{},
add
{})
+
1
;
static_const_reduce_n
<
nDim
>
{}(
GetElementSpace_f
{},
mod_conv
::
plus
<
index_t
>
{})
+
1
;
return
align
.
Get
()
*
((
element_space_unaligned
+
align
.
Get
()
-
1
)
/
align
.
Get
());
return
align
.
Get
()
*
((
element_space_unaligned
+
align
.
Get
()
-
1
)
/
align
.
Get
());
}
}
template
<
class
...
Is
>
template
<
class
...
Is
>
__host__
__device__
index_t
Get1dIndex
(
Is
...
is
)
const
__host__
__device__
static
index_t
Get1dIndex
(
Is
...
is
)
{
{
static_assert
(
sizeof
...(
Is
)
==
nDim
,
"number of multi-index is wrong"
);
static_assert
(
sizeof
...(
Is
)
==
nDim
,
"number of multi-index is wrong"
);
...
@@ -194,7 +166,7 @@ struct ConstantTensorDescriptor
...
@@ -194,7 +166,7 @@ struct ConstantTensorDescriptor
index_t
id
=
0
;
index_t
id
=
0
;
static_
loop_n
<
nDim
>
{}([
&
](
auto
IDim
)
{
static_
for
<
0
,
nDim
,
1
>
{}([
&
](
auto
IDim
)
{
constexpr
index_t
idim
=
IDim
.
Get
();
constexpr
index_t
idim
=
IDim
.
Get
();
#if DEVICE_BACKEND_HIP
#if DEVICE_BACKEND_HIP
id
+=
__mul24
(
multi_id
[
idim
],
GetStride
(
IDim
));
id
+=
__mul24
(
multi_id
[
idim
],
GetStride
(
IDim
));
...
@@ -206,16 +178,25 @@ struct ConstantTensorDescriptor
...
@@ -206,16 +178,25 @@ struct ConstantTensorDescriptor
return
id
;
return
id
;
}
}
__host__
__device__
constexpr
auto
Condense
()
const
__host__
__device__
static
Array
<
index_t
,
nDim
>
GetMultiIndex
(
index_t
id
)
{
{
constexpr
auto
default_strides
=
calculate_default_strides
(
Lengths
{});
Array
<
index_t
,
nDim
>
multi_id
;
return
ConstantTensorDescriptor
<
Lengths
,
decltype
(
default_strides
)
>
{};
static_for
<
0
,
nDim
-
1
,
1
>
{}([
&
](
auto
IDim
)
{
constexpr
index_t
idim
=
IDim
.
Get
();
multi_id
[
idim
]
=
id
/
GetStride
(
IDim
);
id
-=
multi_id
[
idim
]
*
GetStride
(
IDim
);
});
multi_id
[
nDim
-
1
]
=
id
/
GetStride
(
Number
<
nDim
-
1
>
{});
return
multi_id
;
}
}
template
<
index_t
IDim
,
index_t
NVector
>
__host__
__device__
static
constexpr
auto
Condense
()
__host__
__device__
constexpr
auto
Vectorize
(
Number
<
IDim
>
,
Number
<
NVector
>
)
const
{
{
assert
(
false
);
// not implemented
constexpr
auto
default_strides
=
calculate_default_strides
(
Lengths
{});
return
ConstantTensorDescriptor
<
Lengths
,
decltype
(
default_strides
)
>
{};
}
}
};
};
...
...
src/include/Sequence.hip.hpp
View file @
17f3d2d4
...
@@ -17,6 +17,8 @@ struct Sequence
...
@@ -17,6 +17,8 @@ struct Sequence
return
mData
[
I
];
return
mData
[
I
];
}
}
__host__
__device__
index_t
operator
[](
index_t
i
)
const
{
return
mData
[
i
];
}
// this is ugly, only for nDIm = 4
// this is ugly, only for nDIm = 4
template
<
index_t
I0
,
index_t
I1
,
index_t
I2
,
index_t
I3
>
template
<
index_t
I0
,
index_t
I1
,
index_t
I2
,
index_t
I3
>
__host__
__device__
constexpr
auto
ReorderByGetNewFromOld
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
__host__
__device__
constexpr
auto
ReorderByGetNewFromOld
(
Sequence
<
I0
,
I1
,
I2
,
I3
>
)
const
...
@@ -90,3 +92,21 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
...
@@ -90,3 +92,21 @@ __host__ __device__ constexpr auto Sequence<Is...>::PopBack() const
{
{
return
sequence_pop_back
(
Type
{});
return
sequence_pop_back
(
Type
{});
}
}
template
<
class
Seq
>
struct
accumulate_on_sequence_f
{
template
<
class
IDim
>
__host__
__device__
constexpr
index_t
operator
()(
IDim
)
const
{
return
Seq
{}.
Get
(
IDim
{});
}
};
template
<
class
Seq
,
class
Reduce
,
index_t
I
>
__host__
__device__
constexpr
index_t
accumulate_on_sequence
(
Seq
,
Reduce
,
Number
<
I
>
)
{
constexpr
index_t
a
=
static_const_reduce_n
<
Seq
::
nDim
>
{}(
accumulate_on_sequence_f
<
Seq
>
{},
Reduce
{});
return
Reduce
{}(
a
,
I
);
}
src/include/blockwise_2d_tensor_op.hip.hpp
View file @
17f3d2d4
...
@@ -211,8 +211,7 @@ struct Blockwise2dTensorCopy1
...
@@ -211,8 +211,7 @@ struct Blockwise2dTensorCopy1
constexpr
index_t
read_per_d1
=
integer_divide_ceil
(
L1
,
DataPerRead
);
constexpr
index_t
read_per_d1
=
integer_divide_ceil
(
L1
,
DataPerRead
);
constexpr
auto
ref_desc
=
constexpr
auto
ref_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
L0
,
read_per_d1
>
{});
make_ConstantTensorDescriptor
(
Sequence
<
L0
,
read_per_d1
>
{});
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
...
@@ -225,10 +224,8 @@ struct Blockwise2dTensorCopy1
...
@@ -225,10 +224,8 @@ struct Blockwise2dTensorCopy1
did
[
1
]
=
is
/
ref_desc
.
GetStride
(
I1
);
did
[
1
]
=
is
/
ref_desc
.
GetStride
(
I1
);
const
index_t
src_index
=
const
index_t
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
]
*
DataPerRead
);
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
]
*
DataPerRead
);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
]
*
DataPerRead
);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
]
*
DataPerRead
);
*
(
reinterpret_cast
<
vector_t
*>
(
p_dst
+
dst_index
))
=
*
(
reinterpret_cast
<
vector_t
*>
(
p_dst
+
dst_index
))
=
*
(
reinterpret_cast
<
const
vector_t
*>
(
p_src
+
src_index
));
*
(
reinterpret_cast
<
const
vector_t
*>
(
p_src
+
src_index
));
...
...
src/include/blockwise_3d_tensor_op.hip.hpp
View file @
17f3d2d4
...
@@ -54,8 +54,7 @@ struct Blockwise3dTensorCopy1
...
@@ -54,8 +54,7 @@ struct Blockwise3dTensorCopy1
constexpr
index_t
read_per_d2
=
integer_divide_ceil
(
L2
,
DataPerRead
);
constexpr
index_t
read_per_d2
=
integer_divide_ceil
(
L2
,
DataPerRead
);
constexpr
auto
ref_desc
=
constexpr
auto
ref_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
L0
,
L1
,
read_per_d2
>
{});
make_ConstantTensorDescriptor
(
Sequence
<
L0
,
L1
,
read_per_d2
>
{});
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
...
@@ -72,10 +71,8 @@ struct Blockwise3dTensorCopy1
...
@@ -72,10 +71,8 @@ struct Blockwise3dTensorCopy1
did
[
2
]
=
is
/
ref_desc
.
GetStride
(
I2
);
did
[
2
]
=
is
/
ref_desc
.
GetStride
(
I2
);
const
index_t
src_index
=
const
index_t
src_index
=
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
]
*
DataPerRead
);
src_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
]
*
DataPerRead
);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
]
*
DataPerRead
);
const
index_t
dst_index
=
dst_desc
.
Get1dIndex
(
did
[
0
],
did
[
1
],
did
[
2
]
*
DataPerRead
);
*
(
reinterpret_cast
<
vector_t
*>
(
p_dst
+
dst_index
))
=
*
(
reinterpret_cast
<
vector_t
*>
(
p_dst
+
dst_index
))
=
*
(
reinterpret_cast
<
const
vector_t
*>
(
p_src
+
src_index
));
*
(
reinterpret_cast
<
const
vector_t
*>
(
p_src
+
src_index
));
...
...
src/include/blockwise_4d_tensor_op.hip.hpp
View file @
17f3d2d4
...
@@ -340,11 +340,10 @@ struct BlockwiseChwnTensorCopyPadded
...
@@ -340,11 +340,10 @@ struct BlockwiseChwnTensorCopyPadded
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
constexpr
index_t
NLoop
=
ref_desc
.
GetElementSize
()
/
BlockSize
;
const
Float
*
p_src_tmp
=
const
Float
*
p_src_tmp
=
p_src
+
p_src
+
src_desc
.
Get1dIndex
(
c_block_data_begin
,
src_desc
.
Get1dIndex
(
c_block_data_begin
,
(
ho_block_data_begin
+
h_block_pad_low
)
-
h_global_pad_low
,
(
ho_block_data_begin
+
h_block_pad_low
)
-
h_global_pad_low
,
(
wo_block_data_begin
+
w_block_pad_low
)
-
w_global_pad_low
,
(
wo_block_data_begin
+
w_block_pad_low
)
-
w_global_pad_low
,
n_block_data_begin
);
n_block_data_begin
);
#if 0
#if 0
if(get_thread_local_1d_id() == 0)
if(get_thread_local_1d_id() == 0)
...
@@ -494,7 +493,7 @@ struct Blockwise4dTensorCopy3
...
@@ -494,7 +493,7 @@ struct Blockwise4dTensorCopy3
"wrrong! BlockSize is not big enough for ThreadPerDims!"
);
"wrrong! BlockSize is not big enough for ThreadPerDims!"
);
constexpr
index_t
num_active_thread
=
constexpr
index_t
num_active_thread
=
thread_per_d0
*
thread_per_d1
*
thread_per_d2
*
thread_per_d3
;
accumulate_on_sequence
(
ThreadPerDims
{},
mod_conv
::
multiplies
<
index_t
>
{},
Number
<
1
>
{})
;
if
(
BlockSize
>
num_active_thread
)
if
(
BlockSize
>
num_active_thread
)
{
{
...
@@ -504,19 +503,18 @@ struct Blockwise4dTensorCopy3
...
@@ -504,19 +503,18 @@ struct Blockwise4dTensorCopy3
}
}
}
}
const
index_t
thread_id_d0
=
constexpr
auto
thread_cluster_desc
=
make_ConstantTensorDescriptor
(
ThreadPerDims
{});
get_thread_local_1d_id
()
/
(
thread_per_d1
*
thread_per_d2
*
thread_per_d3
);
const
auto
thread_multi_id
=
thread_cluster_desc
.
GetMultiIndex
(
get_thread_local_1d_id
());
index_t
itmp
=
get_thread_local_1d_id
()
-
thread_id_d0
*
(
thread_per_d1
*
thread_per_d2
*
thread_per_d3
);
const
index_t
thread_id_d1
=
itmp
/
(
thread_per_d2
*
thread_per_d3
);
itmp
-=
thread_id_d1
*
(
thread_per_d2
*
thread_per_d3
);
const
index_t
thread_id_d2
=
itmp
/
thread_per_d3
;
const
index_t
thread_id_d3
=
itmp
-
thread_id_d2
*
thread_per_d3
;
mSrcMyThreadOffset
=
SrcDesc
{}.
Get1dIndex
(
mSrcMyThreadOffset
=
SrcDesc
{}.
Get1dIndex
(
thread_multi_id
[
0
],
thread_id_d0
,
thread_id_d1
,
thread_id_d2
,
thread_id_d3
*
DataPerRead
);
thread_multi_id
[
1
],
mDstMyThreadOffset
=
DstDesc
{}.
Get1dIndex
(
thread_multi_id
[
2
],
thread_id_d0
,
thread_id_d1
,
thread_id_d2
,
thread_id_d3
*
DataPerRead
);
thread_multi_id
[
3
]
*
DataPerRead
);
mDstMyThreadOffset
=
DstDesc
{}.
Get1dIndex
(
thread_multi_id
[
0
],
thread_multi_id
[
1
],
thread_multi_id
[
2
],
thread_multi_id
[
3
]
*
DataPerRead
);
}
}
__device__
void
Run
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_dst
)
const
__device__
void
Run
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_dst
)
const
...
@@ -745,3 +743,113 @@ struct Blockwise4dTensorCopy3
...
@@ -745,3 +743,113 @@ struct Blockwise4dTensorCopy3
}
}
}
}
};
};
template
<
index_t
BlockSize
,
class
Float
,
class
SrcDesc
,
class
DstDesc
,
class
SrcOpLengths
,
class
DstFromSrcReorder
>
struct
Blockwise4dTensorCopyReorder1
{
__device__
void
Run
(
const
Float
*
__restrict__
p_src
,
Float
*
__restrict__
p_dst
)
const
{
auto
f_copy
=
[](
const
Float
&
src
,
Float
&
dst
)
{
dst
=
src
;
};
blockwise_4d_tensor_pointwise_operation_binary_reorder_by_get_dst_from_src
<
BlockSize
>
(
SrcDesc
{},
p_src
,
DstDesc
{},
p_dst
,
SrcOpLengths
{},
DstFromSrcReorder
{},
f_copy
);
}
};
#if 0
template <index_t BlockSize,
class Float,
class SrcDesc,
class DstDesc,
class SrcLengths,
class SrcSubLengths,
class SrcThreadPerDims,
class DstFromSrcReorder,
index_t DataPerRead,
index_t DataPerWrite>
struct Blockwise4dTensorCopyReorder3
{
index_t mSrcMyThreadOffset;
index_t mDstMyThreadOffset;
__device__ Blockwise4dTensorCopyReorder3()
{
constexpr index_t nDim = SrcDesc{}.GetDimension();
static_assert(DstDesc{}.GetDimension() == nDim && SrcOpLengths::nDim == nDim &&
SrcOpThreadPerDims::nDim == nDim && DstFromSrcReorder::nDim == nDim,
"wrong! nDim is not consistent\n");
// Src
static_assert(DataPerRead == 1 || DataPerRead == 2 || DataPerRead == 4,
"wrong! only support DataPerRead == 1, 2 or 4!\n");
static_assert(DataPerRead == 1 || SrcDesc{}.GetStride(Number<nDim-1>{}) == 1,
"wrong! only support src.stride(nDim-1) == 1 if DataPerRead > 1!\n");
static_assert(
SrcDesc{}.GetStride(Number<nDim-2>{}) % DataPerRead == 0,
"wrong! src.stride(nDim-2) should be multiple of DataPerRead to keep alignment");
static_assert(SrcSubLengths{}.Get(Number<nDim-1>{}) % DataPerRead == 0, "wrong! SrcSubLengths[nDim-1] % DataPerRead != 0\n");
static_loop<nDim-1>([](auto I){
constexpr index_t src_len = SrcLengths{}.Get(I);
constexpr index_t src_sub_len = SrcSubLengths{}.Get(I);
constexpr index_t thread_per_dim = SrcThreadPerDims{}.Get(I);
static_assert(src_len % (src_sub_len * thread_per_dim) == 0,
"wrong! cannot evenly divide tensor lengths");
});
constexpr index_t num_active_thread = accumulate_on_sequence(SrcOpThreadPerDims{}, mod_conv::multiplies<index_t>{}, Number<1>{});
static_assert(BlockSize >= num_active_thread,
"wrong! BlockSize is not big enough for ThreadPerDims!");
if(BlockSize > num_active_thread)
{
if(get_thread_local_1d_id() >= num_active_thread)
{
return;
}
}
const auto thread_multi_id = SrcOpThreadPerDims::GetMultiIndex(get_thread_local_1d_id());
const index_t thread_id_d0 =
get_thread_local_1d_id() / (thread_per_d1 * thread_per_d2 * thread_per_d3);
index_t itmp = get_thread_local_1d_id() -
thread_id_d0 * (thread_per_d1 * thread_per_d2 * thread_per_d3);
const index_t thread_id_d1 = itmp / (thread_per_d2 * thread_per_d3);
itmp -= thread_id_d1 * (thread_per_d2 * thread_per_d3);
const index_t thread_id_d2 = itmp / thread_per_d3;
const index_t thread_id_d3 = itmp - thread_id_d2 * thread_per_d3;
mSrcMyThreadOffset = SrcDesc{}.Get1dIndex(
thread_id_d0, thread_id_d1, thread_id_d2, thread_id_d3 * DataPerRead);
}
__device__ static constexpr index_t GetRegisterClipboardSize()
{
static_assert(is_same<Float, float>::value, "wrong! only support float!\n");
}
__device__ void RunLoadRegisterClipboard(const Float* __restrict__ p_src,
Float* __restrict__ p_clipboard) const
{
}
__device__ void RunStoreRegisterClipboard(const Float* __restrict__ p_clipboard,
Float* __restrict__ p_dst) const
{
}
};
#endif
src/include/blockwise_batched_gemm.hip.hpp
View file @
17f3d2d4
...
@@ -393,9 +393,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
...
@@ -393,9 +393,8 @@ struct BlockwiseBatchGemmBlockABlockBThreadCTransANormalBNormalC_V2
{
{
threadwise_matrix_copy
(
threadwise_matrix_copy
(
c_thread_sub_mtx
,
c_thread_sub_mtx
,
p_c_thread
+
p_c_thread
+
c_thread_sub_mtx
.
Get1dIndex
(
m_repeat
*
MPerLevel1Cluster
,
c_thread_sub_mtx
.
Get1dIndex
(
m_repeat
*
MPerLevel1Cluster
,
n_repeat
*
NPerLevel1Cluster
),
n_repeat
*
NPerLevel1Cluster
),
c_block_mtx
,
c_block_mtx
,
p_c_block
+
p_c_block
+
c_block_mtx
.
Get1dIndex
(
m_repeat
*
MPerLevel1Cluster
,
c_block_mtx
.
Get1dIndex
(
m_repeat
*
MPerLevel1Cluster
,
...
...
src/include/blockwise_direct_convolution.hip.hpp
View file @
17f3d2d4
...
@@ -93,11 +93,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -93,11 +93,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
Float
p_out_thread
[
out_thread_desc
.
GetElementSpace
()];
Float
p_out_thread
[
out_thread_desc
.
GetElementSpace
()];
threadwise_4d_tensor_copy
(
out_block_desc
,
threadwise_4d_tensor_copy
(
out_block_desc
,
p_out_block
+
p_out_block
+
out_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
out_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
k_thread_data_begin
,
k_thread_data_begin
,
ho_thread_data_begin
,
ho_thread_data_begin
,
wo_thread_data_begin
),
wo_thread_data_begin
),
out_thread_desc
,
out_thread_desc
,
p_out_thread
,
p_out_thread
,
out_thread_desc
.
GetLengths
());
out_thread_desc
.
GetLengths
());
...
@@ -108,11 +107,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -108,11 +107,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
// threadwise convolution
// threadwise convolution
threadwise_direct_convolution_2
(
threadwise_direct_convolution_2
(
in_thread_block_desc
,
in_thread_block_desc
,
p_in_block
+
p_in_block
+
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
in_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data_begin
,
c_thread_data_begin
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
wei_thread_block_desc
,
wei_thread_block_desc
,
p_wei_block
+
p_wei_block
+
wei_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data_begin
,
0
,
0
),
wei_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data_begin
,
0
,
0
),
...
@@ -124,11 +122,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
...
@@ -124,11 +122,10 @@ __device__ void blockwise_direct_convolution(InBlockDesc,
threadwise_4d_tensor_copy
(
out_thread_desc
,
threadwise_4d_tensor_copy
(
out_thread_desc
,
p_out_thread
,
p_out_thread
,
out_block_desc
,
out_block_desc
,
p_out_block
+
p_out_block
+
out_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
out_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
k_thread_data_begin
,
k_thread_data_begin
,
ho_thread_data_begin
,
ho_thread_data_begin
,
wo_thread_data_begin
),
wo_thread_data_begin
),
out_thread_desc
.
GetLengths
());
out_thread_desc
.
GetLengths
());
}
}
}
}
src/include/functional.hip.hpp
View file @
17f3d2d4
#pragma once
#pragma once
#include "constant_integral.hip.hpp"
#include "constant_integral.hip.hpp"
template
<
index_t
NLoop
>
template
<
index_t
Iter
,
index_t
Remaining
,
index_t
Increment
>
struct
static_
loop_n
struct
static_
for_impl
{
{
template
<
class
F
>
template
<
class
F
>
__host__
__device__
void
operator
()(
F
f
)
const
__host__
__device__
void
operator
()(
F
f
)
const
{
{
static_assert
(
NLoop
>
1
,
"out-of-range"
);
static_assert
(
Remaining
%
Increment
==
0
,
"wrong! Remaining % Increment != 0"
);
static_assert
(
Increment
<=
Remaining
,
"will go out-of-range"
);
f
(
Number
<
NLoop
-
1
>
{});
f
(
Number
<
Iter
>
{});
static_
loop_n
<
NLoop
-
1
>
{}(
f
);
static_
for_impl
<
Iter
+
Increment
,
Remaining
-
Increment
,
Increment
>
{}(
f
);
}
}
};
};
template
<
>
template
<
index_t
Iter
,
index_t
Increment
>
struct
static_loop_n
<
1
>
struct
static_for_impl
<
Iter
,
0
,
Increment
>
{
template
<
class
F
>
__host__
__device__
void
operator
()(
F
)
const
{
// do nothing
return
;
}
};
template
<
index_t
NBegin
,
index_t
NEnd
,
index_t
Increment
>
struct
static_for
{
{
template
<
class
F
>
template
<
class
F
>
__host__
__device__
void
operator
()(
F
f
)
const
__host__
__device__
void
operator
()(
F
f
)
const
{
{
f
(
Number
<
0
>
{});
static_assert
(
NBegin
<
NEnd
,
"Wrong! we should have NBegin < NEnd"
);
static_assert
((
NEnd
-
NBegin
)
%
Increment
==
0
,
"Wrong! should satisfy (NEnd - NBegin) % Increment == 0"
);
static_for_impl
<
NBegin
,
NEnd
-
NBegin
,
Increment
>
{}(
f
);
}
}
};
};
...
@@ -54,4 +69,19 @@ __host__ __device__ constexpr auto unpacker(F f)
...
@@ -54,4 +69,19 @@ __host__ __device__ constexpr auto unpacker(F f)
{
{
return [=](auto xs_array){ f(xs...); };
return [=](auto xs_array){ f(xs...); };
}
}
#endif
#endif
\ No newline at end of file
namespace mod_conv {

// Analogue of std::multiplies<T>, callable from both host and device code
// (the std functor carries no __host__ __device__ qualifiers).
template <class T>
struct multiplies
{
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const
    {
        // operand order preserved: T is generic, so lhs * rhs need not commute
        return lhs * rhs;
    }
};

// Analogue of std::plus<T>, callable from both host and device code.
template <class T>
struct plus
{
    __host__ __device__ constexpr T operator()(T lhs, T rhs) const
    {
        return lhs + rhs;
    }
};

} // namespace mod_conv
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn.hip.hpp
View file @
17f3d2d4
...
@@ -99,8 +99,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
...
@@ -99,8 +99,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
// tensor view of blockwise input and weight in LDS
// tensor view of blockwise input and weight in LDS
// be careful of alignment
// be careful of alignment
constexpr
index_t
max_align
=
constexpr
index_t
max_align
=
mod_conv
::
max
(
mod_conv
::
max
(
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
,
GemmDataPerReadA
,
GemmDataPerReadB
);
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
,
GemmDataPerReadA
,
GemmDataPerReadB
);
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
in_chwn_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
Sequence
<
CPerBlock
,
HiPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
...
@@ -135,16 +135,15 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
...
@@ -135,16 +135,15 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
InBlockCopyDataPerRead
>
{};
InBlockCopyDataPerRead
>
{};
#endif
#endif
// blockwise wei copy
// blockwise wei copy
// format is [CPerBlock*Y*X,KPerBlock]
// format is [CPerBlock*Y*X,KPerBlock]
const
auto
blockwise_wei_copy
=
const
auto
blockwise_wei_copy
=
Blockwise2dTensorCopy3
<
BlockSize
,
Blockwise2dTensorCopy3
<
BlockSize
,
Float
,
Float
,
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_global_desc
),
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_block_desc
),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
decltype
(
wei_ek_block_desc
.
GetLengths
()),
WeiBlockCopyDataPerRead
>
{};
WeiBlockCopyDataPerRead
>
{};
// a series of blockwise batched GEMM
// a series of blockwise batched GEMM
// C_matrix += transpose(A_matrix) * B_matrix
// C_matrix += transpose(A_matrix) * B_matrix
...
@@ -202,9 +201,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
...
@@ -202,9 +201,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
threadwise_4d_tensor_set_zero
(
out_khwn_thread_desc
,
p_out_thread
);
threadwise_4d_tensor_set_zero
(
out_khwn_thread_desc
,
p_out_thread
);
const
Float
*
p_in_global_block_offset
=
const
Float
*
p_in_global_block_offset
=
p_in_global
+
p_in_global
+
in_chwn_global_desc
.
Get1dIndex
(
in_chwn_global_desc
.
Get1dIndex
(
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
const
Float
*
p_wei_global_block_offset
=
const
Float
*
p_wei_global_block_offset
=
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
...
@@ -323,17 +321,16 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
...
@@ -323,17 +321,16 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn
}
}
#endif
#endif
threadwise_10d_tensor_copy
(
threadwise_10d_tensor_copy
(
out_10d_thread_desc
,
out_10d_thread_desc
,
p_out_thread
,
p_out_thread
,
out_10d_global_desc
,
out_10d_global_desc
,
p_out_global
+
out_khwn_global_desc
.
Get1dIndex
(
p_out_global
+
k_block_data_begin
+
k_thread_data_begin
,
out_khwn_global_desc
.
Get1dIndex
(
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
),
n_block_data_begin
+
n_thread_data_begin
),
out_10d_thread_desc
.
GetLengths
(),
out_10d_thread_desc
.
GetLengths
(),
Number
<
OutThreadCopyDataPerWrite
>
{});
Number
<
OutThreadCopyDataPerWrite
>
{});
#endif
#endif
}
}
};
};
src/include/gridwise_convolution_implicit_gemm_v1r1_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
View file @
17f3d2d4
...
@@ -190,9 +190,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
...
@@ -190,9 +190,8 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
__shared__
Float
p_wei_block_double
[
2
*
wei_block_space
];
__shared__
Float
p_wei_block_double
[
2
*
wei_block_space
];
const
Float
*
p_in_global_block_offset
=
const
Float
*
p_in_global_block_offset
=
p_in_global
+
p_in_global
+
in_chwn_global_desc
.
Get1dIndex
(
in_chwn_global_desc
.
Get1dIndex
(
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
const
Float
*
p_wei_global_block_offset
=
const
Float
*
p_wei_global_block_offset
=
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
p_wei_global
+
wei_cyxk_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
...
@@ -393,17 +392,16 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
...
@@ -393,17 +392,16 @@ struct GridwiseConvolutionImplicitGemm_v1r1_chwn_cyxk_khwn_lds_double_buffer
}
}
#endif
#endif
threadwise_10d_tensor_copy
(
threadwise_10d_tensor_copy
(
out_10d_thread_desc
,
out_10d_thread_desc
,
p_out_thread
,
p_out_thread
,
out_10d_global_desc
,
out_10d_global_desc
,
p_out_global
+
out_khwn_global_desc
.
Get1dIndex
(
p_out_global
+
k_block_data_begin
+
k_thread_data_begin
,
out_khwn_global_desc
.
Get1dIndex
(
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
),
n_block_data_begin
+
n_thread_data_begin
),
out_10d_thread_desc
.
GetLengths
(),
out_10d_thread_desc
.
GetLengths
(),
Number
<
OutThreadCopyDataPerWrite
>
{});
Number
<
OutThreadCopyDataPerWrite
>
{});
#endif
#endif
}
}
};
};
src/include/gridwise_convolution_implicit_gemm_v1r2_chwn_cyxk_khwn.hip.hpp
View file @
17f3d2d4
...
@@ -101,8 +101,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -101,8 +101,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// LDS tensor view
// LDS tensor view
// be careful of alignment
// be careful of alignment
constexpr
index_t
max_align
=
constexpr
index_t
max_align
=
mod_conv
::
max
(
mod_conv
::
max
(
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
,
GemmDataPerReadA
,
GemmDataPerReadB
);
InBlockCopyDataPerRead
,
WeiBlockCopyDataPerRead
,
GemmDataPerReadA
,
GemmDataPerReadB
);
constexpr
auto
in_c_h_w_n_block_desc
=
make_ConstantTensorDescriptor_aligned
(
constexpr
auto
in_c_h_w_n_block_desc
=
make_ConstantTensorDescriptor_aligned
(
Sequence
<
CPerBlock
,
HoPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
Sequence
<
CPerBlock
,
HoPerBlock
,
WiPerBlock
,
NPerBlock
>
{},
Number
<
max_align
>
{});
...
@@ -116,8 +116,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -116,8 +116,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// blockwise copy
// blockwise copy
// input: format is [C, Hi, Wi, N]
// input: format is [C, Hi, Wi, N]
const
auto
blockwise_in_copy
=
#if 0
#if 0
const auto blockwise_in_copy =
Blockwise4dTensorCopy1<BlockSize,
Blockwise4dTensorCopy1<BlockSize,
Float,
Float,
decltype(in_c_h_w_n_global_desc),
decltype(in_c_h_w_n_global_desc),
...
@@ -125,6 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -125,6 +125,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
decltype(in_c_h_w_n_block_desc.GetLengths()),
decltype(in_c_h_w_n_block_desc.GetLengths()),
InBlockCopyDataPerRead>{};
InBlockCopyDataPerRead>{};
#else
#else
const
auto
blockwise_in_copy
=
Blockwise4dTensorCopy3
<
BlockSize
,
Blockwise4dTensorCopy3
<
BlockSize
,
Float
,
Float
,
decltype
(
in_c_h_w_n_global_desc
),
decltype
(
in_c_h_w_n_global_desc
),
...
@@ -150,10 +151,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -150,10 +151,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// A_matrix[C,K] is a sub-matrix of wei_block[C,K]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// B_matrix[C,Wo*N] is a sub-matrix of in_block[C,Hi,Wi,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
// C_matrix[K,Wo*N] is a sub-matrix of out_block[K,Ho,Wo,N]
constexpr
auto
a_c_k_block_mtx_desc
=
constexpr
auto
a_c_k_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
Number
<
CPerBlock
>
{},
Number
<
KPerBlock
>
{},
Number
<
wei_c_x_k_block_desc
.
GetStride
(
I0
)
>
{});
Number
<
KPerBlock
>
{},
Number
<
wei_c_x_k_block_desc
.
GetStride
(
I0
)
>
{});
constexpr
auto
b_c_wn_block_mtx_desc
=
constexpr
auto
b_c_wn_block_mtx_desc
=
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
make_ConstantMatrixDescriptor
(
Number
<
CPerBlock
>
{},
...
@@ -187,8 +186,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -187,8 +186,10 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
GemmDataPerReadB
>
{};
GemmDataPerReadB
>
{};
// LDS: be careful of alignment
// LDS: be careful of alignment
constexpr
index_t
in_block_space
=
in_c_h_w_n_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
in_block_space
=
constexpr
index_t
wei_block_space
=
wei_c_x_k_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
in_c_h_w_n_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
constexpr
index_t
wei_block_space
=
wei_c_x_k_block_desc
.
GetElementSpace
(
Number
<
max_align
>
{});
__shared__
Float
p_in_block
[
in_block_space
];
__shared__
Float
p_in_block
[
in_block_space
];
__shared__
Float
p_wei_block
[
wei_block_space
];
__shared__
Float
p_wei_block
[
wei_block_space
];
...
@@ -213,9 +214,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -213,9 +214,8 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
threadwise_4d_tensor_set_zero
(
out_k_h_w_n_thread_desc
,
p_out_thread
);
threadwise_4d_tensor_set_zero
(
out_k_h_w_n_thread_desc
,
p_out_thread
);
const
Float
*
p_in_global_block_offset
=
const
Float
*
p_in_global_block_offset
=
p_in_global
+
p_in_global
+
in_c_h_w_n_global_desc
.
Get1dIndex
(
in_c_h_w_n_global_desc
.
Get1dIndex
(
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
0
,
hi_block_data_begin
,
wi_block_data_begin
,
n_block_data_begin
);
const
Float
*
p_wei_global_block_offset
=
const
Float
*
p_wei_global_block_offset
=
p_wei_global
+
wei_c_y_x_k_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
p_wei_global
+
wei_c_y_x_k_global_desc
.
Get1dIndex
(
0
,
0
,
0
,
k_block_data_begin
);
...
@@ -227,7 +227,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -227,7 +227,7 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
for
(
index_t
y
=
0
;
y
<
Y
;
++
y
)
{
{
blockwise_in_copy
.
Run
(
p_in_global_block_offset
+
blockwise_in_copy
.
Run
(
p_in_global_block_offset
+
in_c_h_w_n_global_desc
.
Get1dIndex
(
0
,
y
,
0
,
0
),
in_c_h_w_n_global_desc
.
Get1dIndex
(
0
,
y
,
0
,
0
),
p_in_block
);
p_in_block
);
blockwise_wei_copy
.
Run
(
p_wei_global_block_offset
+
blockwise_wei_copy
.
Run
(
p_wei_global_block_offset
+
...
@@ -239,9 +239,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -239,9 +239,9 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
for
(
index_t
x
=
0
;
x
<
X
;
++
x
)
{
{
blockwise_batch_gemm
.
Run
(
p_wei_block
+
wei_c_x_k_block_desc
.
Get1dIndex
(
0
,
x
,
0
),
blockwise_batch_gemm
.
Run
(
p_wei_block
+
wei_c_x_k_block_desc
.
Get1dIndex
(
0
,
x
,
0
),
p_in_block
+
in_c_h_w_n_block_desc
.
Get1dIndex
(
0
,
0
,
x
,
0
),
p_in_block
+
in_c_h_w_n_block_desc
.
Get1dIndex
(
0
,
0
,
x
,
0
),
p_out_thread
);
p_out_thread
);
}
}
__syncthreads
();
__syncthreads
();
...
@@ -321,17 +321,16 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
...
@@ -321,17 +321,16 @@ struct GridwiseConvolutionImplicitGemm_v1r2_chwn_cyxk_khwn
}
}
#endif
#endif
threadwise_10d_tensor_copy
(
threadwise_10d_tensor_copy
(
out_10d_thread_desc
,
out_10d_thread_desc
,
p_out_thread
,
p_out_thread
,
out_10d_global_desc
,
out_10d_global_desc
,
p_out_global
+
out_k_h_w_n_global_desc
.
Get1dIndex
(
p_out_global
+
k_block_data_begin
+
k_thread_data_begin
,
out_k_h_w_n_global_desc
.
Get1dIndex
(
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
,
n_block_data_begin
+
n_thread_data_begin
),
n_block_data_begin
+
n_thread_data_begin
),
out_10d_thread_desc
.
GetLengths
(),
out_10d_thread_desc
.
GetLengths
(),
Number
<
OutThreadCopyDataPerWrite
>
{});
Number
<
OutThreadCopyDataPerWrite
>
{});
#endif
#endif
}
}
};
};
src/include/gridwise_convolution_implicit_gemm_v2_chwn_cyxk_khwn_lds_double_buffer.hip.hpp
View file @
17f3d2d4
...
@@ -365,14 +365,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
...
@@ -365,14 +365,13 @@ struct GridwiseConvolutionImplicitGemm_v2_chwn_cyxk_khwn_lds_double_buffer
constexpr
auto
out_kb_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
,
B
>
{});
constexpr
auto
out_kb_global_desc
=
make_ConstantTensorDescriptor
(
Sequence
<
K
,
B
>
{});
threadwise_6d_tensor_copy
(
threadwise_6d_tensor_copy
(
out_6d_thread_desc
,
out_6d_thread_desc
,
p_out_thread
,
p_out_thread
,
out_6d_global_desc
,
out_6d_global_desc
,
p_out_global
+
out_kb_global_desc
.
Get1dIndex
(
p_out_global
+
k_thread_data_begin
,
b_thread_data_begin
),
out_kb_global_desc
.
Get1dIndex
(
k_thread_data_begin
,
b_thread_data_begin
),
out_6d_thread_desc
.
GetLengths
(),
out_6d_thread_desc
.
GetLengths
(),
Number
<
OutThreadCopyDataPerWrite
>
{});
Number
<
OutThreadCopyDataPerWrite
>
{});
}
}
else
else
{
{
...
...
src/include/gridwise_direct_convolution_1.hip.hpp
View file @
17f3d2d4
...
@@ -113,11 +113,10 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
...
@@ -113,11 +113,10 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
c_block_work_begin
+=
CPerBlock
)
c_block_work_begin
+=
CPerBlock
)
{
{
// copy input tensor to LDS
// copy input tensor to LDS
blockwise_in_copy
.
Run
(
p_in_global
+
blockwise_in_copy
.
Run
(
p_in_global
+
in_global_desc
.
Get1dIndex
(
n_block_work_begin
,
in_global_desc
.
Get1dIndex
(
n_block_work_begin
,
c_block_work_begin
,
c_block_work_begin
,
hi_block_work_begin
,
hi_block_work_begin
,
wi_block_work_begin
),
wi_block_work_begin
),
p_in_block
);
p_in_block
);
// copy weight tensor to LDS
// copy weight tensor to LDS
...
@@ -144,9 +143,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
...
@@ -144,9 +143,9 @@ __global__ void gridwise_direct_convolution_1(const Float* const __restrict__ p_
}
}
// copy output tensor from LDS to device mem
// copy output tensor from LDS to device mem
blockwise_out_copy
.
Run
(
blockwise_out_copy
.
Run
(
p_out_block
,
p_out_block
,
p_out_global
+
out_global_desc
.
Get1dIndex
(
n_block_work_begin
,
p_out_global
+
k_block_work_begin
,
out_global_desc
.
Get1dIndex
(
ho_block_work_begin
,
n_block_work_begin
,
k_block_work_begin
,
ho_block_work_begin
,
wo_block_work_begin
));
wo_block_work_begin
));
}
}
src/include/gridwise_direct_convolution_2_nchw_kcyx_nkhw.hip.hpp
View file @
17f3d2d4
...
@@ -175,18 +175,16 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
...
@@ -175,18 +175,16 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
c_block_data_begin
+=
CPerBlock
,
__syncthreads
())
{
{
// copy input tensor to LDS
// copy input tensor to LDS
blockwise_in_copy
.
Run
(
p_in_global
+
blockwise_in_copy
.
Run
(
p_in_global
+
in_nchw_global_desc
.
Get1dIndex
(
n_block_data_begin
,
in_nchw_global_desc
.
Get1dIndex
(
n_block_data_begin
,
c_block_data_begin
,
c_block_data_begin
,
hi_block_data_begin
,
hi_block_data_begin
,
wi_block_data_begin
),
wi_block_data_begin
),
p_in_block
);
p_in_block
);
// copy weight tensor to LDS
// copy weight tensor to LDS
blockwise_wei_copy
.
Run
(
blockwise_wei_copy
.
Run
(
p_wei_global
+
wei_kcyx_global_desc
.
Get1dIndex
(
p_wei_global
+
k_block_data_begin
,
c_block_data_begin
,
0
,
0
),
wei_kcyx_global_desc
.
Get1dIndex
(
k_block_data_begin
,
c_block_data_begin
,
0
,
0
),
p_wei_block
);
p_wei_block
);
__syncthreads
();
__syncthreads
();
...
@@ -196,11 +194,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
...
@@ -196,11 +194,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
#if 1
#if 1
threadwise_direct_convolution_2
(
threadwise_direct_convolution_2
(
in_nchw_thread_block_desc
,
in_nchw_thread_block_desc
,
p_in_block
+
p_in_block
+
in_nchw_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
in_nchw_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
c_thread_data
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
wei_kcyx_thread_block_desc
,
wei_kcyx_thread_block_desc
,
p_wei_block
+
p_wei_block
+
wei_kcyx_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
wei_kcyx_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
...
@@ -209,11 +206,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
...
@@ -209,11 +206,10 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
#elif 0
#elif 0
threadwise_direct_convolution_3
(
threadwise_direct_convolution_3
(
in_nchw_thread_block_desc
,
in_nchw_thread_block_desc
,
p_in_block
+
p_in_block
+
in_nchw_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
in_nchw_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
c_thread_data
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
wei_kcyx_thread_block_desc
,
wei_kcyx_thread_block_desc
,
p_wei_block
+
p_wei_block
+
wei_kcyx_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
wei_kcyx_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
...
@@ -228,10 +224,9 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
...
@@ -228,10 +224,9 @@ gridwise_direct_convolution_2_nchw_kcyx_nkhw(const Float* const __restrict__ p_i
out_nkhw_thread_desc
,
out_nkhw_thread_desc
,
p_out_thread
,
p_out_thread
,
out_nkhw_global_desc
,
out_nkhw_global_desc
,
p_out_global
+
p_out_global
+
out_nkhw_global_desc
.
Get1dIndex
(
n_block_data_begin
+
n_thread_data_begin
,
out_nkhw_global_desc
.
Get1dIndex
(
n_block_data_begin
+
n_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
out_nkhw_thread_desc
.
GetLengths
());
out_nkhw_thread_desc
.
GetLengths
());
}
}
src/include/gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw.hip.hpp
View file @
17f3d2d4
...
@@ -198,10 +198,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -198,10 +198,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
p_in_vec_block
);
p_in_vec_block
);
// copy weight tensor to LDS
// copy weight tensor to LDS
blockwise_wei_copy
.
Run
(
blockwise_wei_copy
.
Run
(
p_wei_vec_global
+
wei_kcyx_vec_global_desc
.
Get1dIndex
(
p_wei_vec_global
+
k_block_data_begin
,
c_block_data_begin
,
0
,
0
),
wei_kcyx_vec_global_desc
.
Get1dIndex
(
k_block_data_begin
,
c_block_data_begin
,
0
,
0
),
p_wei_vec_block
);
p_wei_vec_block
);
__syncthreads
();
__syncthreads
();
...
@@ -211,11 +210,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -211,11 +210,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#if 1
#if 1
threadwise_direct_convolution_2
(
threadwise_direct_convolution_2
(
in_nchw_vec_thread_block_desc
,
in_nchw_vec_thread_block_desc
,
p_in_vec_block
+
p_in_vec_block
+
in_nchw_vec_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
in_nchw_vec_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
c_thread_data
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
wei_kcyx_vec_thread_block_desc
,
wei_kcyx_vec_thread_block_desc
,
p_wei_vec_block
+
p_wei_vec_block
+
wei_kcyx_vec_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
wei_kcyx_vec_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
...
@@ -224,11 +222,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -224,11 +222,10 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
#elif 0
#elif 0
threadwise_direct_convolution_3
(
threadwise_direct_convolution_3
(
in_nchw_vec_thread_block_desc
,
in_nchw_vec_thread_block_desc
,
p_in_vec_block
+
p_in_vec_block
+
in_nchw_vec_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
in_nchw_vec_block_desc
.
Get1dIndex
(
n_thread_data_begin
,
c_thread_data
,
c_thread_data
,
hi_thread_data_begin
,
hi_thread_data_begin
,
wi_thread_data_begin
),
wi_thread_data_begin
),
wei_kcyx_vec_thread_block_desc
,
wei_kcyx_vec_thread_block_desc
,
p_wei_vec_block
+
p_wei_vec_block
+
wei_kcyx_vec_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
wei_kcyx_vec_block_desc
.
Get1dIndex
(
k_thread_data_begin
,
c_thread_data
,
0
,
0
),
...
@@ -243,10 +240,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
...
@@ -243,10 +240,9 @@ __global__ void gridwise_direct_convolution_2_vectorized_nchw_kcyx_nkhw(
out_nkhw_thread_desc
,
out_nkhw_thread_desc
,
p_out_thread
,
p_out_thread
,
out_nkhw_global_desc
,
out_nkhw_global_desc
,
p_out_global
+
p_out_global
+
out_nkhw_global_desc
.
Get1dIndex
(
n_block_data_begin
+
n_thread_data_begin
,
out_nkhw_global_desc
.
Get1dIndex
(
n_block_data_begin
+
n_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
k_block_data_begin
+
k_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
ho_block_data_begin
+
ho_thread_data_begin
,
wo_block_data_begin
+
wo_thread_data_begin
),
wo_block_data_begin
+
wo_thread_data_begin
),
out_nkhw_thread_desc
.
GetLengths
());
out_nkhw_thread_desc
.
GetLengths
());
}
}
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment