gaoqiong / composable_kernel

Commit bd34d666, authored May 25, 2022 by rocking
Add 5ary elementwise for normalization
parent 980ed33a
Showing 2 changed files with 512 additions and 0 deletions:

include/ck/tensor_operation/gpu/device/device_5ary_elementwise_xdl_cshuffle.hpp  (+261, -0)
include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp  (+251, -0)
include/ck/tensor_operation/gpu/device/device_normalize_xdl_cshuffle.hpp → include/ck/tensor_operation/gpu/device/device_5ary_elementwise_xdl_cshuffle.hpp
@@ -4,6 +4,7 @@
 #include "device.hpp"
 #include "device_base.hpp"
 #include "common_header.hpp"
+#include "gridwise_5ary_Elementwise_1d.hpp"
 #include "tensor_layout.hpp"
 #include "tensor_descriptor.hpp"
 #include "tensor_descriptor_helper.hpp"
@@ -12,24 +13,23 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
-// Input: x, E[x], E[x^2], Gamma, Beta
-// Output: {[(x - E[x]) / sqrt(E[x^2] - E[x]^2 + epsilon)] * gamma} + beta
-template <typename XDataType,
-          typename MeanDataType,
-          typename MeanSquareDataType,
-          typename GammaDataType,
-          typename BetaDataType,
-          typename OutDataType,
+template <typename ADataType,
+          typename BDataType,
+          typename CDataType,
+          typename DDataType,
+          typename EDataType,
+          typename FDataType,
           typename ComputeDataType,
-          typename OutElementwiseFunctor,
+          typename ElementwiseFunctor,
           index_t NDim,
           index_t MPerThread,
-          index_t XScalarPerVector,
-          index_t MeanScalarPerVector,
-          index_t MeanSquareScalarPerVector,
-          index_t GammaScalarPerVector,
-          index_t BetaScalarPerVector>
-struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
+          index_t AScalarPerVector,
+          index_t BScalarPerVector,
+          index_t CScalarPerVector,
+          index_t DScalarPerVector,
+          index_t EScalarPerVector,
+          index_t FScalarPerVector>
+struct Device5AryElementwise_Xdl_CShuffle : public BaseOperator
 {
     static constexpr auto I0 = Number<0>{};
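The comment removed above documents the normalization this operator was originally hard-wired for: out = (x - E[x]) / sqrt(E[x^2] - E[x]^2 + epsilon) * gamma + beta. Under the generalized interface, that computation moves into the user-supplied ElementwiseFunctor. A minimal sketch of such a functor is shown below; it is hypothetical (float-only, with an assumed epsilon_ member), but its argument order (output first, then the five inputs) matches how the gridwise kernel in the second file invokes the functor.

#include <cmath>

// Hypothetical 5-input normalization functor; 'Normalize', 'epsilon_' and the
// float-only signature are assumptions, not part of this commit.
struct Normalize
{
    float epsilon_; // e.g. 1e-5f

    __host__ __device__ void operator()(float& out,
                                        const float& x,
                                        const float& mean,
                                        const float& mean_square,
                                        const float& gamma,
                                        const float& beta) const
    {
        const float variance = mean_square - mean * mean; // E[x^2] - E[x]^2
        out = (x - mean) / sqrtf(variance + epsilon_) * gamma + beta;
    }
};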
@@ -73,68 +73,96 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
         return PadDescriptor_M_1d(desc, gridSize, blockSize);
     }

-    using GridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using AGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using BGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using CGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using DGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using EGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+    using FGridDesc_M = decltype(MakeDescriptor_M({1, 1}, {1, 1}, 1, 1));
+
+    using Gridwise5AryEltwise = Gridwise5AryElementwise_1D<ADataType,
+                                                           BDataType,
+                                                           CDataType,
+                                                           DDataType,
+                                                           EDataType,
+                                                           FDataType,
+                                                           ComputeDataType,
+                                                           AGridDesc_M,
+                                                           BGridDesc_M,
+                                                           CGridDesc_M,
+                                                           DGridDesc_M,
+                                                           EGridDesc_M,
+                                                           FGridDesc_M,
+                                                           ElementwiseFunctor,
+                                                           MPerThread,
+                                                           AScalarPerVector,
+                                                           BScalarPerVector,
+                                                           CScalarPerVector,
+                                                           DScalarPerVector,
+                                                           EScalarPerVector,
+                                                           FScalarPerVector>;

     struct Argument : public BaseArgument
     {
-        Argument(const XDataType* p_x,
-                 const MeanDataType* p_mean,
-                 const MeanSquareDataType* p_mean_square,
-                 const GammaDataType* p_gamma,
-                 const BetaDataType* p_beta,
-                 OutDataType* p_output,
+        Argument(const ADataType* p_a,
+                 const BDataType* p_b,
+                 const CDataType* p_c,
+                 const DDataType* p_d,
+                 const EDataType* p_e,
+                 FDataType* p_f,
                  const std::vector<index_t>& lengths,
-                 const std::vector<index_t>& stride_x,
-                 const std::vector<index_t>& stride_mean,
-                 const std::vector<index_t>& stride_mean_square,
-                 const std::vector<index_t>& stride_gamma,
-                 const std::vector<index_t>& stride_beta,
-                 const std::vector<index_t>& stride_output,
-                 OutElementwiseFunctor functor)
-            : p_x_(p_x),
-              p_mean_(p_mean),
-              p_mean_square_(p_mean_square),
-              p_gamma_(p_gamma),
-              p_beta_(p_beta),
-              p_output_(p_output),
+                 const std::vector<index_t>& a_strides,
+                 const std::vector<index_t>& b_strides,
+                 const std::vector<index_t>& c_strides,
+                 const std::vector<index_t>& d_strides,
+                 const std::vector<index_t>& e_strides,
+                 const std::vector<index_t>& f_strides,
+                 ElementwiseFunctor functor)
+            : p_a_(p_a),
+              p_b_(p_b),
+              p_c_(p_c),
+              p_d_(p_d),
+              p_e_(p_e),
+              p_f_(p_f),
               lengths_(lengths),
-              stride_x_(stride_x),
-              stride_mean_(stride_mean),
-              stride_mean_square_(stride_mean_square),
-              stride_gamma_(stride_gamma),
-              stride_beta_(stride_beta),
+              a_strides_(a_strides),
+              b_strides_(b_strides),
+              c_strides_(c_strides),
+              d_strides_(d_strides),
+              e_strides_(e_strides),
+              f_strides_(f_strides),
               functor_(functor),
               blockSize_(256),
               gridSize_(120) // FIXME - Calculate the grid size by number of CU in the future
         {
-            x_grid_desc_m_    = MakeDescriptor_M(lengths, stride_x, gridSize_, blockSize_);
-            mean_grid_desc_m_ = MakeDescriptor_M(lengths, stride_mean, gridSize_, blockSize_);
-            mean_square_grid_desc_m_ =
-                MakeDescriptor_M(lengths, stride_mean_square, gridSize_, blockSize_);
-            gamma_grid_desc_m_  = MakeDescriptor_M(lengths, stride_gamma, gridSize_, blockSize_);
-            beta_grid_desc_m_   = MakeDescriptor_M(lengths, stride_beta, gridSize_, blockSize_);
-            output_grid_desc_m_ = MakeDescriptor_M(lengths, stride_output, gridSize_, blockSize_);
+            a_grid_desc_m_ = MakeDescriptor_M(lengths, a_strides, gridSize_, blockSize_);
+            b_grid_desc_m_ = MakeDescriptor_M(lengths, b_strides, gridSize_, blockSize_);
+            c_grid_desc_m_ = MakeDescriptor_M(lengths, c_strides, gridSize_, blockSize_);
+            d_grid_desc_m_ = MakeDescriptor_M(lengths, d_strides, gridSize_, blockSize_);
+            e_grid_desc_m_ = MakeDescriptor_M(lengths, e_strides, gridSize_, blockSize_);
+            f_grid_desc_m_ = MakeDescriptor_M(lengths, f_strides, gridSize_, blockSize_);
         }

-        const XDataType* p_x_;
-        const MeanDataType* p_mean_;
-        const MeanSquareDataType* p_mean_square_;
-        const GammaDataType* p_gamma_;
-        const BetaDataType* p_beta_;
-        OutDataType* p_output_;
+        const ADataType* p_a_;
+        const BDataType* p_b_;
+        const CDataType* p_c_;
+        const DDataType* p_d_;
+        const EDataType* p_e_;
+        FDataType* p_f_;
         std::vector<index_t> lengths_;
-        GridDesc_M x_grid_desc_m_;
-        GridDesc_M mean_grid_desc_m_;
-        GridDesc_M mean_square_grid_desc_m_;
-        GridDesc_M gamma_grid_desc_m_;
-        GridDesc_M beta_grid_desc_m_;
-        GridDesc_M output_grid_desc_m_;
-        std::vector<index_t> stride_x_;
-        std::vector<index_t> stride_mean_;
-        std::vector<index_t> stride_mean_square_;
-        std::vector<index_t> stride_gamma_;
-        std::vector<index_t> stride_beta_;
-        OutElementwiseFunctor functor_;
+        AGridDesc_M a_grid_desc_m_;
+        BGridDesc_M b_grid_desc_m_;
+        CGridDesc_M c_grid_desc_m_;
+        DGridDesc_M d_grid_desc_m_;
+        EGridDesc_M e_grid_desc_m_;
+        FGridDesc_M f_grid_desc_m_;
+        std::vector<index_t> a_strides_;
+        std::vector<index_t> b_strides_;
+        std::vector<index_t> c_strides_;
+        std::vector<index_t> d_strides_;
+        std::vector<index_t> e_strides_;
+        std::vector<index_t> f_strides_;
+        ElementwiseFunctor functor_;
         index_t blockSize_;
         index_t gridSize_;
     };
@@ -143,10 +171,37 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
     {
         float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
         {
-            // TODO
-            (void)arg;
-            (void)stream_config;
-            return 0;
+            const auto kernel = kernel_5ary_elementwise_1d<Gridwise5AryEltwise,
+                                                           ADataType,
+                                                           BDataType,
+                                                           CDataType,
+                                                           DDataType,
+                                                           EDataType,
+                                                           FDataType,
+                                                           AGridDesc_M,
+                                                           BGridDesc_M,
+                                                           CGridDesc_M,
+                                                           DGridDesc_M,
+                                                           EGridDesc_M,
+                                                           FGridDesc_M,
+                                                           ElementwiseFunctor>;
+
+            float elapsed_time = launch_and_time_kernel(stream_config,
+                                                        kernel,
+                                                        dim3(arg.gridSize_),
+                                                        dim3(arg.blockSize_),
+                                                        0,
+                                                        arg.p_a_,
+                                                        arg.p_b_,
+                                                        arg.p_c_,
+                                                        arg.p_d_,
+                                                        arg.p_e_,
+                                                        arg.p_f_,
+                                                        arg.a_grid_desc_m_,
+                                                        arg.b_grid_desc_m_,
+                                                        arg.c_grid_desc_m_,
+                                                        arg.d_grid_desc_m_,
+                                                        arg.e_grid_desc_m_,
+                                                        arg.f_grid_desc_m_,
+                                                        arg.functor_);
+            return elapsed_time;
         }

         // polymorphic
@@ -181,20 +236,22 @@ struct DeviceNormalize_Xdl_CShuffle : public BaseOperator
             return ret;
         };

-        if(!IsScalarPerVectorValid(pArg->stride_x_.back() == 1, XScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->a_strides_.back() == 1, AScalarPerVector))
+            return false;
+
+        if(!IsScalarPerVectorValid(pArg->b_strides_.back() == 1, BScalarPerVector))
             return false;

-        if(!IsScalarPerVectorValid(pArg->stride_mean_.back() == 1, MeanScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->c_strides_.back() == 1, CScalarPerVector))
             return false;

-        if(!IsScalarPerVectorValid(pArg->stride_mean_square_.back() == 1, MeanSquareScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->d_strides_.back() == 1, DScalarPerVector))
             return false;

-        if(!IsScalarPerVectorValid(pArg->stride_gamma_.back() == 1, GammaScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->e_strides_.back() == 1, EScalarPerVector))
             return false;

-        if(!IsScalarPerVectorValid(pArg->stride_beta_.back() == 1, BetaScalarPerVector))
+        if(!IsScalarPerVectorValid(pArg->f_strides_.back() == 1, FScalarPerVector))
             return false;
     };
 };
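Putting the device-side pieces together, a hypothetical instantiation could look like the sketch below. The data types, NDim, MPerThread, and ScalarPerVector values are illustrative assumptions; Normalize is the functor sketched earlier, and the Argument parameter order follows the constructor in this diff. The pointers are assumed to be valid device allocations.

// Illustrative instantiation only; all numeric choices are assumptions.
using DeviceNormalize = ck::tensor_operation::device::Device5AryElementwise_Xdl_CShuffle<
    float, float, float, float, float, float, // ADataType..FDataType = x, E[x], E[x^2], gamma, beta, out
    float,      // ComputeDataType
    Normalize,  // ElementwiseFunctor
    2,          // NDim
    8,          // MPerThread
    1, 1, 1, 1, 1, 1>; // A..F ScalarPerVector; wider vectors require the innermost stride to be 1

// Every tensor shares 'lengths'; each supplies its own strides (here all identical, row-major).
std::vector<ck::index_t> lengths = {1024, 1024};
std::vector<ck::index_t> strides = {1024, 1};

auto argument = DeviceNormalize::Argument(p_x, p_mean, p_mean_square, p_gamma, p_beta, p_out,
                                          lengths,
                                          strides, strides, strides, strides, strides, strides,
                                          Normalize{1e-5f});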
include/ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp (new file, mode 100644)
#pragma once

#include "cluster_descriptor.hpp"
#include "data_type.hpp"
#include "element_wise_operation.hpp"
#include "threadwise_tensor_slice_transfer.hpp"

namespace ck {

template <typename Gridwise5AryEltwise,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename EDataType,
          typename FDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename DGridDesc_M,
          typename EGridDesc_M,
          typename FGridDesc_M,
          typename ElementwiseFunctor>
__global__ void kernel_5ary_elementwise_1d(const ADataType* __restrict__ p_a_global,
                                           const BDataType* __restrict__ p_b_global,
                                           const CDataType* __restrict__ p_c_global,
                                           const DDataType* __restrict__ p_d_global,
                                           const EDataType* __restrict__ p_e_global,
                                           FDataType* __restrict__ p_f_global,
                                           const AGridDesc_M a_grid_desc_m,
                                           const BGridDesc_M b_grid_desc_m,
                                           const CGridDesc_M c_grid_desc_m,
                                           const DGridDesc_M d_grid_desc_m,
                                           const EGridDesc_M e_grid_desc_m,
                                           const FGridDesc_M f_grid_desc_m,
                                           const ElementwiseFunctor functor)
{
    Gridwise5AryEltwise::Run(p_a_global,
                             p_b_global,
                             p_c_global,
                             p_d_global,
                             p_e_global,
                             p_f_global,
                             a_grid_desc_m,
                             b_grid_desc_m,
                             c_grid_desc_m,
                             d_grid_desc_m,
                             e_grid_desc_m,
                             f_grid_desc_m,
                             functor);
}
// TODO - implement n-ary Elementwise_1D, tuple of inputs and tuple of outputs
template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename DDataType,
          typename EDataType,
          typename FDataType,
          typename ComputeDataType,
          typename AGridDesc_M,
          typename BGridDesc_M,
          typename CGridDesc_M,
          typename DGridDesc_M,
          typename EGridDesc_M,
          typename FGridDesc_M,
          typename ElementwiseFunctor,
          index_t MPerThread,
          index_t AScalarPerVector,
          index_t BScalarPerVector,
          index_t CScalarPerVector,
          index_t DScalarPerVector,
          index_t EScalarPerVector,
          index_t FScalarPerVector>
struct Gridwise5AryElementwise_1D
{
    static constexpr auto I0 = Number<0>{};

    static constexpr auto thread_desc_m =
        make_naive_tensor_descriptor_packed(make_tuple(Number<MPerThread>{}));

    using PassThrough = tensor_operation::element_wise::PassThrough;

    static __device__ auto CalculateElementwiseIndex()
    {
        const index_t global_thread_id = get_thread_global_1d_id();
        return make_multi_index(global_thread_id * MPerThread);
    }

    __device__ static void Run(const ADataType* __restrict__ p_a_global,
                               const BDataType* __restrict__ p_b_global,
                               const CDataType* __restrict__ p_c_global,
                               const DDataType* __restrict__ p_d_global,
                               const EDataType* __restrict__ p_e_global,
                               FDataType* __restrict__ p_f_global,
                               const AGridDesc_M a_grid_desc_m,
                               const BGridDesc_M b_grid_desc_m,
                               const CGridDesc_M c_grid_desc_m,
                               const DGridDesc_M d_grid_desc_m,
                               const EGridDesc_M e_grid_desc_m,
                               const FGridDesc_M f_grid_desc_m,
                               const ElementwiseFunctor functor)
    {
        const auto a_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_a_global, a_grid_desc_m.GetElementSpaceSize());
        const auto b_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_b_global, b_grid_desc_m.GetElementSpaceSize());
        const auto c_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_c_global, c_grid_desc_m.GetElementSpaceSize());
        const auto d_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_d_global, d_grid_desc_m.GetElementSpaceSize());
        const auto e_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_e_global, e_grid_desc_m.GetElementSpaceSize());
        auto f_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
            p_f_global, f_grid_desc_m.GetElementSpaceSize());

        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> a_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> b_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> c_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> d_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> e_thread_buf;
        StaticBuffer<AddressSpaceEnum::Vgpr, ComputeDataType, MPerThread, true> f_thread_buf;

        const auto thread_store_global_offset = CalculateElementwiseIndex();

        auto a_global_load =
            ThreadwiseTensorSliceTransfer_v2<ADataType,
                                             ComputeDataType,
                                             AGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             AScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{a_grid_desc_m, thread_store_global_offset};

        auto b_global_load =
            ThreadwiseTensorSliceTransfer_v2<BDataType,
                                             ComputeDataType,
                                             BGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             BScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{b_grid_desc_m, thread_store_global_offset};

        auto c_global_load =
            ThreadwiseTensorSliceTransfer_v2<CDataType,
                                             ComputeDataType,
                                             CGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             CScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{c_grid_desc_m, thread_store_global_offset};

        auto d_global_load =
            ThreadwiseTensorSliceTransfer_v2<DDataType,
                                             ComputeDataType,
                                             DGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             DScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{d_grid_desc_m, thread_store_global_offset};

        auto e_global_load =
            ThreadwiseTensorSliceTransfer_v2<EDataType,
                                             ComputeDataType,
                                             EGridDesc_M,
                                             decltype(thread_desc_m),
                                             Sequence<MPerThread>, // SliceLengths
                                             Sequence<0>,          // DimAccessOrder
                                             0,                    // SrcVectorDim
                                             EScalarPerVector,     // ScalarPerVector
                                             1,                    // SrcScalarStrideInVector
                                             false>{e_grid_desc_m, thread_store_global_offset};

        auto f_global_write =
            ThreadwiseTensorSliceTransfer_v1r3<ComputeDataType,
                                               FDataType,
                                               decltype(thread_desc_m),
                                               FGridDesc_M,
                                               PassThrough,
                                               Sequence<MPerThread>, // SliceLengths
                                               Sequence<0>,          // DimAccessOrder
                                               0,                    // DstVectorDim
                                               FScalarPerVector,     // ScalarPerVector
                                               InMemoryDataOperationEnum::Set,
                                               1, // DstScalarStrideInVector
                                               false>{
                f_grid_desc_m, thread_store_global_offset, PassThrough{}};

        const index_t blockSize    = get_block_size();
        const index_t blockPerGrid = get_grid_size();
        const auto M               = c_grid_desc_m.GetLength(I0);
        const index_t loop_step    = blockPerGrid * blockSize * MPerThread;
        const auto loop_step_index = make_multi_index(loop_step);

        index_t num_iter = M / (loop_step);
        do
        {
            // read and process MPerThread elements
            a_global_load.Run(
                a_grid_desc_m, a_global_buf, thread_desc_m, make_tuple(I0), a_thread_buf);

            b_global_load.Run(
                b_grid_desc_m, b_global_buf, thread_desc_m, make_tuple(I0), b_thread_buf);

            c_global_load.Run(
                c_grid_desc_m, c_global_buf, thread_desc_m, make_tuple(I0), c_thread_buf);

            d_global_load.Run(
                d_grid_desc_m, d_global_buf, thread_desc_m, make_tuple(I0), d_thread_buf);

            e_global_load.Run(
                e_grid_desc_m, e_global_buf, thread_desc_m, make_tuple(I0), e_thread_buf);

            static_for<0, MPerThread, 1>{}([&](auto m) {
                constexpr auto offset = thread_desc_m.CalculateOffset(make_tuple(m));
                functor(f_thread_buf(Number<offset>{}),
                        a_thread_buf(Number<offset>{}),
                        b_thread_buf(Number<offset>{}),
                        c_thread_buf(Number<offset>{}),
                        d_thread_buf(Number<offset>{}),
                        e_thread_buf(Number<offset>{}));
            });

            f_global_write.Run(thread_desc_m,
                               make_tuple(I0), // SrcSliceOriginIdx
                               f_thread_buf,
                               f_grid_desc_m,
                               f_global_buf);

            a_global_load.MoveSrcSliceWindow(a_grid_desc_m, loop_step_index);
            b_global_load.MoveSrcSliceWindow(b_grid_desc_m, loop_step_index);
            c_global_load.MoveSrcSliceWindow(c_grid_desc_m, loop_step_index);
            d_global_load.MoveSrcSliceWindow(d_grid_desc_m, loop_step_index);
            e_global_load.MoveSrcSliceWindow(e_grid_desc_m, loop_step_index);
            f_global_write.MoveDstSliceWindow(f_grid_desc_m, loop_step_index);
        } while(--num_iter);
    }
};
} // namespace ck