Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
eb6405ee
"include/conv_common.hpp" did not exist on "88b77181aab1198b41b612f6d03b6dfb2d32bd40"
Commit
eb6405ee
authored
Jul 04, 2022
by
rocking
Browse files
Sync the naming
parent
e9a41755
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
100 additions
and
98 deletions
+100
-98
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
+41
-41
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
+59
-57
No files found.
include/ck/tensor_operation/gpu/device/device_layernorm.hpp
View file @
eb6405ee
...
...
@@ -32,13 +32,13 @@ template <typename XDataType,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
In
SrcVectorDim
,
index_t
In
SrcVectorSize
,
index_t
X
SrcVectorDim
,
index_t
X
SrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
Out
DstVectorSize
>
index_t
Y
DstVectorSize
>
struct
DeviceLayernorm
:
public
BaseOperator
{
static_assert
(
...
...
@@ -71,36 +71,36 @@ struct DeviceLayernorm : public BaseOperator
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
In
SrcVectorDim
,
In
SrcVectorSize
,
1
>
;
//
Out
DstVectorSize
X
SrcVectorDim
,
X
SrcVectorSize
,
1
>
;
//
Y
DstVectorSize
using
GridDesc_M_K
=
decltype
(
Reduction
::
MakeSrc2dDescriptor
({
1
},
{
1
},
1
,
1
));
using
GridwiseReduce
=
GridwiseLayernorm_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
In
SrcVectorDim
,
In
SrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
Out
DstVectorSize
,
false
>
;
using
GridwiseReduce
LayernormGeneric
=
GridwiseLayernorm_mk_to_mk
<
XDataType
,
GammaDataType
,
BetaDataType
,
YDataType
,
AccDataType
,
GridDesc_M_K
,
BlockSize
,
MThreadClusterSize
,
KThreadClusterSize
,
MThreadSliceSize
,
KThreadSliceSize
,
X
SrcVectorDim
,
X
SrcVectorSize
,
GammaSrcVectorDim
,
GammaSrcVectorSize
,
BetaSrcVectorDim
,
BetaSrcVectorSize
,
Y
DstVectorSize
,
false
>
;
struct
Argument
:
public
Reduction
::
Argument
{
Argument
(
const
std
::
vector
<
index_t
>
inL
engths
,
const
std
::
vector
<
index_t
>
in
Strides
,
Argument
(
const
std
::
vector
<
index_t
>
l
engths
,
const
std
::
vector
<
index_t
>
x
Strides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
index_t
>
reduceDims
,
...
...
@@ -109,8 +109,8 @@ struct DeviceLayernorm : public BaseOperator
const
GammaDataType
*
p_gamma
,
const
BetaDataType
*
p_beta
,
YDataType
*
p_y
)
:
Reduction
::
Argument
(
inL
engths
,
in
Strides
,
:
Reduction
::
Argument
(
l
engths
,
x
Strides
,
{},
{},
reduceDims
,
...
...
@@ -142,16 +142,16 @@ struct DeviceLayernorm : public BaseOperator
{
float
Run
(
const
Argument
&
arg
,
const
StreamConfig
&
stream_config
=
StreamConfig
{})
{
const
auto
in
_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
const
auto
x
_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
gamma_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
gammaStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
beta_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
betaStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
out
_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
const
auto
y
_grid_desc_m_k
=
Reduction
::
MakeSrc2dDescriptor
(
arg
.
inLengths_
,
arg
.
inStrides_
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
);
const
auto
kernel_main
=
kernel_layernorm
<
GridwiseReduce
,
const
auto
kernel_main
=
kernel_layernorm
<
GridwiseReduce
LayernormGeneric
,
XDataType
,
GammaDataType
,
BetaDataType
,
...
...
@@ -166,10 +166,10 @@ struct DeviceLayernorm : public BaseOperator
dim3
(
arg
.
gridSize
),
dim3
(
BlockSize
),
0
,
in
_grid_desc_m_k
,
x
_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
out
_grid_desc_m_k
,
y
_grid_desc_m_k
,
arg
.
blkGroupSize
,
arg
.
numBlockTileIteration
,
arg
.
epsilon_
,
...
...
@@ -197,7 +197,7 @@ struct DeviceLayernorm : public BaseOperator
return
false
;
}
if
(
p_arg_
->
inLengths_
[
Rank
-
1
]
%
Out
DstVectorSize
!=
0
)
if
(
p_arg_
->
inLengths_
[
Rank
-
1
]
%
Y
DstVectorSize
!=
0
)
{
return
false
;
}
...
...
@@ -241,19 +241,19 @@ struct DeviceLayernorm : public BaseOperator
return
true
;
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
inL
engths
,
const
std
::
vector
<
index_t
>
in
Strides
,
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
vector
<
index_t
>
l
engths
,
const
std
::
vector
<
index_t
>
x
Strides
,
const
std
::
vector
<
index_t
>
gammaStrides
,
const
std
::
vector
<
index_t
>
betaStrides
,
const
std
::
vector
<
int
>
reduceDims
,
const
std
::
vector
<
in
dex_
t
>
reduceDims
,
AccDataType
epsilon
,
const
void
*
p_x
,
const
void
*
p_gamma
,
const
void
*
p_beta
,
void
*
p_y
)
{
return
std
::
make_unique
<
Argument
>
(
inL
engths
,
in
Strides
,
return
std
::
make_unique
<
Argument
>
(
l
engths
,
x
Strides
,
gammaStrides
,
betaStrides
,
reduceDims
,
...
...
@@ -274,7 +274,7 @@ struct DeviceLayernorm : public BaseOperator
str
<<
"DeviceLayernorm<"
<<
BlockSize
<<
","
;
str
<<
"M_C"
<<
MThreadClusterSize
<<
"_S"
<<
MThreadSliceSize
<<
","
;
str
<<
"K_C"
<<
KThreadClusterSize
<<
"_S"
<<
KThreadSliceSize
<<
","
;
str
<<
"
In
SrcVectorDim_"
<<
In
SrcVectorDim
<<
"_
In
SrcVectorSize_"
<<
In
SrcVectorSize
<<
"_
Out
DstVectorSize_"
<<
Out
DstVectorSize
<<
">"
;
str
<<
"
X
SrcVectorDim_"
<<
X
SrcVectorDim
<<
"_
X
SrcVectorSize_"
<<
X
SrcVectorSize
<<
"_
Y
DstVectorSize_"
<<
Y
DstVectorSize
<<
">"
;
// clang-format on
return
str
.
str
();
...
...
include/ck/tensor_operation/gpu/grid/gridwise_layernorm.hpp
View file @
eb6405ee
...
...
@@ -21,10 +21,10 @@ template <typename GridwiseReduction,
typename
YDataType
,
typename
AccDataType
,
typename
GridDesc_M_K
>
__global__
void
kernel_layernorm
(
const
GridDesc_M_K
in
_grid_desc_m_k
,
__global__
void
kernel_layernorm
(
const
GridDesc_M_K
x
_grid_desc_m_k
,
const
GridDesc_M_K
gamma_grid_desc_m_k
,
const
GridDesc_M_K
beta_grid_desc_m_k
,
const
GridDesc_M_K
out
_grid_desc_m_k
,
const
GridDesc_M_K
y
_grid_desc_m_k
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
...
...
@@ -33,10 +33,10 @@ __global__ void kernel_layernorm(const GridDesc_M_K in_grid_desc_m_k,
const
BetaDataType
*
const
__restrict__
p_beta_global
,
YDataType
*
const
__restrict__
p_y_global
)
{
GridwiseReduction
::
Run
(
in
_grid_desc_m_k
,
GridwiseReduction
::
Run
(
x
_grid_desc_m_k
,
gamma_grid_desc_m_k
,
beta_grid_desc_m_k
,
out
_grid_desc_m_k
,
y
_grid_desc_m_k
,
block_group_size
,
num_k_block_tile_iteration
,
epsilon
,
...
...
@@ -57,22 +57,22 @@ template <typename XDataType,
index_t
KThreadClusterSize
,
index_t
MThreadSliceSize
,
index_t
KThreadSliceSize
,
index_t
In
SrcVectorDim
,
index_t
In
SrcVectorSize
,
index_t
X
SrcVectorDim
,
index_t
X
SrcVectorSize
,
index_t
GammaSrcVectorDim
,
index_t
GammaSrcVectorSize
,
index_t
BetaSrcVectorDim
,
index_t
BetaSrcVectorSize
,
index_t
Out
DstVectorSize
,
index_t
Y
DstVectorSize
,
bool
SweepOnce
>
struct
GridwiseLayernorm_mk_to_mk
{
static_assert
(((
In
SrcVectorDim
==
0
&&
MThreadSliceSize
%
In
SrcVectorSize
==
0
)
||
(
In
SrcVectorDim
==
1
&&
KThreadSliceSize
%
In
SrcVectorSize
==
0
))
&&
(
KThreadSliceSize
%
Out
DstVectorSize
==
0
),
static_assert
(((
X
SrcVectorDim
==
0
&&
MThreadSliceSize
%
X
SrcVectorSize
==
0
)
||
(
X
SrcVectorDim
==
1
&&
KThreadSliceSize
%
X
SrcVectorSize
==
0
))
&&
(
KThreadSliceSize
%
Y
DstVectorSize
==
0
),
"Invalid thread slice sizes and/or vector sizes configuration, please check!"
);
static
constexpr
bool
reorder_thread_cluster
=
(
In
SrcVectorDim
==
0
);
static
constexpr
bool
reorder_thread_cluster
=
(
X
SrcVectorDim
==
0
);
using
ThreadClusterLengths_M_K
=
Sequence
<
MThreadClusterSize
,
KThreadClusterSize
>
;
...
...
@@ -115,10 +115,10 @@ struct GridwiseLayernorm_mk_to_mk
static
constexpr
index_t
M_BlockTileSize
=
MThreadClusterSize
*
MThreadSliceSize
;
static
constexpr
index_t
K_BlockTileSize
=
KThreadClusterSize
*
KThreadSliceSize
;
__device__
static
void
Run
(
const
GridDesc_M_K
&
in
_grid_desc_m_k
,
__device__
static
void
Run
(
const
GridDesc_M_K
&
x
_grid_desc_m_k
,
const
GridDesc_M_K
&
gamma_grid_desc_m_k
,
const
GridDesc_M_K
&
beta_grid_desc_m_k
,
const
GridDesc_M_K
&
out
_grid_desc_m_k
,
const
GridDesc_M_K
&
y
_grid_desc_m_k
,
index_t
block_group_size
,
index_t
num_k_block_tile_iteration
,
AccDataType
epsilon
,
...
...
@@ -135,14 +135,14 @@ struct GridwiseLayernorm_mk_to_mk
// LDS
__shared__
AccDataType
p_reduce_work_buffer
[
BlockSize
];
auto
out
_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
out
_grid_desc_m_k
.
GetElementSpaceSize
());
auto
y
_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_y_global
,
y
_grid_desc_m_k
.
GetElementSpaceSize
());
auto
reduce_work_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Lds
>
(
p_reduce_work_buffer
,
BlockSize
);
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
in
_thread_buf
;
x
_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
gamma_thread_buf
;
...
...
@@ -151,12 +151,12 @@ struct GridwiseLayernorm_mk_to_mk
beta_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>
out
_thread_buf
;
y
_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
*
KThreadSliceSize
,
true
>&
in
_square_thread_buf
=
out
_thread_buf
;
true
>&
x
_square_thread_buf
=
y
_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
mean_thread_buf
;
StaticBuffer
<
AddressSpaceEnum
::
Vgpr
,
AccDataType
,
MThreadSliceSize
,
true
>
...
...
@@ -192,11 +192,11 @@ struct GridwiseLayernorm_mk_to_mk
decltype
(
thread_buffer_desc
),
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
In
SrcVectorDim
,
In
SrcVectorSize
,
X
SrcVectorDim
,
X
SrcVectorSize
,
1
,
true
>
(
in
_grid_desc_m_k
,
x
_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
));
...
...
@@ -238,24 +238,26 @@ struct GridwiseLayernorm_mk_to_mk
PassThroughOp
,
ThreadBufferLengths
,
ThreadBufferDimAccessOrder
,
In
SrcVectorDim
,
Out
DstVectorSize
,
X
SrcVectorDim
,
Y
DstVectorSize
,
InMemoryDataOperationEnum
::
Set
,
1
,
true
>
(
out
_grid_desc_m_k
,
y
_grid_desc_m_k
,
make_multi_index
(
blkgroup_id
*
M_BlockTileSize
+
thread_m_cluster_id
*
MThreadSliceSize
,
block_local_id
*
reduceSizePerBlock
+
thread_k_cluster_id
*
KThreadSliceSize
),
PassThroughOp
{});
constexpr
auto
in_thread_copy_fwd_step
=
// Copy x from Cache
// one pass: fwd, second pass: bwd
constexpr
auto
thread_copy_fwd_step
=
make_multi_index
(
0
,
SweepOnce
?
0
:
K_BlockTileSize
);
constexpr
auto
in_
thread_copy_bwd_step
=
constexpr
auto
thread_copy_bwd_step
=
make_multi_index
(
0
,
SweepOnce
?
0
:
-
K_BlockTileSize
);
const
auto
in
_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x_global
,
in
_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
x
_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_x_global
,
x
_grid_desc_m_k
.
GetElementSpaceSize
());
const
auto
gamma_global_val_buf
=
make_dynamic_buffer
<
AddressSpaceEnum
::
Global
>
(
p_gamma_global
,
gamma_grid_desc_m_k
.
GetElementSpaceSize
());
...
...
@@ -264,28 +266,28 @@ struct GridwiseLayernorm_mk_to_mk
p_beta_global
,
beta_grid_desc_m_k
.
GetElementSpaceSize
());
// E(x), E[x^2], var(x)
int
reduce_length
=
in
_grid_desc_m_k
.
GetLength
(
I1
);
int
reduce_length
=
x
_grid_desc_m_k
.
GetLength
(
I1
);
index_t
reducedTiles
=
0
;
do
{
threadwise_x_load
.
Run
(
in
_grid_desc_m_k
,
in
_global_val_buf
,
threadwise_x_load
.
Run
(
x
_grid_desc_m_k
,
x
_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in
_thread_buf
);
x
_thread_buf
);
static_for
<
0
,
MThreadSliceSize
,
1
>
{}([
&
](
auto
iM
)
{
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
in
_square_thread_buf
(
Number
<
offset
>
{})
=
in
_thread_buf
(
Number
<
offset
>
{})
*
in
_thread_buf
(
Number
<
offset
>
{});
x
_square_thread_buf
(
Number
<
offset
>
{})
=
x
_thread_buf
(
Number
<
offset
>
{})
*
x
_thread_buf
(
Number
<
offset
>
{});
});
});
ThreadwiseSumReduce
::
Reduce
(
in
_thread_buf
,
mean_thread_buf
);
ThreadwiseSumReduce
::
Reduce
(
in
_square_thread_buf
,
mean_square_thread_buf
);
ThreadwiseSumReduce
::
Reduce
(
x
_thread_buf
,
mean_thread_buf
);
ThreadwiseSumReduce
::
Reduce
(
x
_square_thread_buf
,
mean_square_thread_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
in
_grid_desc_m_k
,
in_
thread_copy_fwd_step
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x
_grid_desc_m_k
,
thread_copy_fwd_step
);
++
reducedTiles
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
...
...
@@ -303,23 +305,23 @@ struct GridwiseLayernorm_mk_to_mk
});
// y = (x - E[x]) / sqrt(var[x] + epsilon)
auto
thread_copy_tail
=
(
num_k_block_tile_iteration
-
1
)
*
in_
thread_copy_fwd_step
;
auto
thread_copy_tail
=
(
num_k_block_tile_iteration
-
1
)
*
thread_copy_fwd_step
;
threadwise_x_load
.
MoveSrcSliceWindow
(
in
_grid_desc_m_k
,
in_
thread_copy_bwd_step
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
in
_grid_desc_m_k
,
thread_copy_tail
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
in
_grid_desc_m_k
,
thread_copy_tail
);
threadwise_y_store
.
MoveDstSliceWindow
(
out
_grid_desc_m_k
,
thread_copy_tail
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x
_grid_desc_m_k
,
thread_copy_bwd_step
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
x
_grid_desc_m_k
,
thread_copy_tail
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
x
_grid_desc_m_k
,
thread_copy_tail
);
threadwise_y_store
.
MoveDstSliceWindow
(
y
_grid_desc_m_k
,
thread_copy_tail
);
reducedTiles
=
0
;
do
{
if
constexpr
(
!
SweepOnce
)
{
threadwise_x_load
.
Run
(
in
_grid_desc_m_k
,
in
_global_val_buf
,
threadwise_x_load
.
Run
(
x
_grid_desc_m_k
,
x
_global_val_buf
,
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
in
_thread_buf
);
x
_thread_buf
);
}
threadwise_gamma_load
.
Run
(
gamma_grid_desc_m_k
,
...
...
@@ -338,27 +340,27 @@ struct GridwiseLayernorm_mk_to_mk
static_for
<
0
,
KThreadSliceSize
,
1
>
{}([
&
](
auto
iK
)
{
constexpr
auto
offset
=
thread_buffer_desc
.
CalculateOffset
(
make_tuple
(
iM
,
iK
));
// normalize
out
_thread_buf
(
Number
<
offset
>
{})
=
(
in
_thread_buf
(
Number
<
offset
>
{})
-
mean_thread_buf
(
iM
))
/
y
_thread_buf
(
Number
<
offset
>
{})
=
(
x
_thread_buf
(
Number
<
offset
>
{})
-
mean_thread_buf
(
iM
))
/
sqrt
(
var_value_buf
(
iM
)
+
epsilon
);
// affine
out
_thread_buf
(
Number
<
offset
>
{})
=
out
_thread_buf
(
Number
<
offset
>
{})
*
gamma_thread_buf
(
Number
<
offset
>
{})
+
y
_thread_buf
(
Number
<
offset
>
{})
=
y
_thread_buf
(
Number
<
offset
>
{})
*
gamma_thread_buf
(
Number
<
offset
>
{})
+
beta_thread_buf
(
Number
<
offset
>
{});
});
});
threadwise_y_store
.
Run
(
thread_buffer_desc
,
make_tuple
(
I0
,
I0
),
out
_thread_buf
,
out
_grid_desc_m_k
,
out
_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
in
_grid_desc_m_k
,
in_
thread_copy_bwd_step
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
in
_grid_desc_m_k
,
in_
thread_copy_bwd_step
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
in
_grid_desc_m_k
,
in_
thread_copy_bwd_step
);
threadwise_y_store
.
MoveDstSliceWindow
(
out
_grid_desc_m_k
,
in_
thread_copy_bwd_step
);
y
_thread_buf
,
y
_grid_desc_m_k
,
y
_global_val_buf
);
threadwise_x_load
.
MoveSrcSliceWindow
(
x
_grid_desc_m_k
,
thread_copy_bwd_step
);
threadwise_gamma_load
.
MoveSrcSliceWindow
(
x
_grid_desc_m_k
,
thread_copy_bwd_step
);
threadwise_beta_load
.
MoveSrcSliceWindow
(
x
_grid_desc_m_k
,
thread_copy_bwd_step
);
threadwise_y_store
.
MoveDstSliceWindow
(
y
_grid_desc_m_k
,
thread_copy_bwd_step
);
++
reducedTiles
;
}
while
(
reducedTiles
<
num_k_block_tile_iteration
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment