gaoqiong / composable_kernel · Commits

Commit e508995a, authored Nov 14, 2023 by rocking
Parent: b9cb4a21

    Add deviceOp to backward x

Showing 3 changed files, with 463 additions and 8 deletions:

    example/53_layernorm2d_bwd/layernorm2d_bwd_fp16.cpp                              +65  -8
    include/ck/tensor_operation/gpu/device/device_normalization_bwd_x.hpp           +59  -0
    include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp +339 -0
example/53_layernorm2d_bwd/layernorm2d_bwd_fp16.cpp
@@ -15,6 +15,7 @@
 #include "ck/library/utility/literals.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_normalization_bwd_gamma_beta_impl.hpp"
 #include "ck/library/reference_tensor_operation/cpu/reference_layernorm_bwd.hpp"
@@ -46,8 +47,34 @@ constexpr int NumReduceDim = 1;
 // dbeta = reduce_sum(dy, axis=0)
 // [CAUTION]
-// In DeviceNormalizationBwdGammaBetaImpl, M is the invariant dimension and K is the reduced
-// dimension. Hence, M in this example and in DeviceNormalizationBwdGammaBetaImpl is different.
+// In DeviceNormalizationBwdXImpl & DeviceNormalizationBwdGammaBetaImpl, M is the invariant
+// dimension and K is the reduced dimension. Hence, M in this example and in
+// DeviceNormalizationBwdGammaBetaImpl is different.
+using XDeviceInstance =
+    ck::tensor_operation::device::DeviceNormalizationBwdXImpl<DYDataType,
+                                                              XDataType,
+                                                              GammaDataType,
+                                                              MeanInvStdDataType,
+                                                              ComputeDataType,
+                                                              DXDataType,
+                                                              Rank,
+                                                              NumReduceDim,
+                                                              256,   // BlockSize
+                                                              8,     // MThreadClusterSize
+                                                              32,    // KThreadClusterSize
+                                                              1,     // MThreadSliceSize
+                                                              8,     // KThreadSliceSize
+                                                              true,  // IsDYFastestDimReduced
+                                                              8,     // DYSrcVectorSize
+                                                              true,  // IsXFastestDimReduced
+                                                              8,     // XSrcVectorSize
+                                                              true,  // IsGammaFastestDimReduced
+                                                              8,     // GammaSrcVectorSize
+                                                              false, // IsMeanInvStdFastestDimReduced
+                                                              1,     // MeanInvStdSrcVectorSize
+                                                              true,  // IsDXFastestDimReduced
+                                                              8>;    // DXDstVectorSize
+
 using GammaBetaDeviceInstance =
     ck::tensor_operation::device::DeviceNormalizationBwdGammaBetaImpl<DYDataType,
                                                                       XDataType,
...
@@ -58,18 +85,18 @@ using GammaBetaDeviceInstance = ck::tensor_operation::device::DeviceNormalizatio
                                                                       Rank,
                                                                       NumReduceDim,
                                                                       256,   // BlockSize
-                                                                      8,     // ClusterInvarient
-                                                                      32,    // ClusterReduce
-                                                                      8,     // SliceInvarient
-                                                                      1,     // SliceReduce
+                                                                      8,     // MThreadClusterSize
+                                                                      32,    // KThreadClusterSize
+                                                                      8,     // MThreadSliceSize
+                                                                      1,     // KThreadSliceSize
                                                                       false, // IsDYFastestDimReduced
                                                                       8,     // DYSrcVectorSize
                                                                       false, // IsXFastestDimReduced
                                                                       8,     // XSrcVectorSize
                                                                       true,  // IsMeanInvStdFastestDimReduced
                                                                       1,     // MeanInvStdSrcVectorSize
-                                                                      1,     // DGammaDstVectorSize
-                                                                      1>;    // DBetaDstVectorSize
+                                                                      8,     // DGammaDstVectorSize
+                                                                      8>;    // DBetaDstVectorSize

 int main()
 {
@@ -96,8 +123,10 @@ int main()
     DeviceMem dy_dev(sizeof(DYDataType) * dy.mDesc.GetElementSpaceSize());
     DeviceMem x_dev(sizeof(XDataType) * x.mDesc.GetElementSpaceSize());
+    DeviceMem gamma_dev(sizeof(GammaDataType) * gamma.mDesc.GetElementSpaceSize());
     DeviceMem mean_dev(sizeof(MeanInvStdDataType) * mean.mDesc.GetElementSpaceSize());
     DeviceMem inv_std_dev(sizeof(MeanInvStdDataType) * inv_std.mDesc.GetElementSpaceSize());
+    DeviceMem dx_dev(sizeof(DXDataType) * dx.mDesc.GetElementSpaceSize());
     DeviceMem dgamma_dev(sizeof(DGammaDataType) * dgamma.mDesc.GetElementSpaceSize());
     DeviceMem dbeta_dev(sizeof(DBetaDataType) * dbeta.mDesc.GetElementSpaceSize());
@@ -106,6 +135,34 @@ int main()
     mean_dev.ToDevice(mean.mData.data());
     inv_std_dev.ToDevice(inv_std.mData.data());

+    // backward x
+    auto x_device_instance = XDeviceInstance{};
+    auto x_argument_ptr =
+        x_device_instance.MakeArgumentPointer({M, N}, // lengths
+                                              {N, 1}, // dyStrides
+                                              {N, 1}, // xStrides
+                                              {0, 1}, // gammaStrides
+                                              {1, 0}, // meanStrides
+                                              {1, 0}, // invStdStrides
+                                              {N, 1}, // dxStrides
+                                              {1},    // reduceDims
+                                              dy_dev.GetDeviceBuffer(),
+                                              x_dev.GetDeviceBuffer(),
+                                              gamma_dev.GetDeviceBuffer(),
+                                              mean_dev.GetDeviceBuffer(),
+                                              inv_std_dev.GetDeviceBuffer(),
+                                              dx_dev.GetDeviceBuffer());
+
+    if(!x_device_instance.IsSupportedArgument(x_argument_ptr.get()))
+    {
+        std::cout << "The runtime parameters are not supported" << std::endl;
+        return 1;
+    }
+
+    auto x_invoker_ptr = x_device_instance.MakeInvokerPointer();
+    x_invoker_ptr->Run(x_argument_ptr.get(), StreamConfig{nullptr, time_kernel});
+
+    // backward gamma & beta
     auto gamma_beta_device_instance = GammaBetaDeviceInstance{};
     auto gamma_beta_argument_ptr =
         gamma_beta_device_instance.MakeArgumentPointer({M, N}, // inLengths
 ...
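
For context, the gradients this example validates (via reference_layernorm_bwd.hpp) follow the standard layernorm backward formulas: with xhat = (x - mean) * inv_std and dxhat = dy * gamma, we have dbeta = reduce_sum(dy, axis=0), dgamma = reduce_sum(dy * xhat, axis=0), and dx = inv_std * (dxhat - mean(dxhat, axis=1) - xhat * mean(dxhat * xhat, axis=1)). This is also why the stride arguments above broadcast: {0, 1} indexes gamma by column only, and {1, 0} indexes mean/inv_std by row only. Below is a minimal CPU sketch of that computation, assuming row-major [M, N] float tensors; the function name and raw-pointer signature are illustrative, not CK's reference API:

#include <cstddef>

void reference_layernorm2d_bwd(const float* dy, const float* x, const float* gamma,
                               const float* mean, const float* inv_std,
                               float* dx, float* dgamma, float* dbeta,
                               std::size_t M, std::size_t N)
{
    for(std::size_t n = 0; n < N; ++n)
    {
        dgamma[n] = 0.f;
        dbeta[n]  = 0.f;
    }

    for(std::size_t m = 0; m < M; ++m)
    {
        // First pass over the row: accumulate the two row means needed for dx.
        float sum_dxhat = 0.f, sum_dxhat_xhat = 0.f;
        for(std::size_t n = 0; n < N; ++n)
        {
            const float xhat  = (x[m * N + n] - mean[m]) * inv_std[m];
            const float dxhat = dy[m * N + n] * gamma[n];
            sum_dxhat += dxhat;
            sum_dxhat_xhat += dxhat * xhat;
        }
        const float mean_dxhat      = sum_dxhat / N;
        const float mean_dxhat_xhat = sum_dxhat_xhat / N;

        // Second pass: dx for this row, plus the axis=0 reductions for dgamma/dbeta.
        for(std::size_t n = 0; n < N; ++n)
        {
            const float xhat  = (x[m * N + n] - mean[m]) * inv_std[m];
            const float dxhat = dy[m * N + n] * gamma[n];
            dx[m * N + n] = inv_std[m] * (dxhat - mean_dxhat - xhat * mean_dxhat_xhat);
            dgamma[n] += dy[m * N + n] * xhat; // dgamma = reduce_sum(dy * xhat, axis=0)
            dbeta[n]  += dy[m * N + n];        // dbeta  = reduce_sum(dy, axis=0)
        }
    }
}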
include/ck/tensor_operation/gpu/device/device_normalization_bwd_x.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename DXDataType,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceNormalizationBwdX : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const std::vector<index_t> lengths,
                        const std::vector<index_t> dyStrides,
                        const std::vector<index_t> xStrides,
                        const std::vector<index_t> gammaStrides,
                        const std::vector<index_t> meanStrides,
                        const std::vector<index_t> invStdStrides,
                        const std::vector<index_t> dxStrides,
                        const std::vector<index_t> reduceDims,
                        const void* p_dy,
                        const void* p_x,
                        const void* p_gamma,
                        const void* p_mean,
                        const void* p_invStd,
                        void* p_dx) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename DXDataType,
          index_t Rank,
          index_t NumReduceDim>
using DeviceNormalizationBwdXPtr = std::unique_ptr<DeviceNormalizationBwdX<DYDataType,
                                                                           XDataType,
                                                                           GammaDataType,
                                                                           MeanInvStdDataType,
                                                                           DXDataType,
                                                                           Rank,
                                                                           NumReduceDim>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
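
The interface above is type-erased by design: callers program against DeviceNormalizationBwdX (or DeviceNormalizationBwdXPtr) without knowing the concrete tile configuration. A minimal sketch of driving any matching implementation through the base class, assuming the fp16/fp32 types and the row-major 2D problem from the example; run_bwd_x is an illustrative helper, not part of CK:

#include "ck/tensor_operation/gpu/device/device_normalization_bwd_x.hpp"

using DeviceOp = ck::tensor_operation::device::
    DeviceNormalizationBwdX<ck::half_t, ck::half_t, ck::half_t, float, ck::half_t, 2, 1>;

// Run any concrete implementation through the abstract interface. The pointers are
// assumed to be valid device allocations of matching sizes.
bool run_bwd_x(DeviceOp& op,
               ck::index_t M,
               ck::index_t N,
               const void* p_dy,
               const void* p_x,
               const void* p_gamma,
               const void* p_mean,
               const void* p_invStd,
               void* p_dx)
{
    auto arg = op.MakeArgumentPointer({M, N}, // lengths
                                      {N, 1}, // dyStrides
                                      {N, 1}, // xStrides
                                      {0, 1}, // gammaStrides (broadcast over rows)
                                      {1, 0}, // meanStrides (broadcast over columns)
                                      {1, 0}, // invStdStrides
                                      {N, 1}, // dxStrides
                                      {1},    // reduce the last dimension
                                      p_dy, p_x, p_gamma, p_mean, p_invStd, p_dx);

    if(!op.IsSupportedArgument(arg.get()))
        return false;

    op.MakeInvokerPointer()->Run(arg.get());
    return true;
}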
include/ck/tensor_operation/gpu/device/impl/device_normalization_bwd_x_impl.hpp (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <sstream>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_normalization_bwd_x.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_reduce_common.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"

// M is the invariant dimension, K is the reduced dimension
namespace ck {
namespace tensor_operation {
namespace device {

template <typename DYDataType,
          typename XDataType,
          typename GammaDataType,
          typename MeanInvStdDataType,
          typename ComputeDataType,
          typename DXDataType,
          index_t Rank,
          index_t NumReduceDim,
          index_t BlockSize,
          index_t MThreadClusterSize,
          index_t KThreadClusterSize,
          index_t MThreadSliceSize,
          index_t KThreadSliceSize,
          bool IsDYFastestDimReduced,
          index_t DYSrcVectorSize,
          bool IsXFastestDimReduced,
          index_t XSrcVectorSize,
          bool IsGammaFastestDimReduced,
          index_t GammaSrcVectorSize,
          bool IsMeanInvStdFastestDimReduced,
          index_t MeanInvStdSrcVectorSize,
          bool IsDxFastestDimReduced,
          index_t DXDstVectorSize>
struct DeviceNormalizationBwdXImpl : public DeviceNormalizationBwdX<DYDataType,
                                                                    XDataType,
                                                                    GammaDataType,
                                                                    MeanInvStdDataType,
                                                                    DXDataType,
                                                                    Rank,
                                                                    NumReduceDim>
{
    static constexpr index_t DYSrcVectorDim         = IsDYFastestDimReduced ? 1 : 0;
    static constexpr index_t XSrcVectorDim          = IsXFastestDimReduced ? 1 : 0;
    static constexpr index_t GammaSrcVectorDim      = IsGammaFastestDimReduced ? 1 : 0;
    static constexpr index_t MeanInvStdSrcVectorDim = IsMeanInvStdFastestDimReduced ? 1 : 0;
    static constexpr index_t DXDstVectorDim         = IsDxFastestDimReduced ? 1 : 0;

    static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize);

    static_assert(((DYSrcVectorDim == 0 && MThreadSliceSize % DYSrcVectorSize == 0) ||
                   (DYSrcVectorDim == 1 && KThreadSliceSize % DYSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or dy vector sizes configuration, please check!");

    static_assert(((XSrcVectorDim == 0 && MThreadSliceSize % XSrcVectorSize == 0) ||
                   (XSrcVectorDim == 1 && KThreadSliceSize % XSrcVectorSize == 0)),
                  "Invalid thread slice sizes and/or x vector sizes configuration, please check!");

    static_assert(
        ((GammaSrcVectorDim == 0 && MThreadSliceSize % GammaSrcVectorSize == 0) ||
         (GammaSrcVectorDim == 1 && KThreadSliceSize % GammaSrcVectorSize == 0)),
        "Invalid thread slice sizes and/or gamma vector sizes configuration, please check!");

    static_assert(
        (MeanInvStdSrcVectorDim == 0 && MThreadSliceSize % MeanInvStdSrcVectorSize == 0) ||
            (MeanInvStdSrcVectorDim == 1 && KThreadSliceSize % MeanInvStdSrcVectorSize == 0),
        "Invalid thread slice sizes and/or mean and inverse std vector sizes configuration, "
        "please check!");

    static_assert(((DXDstVectorDim == 0 && MThreadSliceSize % DXDstVectorSize == 0) ||
                   (DXDstVectorDim == 1 && KThreadSliceSize % DXDstVectorSize == 0)),
                  "Invalid thread slice sizes and/or dx vector sizes configuration, please check!");
    static constexpr index_t NumInvariantDim = Rank - NumReduceDim;

    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
    static constexpr index_t K_BlockTileSize = KThreadClusterSize * KThreadSliceSize;

    static constexpr bool reduceAllDim = (NumInvariantDim == 0);
    static_assert(!reduceAllDim);

    static auto Make2dDescriptor(const std::vector<index_t>& lengths,
                                 const std::vector<index_t>& strides,
                                 int numBlockTileIteration)
    {
        const auto tupleLengths = make_tuple_from_array(lengths, Number<Rank>{});
        const auto tupleStrides = make_tuple_from_array(strides, Number<Rank>{});

        const auto desc = make_naive_tensor_descriptor(tupleLengths, tupleStrides);

        const auto grid_desc_m_k = [&]() {
            using InvariantDims = typename arithmetic_sequence_gen<0, NumInvariantDim, 1>::type;
            using ReduceDims    = typename arithmetic_sequence_gen<NumInvariantDim, Rank, 1>::type;

            const auto reduceDimLengths =
                make_tuple_from_array_and_index_seq(lengths, ReduceDims{});
            const auto invariantDimLengths =
                make_tuple_from_array_and_index_seq(lengths, InvariantDims{});

            return transform_tensor_descriptor(
                desc,
                make_tuple(make_merge_transform(invariantDimLengths),
                           make_merge_transform(reduceDimLengths)),
                make_tuple(InvariantDims{}, ReduceDims{}),
                make_tuple(Sequence<0>{}, Sequence<1>{}));
        }();

        const auto invariantLength = grid_desc_m_k.GetLength(Number<0>{});
        const auto reduceLength    = grid_desc_m_k.GetLength(Number<1>{});

        const auto pad_M =
            math::integer_least_multiple(invariantLength, M_BlockTileSize) - invariantLength;
        const auto pad_K = K_BlockTileSize * numBlockTileIteration - reduceLength;

        auto grid_desc_m_k_padded =
            transform_tensor_descriptor(grid_desc_m_k,
                                        make_tuple(make_right_pad_transform(invariantLength, pad_M),
                                                   make_right_pad_transform(reduceLength, pad_K)),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}),
                                        make_tuple(Sequence<0>{}, Sequence<1>{}));

        return grid_desc_m_k_padded;
    }

    using GridDesc_M_K = decltype(Make2dDescriptor({1}, {1}, 1));
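
    // A worked example of the padding above (illustrative numbers, combined with the example's
    // tile configuration): lengths {M, N} = {100, 1000} reduced over N gives
    // invariantLength = 100 and reduceLength = 1000. With M_BlockTileSize = 8 and
    // K_BlockTileSize = 256, numBlockTileIteration = ceil(1000 / 256) = 4, so pad_M = 4,
    // pad_K = 24, and the padded grid descriptor is 104 x 1024.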
    struct Argument : public BaseArgument
    {
        Argument(const std::vector<index_t> lengths,
                 const std::vector<index_t> dyStrides,
                 const std::vector<index_t> xStrides,
                 const std::vector<index_t> gammaStrides,
                 const std::vector<index_t> meanStrides,
                 const std::vector<index_t> invStdStrides,
                 const std::vector<index_t> dxStrides,
                 const std::vector<index_t> reduceDims,
                 const DYDataType* p_dy,
                 const XDataType* p_x,
                 const GammaDataType* p_gamma,
                 const MeanInvStdDataType* p_mean,
                 const MeanInvStdDataType* p_invStd,
                 DXDataType* p_dx)
            : p_dy_(p_dy),
              p_x_(p_x),
              p_gamma_(p_gamma),
              p_mean_(p_mean),
              p_invStd_(p_invStd),
              p_dx_(p_dx),
              dxStrides_{dxStrides}
        {
            lengths_      = shuffle_tensor_dimensions<Rank, NumReduceDim>(lengths, reduceDims);
            dyStrides_    = shuffle_tensor_dimensions<Rank, NumReduceDim>(dyStrides, reduceDims);
            xStrides_     = shuffle_tensor_dimensions<Rank, NumReduceDim>(xStrides, reduceDims);
            gammaStrides_ = shuffle_tensor_dimensions<Rank, NumReduceDim>(gammaStrides, reduceDims);
            meanStrides_  = shuffle_tensor_dimensions<Rank, NumReduceDim>(meanStrides, reduceDims);
            invStdStrides_ =
                shuffle_tensor_dimensions<Rank, NumReduceDim>(invStdStrides, reduceDims);

            std::tie(MRaw_, KRaw_) = get_2d_lengths<Rank, NumReduceDim>(lengths);

            numBlockTileIteration_ = math::integer_divide_ceil(KRaw_, K_BlockTileSize);
            gridSize_              = math::integer_divide_ceil(MRaw_, M_BlockTileSize);

            dy_grid_desc_m_k_    = Make2dDescriptor(lengths_, dyStrides_, numBlockTileIteration_);
            x_grid_desc_m_k_     = Make2dDescriptor(lengths_, xStrides_, numBlockTileIteration_);
            gamma_grid_desc_m_k_ = Make2dDescriptor(lengths_, gammaStrides_, numBlockTileIteration_);
            mean_grid_desc_m_k_  = Make2dDescriptor(lengths_, meanStrides_, numBlockTileIteration_);
            inv_std_grid_desc_m_k_ =
                Make2dDescriptor(lengths_, invStdStrides_, numBlockTileIteration_);
            dx_grid_desc_m_k_ = Make2dDescriptor(lengths_, dxStrides_, numBlockTileIteration_);

            isSweeponce_ = dy_grid_desc_m_k_.GetLength(Number<1>{}) <= K_BlockTileSize;
        }

        const DYDataType* p_dy_;
        const XDataType* p_x_;
        const GammaDataType* p_gamma_;
        const MeanInvStdDataType* p_mean_;
        const MeanInvStdDataType* p_invStd_;
        DXDataType* p_dx_;

        std::vector<index_t> lengths_;
        std::vector<index_t> dyStrides_;
        std::vector<index_t> xStrides_;
        std::vector<index_t> gammaStrides_;
        std::vector<index_t> meanStrides_;
        std::vector<index_t> invStdStrides_;
        std::vector<index_t> dxStrides_;

        int numBlockTileIteration_;
        size_t gridSize_;

        // tensor descriptors
        GridDesc_M_K dy_grid_desc_m_k_;
        GridDesc_M_K x_grid_desc_m_k_;
        GridDesc_M_K gamma_grid_desc_m_k_;
        GridDesc_M_K mean_grid_desc_m_k_;
        GridDesc_M_K inv_std_grid_desc_m_k_;
        GridDesc_M_K dx_grid_desc_m_k_;

        bool isSweeponce_;

        index_t MRaw_; // invariant length
        index_t KRaw_; // reduce length
    };

    struct Invoker : public BaseInvoker
    {
        float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
        {
            // TODO: the kernel launch is not implemented yet in this commit
            ignore = arg;
            ignore = stream_config;
            return 0;
        }

        float Run(const BaseArgument* p_arg,
                  const StreamConfig& stream_config = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg), stream_config);
        }
    };

    template <index_t SrcVectorDim, index_t SrcVectorSize>
    bool IsVectorDimSizeValid(const std::vector<index_t>& lengths,
                              const std::vector<index_t>& strides)
    {
        if constexpr(SrcVectorSize == 1)
            return true;

        if constexpr(SrcVectorDim == 0) // fastest dimension is not reduced
        {
            if constexpr(NumInvariantDim == 0)
                return false;

            if(strides[NumInvariantDim - 1] != 1)
                return false;

            if(lengths[NumInvariantDim - 1] % SrcVectorSize != 0)
                return false;
        }
        else // fastest dimension is reduced
        {
            if(strides[Rank - 1] != 1)
                return false;

            if(lengths[Rank - 1] % SrcVectorSize != 0)
                return false;
        }

        return true;
    }
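
    // Example with the strides used in layernorm2d_bwd_fp16.cpp: dy/x/dx use strides {N, 1},
    // so for vector dim 1 the fastest (reduced) dimension is contiguous and a vector size of 8
    // is valid whenever N % 8 == 0. gamma's strides {0, 1} also end in 1, so it vectorizes the
    // same way, while mean/inv_std pass trivially because MeanInvStdSrcVectorSize == 1.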
    bool IsSupportedArgument(const BaseArgument* p_arg) override
    {
        const Argument* p_arg_ = dynamic_cast<const Argument*>(p_arg);

        bool pass = true;
        pass &= IsVectorDimSizeValid<DYSrcVectorDim, DYSrcVectorSize>(p_arg_->lengths_,
                                                                      p_arg_->dyStrides_);
        pass &= IsVectorDimSizeValid<XSrcVectorDim, XSrcVectorSize>(p_arg_->lengths_,
                                                                    p_arg_->xStrides_);
        pass &= IsVectorDimSizeValid<GammaSrcVectorDim, GammaSrcVectorSize>(p_arg_->lengths_,
                                                                            p_arg_->gammaStrides_);
        pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
            p_arg_->lengths_, p_arg_->meanStrides_);
        pass &= IsVectorDimSizeValid<MeanInvStdSrcVectorDim, MeanInvStdSrcVectorSize>(
            p_arg_->lengths_, p_arg_->invStdStrides_);
        pass &= IsVectorDimSizeValid<DXDstVectorDim, DXDstVectorSize>(p_arg_->lengths_,
                                                                      p_arg_->dxStrides_);
        return pass;
    }

    std::unique_ptr<BaseArgument> MakeArgumentPointer(const std::vector<index_t> lengths,
                                                      const std::vector<index_t> dyStrides,
                                                      const std::vector<index_t> xStrides,
                                                      const std::vector<index_t> gammaStrides,
                                                      const std::vector<index_t> meanStrides,
                                                      const std::vector<index_t> invStdStrides,
                                                      const std::vector<index_t> dxStrides,
                                                      const std::vector<index_t> reduceDims,
                                                      const void* p_dy,
                                                      const void* p_x,
                                                      const void* p_gamma,
                                                      const void* p_mean,
                                                      const void* p_invStd,
                                                      void* p_dx) override
    {
        if(lengths.size() != Rank || dyStrides.size() != Rank || xStrides.size() != Rank ||
           gammaStrides.size() != Rank || meanStrides.size() != Rank ||
           invStdStrides.size() != Rank || dxStrides.size() != Rank)
            throw std::runtime_error("dimension is incorrect");

        return std::make_unique<Argument>(lengths,
                                          dyStrides,
                                          xStrides,
                                          gammaStrides,
                                          meanStrides,
                                          invStdStrides,
                                          dxStrides,
                                          reduceDims,
                                          static_cast<const DYDataType*>(p_dy),
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const GammaDataType*>(p_gamma),
                                          static_cast<const MeanInvStdDataType*>(p_mean),
                                          static_cast<const MeanInvStdDataType*>(p_invStd),
                                          static_cast<DXDataType*>(p_dx));
    }

    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    }

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();
        // clang-format off
        str << "DeviceNormalizationBwdXImpl<" << ">";
        // clang-format on
        return str.str();
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck