Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
02ff2522
Commit
02ff2522
authored
Nov 25, 2022
by
Po-Yen, Chen
Browse files
Merge branch 'develop' into feature/restruct-ckprofiler
parents
acc47d12
4e6a5575
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2568 additions
and
27 deletions
+2568
-27
client_example/13_batchnorm/CMakeLists.txt
client_example/13_batchnorm/CMakeLists.txt
+2
-0
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp
+197
-0
example/34_batchnorm/batchnorm_forward_nhwc.cpp
example/34_batchnorm/batchnorm_forward_nhwc.cpp
+14
-10
example/34_batchnorm/batchnorm_infer_nhwc.cpp
example/34_batchnorm/batchnorm_infer_nhwc.cpp
+17
-8
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp
.../tensor_operation/gpu/device/device_batchnorm_forward.hpp
+27
-4
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp
...ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp
+29
-3
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
...eration/gpu/device/impl/device_batchnorm_forward_impl.hpp
+9
-2
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp
...ence_tensor_operation/cpu/reference_batchnorm_forward.hpp
+368
-0
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp
...erence_tensor_operation/cpu/reference_batchnorm_infer.hpp
+300
-0
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp
...brary/tensor_operation_instance/gpu/batchnorm_forward.hpp
+130
-0
library/include/ck/library/utility/host_common_util.hpp
library/include/ck/library/utility/host_common_util.hpp
+60
-0
library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt
...rc/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt
+7
-0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp
.../gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp
+147
-0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp
...e/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp
+147
-0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp
...e/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp
+145
-0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp
...e/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp
+145
-0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_i8_instance.cpp
...ce/gpu/batchnorm/device_batchnorm_forward_i8_instance.cpp
+145
-0
profiler/include/profiler/profile_batchnorm_forward_impl.hpp
profiler/include/profiler/profile_batchnorm_forward_impl.hpp
+440
-0
profiler/src/CMakeLists.txt
profiler/src/CMakeLists.txt
+2
-0
profiler/src/profile_batchnorm_fwd.cpp
profiler/src/profile_batchnorm_fwd.cpp
+237
-0
No files found.
client_example/13_batchnorm/CMakeLists.txt
0 → 100644
View file @
02ff2522
add_executable
(
client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp
)
target_link_libraries
(
client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations
)
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <functional>
#include <numeric>
#include <iomanip>
#include <iostream>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"
using
XDataType
=
float
;
using
YDataType
=
float
;
using
AccDataType
=
float
;
using
ScaleDataType
=
AccDataType
;
using
BiasDataType
=
AccDataType
;
using
MeanVarDataType
=
AccDataType
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
constexpr
int
Rank
=
4
;
constexpr
int
NumBatchNormReduceDim
=
3
;
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
const
double
averageFactor
=
0.1
;
struct
SimpleDeviceMem
{
SimpleDeviceMem
()
=
delete
;
SimpleDeviceMem
(
std
::
size_t
mem_size
)
:
p_mem_
{}
{
(
void
)
hipMalloc
(
static_cast
<
void
**>
(
&
p_mem_
),
mem_size
);
}
void
*
GetDeviceBuffer
()
{
return
p_mem_
;
}
~
SimpleDeviceMem
()
{
(
void
)
hipFree
(
p_mem_
);
}
void
*
p_mem_
;
};
int
main
(
int
argc
,
char
*
argv
[])
{
std
::
array
<
ck
::
index_t
,
Rank
>
xyLengths
{
16
,
8
,
128
,
256
};
std
::
array
<
ck
::
index_t
,
Rank
>
xyStrides
{
8
*
128
*
256
,
128
*
256
,
256
,
1
};
std
::
array
<
ck
::
index_t
,
Rank
-
NumBatchNormReduceDim
>
scaleBiasMeanVarLengths
{
256
};
std
::
array
<
ck
::
index_t
,
Rank
-
NumBatchNormReduceDim
>
scaleBiasMeanVarStrides
{
1
};
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
{
0
,
1
,
2
};
ck
::
index_t
numXYElement
=
std
::
accumulate
(
xyLengths
.
begin
(),
xyLengths
.
end
(),
1
,
std
::
multiplies
<
ck
::
index_t
>
());
ck
::
index_t
numScaleBiasMeanVarElement
=
std
::
accumulate
(
scaleBiasMeanVarLengths
.
begin
(),
scaleBiasMeanVarLengths
.
end
(),
1
,
std
::
multiplies
<
ck
::
index_t
>
());
SimpleDeviceMem
x
(
sizeof
(
XDataType
)
*
numXYElement
);
SimpleDeviceMem
y
(
sizeof
(
YDataType
)
*
numXYElement
);
SimpleDeviceMem
scale
(
sizeof
(
ScaleDataType
)
*
numScaleBiasMeanVarElement
);
SimpleDeviceMem
bias
(
sizeof
(
BiasDataType
)
*
numScaleBiasMeanVarElement
);
SimpleDeviceMem
mean
(
sizeof
(
MeanVarDataType
)
*
numScaleBiasMeanVarElement
);
SimpleDeviceMem
invVariance
(
sizeof
(
MeanVarDataType
)
*
numScaleBiasMeanVarElement
);
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
PassThrough
,
Rank
,
NumBatchNormReduceDim
>
;
const
auto
op_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
op_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_op_name
;
bool
found
=
false
;
int
best_op_id
=
-
1
;
float
best_ave_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
// profile device operation instances
std
::
cout
<<
"Run all instances and do timing"
<<
std
::
endl
;
for
(
int
i
=
0
;
i
<
op_ptrs
.
size
();
++
i
)
{
auto
&
op_ptr
=
op_ptrs
[
i
];
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
xyLengths
,
xyStrides
,
xyStrides
,
reduceDims
,
scaleBiasMeanVarLengths
,
scaleBiasMeanVarStrides
,
scaleBiasMeanVarStrides
,
scaleBiasMeanVarStrides
,
x
.
GetDeviceBuffer
(),
scale
.
GetDeviceBuffer
(),
bias
.
GetDeviceBuffer
(),
epsilon
,
PassThrough
{},
y
.
GetDeviceBuffer
(),
mean
.
GetDeviceBuffer
(),
invVariance
.
GetDeviceBuffer
(),
averageFactor
,
nullptr
,
nullptr
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
std
::
string
op_name
=
op_ptr
->
GetTypeString
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
size_t
workspace_sz
=
op_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
());
SimpleDeviceMem
workspace
(
workspace_sz
);
op_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace
.
GetDeviceBuffer
());
float
ave_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
true
});
std
::
size_t
num_bytes
=
numXYElement
*
(
sizeof
(
XDataType
)
+
sizeof
(
YDataType
))
+
numScaleBiasMeanVarElement
*
(
sizeof
(
ScaleDataType
)
+
sizeof
(
BiasDataType
)
+
sizeof
(
MeanVarDataType
)
+
sizeof
(
MeanVarDataType
));
float
gb_per_sec
=
num_bytes
/
1.E6
/
ave_time
;
std
::
cout
<<
"Perf: "
<<
std
::
setw
(
10
)
<<
ave_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
op_name
<<
std
::
endl
;
if
(
ave_time
<
best_ave_time
)
{
found
=
true
;
best_op_id
=
i
;
best_op_name
=
op_name
;
best_ave_time
=
ave_time
;
best_gb_per_sec
=
gb_per_sec
;
}
}
else
{
std
::
cout
<<
op_name
<<
" does not support this problem"
<<
std
::
endl
;
}
}
if
(
found
)
{
std
::
cout
<<
"Best Perf: "
<<
best_ave_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_op_name
<<
std
::
endl
;
// run the best intance
auto
&
op_ptr
=
op_ptrs
[
best_op_id
];
std
::
cout
<<
"Run the best instance without timing: "
<<
op_ptr
->
GetTypeString
()
<<
std
::
endl
;
auto
argument_ptr
=
op_ptr
->
MakeArgumentPointer
(
xyLengths
,
xyStrides
,
xyStrides
,
reduceDims
,
scaleBiasMeanVarLengths
,
scaleBiasMeanVarStrides
,
scaleBiasMeanVarStrides
,
scaleBiasMeanVarStrides
,
x
.
GetDeviceBuffer
(),
scale
.
GetDeviceBuffer
(),
bias
.
GetDeviceBuffer
(),
epsilon
,
PassThrough
{},
y
.
GetDeviceBuffer
(),
mean
.
GetDeviceBuffer
(),
invVariance
.
GetDeviceBuffer
(),
averageFactor
,
nullptr
,
nullptr
);
auto
invoker_ptr
=
op_ptr
->
MakeInvokerPointer
();
if
(
op_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
false
});
}
std
::
cout
<<
"Done"
<<
std
::
endl
;
}
return
0
;
}
example/34_batchnorm/batchnorm_forward_nhwc.cpp
View file @
02ff2522
...
...
@@ -15,7 +15,7 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward
_nhwc_c
.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
...
...
@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
constexpr
int
Rank
=
4
;
constexpr
int
NumReduceDim
=
3
;
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
const
std
::
vector
<
size_t
>
scaleBiasMeanVarLengths
=
{
inOutLengths
[
3
]};
// input data of the batchnorm forward algorithm
...
...
@@ -300,7 +302,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
i_inOutLengths
,
i_inOutStrides
,
i_inOutStrides
,
{
0
,
1
,
2
},
{
0
,
1
,
2
},
// indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
...
...
@@ -366,13 +368,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
{
using
ReferenceBatchNormFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
<
InOutDataType
,
InOutDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
PassThroughOp
>
;
ck
::
tensor_operation
::
host
::
ReferenceBatchNormFwd
<
InOutDataType
,
InOutDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
PassThroughOp
,
Rank
,
NumReduceDim
>
;
auto
batchNormFwd_ref
=
ReferenceBatchNormFwdInstance
{};
...
...
@@ -380,7 +384,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
i_inOutLengths
,
i_inOutStrides
,
i_inOutStrides
,
{
0
,
1
,
2
},
{
0
,
1
,
2
},
// indicates physical indices of reduce dimensions in lengths[] and strides[]
i_scaleBiasMeanVarLengths
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
...
...
example/34_batchnorm/batchnorm_infer_nhwc.cpp
View file @
02ff2522
...
...
@@ -15,7 +15,8 @@
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp"
#include "batchnorm_infer_impl.hpp"
...
...
@@ -124,6 +125,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
constexpr
int
Rank
=
4
;
constexpr
int
NumReduceDim
=
3
;
// when using lengths[] to create a tensor, lengths[0] is the length of highest dimension
// eg. N of NHWC, so lengths[3] is the dimension C length of NHWC
const
std
::
vector
<
size_t
>
scaleBiasMeanVarLengths
=
{
inOutLengths
[
3
]};
// input data of the batchnorm forward algorithm
...
...
@@ -260,20 +263,25 @@ bool bnorm_infer_nhwc_test(bool do_verification,
if
(
do_verification
)
{
using
PassThroughOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
using
ReferenceBatchNormInferInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchNormInfer_Input_N_H_W_C_Output_C
<
InOutDataType
,
InOutDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
>
;
ck
::
tensor_operation
::
host
::
ReferenceBatchNormInfer
<
InOutDataType
,
InOutDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
AccDataType
,
PassThroughOp
,
Rank
,
NumReduceDim
>
;
auto
batchNormInfer_ref
=
ReferenceBatchNormInferInstance
{};
auto
argument_ptr_ref
=
batchNormInfer_ref
.
MakeArgumentPointer
(
i_inOutLengths
,
i_inOutStrides
,
i_inOutStrides
,
{
0
,
1
,
2
},
i_scaleBiasMeanVarLengths
,
i_scaleBiasMeanVarStrides
,
i_scaleBiasMeanVarStrides
,
...
...
@@ -282,6 +290,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
bnScale
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
epsilon
,
PassThroughOp
{},
estimatedMean
.
mData
.
data
(),
estimatedVariance
.
mData
.
data
(),
y_ref
.
mData
.
data
());
...
...
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp
View file @
02ff2522
...
...
@@ -13,7 +13,15 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
typename
YElementwiseOp
>
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
struct
DeviceBatchNormFwd
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
...
...
@@ -40,9 +48,24 @@ struct DeviceBatchNormFwd : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
,
typename
YElementwiseOp
>
using
DeviceBatchNormFwdPtr
=
std
::
unique_ptr
<
DeviceBatchNormFwd
<
Rank
,
NumBatchNormReduceDim
,
YElementwiseOp
>>
;
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
using
DeviceBatchNormFwdPtr
=
std
::
unique_ptr
<
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp
View file @
02ff2522
...
...
@@ -13,13 +13,22 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
struct
DeviceBatchNormInfer
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
bnBiasStrides
,
...
...
@@ -28,6 +37,7 @@ struct DeviceBatchNormInfer : public BaseOperator
const
void
*
bnScale
,
const
void
*
bnBias
,
double
epsilon
,
const
YElementwiseOp
y_elementwise_op
,
const
void
*
estimatedMean
,
const
void
*
estimatedInvVariance
,
void
*
p_y
)
=
0
;
...
...
@@ -35,8 +45,24 @@ struct DeviceBatchNormInfer : public BaseOperator
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
using
DeviceBatchNormInferPtr
=
std
::
unique_ptr
<
DeviceBatchNormInfer
<
Rank
,
NumBatchNormReduceDim
>>
;
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
using
DeviceBatchNormInferPtr
=
std
::
unique_ptr
<
DeviceBatchNormInfer
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>>
;
}
// namespace device
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp
View file @
02ff2522
...
...
@@ -42,8 +42,15 @@ template <typename XDataType,
index_t
ScaleSrcVectorSize
,
index_t
BiasSrcVectorSize
,
index_t
MeanVarSrcDstVectorSize
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
Rank
,
NumBatchNormReduceDim
,
YElementwiseOp
>
struct
DeviceBatchNormFwdImpl
:
public
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static_assert
(
BlockSize
==
MThreadClusterSize
*
KThreadClusterSize
,
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward
_nhwc_c
.hpp
→
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp
View file @
02ff2522
...
...
@@ -4,13 +4,13 @@
#pragma once
#include <iostream>
#include <vector>
#include <array>
#include <algorithm>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
namespace
ck
{
...
...
@@ -23,20 +23,33 @@ template <typename XDataType,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
>
struct
ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
:
public
device
::
DeviceBatchNormFwd
<
4
,
3
,
YElementwiseOp
>
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
struct
ReferenceBatchNormFwd
:
public
device
::
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
4
>
xyLengths
,
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
yStrides
,
const
std
::
array
<
int
,
3
>
reduceDims
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
1
>
bnMeanVarStrides
,
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
ScaleDataType
*
bnScale
,
const
BiasDataType
*
bnBias
,
...
...
@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
double
averageFactor
,
MeanVarDataType
*
resultRunningMean
,
MeanVarDataType
*
resultRunningVariance
)
:
p_x_
(
p_x
),
:
reduceDims_
(
reduceDims
),
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnBiasStrides_
(
bnBiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
bnScale_
(
bnScale
),
bnBias_
(
bnBias
),
y_elementwise_op_
(
y_elementwise_op
),
...
...
@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunningMean_
(
resultRunningMean
),
resultRunningVariance_
(
resultRunningVariance
)
{
ignore
=
xStrides
;
ignore
=
yStrides
;
ignore
=
bnScaleStrides
;
ignore
=
bnBiasStrides
;
ignore
=
bnMeanVarStrides
;
ignore
=
reduceDims
;
if
(
xyLengths
.
size
()
!=
4
||
bnScaleBiasMeanVarLengths
.
size
()
!=
1
||
bnScaleBiasMeanVarLengths
[
0
]
!=
xyLengths
[
3
])
throw
std
::
runtime_error
(
"Invalid tensor dimensions!"
);
n
=
xyLengths
[
0
];
h
=
xyLengths
[
1
];
w
=
xyLengths
[
2
];
c
=
xyLengths
[
3
];
using
ck
::
host_common
::
get_index_set
;
if
(
std
::
any_of
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
[](
int
d
)
{
return
d
<
0
||
d
>=
Rank
;
}))
throw
std
::
runtime_error
(
"Invalid reduce dimensions!"
);
// get invariant_dims[] and invariant_lengths[]
for
(
int
dim
=
0
,
i
=
0
;
dim
<
Rank
;
dim
++
)
if
(
std
::
none_of
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
[
&
](
int
d
)
{
return
d
==
dim
;
}))
{
invariantDims_
[
i
]
=
dim
;
invariant_lengths_
[
i
]
=
xyLengths
[
dim
];
i
++
;
};
// get reduce_lengths_[]
for
(
int
j
=
0
,
i
=
0
;
j
<
NumBatchNormReduceDim
;
j
++
)
{
int
dim
=
reduceDims
[
j
];
reduce_lengths_
[
i
++
]
=
xyLengths
[
dim
];
};
for
(
int
i
=
0
;
i
<
NumInvariantDim
;
i
++
)
if
(
invariant_lengths_
[
i
]
!=
bnScaleBiasMeanVarLengths_
[
i
])
throw
std
::
runtime_error
(
"Invalid lengths parameters!"
);
for
(
int
j
=
0
,
i
=
0
;
j
<
NumInvariantDim
;
j
++
)
{
int
dim
=
invariantDims_
[
j
];
x_invariant_strides_
[
i
]
=
xStrides
[
dim
];
y_invariant_strides_
[
i
]
=
yStrides
[
dim
];
i
++
;
};
for
(
int
j
=
0
,
i
=
0
;
j
<
NumBatchNormReduceDim
;
j
++
)
{
int
dim
=
reduceDims_
[
j
];
x_reduce_strides_
[
i
]
=
xStrides
[
dim
];
y_reduce_strides_
[
i
]
=
yStrides
[
dim
];
i
++
;
};
invariant_index_set_
=
get_index_set
<
NumInvariantDim
>
(
invariant_lengths_
);
reduce_index_set_
=
get_index_set
<
NumBatchNormReduceDim
>
(
reduce_lengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
averageFactor_
=
type_convert
<
AccDataType
>
(
averageFactor
);
...
...
@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunning
=
(
resultRunningMean
!=
nullptr
&&
resultRunningVariance
!=
nullptr
);
}
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims_
;
std
::
array
<
int
,
NumInvariantDim
>
invariantDims_
;
std
::
array
<
index_t
,
NumInvariantDim
>
invariant_lengths_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
reduce_lengths_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnBiasStrides_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
x_invariant_strides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
y_invariant_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
x_reduce_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
y_reduce_strides_
;
const
XDataType
*
p_x_
;
const
ScaleDataType
*
bnScale_
;
const
BiasDataType
*
bnBias_
;
...
...
@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
bool
resultSave
,
resultRunning
;
index_t
n
,
h
,
w
,
c
;
std
::
vector
<
std
::
array
<
index_t
,
NumInvariantDim
>>
invariant_index_set_
;
std
::
vector
<
std
::
array
<
index_t
,
NumBatchNormReduceDim
>>
reduce_index_set_
;
AccDataType
averageFactor_
;
AccDataType
epsilon_
;
...
...
@@ -104,105 +168,119 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
{
float
Run
(
const
Argument
&
arg
)
{
auto
thread_reduce_func
=
[
&
](
auto
iC
)
{
index_t
offset_C
=
iC
;
using
ck
::
host_common
::
get_offset_from_index
;
auto
thread_reduce_func
=
[
&
](
auto
invariant_index
)
{
size_t
x_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
x_invariant_strides_
,
invariant_index
);
size_t
y_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
y_invariant_strides_
,
invariant_index
);
AccDataType
mean
=
type_convert
<
AccDataType
>
(
0.0
f
);
AccDataType
variance
=
type_convert
<
AccDataType
>
(
0.0
f
);
int32_t
curr_count
=
0
;
// compute mean, variance using welford method
for
(
index_t
iN
=
0
;
iN
<
arg
.
n
;
iN
++
)
for
(
const
auto
&
reduce_index
:
arg
.
reduce_index_set_
)
{
index_t
offset_N
=
iN
*
arg
.
h
*
arg
.
w
*
arg
.
c
;
for
(
index_t
iH
=
0
;
iH
<
arg
.
h
;
iH
++
)
{
index_t
offset_H
=
iH
*
arg
.
w
*
arg
.
c
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w
;
iW
++
)
{
index_t
offset_W
=
iW
*
arg
.
c
;
size_t
x_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
x_reduce_strides_
,
reduce_index
);
auto
offset
=
offset_N
+
offset
_H
+
offset_W
+
offset
_C
;
auto
x_
offset
=
x_invariant_
offset
+
x_reduce_
offset
;
curr_count
++
;
curr_count
++
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
offset
]);
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
x_
offset
]);
AccDataType
delta
=
x
-
mean
;
AccDataType
delta
=
x
-
mean
;
mean
+=
delta
/
curr_count
;
mean
+=
delta
/
curr_count
;
AccDataType
delta2
=
x
-
mean
;
AccDataType
delta2
=
x
-
mean
;
variance
+=
delta
*
delta2
;
};
}
variance
+=
delta
*
delta2
;
};
// actual variance
variance
=
variance
/
curr_count
;
// inv-variance defined as 1/sqrt(epsilon+variance)
AccDataType
invVariance
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
variance
);
// save the mean/inv
V
ariance if required
// save the mean/inv
-v
ariance if required
if
(
arg
.
resultSave
)
{
arg
.
resultSaveMean_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
mean
);
arg
.
resultSaveInvVariance_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
invVariance
);
size_t
offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnMeanVarStrides_
,
invariant_index
);
arg
.
resultSaveMean_
[
offset
]
=
type_convert
<
MeanVarDataType
>
(
mean
);
arg
.
resultSaveInvVariance_
[
offset
]
=
type_convert
<
MeanVarDataType
>
(
invVariance
);
};
// update the moving average if required
if
(
arg
.
resultRunning
)
{
size_t
offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnMeanVarStrides_
,
invariant_index
);
AccDataType
oneMinusAverageFactor
=
type_convert
<
AccDataType
>
(
1.0
)
-
arg
.
averageFactor_
;
arg
.
resultRunningMean_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
type_convert
<
AccDataType
>
(
arg
.
resultRunningMean_
[
iC
])
*
arg
.
resultRunningMean_
[
offset
]
=
type_convert
<
MeanVarDataType
>
(
type_convert
<
AccDataType
>
(
arg
.
resultRunningMean_
[
offset
])
*
oneMinusAverageFactor
+
mean
*
arg
.
averageFactor_
);
arg
.
resultRunningVariance_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
arg
.
resultRunningVariance_
[
iC
]
*
oneMinusAverageFactor
+
arg
.
resultRunningVariance_
[
offset
]
=
type_convert
<
MeanVarDataType
>
(
arg
.
resultRunningVariance_
[
offset
]
*
oneMinusAverageFactor
+
variance
*
arg
.
averageFactor_
);
};
size_t
scale_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnScaleStrides_
,
invariant_index
);
size_t
bias_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnBiasStrides_
,
invariant_index
);
AccDataType
scale
=
type_convert
<
AccDataType
>
(
arg
.
bnScale_
[
scale_offset
]);
AccDataType
bias
=
type_convert
<
AccDataType
>
(
arg
.
bnBias_
[
bias_offset
]);
// Normalization
for
(
index_t
iN
=
0
;
iN
<
arg
.
n
;
iN
++
)
for
(
const
auto
&
reduce_index
:
arg
.
reduce_index_set_
)
{
index_t
offset_N
=
iN
*
arg
.
h
*
arg
.
w
*
arg
.
c
;
for
(
index_t
iH
=
0
;
iH
<
arg
.
h
;
iH
++
)
{
index_t
offset_H
=
iH
*
arg
.
w
*
arg
.
c
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w
;
iW
++
)
{
index_t
offset_W
=
iW
*
arg
.
c
;
size_t
x_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
x_reduce_strides_
,
reduce_index
);
size_t
y_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
y_reduce_strides_
,
reduce_index
);
auto
offset
=
offset_N
+
offset_H
+
offset_W
+
offset_C
;
auto
x_offset
=
x_invariant_offset
+
x_reduce_offset
;
auto
y_offset
=
y_invariant_offset
+
y_reduce_offset
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
offset
]);
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
x_
offset
]);
AccDataType
norm_x
=
arg
.
bnScale_
[
iC
]
*
(
x
-
mean
)
*
invVariance
+
arg
.
bnBias_
[
iC
];
AccDataType
norm_x
=
(
x
-
mean
)
*
invVariance
;
arg
.
p_y_
[
offset
]
=
type_convert
<
YDataType
>
(
norm_x
);
};
}
AccDataType
y
=
scale
*
norm_x
+
bias
;
arg
.
y_elementwise_op_
(
y
,
y
);
arg
.
p_y_
[
y_offset
]
=
type_convert
<
YDataType
>
(
y
);
};
};
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
std
::
size_t
work_per_thread
=
(
arg
.
c
+
num_thread
-
1
)
/
num_thread
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
std
::
size_t
work_per_thread
=
(
arg
.
invariant_index_set_
.
size
()
+
num_thread
-
1
)
/
num_thread
;
std
::
vector
<
joinable_thread
>
threads
(
num_thread
);
for
(
std
::
size_t
it
=
0
;
it
<
num_thread
;
++
it
)
{
std
::
size_t
ic_begin
=
it
*
work_per_thread
;
std
::
size_t
ic_end
=
std
::
min
(
static_cast
<
int
>
((
it
+
1
)
*
work_per_thread
),
arg
.
c
);
std
::
size_t
i_begin
=
it
*
work_per_thread
;
std
::
size_t
i_end
=
std
::
min
(
static_cast
<
size_t
>
((
it
+
1
)
*
work_per_thread
),
arg
.
invariant_index_set_
.
size
());
auto
f
=
[
=
]
{
for
(
std
::
size_t
i
c
=
i
c
_begin
;
i
c
<
i
c
_end
;
++
i
c
)
for
(
std
::
size_t
i
=
i_begin
;
i
<
i_end
;
++
i
)
{
thread_reduce_func
(
ic
);
thread_reduce_func
(
arg
.
invariant_index_set_
[
i
]
);
}
};
...
...
@@ -278,7 +356,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"Reference_BatchNorm_Forward
_NHWC_C<
"
<<
std
::
endl
;
str
<<
"Reference_BatchNorm_Forward"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
...
...
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer
_nhwc_c
.hpp
→
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp
View file @
02ff2522
...
...
@@ -8,6 +8,7 @@
#include <array>
#include <algorithm>
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"
namespace
ck
{
...
...
@@ -19,114 +20,205 @@ template <typename XDataType,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
>
struct
ReferenceBatchNormInfer_Input_N_H_W_C_Output_C
:
public
device
::
DeviceBatchNormInfer
<
4
,
3
>
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
struct
ReferenceBatchNormInfer
:
public
device
::
DeviceBatchNormInfer
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
struct
Argument
:
public
device
::
BaseArgument
{
Argument
(
const
std
::
array
<
index_t
,
4
>
xyLengths
,
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
yStrides
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
1
>
bnMeanVarStrides
,
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
ScaleDataType
*
bnScale
,
const
BiasDataType
*
bnBias
,
double
epsilon
,
const
YElementwiseOp
y_elementwise_op
,
const
MeanVarDataType
*
estimatedMean
,
const
MeanVarDataType
*
estimatedVariance
,
YDataType
*
p_y
)
:
p_x_
(
p_x
),
:
reduceDims_
(
reduceDims
),
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnBiasStrides_
(
bnBiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
bnScale_
(
bnScale
),
bnBias_
(
bnBias
),
epsilon_
(
epsilon
),
y_elementwise_op_
(
y_elementwise_op
),
estimatedMean_
(
estimatedMean
),
estimatedVariance_
(
estimatedVariance
),
p_y_
(
p_y
)
{
ignore
=
xStrides
;
ignore
=
yStrides
;
ignore
=
bnScaleStrides
;
ignore
=
bnBiasStrides
;
ignore
=
bnMeanVarStrides
;
if
(
xyLengths
.
size
()
!=
4
||
bnScaleBiasMeanVarLengths
.
size
()
!=
1
||
bnScaleBiasMeanVarLengths
[
0
]
!=
xyLengths
[
3
])
throw
std
::
runtime_error
(
"Invalid tensor dimensions!"
);
n_
=
xyLengths
[
0
];
h_
=
xyLengths
[
1
];
w_
=
xyLengths
[
2
];
c_
=
xyLengths
[
3
];
using
ck
::
host_common
::
get_index_set
;
if
(
std
::
any_of
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
[](
int
d
)
{
return
d
<
0
||
d
>=
Rank
;
}))
throw
std
::
runtime_error
(
"Invalid reduce dimensions!"
);
// get invariant_dims[] and invariant_lengths[]
for
(
int
dim
=
0
,
i
=
0
;
dim
<
Rank
;
dim
++
)
if
(
std
::
none_of
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
[
&
](
int
d
)
{
return
d
==
dim
;
}))
{
invariantDims_
[
i
]
=
dim
;
invariant_lengths_
[
i
]
=
xyLengths
[
dim
];
i
++
;
};
// get reduce_lengths_[]
for
(
int
j
=
0
,
i
=
0
;
j
<
NumBatchNormReduceDim
;
j
++
)
{
int
dim
=
reduceDims
[
j
];
reduce_lengths_
[
i
++
]
=
xyLengths
[
dim
];
};
// check invariant_lengths_ and bnScaleBiasMeanVarLengths
for
(
int
i
=
0
;
i
<
NumInvariantDim
;
i
++
)
if
(
invariant_lengths_
[
i
]
!=
bnScaleBiasMeanVarLengths_
[
i
])
throw
std
::
runtime_error
(
"Invalid lengths parameters!"
);
for
(
int
j
=
0
,
i
=
0
;
j
<
NumInvariantDim
;
j
++
)
{
int
dim
=
invariantDims_
[
j
];
x_invariant_strides_
[
i
]
=
xStrides
[
dim
];
y_invariant_strides_
[
i
]
=
yStrides
[
dim
];
i
++
;
};
for
(
int
j
=
0
,
i
=
0
;
j
<
NumBatchNormReduceDim
;
j
++
)
{
int
dim
=
reduceDims_
[
j
];
x_reduce_strides_
[
i
]
=
xStrides
[
dim
];
y_reduce_strides_
[
i
]
=
yStrides
[
dim
];
i
++
;
};
invariant_index_set_
=
get_index_set
<
NumInvariantDim
>
(
invariant_lengths_
);
reduce_index_set_
=
get_index_set
<
NumBatchNormReduceDim
>
(
reduce_lengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
}
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims_
;
std
::
array
<
int
,
NumInvariantDim
>
invariantDims_
;
std
::
array
<
index_t
,
NumInvariantDim
>
invariant_lengths_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
reduce_lengths_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnBiasStrides_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
x_invariant_strides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
y_invariant_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
x_reduce_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
y_reduce_strides_
;
const
XDataType
*
p_x_
;
const
ScaleDataType
*
bnScale_
;
const
BiasDataType
*
bnBias_
;
double
epsilon_
;
const
YElementwiseOp
y_elementwise_op_
;
const
MeanVarDataType
*
estimatedMean_
;
const
MeanVarDataType
*
estimatedVariance_
;
YDataType
*
p_y_
;
index_t
n_
,
h_
,
w_
,
c_
;
std
::
vector
<
std
::
array
<
index_t
,
NumInvariantDim
>>
invariant_index_set_
;
std
::
vector
<
std
::
array
<
index_t
,
NumBatchNormReduceDim
>>
reduce_index_set_
;
AccDataType
epsilon_
;
};
struct
Invoker
:
public
device
::
BaseInvoker
{
float
Run
(
const
Argument
&
arg
)
{
auto
thread_reduce_func
=
[
&
](
auto
iC
)
{
index_t
offset_C
=
iC
;
AccDataType
mean
=
arg
.
estimatedMean_
[
offset_C
];
AccDataType
variance
=
arg
.
estimatedVariance_
[
offset_C
];
using
ck
::
host_common
::
get_offset_from_index
;
auto
thread_reduce_func
=
[
&
](
auto
invariant_index
)
{
size_t
x_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
x_invariant_strides_
,
invariant_index
);
size_t
y_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
y_invariant_strides_
,
invariant_index
);
size_t
mean_variance_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnMeanVarStrides_
,
invariant_index
);
AccDataType
mean
=
arg
.
estimatedMean_
[
mean_variance_offset
];
AccDataType
variance
=
arg
.
estimatedVariance_
[
mean_variance_offset
];
// inv-variance defined as 1/sqrt(epsilon+variance)
AccDataType
invVariance
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
std
::
sqrt
(
type_convert
<
AccDataType
>
(
arg
.
epsilon_
)
+
variance
);
type_convert
<
AccDataType
>
(
1.0
f
)
/
std
::
sqrt
(
arg
.
epsilon_
+
variance
);
size_t
scale_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnScaleStrides_
,
invariant_index
);
size_t
bias_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnBiasStrides_
,
invariant_index
);
AccDataType
scale
=
type_convert
<
AccDataType
>
(
arg
.
bnScale_
[
scale_offset
]);
AccDataType
bias
=
type_convert
<
AccDataType
>
(
arg
.
bnBias_
[
bias_offset
]);
//
N
ormalization
for
(
index_t
iN
=
0
;
iN
<
arg
.
n_
;
iN
++
)
//
n
ormalization
for
(
const
auto
&
reduce_index
:
arg
.
reduce_index_set_
)
{
index_t
offset_N
=
iN
*
arg
.
h_
*
arg
.
w_
*
arg
.
c_
;
for
(
index_t
iH
=
0
;
iH
<
arg
.
h_
;
iH
++
)
{
index_t
offset_H
=
iH
*
arg
.
w_
*
arg
.
c_
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w_
;
iW
++
)
{
index_t
offset_W
=
iW
*
arg
.
c_
;
size_t
x_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
x_reduce_strides_
,
reduce_index
);
size_t
y_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
arg
.
y_reduce_strides_
,
reduce_index
);
auto
offset
=
offset_N
+
offset_H
+
offset_W
+
offset_C
;
auto
x_offset
=
x_invariant_offset
+
x_reduce_offset
;
auto
y_offset
=
y_invariant_offset
+
y_reduce_offset
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
offset
]);
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
x_
offset
]);
AccDataType
norm_x
=
arg
.
bnScale_
[
iC
]
*
(
x
-
mean
)
*
invVariance
+
arg
.
bnBias_
[
iC
];
AccDataType
norm_x
=
(
x
-
mean
)
*
invVariance
;
arg
.
p_y_
[
offset
]
=
type_convert
<
YDataType
>
(
norm_x
);
};
}
AccDataType
y
=
scale
*
norm_x
+
bias
;
arg
.
y_elementwise_op_
(
y
,
y
);
arg
.
p_y_
[
y_offset
]
=
type_convert
<
YDataType
>
(
y
);
};
};
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
std
::
size_t
work_per_thread
=
(
arg
.
c_
+
num_thread
-
1
)
/
num_thread
;
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
std
::
size_t
work_per_thread
=
(
arg
.
invariant_index_set_
.
size
()
+
num_thread
-
1
)
/
num_thread
;
std
::
vector
<
joinable_thread
>
threads
(
num_thread
);
for
(
std
::
size_t
it
=
0
;
it
<
num_thread
;
++
it
)
{
std
::
size_t
ic_begin
=
it
*
work_per_thread
;
std
::
size_t
ic_end
=
std
::
min
(
static_cast
<
int
>
((
it
+
1
)
*
work_per_thread
),
arg
.
c_
);
std
::
size_t
i_begin
=
it
*
work_per_thread
;
std
::
size_t
i_end
=
std
::
min
(
static_cast
<
size_t
>
((
it
+
1
)
*
work_per_thread
),
arg
.
invariant_index_set_
.
size
());
auto
f
=
[
=
]
{
for
(
std
::
size_t
i
c
=
i
c
_begin
;
i
c
<
i
c
_end
;
++
i
c
)
for
(
std
::
size_t
i
=
i_begin
;
i
<
i_end
;
++
i
)
{
thread_reduce_func
(
ic
);
thread_reduce_func
(
arg
.
invariant_index_set_
[
i
]
);
}
};
...
...
@@ -151,17 +243,19 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
};
std
::
unique_ptr
<
device
::
BaseArgument
>
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
4
>
xyLengths
,
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
yStrides
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
1
>
bnMeanVarStrides
,
MakeArgumentPointer
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
void
*
p_x
,
const
void
*
bnScale
,
const
void
*
bnBias
,
double
epsilon
,
const
YElementwiseOp
y_elementwise_op
,
const
void
*
estimatedMean
,
const
void
*
estimatedVariance
,
void
*
p_y
)
override
...
...
@@ -169,6 +263,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
return
std
::
make_unique
<
Argument
>
(
xyLengths
,
xStrides
,
yStrides
,
reduceDims
,
bnScaleBiasMeanVarLengths
,
bnScaleStrides
,
bnBiasStrides
,
...
...
@@ -177,6 +272,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
static_cast
<
const
ScaleDataType
*>
(
bnScale
),
static_cast
<
const
BiasDataType
*>
(
bnBias
),
epsilon
,
y_elementwise_op
,
static_cast
<
const
MeanVarDataType
*>
(
estimatedMean
),
static_cast
<
const
MeanVarDataType
*>
(
estimatedVariance
),
static_cast
<
YDataType
*>
(
p_y
));
...
...
@@ -192,7 +288,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"Reference_BatchNorm_
Forward_NHWC_C
<"
<<
std
::
endl
;
str
<<
"Reference_BatchNorm_
Infer
<"
<<
std
::
endl
;
// clang-format on
return
str
.
str
();
...
...
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
// FP16
void
add_device_batchnorm_forward_rank_4_3_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
PassThrough
,
4
,
3
>>>&
);
// FP32
void
add_device_batchnorm_forward_rank_4_3_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
);
// BF16
void
add_device_batchnorm_forward_rank_4_3_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
PassThrough
,
4
,
3
>>>&
);
// Int8
void
add_device_batchnorm_forward_rank_4_3_i8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
PassThrough
,
4
,
3
>>>&
);
// FP64
void
add_device_batchnorm_forward_rank_4_3_f64_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
PassThrough
,
4
,
3
>>>&
);
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
,
index_t
Rank
,
index_t
NumReduceDim
>
struct
DeviceOperationInstanceFactory
<
ck
::
tensor_operation
::
device
::
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumReduceDim
>>
{
using
DeviceOp
=
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumReduceDim
>
;
static
auto
GetInstances
()
{
std
::
vector
<
std
::
unique_ptr
<
DeviceOp
>>
op_ptrs
;
if
constexpr
(
is_same_v
<
XDataType
,
F16
>
&&
is_same_v
<
YDataType
,
F16
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F16
>
&&
is_same_v
<
BiasDataType
,
F16
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
YElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_forward_rank_4_3_f16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F32
>
&&
is_same_v
<
YDataType
,
F32
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
F32
>
&&
is_same_v
<
BiasDataType
,
F32
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
YElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_forward_rank_4_3_f32_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
BF16
>
&&
is_same_v
<
YDataType
,
BF16
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
BF16
>
&&
is_same_v
<
BiasDataType
,
BF16
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
YElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_forward_rank_4_3_bf16_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
I8
>
&&
is_same_v
<
YDataType
,
I8
>
&&
is_same_v
<
AccDataType
,
F32
>
&&
is_same_v
<
ScaleDataType
,
I8
>
&&
is_same_v
<
BiasDataType
,
I8
>
&&
is_same_v
<
MeanVarDataType
,
F32
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
YElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_forward_rank_4_3_i8_instances
(
op_ptrs
);
}
}
else
if
constexpr
(
is_same_v
<
XDataType
,
F64
>
&&
is_same_v
<
YDataType
,
F64
>
&&
is_same_v
<
AccDataType
,
F64
>
&&
is_same_v
<
ScaleDataType
,
F64
>
&&
is_same_v
<
BiasDataType
,
F64
>
&&
is_same_v
<
MeanVarDataType
,
F64
>
)
{
if
constexpr
(
Rank
==
4
&&
NumReduceDim
==
3
&&
is_same_v
<
YElementwiseOp
,
PassThrough
>
)
{
add_device_batchnorm_forward_rank_4_3_f64_instances
(
op_ptrs
);
}
}
return
op_ptrs
;
}
};
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/include/ck/library/utility/host_common_util.hpp
View file @
02ff2522
...
...
@@ -4,9 +4,11 @@
#pragma once
#include <vector>
#include <array>
#include <iostream>
#include <fstream>
#include <string>
#include <algorithm>
#include "ck/ck.hpp"
...
...
@@ -72,5 +74,63 @@ static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
return
(
values
);
}
template
<
int
NDim
>
static
inline
std
::
vector
<
std
::
array
<
index_t
,
NDim
>>
get_index_set
(
const
std
::
array
<
index_t
,
NDim
>&
dim_lengths
)
{
static_assert
(
NDim
>=
1
,
"NDim >= 1 is required to use this function!"
);
if
constexpr
(
NDim
==
1
)
{
std
::
vector
<
std
::
array
<
index_t
,
NDim
>>
index_set
;
for
(
int
i
=
0
;
i
<
dim_lengths
[
0
];
i
++
)
{
std
::
array
<
index_t
,
1
>
index
{
i
};
index_set
.
push_back
(
index
);
};
return
index_set
;
}
else
{
std
::
vector
<
std
::
array
<
index_t
,
NDim
>>
index_set
;
std
::
array
<
index_t
,
NDim
-
1
>
partial_dim_lengths
;
std
::
copy
(
dim_lengths
.
begin
()
+
1
,
dim_lengths
.
end
(),
partial_dim_lengths
.
begin
());
std
::
vector
<
std
::
array
<
index_t
,
NDim
-
1
>>
partial_index_set
;
partial_index_set
=
get_index_set
<
NDim
-
1
>
(
partial_dim_lengths
);
for
(
index_t
i
=
0
;
i
<
dim_lengths
[
0
];
i
++
)
for
(
const
auto
&
partial_index
:
partial_index_set
)
{
std
::
array
<
index_t
,
NDim
>
index
;
index
[
0
]
=
i
;
std
::
copy
(
partial_index
.
begin
(),
partial_index
.
end
(),
index
.
begin
()
+
1
);
index_set
.
push_back
(
index
);
};
return
index_set
;
};
};
template
<
int
NDim
>
static
inline
size_t
get_offset_from_index
(
const
std
::
array
<
index_t
,
NDim
>&
strides
,
const
std
::
array
<
index_t
,
NDim
>&
index
)
{
size_t
offset
=
0
;
for
(
int
i
=
0
;
i
<
NDim
;
i
++
)
offset
+=
index
[
i
]
*
strides
[
i
];
return
(
offset
);
};
}
// namespace host_common
}
// namespace ck
library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt
0 → 100644
View file @
02ff2522
add_instance_library
(
device_batchnorm_instance
device_batchnorm_forward_f16_instance.cpp
device_batchnorm_forward_f32_instance.cpp
device_batchnorm_forward_bf16_instance.cpp
device_batchnorm_forward_i8_instance.cpp
device_batchnorm_forward_f64_instance.cpp
)
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
BF16
=
ck
::
bhalf_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_bf16_blockwise_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_bf16_multiblock_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
void
add_device_batchnorm_forward_rank_4_3_bf16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
PassThrough
,
4
,
3
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batchnorm_forward_bf16_blockwise_instances
<
4
,
3
,
PassThrough
>
{});
add_device_operation_instances
(
instances
,
device_batchnorm_forward_bf16_multiblock_instances
<
4
,
3
,
PassThrough
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_f16_blockwise_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_f16_multiblock_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
void
add_device_batchnorm_forward_rank_4_3_f16_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
PassThrough
,
4
,
3
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batchnorm_forward_f16_blockwise_instances
<
4
,
3
,
PassThrough
>
{});
add_device_operation_instances
(
instances
,
device_batchnorm_forward_f16_multiblock_instances
<
4
,
3
,
PassThrough
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_f32_blockwise_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_f32_multiblock_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
void
add_device_batchnorm_forward_rank_4_3_f32_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
PassThrough
,
4
,
3
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batchnorm_forward_f32_blockwise_instances
<
4
,
3
,
PassThrough
>
{});
add_device_operation_instances
(
instances
,
device_batchnorm_forward_f32_multiblock_instances
<
4
,
3
,
PassThrough
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F64
=
double
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_f64_blockwise_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_f64_multiblock_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
void
add_device_batchnorm_forward_rank_4_3_f64_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
PassThrough
,
4
,
3
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batchnorm_forward_f64_blockwise_instances
<
4
,
3
,
PassThrough
>
{});
add_device_operation_instances
(
instances
,
device_batchnorm_forward_f64_multiblock_instances
<
4
,
3
,
PassThrough
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_i8_instance.cpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
I8
=
int8_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_i8_blockwise_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_i8_multiblock_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
void
add_device_batchnorm_forward_rank_4_3_i8_instances
(
std
::
vector
<
std
::
unique_ptr
<
DeviceBatchNormFwd
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
PassThrough
,
4
,
3
>>>&
instances
)
{
add_device_operation_instances
(
instances
,
device_batchnorm_forward_i8_blockwise_instances
<
4
,
3
,
PassThrough
>
{});
add_device_operation_instances
(
instances
,
device_batchnorm_forward_i8_multiblock_instances
<
4
,
3
,
PassThrough
>
{});
}
}
// namespace instance
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
profiler/include/profiler/profile_batchnorm_forward_impl.hpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iomanip>
#include <stdexcept>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
namespace
ck
{
namespace
profiler
{
template
<
typename
XDataType
,
typename
YDataType
,
typename
AccDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
index_t
Rank
,
index_t
NumBatchNormReduceDim
>
bool
profile_batchnorm_forward_impl
(
int
do_verification
,
int
init_method
,
bool
do_dumpout
,
bool
time_kernel
,
const
std
::
vector
<
size_t
>
inOutLengths
,
const
std
::
vector
<
int
>
reduceDims
,
bool
updateMovingAverage
,
bool
saveMeanAndInvVariance
,
double
averageFactor
,
double
epsilon
)
{
if
(
inOutLengths
.
size
()
!=
Rank
||
reduceDims
.
size
()
!=
NumBatchNormReduceDim
)
{
throw
std
::
runtime_error
(
"Invalid tensor lengths or number of reduce dimensions!"
);
};
std
::
vector
<
size_t
>
scaleBiasMeanVarLengths
;
// used for calculating the effective transferred bytes by each operation
size_t
total_length
;
size_t
invariant_length
=
1
;
total_length
=
std
::
accumulate
(
inOutLengths
.
begin
(),
inOutLengths
.
end
(),
1
,
std
::
multiplies
<
size_t
>
{});
if
(
std
::
any_of
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
[](
int
d
)
{
return
d
<
0
||
d
>=
Rank
;
}))
throw
std
::
runtime_error
(
"Invalid reduce dimensions!"
);
for
(
int
dim
=
0
;
dim
<
Rank
;
dim
++
)
{
if
(
std
::
none_of
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
[
&
](
int
d
)
{
return
dim
==
d
;
}))
{
scaleBiasMeanVarLengths
.
push_back
(
inOutLengths
[
dim
]);
invariant_length
*=
inOutLengths
[
dim
];
};
}
// input data of the batchnorm forward algorithm
Tensor
<
XDataType
>
x
(
inOutLengths
);
Tensor
<
ScaleDataType
>
bnScale
(
scaleBiasMeanVarLengths
);
Tensor
<
BiasDataType
>
bnBias
(
scaleBiasMeanVarLengths
);
// output data of the batchnorm forward algorithm
Tensor
<
YDataType
>
y_ref
(
inOutLengths
);
Tensor
<
YDataType
>
y
(
inOutLengths
);
Tensor
<
MeanVarDataType
>
resultSaveMean_ref
(
scaleBiasMeanVarLengths
);
Tensor
<
MeanVarDataType
>
resultSaveInvVariance_ref
(
scaleBiasMeanVarLengths
);
Tensor
<
MeanVarDataType
>
resultRunningMean_ref
(
scaleBiasMeanVarLengths
);
Tensor
<
MeanVarDataType
>
resultRunningVariance_ref
(
scaleBiasMeanVarLengths
);
auto
inOutStrides
=
x
.
mDesc
.
GetStrides
();
auto
scaleBiasMeanVarStrides
=
bnScale
.
mDesc
.
GetStrides
();
std
::
size_t
num_thread
=
std
::
thread
::
hardware_concurrency
();
if
(
updateMovingAverage
)
{
if
constexpr
(
ck
::
is_same_v
<
XDataType
,
int8_t
>
)
{
x
.
GenerateTensorValue
(
GeneratorTensor_2
<
XDataType
>
{
-
5
,
5
},
num_thread
);
const
float
x_mean
=
0.0
f
;
const
float
x_stddev
=
2.5
f
;
const
float
noise_stddev
=
0.04
f
;
resultRunningMean_ref
.
GenerateTensorValue
(
GeneratorTensor_4
<
MeanVarDataType
>
{
x_mean
,
noise_stddev
},
num_thread
);
resultRunningVariance_ref
.
GenerateTensorValue
(
GeneratorTensor_4
<
MeanVarDataType
>
{
x_stddev
*
x_stddev
,
noise_stddev
},
num_thread
);
}
else
{
const
float
x_mean
=
0.0
f
;
const
float
x_stddev
=
1.0
f
;
const
float
noise_stddev
=
0.04
f
;
// input data in normal distribution
x
.
GenerateTensorValue
(
GeneratorTensor_4
<
XDataType
>
{
x_mean
,
x_stddev
},
num_thread
);
// initialize the runningMean to be values with tiny variation to the mean of the x
// values
resultRunningMean_ref
.
GenerateTensorValue
(
GeneratorTensor_4
<
MeanVarDataType
>
{
x_mean
,
noise_stddev
},
num_thread
);
// initialize the runningVariance to be values with tiny variation to the variance of
// the x values
resultRunningVariance_ref
.
GenerateTensorValue
(
GeneratorTensor_4
<
MeanVarDataType
>
{
x_stddev
*
x_stddev
,
noise_stddev
},
num_thread
);
};
}
else
{
if
constexpr
(
ck
::
is_same_v
<
XDataType
,
int8_t
>
)
x
.
GenerateTensorValue
(
GeneratorTensor_2
<
XDataType
>
{
-
5
,
5
},
num_thread
);
else
x
.
GenerateTensorValue
(
GeneratorTensor_3
<
XDataType
>
{
-
1.0
f
,
1.0
f
},
num_thread
);
};
if
(
do_verification
)
{
if
constexpr
(
ck
::
is_same_v
<
ScaleDataType
,
int8_t
>
&&
ck
::
is_same_v
<
BiasDataType
,
int8_t
>
)
{
bnScale
.
GenerateTensorValue
(
GeneratorTensor_2
<
ScaleDataType
>
{
-
5
,
5
},
num_thread
);
bnBias
.
GenerateTensorValue
(
GeneratorTensor_2
<
BiasDataType
>
{
-
5
,
5
},
num_thread
);
}
else
{
switch
(
init_method
)
{
case
0
:
bnScale
.
GenerateTensorValue
(
GeneratorTensor_0
<
ScaleDataType
>
{},
num_thread
);
bnBias
.
GenerateTensorValue
(
GeneratorTensor_0
<
BiasDataType
>
{},
num_thread
);
break
;
case
1
:
bnScale
.
GenerateTensorValue
(
GeneratorTensor_1
<
ScaleDataType
>
{
1
},
num_thread
);
bnBias
.
GenerateTensorValue
(
GeneratorTensor_1
<
BiasDataType
>
{
0
},
num_thread
);
break
;
case
2
:
bnScale
.
GenerateTensorValue
(
GeneratorTensor_2
<
ScaleDataType
>
{
-
5
,
5
},
num_thread
);
bnBias
.
GenerateTensorValue
(
GeneratorTensor_2
<
BiasDataType
>
{
-
5
,
5
},
num_thread
);
break
;
default:
bnScale
.
GenerateTensorValue
(
GeneratorTensor_3
<
ScaleDataType
>
{
-
1.0
f
,
1.0
f
},
num_thread
);
bnBias
.
GenerateTensorValue
(
GeneratorTensor_3
<
BiasDataType
>
{
-
1.0
f
,
1.0
f
},
num_thread
);
}
};
};
// these buffers are usually provided by the user application
DeviceMem
x_dev
(
sizeof
(
XDataType
)
*
x
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
y_dev
(
sizeof
(
XDataType
)
*
y
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
bnScale_dev
(
sizeof
(
ScaleDataType
)
*
bnScale
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
bnBias_dev
(
sizeof
(
BiasDataType
)
*
bnBias
.
mDesc
.
GetElementSpaceSize
());
// mean_dev or resultSaveMean_dev
DeviceMem
resultSaveMean_dev
(
sizeof
(
MeanVarDataType
)
*
resultSaveMean_ref
.
mDesc
.
GetElementSpaceSize
());
// meansquare_dev or resultSaveInvVariance_dev
DeviceMem
resultSaveInvVariance_dev
(
sizeof
(
MeanVarDataType
)
*
resultSaveInvVariance_ref
.
mDesc
.
GetElementSpaceSize
());
// resultRunningMean_dev
DeviceMem
resultRunningMean_dev
(
sizeof
(
MeanVarDataType
)
*
resultRunningMean_ref
.
mDesc
.
GetElementSpaceSize
());
// resultRunningVariance_dev
DeviceMem
resultRunningVariance_dev
(
sizeof
(
MeanVarDataType
)
*
resultRunningVariance_ref
.
mDesc
.
GetElementSpaceSize
());
x_dev
.
ToDevice
(
x
.
mData
.
data
());
bnScale_dev
.
ToDevice
(
bnScale
.
mData
.
data
());
bnBias_dev
.
ToDevice
(
bnBias
.
mData
.
data
());
if
(
updateMovingAverage
)
{
resultRunningMean_dev
.
ToDevice
(
resultRunningMean_ref
.
mData
.
data
());
resultRunningVariance_dev
.
ToDevice
(
resultRunningVariance_ref
.
mData
.
data
());
};
// used for storing the device result for verification when updateMovingAverage is enabled
Tensor
<
MeanVarDataType
>
resultRunningMean
(
scaleBiasMeanVarLengths
);
Tensor
<
MeanVarDataType
>
resultRunningVariance
(
scaleBiasMeanVarLengths
);
// used for storing the device result for verification when saveMeanAndInvVariance is enabled
Tensor
<
MeanVarDataType
>
resultSaveMean
(
scaleBiasMeanVarLengths
);
Tensor
<
MeanVarDataType
>
resultSaveInvVariance
(
scaleBiasMeanVarLengths
);
std
::
array
<
index_t
,
Rank
>
arrInOutLengths
;
std
::
array
<
index_t
,
Rank
>
arrInOutStrides
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
arrScaleBiasMeanVarLengths
;
std
::
array
<
index_t
,
Rank
-
NumBatchNormReduceDim
>
arrScaleBiasMeanVarStrides
;
std
::
array
<
int
,
NumBatchNormReduceDim
>
arrReduceDims
;
std
::
copy
(
inOutLengths
.
begin
(),
inOutLengths
.
end
(),
arrInOutLengths
.
begin
());
std
::
copy
(
inOutStrides
.
begin
(),
inOutStrides
.
end
(),
arrInOutStrides
.
begin
());
std
::
copy
(
scaleBiasMeanVarLengths
.
begin
(),
scaleBiasMeanVarLengths
.
end
(),
arrScaleBiasMeanVarLengths
.
begin
());
std
::
copy
(
scaleBiasMeanVarStrides
.
begin
(),
scaleBiasMeanVarStrides
.
end
(),
arrScaleBiasMeanVarStrides
.
begin
());
std
::
copy
(
reduceDims
.
begin
(),
reduceDims
.
end
(),
arrReduceDims
.
begin
());
using
PassThroughOp
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// add device batchnorm-forward instances
using
DeviceOp
=
ck
::
tensor_operation
::
device
::
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
PassThroughOp
,
Rank
,
NumBatchNormReduceDim
>
;
// get device op instances
const
auto
instance_ptrs
=
ck
::
tensor_operation
::
device
::
instance
::
DeviceOperationInstanceFactory
<
DeviceOp
>::
GetInstances
();
std
::
cout
<<
"found "
<<
instance_ptrs
.
size
()
<<
" instances"
<<
std
::
endl
;
std
::
string
best_instance_name
;
float
best_avg_time
=
std
::
numeric_limits
<
float
>::
max
();
float
best_gb_per_sec
=
0
;
if
(
do_verification
)
{
using
ReferenceBatchNormFwdInstance
=
ck
::
tensor_operation
::
host
::
ReferenceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
PassThroughOp
,
Rank
,
NumBatchNormReduceDim
>
;
auto
batchNormFwd_ref
=
ReferenceBatchNormFwdInstance
{};
auto
argument_ptr_ref
=
batchNormFwd_ref
.
MakeArgumentPointer
(
arrInOutLengths
,
arrInOutStrides
,
arrInOutStrides
,
arrReduceDims
,
arrScaleBiasMeanVarLengths
,
arrScaleBiasMeanVarStrides
,
arrScaleBiasMeanVarStrides
,
arrScaleBiasMeanVarStrides
,
x
.
mData
.
data
(),
bnScale
.
mData
.
data
(),
bnBias
.
mData
.
data
(),
epsilon
,
PassThroughOp
{},
y_ref
.
mData
.
data
(),
saveMeanAndInvVariance
?
resultSaveMean_ref
.
mData
.
data
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_ref
.
mData
.
data
()
:
nullptr
,
averageFactor
,
updateMovingAverage
?
resultRunningMean_ref
.
mData
.
data
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_ref
.
mData
.
data
()
:
nullptr
);
if
(
!
batchNormFwd_ref
.
IsSupportedArgument
(
argument_ptr_ref
.
get
()))
{
std
::
cout
<<
"The runtime parameters not supported by the reference instance, exiting!"
<<
std
::
endl
;
return
(
false
);
};
auto
invoker_ptr_ref
=
batchNormFwd_ref
.
MakeInvokerPointer
();
(
void
)
invoker_ptr_ref
->
Run
(
argument_ptr_ref
.
get
());
}
int
num_kernel
=
0
;
bool
pass
=
true
;
for
(
auto
&
inst_ptr
:
instance_ptrs
)
{
auto
argument_ptr
=
inst_ptr
->
MakeArgumentPointer
(
arrInOutLengths
,
arrInOutStrides
,
arrInOutStrides
,
arrReduceDims
,
arrScaleBiasMeanVarLengths
,
arrScaleBiasMeanVarStrides
,
arrScaleBiasMeanVarStrides
,
arrScaleBiasMeanVarStrides
,
x_dev
.
GetDeviceBuffer
(),
bnScale_dev
.
GetDeviceBuffer
(),
bnBias_dev
.
GetDeviceBuffer
(),
epsilon
,
PassThroughOp
{},
y_dev
.
GetDeviceBuffer
(),
saveMeanAndInvVariance
?
resultSaveMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
saveMeanAndInvVariance
?
resultSaveInvVariance_dev
.
GetDeviceBuffer
()
:
nullptr
,
averageFactor
,
updateMovingAverage
?
resultRunningMean_dev
.
GetDeviceBuffer
()
:
nullptr
,
updateMovingAverage
?
resultRunningVariance_dev
.
GetDeviceBuffer
()
:
nullptr
);
if
(
inst_ptr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
num_kernel
++
;
}
else
{
if
(
time_kernel
)
{
std
::
cout
<<
inst_ptr
->
GetTypeString
()
<<
" skipped due to unsupported argument: "
<<
std
::
endl
;
}
continue
;
};
size_t
workspace_sz
=
inst_ptr
->
GetWorkSpaceSize
(
argument_ptr
.
get
());
DeviceMem
workspace_dev
(
workspace_sz
);
inst_ptr
->
SetWorkSpacePointer
(
argument_ptr
.
get
(),
workspace_dev
.
GetDeviceBuffer
());
auto
invoker_ptr
=
inst_ptr
->
MakeInvokerPointer
();
float
avg_time
=
invoker_ptr
->
Run
(
argument_ptr
.
get
(),
StreamConfig
{
nullptr
,
time_kernel
});
size_t
num_bytes
=
0
;
// inputing of x, scale, bias, outputing of y
num_bytes
+=
total_length
*
(
sizeof
(
XDataType
)
+
sizeof
(
YDataType
))
+
invariant_length
*
(
sizeof
(
ScaleDataType
)
+
sizeof
(
BiasDataType
));
// outputing of mean, inv-variance
num_bytes
+=
saveMeanAndInvVariance
?
invariant_length
*
sizeof
(
MeanVarDataType
)
*
2
:
0
;
// updating of moving mean, variance
num_bytes
+=
updateMovingAverage
?
invariant_length
*
sizeof
(
MeanVarDataType
)
*
4
:
0
;
float
gb_per_sec
=
num_bytes
/
1.E6
/
avg_time
;
if
(
time_kernel
)
std
::
cout
<<
"Perf: "
<<
avg_time
<<
" ms, "
<<
gb_per_sec
<<
" GB/s, "
<<
inst_ptr
->
GetTypeString
()
<<
std
::
endl
;
if
(
avg_time
<
best_avg_time
)
{
best_instance_name
=
inst_ptr
->
GetTypeString
();
best_avg_time
=
avg_time
;
best_gb_per_sec
=
gb_per_sec
;
}
if
(
do_verification
)
{
using
ck
::
utils
::
check_err
;
bool
single_pass
;
y_dev
.
FromDevice
(
y
.
mData
.
data
());
if
constexpr
(
ck
::
is_same_v
<
YDataType
,
ck
::
bhalf_t
>
)
single_pass
=
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
1e-2
,
1e-2
);
else
single_pass
=
check_err
(
y
.
mData
,
y_ref
.
mData
,
"y results"
,
4e-3
,
4e-3
);
if
(
updateMovingAverage
)
{
resultRunningMean_dev
.
FromDevice
(
resultRunningMean
.
mData
.
data
());
resultRunningVariance_dev
.
FromDevice
(
resultRunningVariance
.
mData
.
data
());
// clang-format off
single_pass
=
single_pass
&&
check_err
(
resultRunningMean
.
mData
,
resultRunningMean_ref
.
mData
,
"average mean results"
,
1.5e-5
,
1.5e-5
);
single_pass
=
single_pass
&&
check_err
(
resultRunningVariance
.
mData
,
resultRunningVariance_ref
.
mData
,
"average variance results"
,
1e-5
,
1e-5
);
// clang-format on
};
if
(
saveMeanAndInvVariance
)
{
resultSaveMean_dev
.
FromDevice
(
resultSaveMean
.
mData
.
data
());
resultSaveInvVariance_dev
.
FromDevice
(
resultSaveInvVariance
.
mData
.
data
());
// clang-format off
single_pass
=
single_pass
&&
check_err
(
resultSaveMean
.
mData
,
resultSaveMean_ref
.
mData
,
"mean results"
,
3e-5
,
3e-5
);
single_pass
=
single_pass
&&
check_err
(
resultSaveInvVariance
.
mData
,
resultSaveInvVariance_ref
.
mData
,
"inv-variance results"
,
7e-5
,
7e-5
);
// clang-format on
};
pass
=
pass
&&
single_pass
;
};
if
(
do_dumpout
)
{
using
ck
::
host_common
::
dumpBufferToFile
;
// clang-format off
dumpBufferToFile
(
"dump_x.bin"
,
x
.
mData
.
data
(),
x
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_y.bin"
,
y
.
mData
.
data
(),
y
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_y_ref.bin"
,
y_ref
.
mData
.
data
(),
y_ref
.
mDesc
.
GetElementSize
());
// clang-format off
if
(
saveMeanAndInvVariance
)
{
// clang-format off
dumpBufferToFile
(
"dump_mean.bin"
,
resultSaveMean
.
mData
.
data
(),
resultSaveMean
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_mean_ref.bin"
,
resultSaveMean_ref
.
mData
.
data
(),
resultSaveMean_ref
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_invvar.bin"
,
resultSaveInvVariance
.
mData
.
data
(),
resultSaveInvVariance
.
mDesc
.
GetElementSize
());
dumpBufferToFile
(
"dump_invvar_ref.bin"
,
resultSaveInvVariance_ref
.
mData
.
data
(),
resultSaveInvVariance_ref
.
mDesc
.
GetElementSize
());
// clang-format on
};
};
}
if
(
time_kernel
)
{
std
::
cout
<<
"best perf = "
<<
best_avg_time
<<
" ms, "
<<
best_gb_per_sec
<<
" GB/s, "
<<
best_instance_name
<<
std
::
endl
;
}
if
(
num_kernel
==
0
)
{
std
::
cout
<<
"Error: No kernel is applicable"
<<
std
::
endl
;
return
false
;
}
return
pass
;
}
}
// namespace profiler
}
// namespace ck
profiler/src/CMakeLists.txt
View file @
02ff2522
...
...
@@ -22,6 +22,7 @@ set(PROFILER_SOURCES
profile_groupnorm.cpp
profile_layernorm.cpp
profile_softmax.cpp
profile_batchnorm_fwd.cpp
)
set
(
PROFILER_EXECUTABLE ckProfiler
)
...
...
@@ -56,5 +57,6 @@ target_link_libraries(${PROFILER_EXECUTABLE} PRIVATE device_conv2d_fwd_bias_relu
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_normalization_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_softmax_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_reduce_instance
)
target_link_libraries
(
${
PROFILER_EXECUTABLE
}
PRIVATE device_batchnorm_instance
)
rocm_install
(
TARGETS
${
PROFILER_EXECUTABLE
}
COMPONENT profiler
)
profiler/src/profile_batchnorm_fwd.cpp
0 → 100644
View file @
02ff2522
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <vector>
#include <getopt.h>
#include "ck/library/utility/host_common_util.hpp"
#include "profiler/profile_batchnorm_forward_impl.hpp"
#include "profiler_operation_registry.hpp"
using
ck
::
index_t
;
using
namespace
std
;
static
const
struct
option
long_options
[]
=
{{
"inOutLengths"
,
required_argument
,
nullptr
,
'D'
},
{
"reduceDims"
,
required_argument
,
nullptr
,
'R'
},
{
"dumpout"
,
required_argument
,
nullptr
,
'o'
},
{
"verify"
,
required_argument
,
nullptr
,
'v'
},
{
"help"
,
no_argument
,
nullptr
,
'?'
},
{
nullptr
,
0
,
nullptr
,
0
}};
class
BatchnormFwdArgParser
{
private:
int
option_index
=
0
;
public:
std
::
vector
<
size_t
>
inLengths
;
std
::
vector
<
int
>
reduceDims
;
bool
do_verification
=
false
;
bool
do_dumpout
=
false
;
bool
updateMovingAverage
;
bool
saveMeanAndInvVariance
;
int
data_type
=
0
;
int
init_method
=
2
;
bool
time_kernel
=
false
;
BatchnormFwdArgParser
()
=
default
;
~
BatchnormFwdArgParser
()
=
default
;
void
show_usage
(
const
char
*
cmd
)
{
// clang-format off
std
::
cout
<<
"Usage of "
<<
cmd
<<
std
::
endl
;
std
::
cout
<<
"--inOutLengths or -D, comma separated list of input tensor dimension lengths, must have 4 integers for nhwc"
<<
std
::
endl
;
std
::
cout
<<
"--reduceDims or -R, comma separated list of dimensions to reduce on"
<<
std
::
endl
;
std
::
cout
<<
"--verify or -v, 1/0 to indicate whether to verify the result by comparing with the host-based batch-normalization"
<<
std
::
endl
;
std
::
cout
<<
"Arg1: data type (0: fp16, 1: fp32, 3: int8, 5: bp16, 6: fp64)"
<<
std
::
endl
;
std
::
cout
<<
"Arg2: 1/0 to indicate whether to update the moving average and variance (0=no, 1=yes)"
<<
std
::
endl
;
std
::
cout
<<
"Arg3: 1/0 to indicate whether to save the calculated mean and invVariance (0=no, 1=yes)"
<<
std
::
endl
;
std
::
cout
<<
"Arg4: init method used for bnScale and bnBias (0=no init, 1=single integer value, 2=scope integer value, 3=decimal value)"
<<
std
::
endl
;
std
::
cout
<<
"Arg5: time kernel (0=no, 1=yes)"
<<
std
::
endl
;
// clang-format on
};
int
operator
()(
int
argc
,
char
*
argv
[])
{
using
ck
::
host_common
::
getTypeValuesFromString
;
int
ch
;
optind
++
;
// to skip the module name
while
(
1
)
{
ch
=
getopt_long
(
argc
,
argv
,
"D:R:v:o:"
,
long_options
,
&
option_index
);
if
(
ch
==
-
1
)
break
;
switch
(
ch
)
{
case
'D'
:
if
(
!
optarg
)
throw
std
::
runtime_error
(
"Invalid option format!"
);
inLengths
=
getTypeValuesFromString
<
size_t
>
(
optarg
);
break
;
case
'R'
:
if
(
!
optarg
)
throw
std
::
runtime_error
(
"Invalid option format!"
);
reduceDims
=
getTypeValuesFromString
<
int
>
(
optarg
);
break
;
case
'v'
:
if
(
!
optarg
)
throw
std
::
runtime_error
(
"Invalid option format!"
);
do_verification
=
static_cast
<
bool
>
(
std
::
atoi
(
optarg
));
break
;
case
'o'
:
if
(
!
optarg
)
throw
std
::
runtime_error
(
"Invalid option format!"
);
do_dumpout
=
static_cast
<
bool
>
(
std
::
atoi
(
optarg
));
break
;
case
'?'
:
if
(
std
::
string
(
long_options
[
option_index
].
name
)
==
"help"
)
{
show_usage
(
argv
[
0
]);
return
-
1
;
};
break
;
default:
show_usage
(
argv
[
0
]);
std
::
cerr
<<
"Invalid cmd-line options!"
<<
std
::
endl
;
return
-
1
;
};
};
if
(
optind
+
5
>
argc
)
throw
std
::
runtime_error
(
"Invalid cmd-line arguments, more argumetns are needed!"
);
data_type
=
std
::
atoi
(
argv
[
optind
++
]);
updateMovingAverage
=
std
::
atoi
(
argv
[
optind
++
]);
saveMeanAndInvVariance
=
std
::
atoi
(
argv
[
optind
++
]);
init_method
=
std
::
atoi
(
argv
[
optind
++
]);
time_kernel
=
static_cast
<
bool
>
(
std
::
atoi
(
argv
[
optind
++
]));
if
(
data_type
!=
0
&&
data_type
!=
1
&&
data_type
!=
3
&&
data_type
!=
5
&&
data_type
!=
6
)
return
-
1
;
return
0
;
};
};
// end of class AppArgs
static
const
double
epsilon
=
std
::
numeric_limits
<
float
>::
epsilon
();
static
const
double
averageFactor
=
0.1
;
int
profile_batchnorm_forward
(
int
argc
,
char
*
argv
[])
{
using
ck
::
profiler
::
profile_batchnorm_forward_impl
;
BatchnormFwdArgParser
arg_parser
;
if
(
arg_parser
(
argc
,
argv
)
!=
0
)
return
-
1
;
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
BF16
=
ck
::
bhalf_t
;
using
I8
=
int8_t
;
using
F64
=
double
;
if
(
arg_parser
.
data_type
==
0
)
{
if
(
arg_parser
.
inLengths
.
size
()
==
4
&&
arg_parser
.
reduceDims
.
size
()
==
3
)
{
profile_batchnorm_forward_impl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F16
,
4
,
3
>
(
arg_parser
.
do_verification
,
arg_parser
.
init_method
,
arg_parser
.
do_dumpout
,
arg_parser
.
time_kernel
,
arg_parser
.
inLengths
,
arg_parser
.
reduceDims
,
arg_parser
.
updateMovingAverage
,
arg_parser
.
saveMeanAndInvVariance
,
epsilon
,
averageFactor
);
};
}
else
if
(
arg_parser
.
data_type
==
1
)
{
if
(
arg_parser
.
inLengths
.
size
()
==
4
&&
arg_parser
.
reduceDims
.
size
()
==
3
)
{
profile_batchnorm_forward_impl
<
F32
,
F32
,
F32
,
F32
,
F32
,
F32
,
4
,
3
>
(
arg_parser
.
do_verification
,
arg_parser
.
init_method
,
arg_parser
.
do_dumpout
,
arg_parser
.
time_kernel
,
arg_parser
.
inLengths
,
arg_parser
.
reduceDims
,
arg_parser
.
updateMovingAverage
,
arg_parser
.
saveMeanAndInvVariance
,
epsilon
,
averageFactor
);
};
}
else
if
(
arg_parser
.
data_type
==
3
)
{
if
(
arg_parser
.
inLengths
.
size
()
==
4
&&
arg_parser
.
reduceDims
.
size
()
==
3
)
{
profile_batchnorm_forward_impl
<
I8
,
I8
,
F32
,
I8
,
I8
,
F32
,
4
,
3
>
(
arg_parser
.
do_verification
,
arg_parser
.
init_method
,
arg_parser
.
do_dumpout
,
arg_parser
.
time_kernel
,
arg_parser
.
inLengths
,
arg_parser
.
reduceDims
,
arg_parser
.
updateMovingAverage
,
arg_parser
.
saveMeanAndInvVariance
,
epsilon
,
averageFactor
);
};
}
else
if
(
arg_parser
.
data_type
==
5
)
{
if
(
arg_parser
.
inLengths
.
size
()
==
4
&&
arg_parser
.
reduceDims
.
size
()
==
3
)
{
profile_batchnorm_forward_impl
<
BF16
,
BF16
,
F32
,
BF16
,
BF16
,
F32
,
4
,
3
>
(
arg_parser
.
do_verification
,
arg_parser
.
init_method
,
arg_parser
.
do_dumpout
,
arg_parser
.
time_kernel
,
arg_parser
.
inLengths
,
arg_parser
.
reduceDims
,
arg_parser
.
updateMovingAverage
,
arg_parser
.
saveMeanAndInvVariance
,
epsilon
,
averageFactor
);
};
}
else
if
(
arg_parser
.
data_type
==
6
)
{
if
(
arg_parser
.
inLengths
.
size
()
==
4
&&
arg_parser
.
reduceDims
.
size
()
==
3
)
{
profile_batchnorm_forward_impl
<
F64
,
F64
,
F64
,
F64
,
F64
,
F64
,
4
,
3
>
(
arg_parser
.
do_verification
,
arg_parser
.
init_method
,
arg_parser
.
do_dumpout
,
arg_parser
.
time_kernel
,
arg_parser
.
inLengths
,
arg_parser
.
reduceDims
,
arg_parser
.
updateMovingAverage
,
arg_parser
.
saveMeanAndInvVariance
,
epsilon
,
averageFactor
);
};
}
return
0
;
}
REGISTER_PROFILER_OPERATION
(
"bnorm_fwd"
,
"Batchnorm forward"
,
profile_batchnorm_forward
);
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment