gaoqiong / composable_kernel · Commits

Commit b7f500f0
Authored Nov 28, 2022 by rocking5566; committed by rocking on Nov 28, 2022
Merge branch 'develop' into gemm_layernorm_welford

Parents: 694057a7, 4e6a5575
Changes: 26
Showing 20 changed files with 2111 additions and 149 deletions (+2111 -149)
client_example/13_batchnorm/CMakeLists.txt                                                        +2   -0
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp                                                +197 -0
example/34_batchnorm/batchnorm_forward_nhwc.cpp                                                   +14  -10
example/34_batchnorm/batchnorm_infer_nhwc.cpp                                                     +17  -8
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp                               +27  -4
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp                                 +29  -3
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp          +118 -117
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp                     +9   -2
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp  +102 -5
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp         +368 -0
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp           +300 -0
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp                    +130 -0
library/include/ck/library/utility/host_common_util.hpp                                           +60  -0
library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt                                +7   -0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp    +147 -0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp     +147 -0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp     +145 -0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp     +145 -0
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_i8_instance.cpp      +145 -0
profiler/CMakeLists.txt                                                                           +2   -0
client_example/13_batchnorm/CMakeLists.txt (new file, mode 100644)

add_executable(client_batchnorm_fwd_nhwc batchnorm_fwd_nhwc.cpp)
target_link_libraries(client_batchnorm_fwd_nhwc PRIVATE composable_kernel::device_operations)
client_example/13_batchnorm/batchnorm_fwd_nhwc.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <array>
#include <functional>
#include <limits>
#include <numeric>
#include <iomanip>
#include <iostream>
#include <string>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_reduce.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"

using XDataType       = float;
using YDataType       = float;
using AccDataType     = float;
using ScaleDataType   = AccDataType;
using BiasDataType    = AccDataType;
using MeanVarDataType = AccDataType;
using PassThrough     = ck::tensor_operation::element_wise::PassThrough;

constexpr int Rank                  = 4;
constexpr int NumBatchNormReduceDim = 3;

const double epsilon       = std::numeric_limits<float>::epsilon();
const double averageFactor = 0.1;

struct SimpleDeviceMem
{
    SimpleDeviceMem() = delete;

    SimpleDeviceMem(std::size_t mem_size) : p_mem_{}
    {
        (void)hipMalloc(static_cast<void**>(&p_mem_), mem_size);
    }

    void* GetDeviceBuffer() { return p_mem_; }

    ~SimpleDeviceMem() { (void)hipFree(p_mem_); }

    void* p_mem_;
};

int main(int argc, char* argv[])
{
    std::array<ck::index_t, Rank> xyLengths{16, 8, 128, 256};
    std::array<ck::index_t, Rank> xyStrides{8 * 128 * 256, 128 * 256, 256, 1};
    std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarLengths{256};
    std::array<ck::index_t, Rank - NumBatchNormReduceDim> scaleBiasMeanVarStrides{1};
    std::array<int, NumBatchNormReduceDim> reduceDims{0, 1, 2};

    ck::index_t numXYElement =
        std::accumulate(xyLengths.begin(), xyLengths.end(), 1, std::multiplies<ck::index_t>());

    ck::index_t numScaleBiasMeanVarElement = std::accumulate(scaleBiasMeanVarLengths.begin(),
                                                             scaleBiasMeanVarLengths.end(),
                                                             1,
                                                             std::multiplies<ck::index_t>());

    SimpleDeviceMem x(sizeof(XDataType) * numXYElement);
    SimpleDeviceMem y(sizeof(YDataType) * numXYElement);
    SimpleDeviceMem scale(sizeof(ScaleDataType) * numScaleBiasMeanVarElement);
    SimpleDeviceMem bias(sizeof(BiasDataType) * numScaleBiasMeanVarElement);
    SimpleDeviceMem mean(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);
    SimpleDeviceMem invVariance(sizeof(MeanVarDataType) * numScaleBiasMeanVarElement);

    using DeviceOp = ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
                                                                      YDataType,
                                                                      AccDataType,
                                                                      ScaleDataType,
                                                                      BiasDataType,
                                                                      MeanVarDataType,
                                                                      PassThrough,
                                                                      Rank,
                                                                      NumBatchNormReduceDim>;

    const auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    std::cout << "found " << op_ptrs.size() << " instances" << std::endl;

    std::string best_op_name;
    bool found            = false;
    int best_op_id        = -1;
    float best_ave_time   = std::numeric_limits<float>::max();
    float best_gb_per_sec = 0;

    // profile device operation instances
    std::cout << "Run all instances and do timing" << std::endl;

    for(int i = 0; i < op_ptrs.size(); ++i)
    {
        auto& op_ptr = op_ptrs[i];

        auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
                                                        xyStrides,
                                                        xyStrides,
                                                        reduceDims,
                                                        scaleBiasMeanVarLengths,
                                                        scaleBiasMeanVarStrides,
                                                        scaleBiasMeanVarStrides,
                                                        scaleBiasMeanVarStrides,
                                                        x.GetDeviceBuffer(),
                                                        scale.GetDeviceBuffer(),
                                                        bias.GetDeviceBuffer(),
                                                        epsilon,
                                                        PassThrough{},
                                                        y.GetDeviceBuffer(),
                                                        mean.GetDeviceBuffer(),
                                                        invVariance.GetDeviceBuffer(),
                                                        averageFactor,
                                                        nullptr,
                                                        nullptr);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        std::string op_name = op_ptr->GetTypeString();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            size_t workspace_sz = op_ptr->GetWorkSpaceSize(argument_ptr.get());

            SimpleDeviceMem workspace(workspace_sz);

            op_ptr->SetWorkSpacePointer(argument_ptr.get(), workspace.GetDeviceBuffer());

            float ave_time = invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, true});

            std::size_t num_bytes =
                numXYElement * (sizeof(XDataType) + sizeof(YDataType)) +
                numScaleBiasMeanVarElement * (sizeof(ScaleDataType) + sizeof(BiasDataType) +
                                              sizeof(MeanVarDataType) + sizeof(MeanVarDataType));

            float gb_per_sec = num_bytes / 1.E6 / ave_time;

            std::cout << "Perf: " << std::setw(10) << ave_time << " ms, " << gb_per_sec
                      << " GB/s, " << op_name << std::endl;

            if(ave_time < best_ave_time)
            {
                found           = true;
                best_op_id      = i;
                best_op_name    = op_name;
                best_ave_time   = ave_time;
                best_gb_per_sec = gb_per_sec;
            }
        }
        else
        {
            std::cout << op_name << " does not support this problem" << std::endl;
        }
    }

    if(found)
    {
        std::cout << "Best Perf: " << best_ave_time << " ms, " << best_gb_per_sec << " GB/s, "
                  << best_op_name << std::endl;
        // run the best instance
        auto& op_ptr = op_ptrs[best_op_id];

        std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
                  << std::endl;

        auto argument_ptr = op_ptr->MakeArgumentPointer(xyLengths,
                                                        xyStrides,
                                                        xyStrides,
                                                        reduceDims,
                                                        scaleBiasMeanVarLengths,
                                                        scaleBiasMeanVarStrides,
                                                        scaleBiasMeanVarStrides,
                                                        scaleBiasMeanVarStrides,
                                                        x.GetDeviceBuffer(),
                                                        scale.GetDeviceBuffer(),
                                                        bias.GetDeviceBuffer(),
                                                        epsilon,
                                                        PassThrough{},
                                                        y.GetDeviceBuffer(),
                                                        mean.GetDeviceBuffer(),
                                                        invVariance.GetDeviceBuffer(),
                                                        averageFactor,
                                                        nullptr,
                                                        nullptr);

        auto invoker_ptr = op_ptr->MakeInvokerPointer();

        if(op_ptr->IsSupportedArgument(argument_ptr.get()))
        {
            invoker_ptr->Run(argument_ptr.get(), StreamConfig{nullptr, false});
        }

        std::cout << "Done" << std::endl;
    }

    return 0;
}
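The example above writes the packed NHWC strides out by hand (8 * 128 * 256, 128 * 256, 256, 1). A minimal sketch of the general rule; the helper name make_packed_strides is ours for illustration, not part of composable_kernel:

#include <array>
#include <cstddef>

// Hypothetical helper (not in CK): derive packed row-major strides from
// tensor lengths, so strides[last] == 1 and
// strides[d] == lengths[d + 1] * strides[d + 1].
template <std::size_t Rank>
std::array<long, Rank> make_packed_strides(const std::array<long, Rank>& lengths)
{
    std::array<long, Rank> strides{};
    long running = 1;
    for(std::size_t d = Rank; d-- > 0;)
    {
        strides[d] = running; // innermost dimension (C of NHWC) gets stride 1
        running *= lengths[d];
    }
    return strides;
}

// make_packed_strides<4>({16, 8, 128, 256}) == {8 * 128 * 256, 128 * 256, 256, 1},
// matching the xyStrides written out by hand in the example above.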
example/34_batchnorm/batchnorm_forward_nhwc.cpp

@@ -15,7 +15,7 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp"
 #include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
 #include "ck/library/utility/host_common_util.hpp"
 #include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
@@ -142,6 +142,8 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
     constexpr int Rank         = 4;
     constexpr int NumReduceDim = 3;
 
+    // when using lengths[] to create a tensor, lengths[0] is the length of the highest dimension,
+    // e.g. N of NHWC, so lengths[3] is the length of dimension C of NHWC
     const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
 
     // input data of the batchnorm forward algorithm
@@ -300,7 +302,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
             i_inOutLengths,
             i_inOutStrides,
             i_inOutStrides,
-            {0, 1, 2},
+            {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
             i_scaleBiasMeanVarLengths,
             i_scaleBiasMeanVarStrides,
             i_scaleBiasMeanVarStrides,
@@ -366,13 +368,15 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
         {
             using ReferenceBatchNormFwdInstance =
-                ck::tensor_operation::host::ReferenceBatchNormFwd_Input_N_H_W_C_Output_C<InOutDataType,
-                                                                                         InOutDataType,
-                                                                                         AccDataType,
-                                                                                         AccDataType,
-                                                                                         AccDataType,
-                                                                                         AccDataType,
-                                                                                         PassThroughOp>;
+                ck::tensor_operation::host::ReferenceBatchNormFwd<InOutDataType,
+                                                                  InOutDataType,
+                                                                  AccDataType,
+                                                                  AccDataType,
+                                                                  AccDataType,
+                                                                  AccDataType,
+                                                                  PassThroughOp,
+                                                                  Rank,
+                                                                  NumReduceDim>;
 
             auto batchNormFwd_ref = ReferenceBatchNormFwdInstance{};
@@ -380,7 +384,7 @@ bool bnorm_fwd_nhwc_test(bool do_verification,
                 i_inOutLengths,
                 i_inOutStrides,
                 i_inOutStrides,
-                {0, 1, 2},
+                {0, 1, 2}, // indicates physical indices of reduce dimensions in lengths[] and strides[]
                 i_scaleBiasMeanVarLengths,
                 i_scaleBiasMeanVarStrides,
                 i_scaleBiasMeanVarStrides,
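The new comments in this example pin down the dimension convention: lengths[0] is the slowest-varying dimension (N of NHWC) and lengths[3] is C, so reduceDims = {0, 1, 2} leaves one statistic per channel. A minimal host-side sketch (ours, not CK code) of what that per-channel reduction means:

#include <cstddef>
#include <vector>

// Sketch (ours): batchnorm with reduceDims = {0, 1, 2} on a packed NHWC tensor
// reduces over N, H and W, producing one mean per channel C.
void per_channel_mean(const std::vector<float>& x, // packed NHWC, size n * h * w * c
                      std::size_t n, std::size_t h, std::size_t w, std::size_t c,
                      std::vector<float>& mean)    // size c
{
    std::vector<double> sum(c, 0.0);
    for(std::size_t i = 0; i < n * h * w; ++i) // every reduced (N, H, W) position
        for(std::size_t ic = 0; ic < c; ++ic)  // channel is the fastest-varying dim
            sum[ic] += x[i * c + ic];
    for(std::size_t ic = 0; ic < c; ++ic)
        mean[ic] = static_cast<float>(sum[ic] / (n * h * w));
}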
example/34_batchnorm/batchnorm_infer_nhwc.cpp

@@ -15,7 +15,8 @@
 #include "ck/library/utility/host_tensor.hpp"
 #include "ck/library/utility/host_tensor_generator.hpp"
 #include "ck/library/utility/host_common_util.hpp"
-#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp"
+#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
+#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp"
 #include "batchnorm_infer_impl.hpp"
@@ -124,6 +125,8 @@ bool bnorm_infer_nhwc_test(bool do_verification,
     constexpr int Rank         = 4;
     constexpr int NumReduceDim = 3;
 
+    // when using lengths[] to create a tensor, lengths[0] is the length of the highest dimension,
+    // e.g. N of NHWC, so lengths[3] is the length of dimension C of NHWC
     const std::vector<size_t> scaleBiasMeanVarLengths = {inOutLengths[3]};
 
     // input data of the batchnorm forward algorithm
@@ -260,20 +263,25 @@ bool bnorm_infer_nhwc_test(bool do_verification,
     if(do_verification)
     {
+        using PassThroughOp = ck::tensor_operation::element_wise::PassThrough;
+
         using ReferenceBatchNormInferInstance =
-            ck::tensor_operation::host::ReferenceBatchNormInfer_Input_N_H_W_C_Output_C<InOutDataType,
-                                                                                       InOutDataType,
-                                                                                       AccDataType,
-                                                                                       AccDataType,
-                                                                                       AccDataType,
-                                                                                       AccDataType>;
+            ck::tensor_operation::host::ReferenceBatchNormInfer<InOutDataType,
+                                                                InOutDataType,
+                                                                AccDataType,
+                                                                AccDataType,
+                                                                AccDataType,
+                                                                AccDataType,
+                                                                PassThroughOp,
+                                                                Rank,
+                                                                NumReduceDim>;
 
         auto batchNormInfer_ref = ReferenceBatchNormInferInstance{};
 
         auto argument_ptr_ref =
             batchNormInfer_ref.MakeArgumentPointer(i_inOutLengths,
                                                    i_inOutStrides,
                                                    i_inOutStrides,
+                                                   {0, 1, 2},
                                                    i_scaleBiasMeanVarLengths,
                                                    i_scaleBiasMeanVarStrides,
                                                    i_scaleBiasMeanVarStrides,
@@ -282,6 +290,7 @@ bool bnorm_infer_nhwc_test(bool do_verification,
                                                    bnScale.mData.data(),
                                                    bnBias.mData.data(),
                                                    epsilon,
+                                                   PassThroughOp{},
                                                    estimatedMean.mData.data(),
                                                    estimatedVariance.mData.data(),
                                                    y_ref.mData.data());
include/ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp

@@ -13,7 +13,15 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
 struct DeviceBatchNormFwd : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
@@ -40,9 +48,24 @@ struct DeviceBatchNormFwd : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
-template <index_t Rank, index_t NumBatchNormReduceDim, typename YElementwiseOp>
-using DeviceBatchNormFwdPtr =
-    std::unique_ptr<DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>>;
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+using DeviceBatchNormFwdPtr = std::unique_ptr<DeviceBatchNormFwd<XDataType,
+                                                                 YDataType,
+                                                                 AccDataType,
+                                                                 ScaleDataType,
+                                                                 BiasDataType,
+                                                                 MeanVarDataType,
+                                                                 YElementwiseOp,
+                                                                 Rank,
+                                                                 NumBatchNormReduceDim>>;
 
 } // namespace device
 } // namespace tensor_operation
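The hunks above move the data types into the interface template, ahead of YElementwiseOp, Rank and NumBatchNormReduceDim. A minimal sketch of instantiating the generalized interface in the new parameter order; the fp16-in/fp32-accumulate configuration is an assumption for illustration, not taken from this commit:

// Minimal sketch, assuming the post-change parameter order shown in the diff.
using F16         = ck::half_t;
using F32         = float;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using BnormFwdF16 =
    ck::tensor_operation::device::DeviceBatchNormFwd<F16,         // XDataType
                                                     F16,         // YDataType
                                                     F32,         // AccDataType
                                                     F32,         // ScaleDataType
                                                     F32,         // BiasDataType
                                                     F32,         // MeanVarDataType
                                                     PassThrough, // YElementwiseOp
                                                     4,           // Rank (NHWC)
                                                     3>;          // NumBatchNormReduceDim (N, H, W)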
include/ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp

@@ -13,13 +13,22 @@ namespace ck {
 namespace tensor_operation {
 namespace device {
 
-template <index_t Rank, index_t NumBatchNormReduceDim>
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
 struct DeviceBatchNormInfer : public BaseOperator
 {
     virtual std::unique_ptr<BaseArgument> MakeArgumentPointer(
         const std::array<index_t, Rank> xyLengths,
         const std::array<index_t, Rank> xStrides,
         const std::array<index_t, Rank> yStrides,
+        const std::array<int, NumBatchNormReduceDim> reduceDims,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleBiasMeanVarLengths,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnScaleStrides,
         const std::array<index_t, Rank - NumBatchNormReduceDim> bnBiasStrides,
@@ -28,6 +37,7 @@ struct DeviceBatchNormInfer : public BaseOperator
         const void* bnScale,
         const void* bnBias,
         double epsilon,
+        const YElementwiseOp y_elementwise_op,
         const void* estimatedMean,
         const void* estimatedInvVariance,
         void* p_y) = 0;
@@ -35,8 +45,24 @@ struct DeviceBatchNormInfer : public BaseOperator
     virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
 };
 
-template <index_t Rank, index_t NumBatchNormReduceDim>
-using DeviceBatchNormInferPtr =
-    std::unique_ptr<DeviceBatchNormInfer<Rank, NumBatchNormReduceDim>>;
+template <typename XDataType,
+          typename YDataType,
+          typename AccDataType,
+          typename ScaleDataType,
+          typename BiasDataType,
+          typename MeanVarDataType,
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+using DeviceBatchNormInferPtr = std::unique_ptr<DeviceBatchNormInfer<XDataType,
+                                                                     YDataType,
+                                                                     AccDataType,
+                                                                     ScaleDataType,
+                                                                     BiasDataType,
+                                                                     MeanVarDataType,
+                                                                     YElementwiseOp,
+                                                                     Rank,
+                                                                     NumBatchNormReduceDim>>;
 
 } // namespace device
 } // namespace tensor_operation
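DeviceBatchNormInfer now also threads a YElementwiseOp through MakeArgumentPointer. A scalar sketch (ours, not CK code) of the per-element computation an inference instance performs, using a running variance the way the reference op does:

#include <cmath>

// Sketch (ours): batchnorm inference for one element. Normalize with the
// pre-computed running statistics, scale, shift, then apply the new
// y_elementwise_op; CK element-wise ops use an (out, in) reference style,
// matching the arg.y_elementwise_op_(y, y) call seen later in this commit.
template <typename Op>
float batchnorm_infer_scalar(float x,
                             float estimatedMean,
                             float estimatedVariance,
                             float scale,
                             float bias,
                             double epsilon,
                             Op y_elementwise_op)
{
    float invStd = 1.0f / std::sqrt(static_cast<float>(epsilon) + estimatedVariance);
    float y      = scale * (x - estimatedMean) * invStd + bias;
    y_elementwise_op(y, y);
    return y;
}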
include/ck/tensor_operation/gpu/device/device_gemm_multiple_d_layernorm_xdl_cshuffle.hpp

@@ -13,15 +13,14 @@
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
 #include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_gemm_multiple_d_welford_first_half_xdl_cshuffle.hpp"
-// #include
-// "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
+#include "ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp"
 #include "ck/host_utility/device_prop.hpp"
 #include "ck/host_utility/kernel_launch.hpp"
 #include "device_base.hpp"
 
 namespace ck {
 
-template <typename GridwiseGemm,
+template <typename GridwiseGemmWelford,
           typename ABDataType,
           typename DsPointer,
           typename EDataType,
@@ -63,25 +62,26 @@ __global__ void
           const Block2ETileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
-    __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
+    __shared__ char p_shared[GridwiseGemmWelford::GetSharedMemoryNumberOfByte()];
 
-    GridwiseGemm::template Run<HasMainKBlockLoop>(p_a_grid,
-                                                  p_b_grid,
-                                                  p_ds_grid,
-                                                  p_e_grid,
-                                                  p_mean_grid,
-                                                  p_var_grid,
-                                                  p_shared,
-                                                  a_element_op,
-                                                  b_element_op,
-                                                  cde_element_op,
-                                                  a_grid_desc_ak0_m_ak1,
-                                                  b_grid_desc_bk0_n_bk1,
-                                                  ds_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  e_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                  mean_grid_desc_mblock_mperblock_nblock,
-                                                  var_grid_desc_mblock_mperblock_nblock,
-                                                  block_2_etile_map);
+    GridwiseGemmWelford::template Run<HasMainKBlockLoop>(p_a_grid,
+                                                         p_b_grid,
+                                                         p_ds_grid,
+                                                         p_e_grid,
+                                                         p_mean_grid,
+                                                         p_var_grid,
+                                                         p_shared,
+                                                         a_element_op,
+                                                         b_element_op,
+                                                         cde_element_op,
+                                                         a_grid_desc_ak0_m_ak1,
+                                                         b_grid_desc_bk0_n_bk1,
+                                                         ds_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                         e_grid_desc_mblock_mperblock_nblock_nperblock,
+                                                         mean_grid_desc_mblock_mperblock_nblock,
+                                                         var_grid_desc_mblock_mperblock_nblock,
+                                                         block_2_etile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
@@ -102,23 +102,23 @@ __global__ void
 #endif
 }
 
-// template <typename GridwiseWelfordLayernorm,
-//           typename EDataType,
-//           typename HDataType,
-//           typename MeanDataType,
-//           typename VarDataType>
-// __global__ void
-// #if CK_USE_LAUNCH_BOUNDS
-//     __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
-// #endif
-//         kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
-//                                                const MeanDataType* __restrict__ p_mean_grid,
-//                                                const VarDataType* __restrict__ p_var_grid,
-//                                                HDataType* __restrict__ p_y_grid,
-//                                                index_t blkgroup_size)
-// {
-//     GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
-// }
+template <typename GridwiseWelfordLayernorm,
+          typename EDataType,
+          typename HDataType,
+          typename MeanDataType,
+          typename VarDataType>
+__global__ void
+#if CK_USE_LAUNCH_BOUNDS
+    __launch_bounds__(CK_MAX_THREAD_PER_BLOCK, CK_MIN_BLOCK_PER_CU)
+#endif
+        kernel_welford_layernorm2d_second_half(const EDataType* __restrict__ p_x_grid,
+                                               const MeanDataType* __restrict__ p_mean_grid,
+                                               const VarDataType* __restrict__ p_var_grid,
+                                               HDataType* __restrict__ p_y_grid,
+                                               index_t blkgroup_size)
+{
+    GridwiseWelfordLayernorm::Run(p_x_grid, p_mean_grid, p_var_grid, p_y_grid, blkgroup_size);
+}
 
 } // namespace ck
@@ -335,8 +335,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
     using MeanVarGridDesc_M = decltype(MakeDescriptor_M(1));
     using HGridDesc_M_N     = decltype(MakeGridDescriptor_M_N<HLayout>(1, 1, 1));
 
-    // GridwiseGemm
-    using GridwiseGemm = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
+    using GridwiseGemmWelford = GridwiseGemmMultipleDWelfordFirstHalf_xdl_cshuffle<
         ADataType, // TODO: distinguish A/B datatype
         AccDataType,
         CShuffleDataType,
@@ -388,29 +387,29 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         1,
         1,
         LoopSched>;
 
-    using Block2ETileMap = typename GridwiseGemm::DefaultBlock2ETileMap;
+    using Block2ETileMap = typename GridwiseGemmWelford::DefaultBlock2ETileMap;
 
-    // using GridwiseWelfordLayernorm =
-    //     GridwiseWelfordSecondHalfLayernorm2d<EDataType,
-    //                                          HDataType,
-    //                                          MeanDataType,
-    //                                          VarDataType,
-    //                                          AccDataType,
-    //                                          HGridDesc_M_N,
-    //                                          MeanGridDesc_M_N,
-    //                                          GammaBetaGridDesc_N,
-    //                                          MeanVarGridDesc_M,
-    //                                          BlockSize,
-    //                                          LayernormMThreadClusterSize,
-    //                                          LayernormNThreadClusterSize,
-    //                                          LayernormMThreadSliceSize,
-    //                                          LayernormNThreadSliceSize,
-    //                                          LayernormESrcHDstVectorDim,
-    //                                          LayernormESrcVectorSize,
-    //                                          LayernormHDstVectorSize,
-    //                                          LayernormGammaSrcVectorSize,
-    //                                          LayernormBetaSrcVectorSize,
-    //                                          LayernormMeanVarSrcDstVectorSize>;
+    using GridwiseWelfordLayernorm =
+        GridwiseWelfordSecondHalfLayernorm2d<EDataType,
+                                             HDataType,
+                                             MeanDataType,
+                                             VarDataType,
+                                             AccDataType,
+                                             HGridDesc_M_N,
+                                             MeanVarGridDesc_M_N,
+                                             GammaBetaGridDesc_N,
+                                             MeanVarGridDesc_M,
+                                             BlockSize,
+                                             LayernormThreadClusterSize_M_N::At(I0),
+                                             LayernormThreadClusterSize_M_N::At(I1),
+                                             LayernormThreadSliceSize_M_N::At(I0),
+                                             LayernormThreadSliceSize_M_N::At(I1),
+                                             LayernormESrcHDstVectorDim,
+                                             LayernormESrcVectorSize,
+                                             LayernormHDstVectorSize,
+                                             LayernormGammaSrcVectorSize,
+                                             LayernormBetaSrcVectorSize,
+                                             LayernormMeanVarSrcDstVectorSize>;
 
     // Argument
     struct Argument : public BaseArgument
@@ -451,7 +450,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
               gamma_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
               beta_grid_desc_n_{DeviceOp::MakeDescriptor_N(NRaw)},
               h_grid_desc_m_n_{DeviceOp::MakeGridDescriptor_M_N<HLayout>(MRaw, NRaw, StrideH)},
-              block_2_etile_map_{GridwiseGemm::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
+              block_2_etile_map_{GridwiseGemmWelford::MakeDefaultBlock2ETileMap(e_grid_desc_m_n_)},
               a_element_op_{a_element_op},
               b_element_op_{b_element_op},
               cde_element_op_{cde_element_op},
@@ -484,28 +483,28 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
             });
 
             // populate desc for Ds/E/F/G
-            if(GridwiseGemm::CheckValidity(a_grid_desc_m_k_,
-                                           b_grid_desc_n_k_,
-                                           ds_grid_desc_m_n_,
-                                           e_grid_desc_m_n_,
-                                           mean_grid_desc_m_n_,
-                                           var_grid_desc_m_n_,
-                                           block_2_etile_map_))
+            if(GridwiseGemmWelford::CheckValidity(a_grid_desc_m_k_,
+                                                  b_grid_desc_n_k_,
+                                                  ds_grid_desc_m_n_,
+                                                  e_grid_desc_m_n_,
+                                                  mean_grid_desc_m_n_,
+                                                  var_grid_desc_m_n_,
+                                                  block_2_etile_map_))
             {
                 ds_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    GridwiseGemmWelford::MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         ds_grid_desc_m_n_);
                 e_grid_desc_mblock_mperblock_nblock_nperblock_ =
-                    GridwiseGemm::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
+                    GridwiseGemmWelford::MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         e_grid_desc_m_n_);
                 mean_grid_desc_mblock_mperblock_nblock_ =
-                    GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
+                    GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
                         mean_grid_desc_m_n_);
                 var_grid_desc_mblock_mperblock_nblock_ =
-                    GridwiseGemm::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
+                    GridwiseGemmWelford::MakeMeanVarGridDescriptor_MBlock_MPerBlock_NBlock(
                         var_grid_desc_m_n_);
             }
@@ -526,7 +525,7 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         // pointers
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
-        typename GridwiseGemm::DsGridPointer p_ds_grid_;
+        typename GridwiseGemmWelford::DsGridPointer p_ds_grid_;
         EDataType* p_e_grid_;
         MeanDataType* p_mean_grid_; // mean
        VarDataType* p_var_grid_;   // variance * count
@@ -546,15 +545,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         HGridDesc_M_N h_grid_desc_m_n_;
 
         // tensor descriptors for block/thread-wise copy
-        typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
-        typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
-        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+        typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
+        typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
+        typename GridwiseGemmWelford::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             ds_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
+        typename GridwiseGemmWelford::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
            e_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock
+        typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock
            mean_grid_desc_mblock_mperblock_nblock_;
-        typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock
+        typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock
            var_grid_desc_mblock_mperblock_nblock_;
 
         // block-to-e-tile map
@@ -579,15 +578,15 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
         {
             float avg_time = 0;
 
-            if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_m_k_,
-                                            arg.b_grid_desc_n_k_,
-                                            arg.ds_grid_desc_m_n_,
-                                            arg.e_grid_desc_m_n_,
-                                            arg.mean_grid_desc_m_n_,
-                                            arg.var_grid_desc_m_n_,
-                                            arg.block_2_etile_map_))
+            if(!GridwiseGemmWelford::CheckValidity(arg.a_grid_desc_m_k_,
+                                                   arg.b_grid_desc_n_k_,
+                                                   arg.ds_grid_desc_m_n_,
+                                                   arg.e_grid_desc_m_n_,
+                                                   arg.mean_grid_desc_m_n_,
+                                                   arg.var_grid_desc_m_n_,
+                                                   arg.block_2_etile_map_))
             {
-                throw std::runtime_error("wrong! GridwiseGemm has invalid setting");
+                throw std::runtime_error("wrong! GridwiseGemmWelford has invalid setting");
             }
 
             const index_t grid_size =
@@ -601,30 +600,32 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                 const auto kernel_gemm_welford =
                     kernel_gemm_multiple_d_welford_first_half_xdl_cshuffle<
-                        GridwiseGemm,
+                        GridwiseGemmWelford,
                         ADataType, // TODO: distinguish A/B datatype
-                        typename GridwiseGemm::DsGridPointer,
+                        typename GridwiseGemmWelford::DsGridPointer,
                         EDataType,
                         MeanDataType,
                         VarDataType,
                         AElementwiseOperation,
                         BElementwiseOperation,
                         CDEElementwiseOperation,
-                        typename GridwiseGemm::DefaultAGridDesc_AK0_M_AK1,
-                        typename GridwiseGemm::DefaultBGridDesc_BK0_N_BK1,
-                        typename GridwiseGemm::DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                        typename GridwiseGemm::EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                        typename GridwiseGemm::MeanGridDescriptor_MBlock_MPerBlock_NBlock,
-                        typename GridwiseGemm::VarGridDescriptor_MBlock_MPerBlock_NBlock,
-                        typename GridwiseGemm::DefaultBlock2ETileMap,
+                        typename GridwiseGemmWelford::DefaultAGridDesc_AK0_M_AK1,
+                        typename GridwiseGemmWelford::DefaultBGridDesc_BK0_N_BK1,
+                        typename GridwiseGemmWelford::
+                            DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        typename GridwiseGemmWelford::
+                            EGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
+                        typename GridwiseGemmWelford::MeanGridDescriptor_MBlock_MPerBlock_NBlock,
+                        typename GridwiseGemmWelford::VarGridDescriptor_MBlock_MPerBlock_NBlock,
+                        typename GridwiseGemmWelford::DefaultBlock2ETileMap,
                         has_main_loop>;
 
-                // const auto kernel_welford_layernorm =
-                //     kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
-                //                                            EDataType,
-                //                                            HDataType,
-                //                                            MeanDataType,
-                //                                            VarDataType>;
+                const auto kernel_welford_layernorm =
+                    kernel_welford_layernorm2d_second_half<GridwiseWelfordLayernorm,
+                                                           EDataType,
+                                                           HDataType,
+                                                           MeanDataType,
+                                                           VarDataType>;
 
                 avg_time += launch_and_time_kernel(stream_config,
@@ -649,21 +650,21 @@ struct DeviceGemmMultipleDLayernorm_Xdl_CShuffle : public BaseOperator
                                                    arg.var_grid_desc_mblock_mperblock_nblock_,
                                                    arg.block_2_etile_map_);
 
-                // avg_time += launch_and_time_kernel(stream_config,
-                //                                    kernel_welford_layernorm,
-                //                                    dim3(grid_size),
-                //                                    dim3(BlockSize),
-                //                                    0,
-                //                                    arg.p_e_grid_,
-                //                                    arg.p_mean_grid_,
-                //                                    arg.p_var_grid_,
-                //                                    arg.p_h_grid_,
-                //                                    arg.blkGroupSize_);
+                avg_time += launch_and_time_kernel(stream_config,
+                                                   kernel_welford_layernorm,
+                                                   dim3(grid_size),
+                                                   dim3(BlockSize),
+                                                   0,
+                                                   arg.p_e_grid_,
+                                                   arg.p_mean_grid_,
+                                                   arg.p_var_grid_,
+                                                   arg.p_h_grid_,
+                                                   arg.blkGroupSize_);
 
                 return avg_time;
             };
 
-            if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
+            if(GridwiseGemmWelford::CalculateHasMainKBlockLoop(K))
             {
                 return launch_kernel(integral_constant<bool, true>{});
             }
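The commit uncomments and wires up the second kernel, so the fused operation now runs in two launches: kernel_gemm_welford produces E plus partial mean/variance, and kernel_welford_layernorm2d_second_half finishes the reduction and normalizes. A scalar sketch (ours, not the kernel code) of the layernorm step applied once a row's final statistics are known:

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch (ours): given one row of E with its final mean and variance,
// layernorm produces H = (E - mean) / sqrt(var + eps) * gamma + beta.
// The merge of per-block partials that precedes this is omitted here.
void layernorm_row(const std::vector<float>& e_row,
                   const std::vector<float>& gamma,
                   const std::vector<float>& beta,
                   float mean,
                   float variance,
                   float epsilon,
                   std::vector<float>& h_row)
{
    float invStd = 1.0f / std::sqrt(variance + epsilon);
    for(std::size_t n = 0; n < e_row.size(); ++n)
        h_row[n] = (e_row[n] - mean) * invStd * gamma[n] + beta[n];
}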
include/ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp

@@ -42,8 +42,15 @@ template <typename XDataType,
           index_t ScaleSrcVectorSize,
           index_t BiasSrcVectorSize,
           index_t MeanVarSrcDstVectorSize>
 struct DeviceBatchNormFwdImpl
-    : public DeviceBatchNormFwd<Rank, NumBatchNormReduceDim, YElementwiseOp>
+    : public DeviceBatchNormFwd<XDataType,
+                                YDataType,
+                                AccDataType,
+                                ScaleDataType,
+                                BiasDataType,
+                                MeanVarDataType,
+                                YElementwiseOp,
+                                Rank,
+                                NumBatchNormReduceDim>
 {
     static_assert(Rank <= 6, "Bigger Rank size is not supported!");
     static_assert(BlockSize == MThreadClusterSize * KThreadClusterSize,
include/ck/tensor_operation/gpu/grid/gemm_layernorm/gridwise_welford_second_half_layernorm2d.hpp

@@ -19,18 +19,115 @@
 namespace ck {
 
-template <typename XDataType, typename YDataType, typename MeanDataType, typename VarDataType>
+template <typename EDataType,
+          typename HDataType,
+          typename MeanDataType,
+          typename VarDataType,
+          typename ComputeDataType,
+          typename XYGridDesc_M_N,
+          typename MeanVarGridDesc_M_N,
+          typename GammaBetaGridDesc_N,
+          typename MeanVarGridDesc_M,
+          index_t BlockSize,
+          index_t MThreadClusterSize,
+          index_t NThreadClusterSize,
+          index_t MThreadSliceSize,
+          index_t NThreadSliceSize,
+          index_t XSrcYDstVectorDim,
+          index_t XSrcVectorSize,
+          index_t YDstVectorSize,
+          index_t GammaSrcVectorSize,
+          index_t BetaSrcVectorSize,
+          index_t MeanVarSrcDstVectorSize>
 struct GridwiseWelfordSecondHalfLayernorm2d
 {
-    __device__ static void Run(const XDataType* __restrict__ p_x_grid,
+    static constexpr bool reorder_thread_cluster = (XSrcYDstVectorDim == 0);
+
+    using ThreadClusterLengths_M_N = Sequence<MThreadClusterSize, NThreadClusterSize>;
+
+    using ThreadBufferDimAccessOrder =
+        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+
+    using ThreadClusterArrangeOrder =
+        typename conditional<reorder_thread_cluster, Sequence<1, 0>, Sequence<0, 1>>::type;
+
+    static constexpr auto thread_cluster_desc =
+        make_cluster_descriptor(ThreadClusterLengths_M_N{}, ThreadClusterArrangeOrder{});
+
+    using ThreadReduceSrcDesc_M_1 = decltype(
+        make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}, Number<1>{})));
+    using ThreadReduceDstDesc_M =
+        decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
+
+    using ThreadwiseWelford =
+        ThreadwiseWelfordMerge<ComputeDataType, ThreadReduceSrcDesc_M_1, ThreadReduceDstDesc_M>;
+
+    using BlockwiseWelford = BlockwiseWelford<ComputeDataType,
+                                              BlockSize,
+                                              ThreadClusterLengths_M_N,
+                                              ThreadClusterArrangeOrder>;
+
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
+
+    static constexpr auto I0 = Number<0>{};
+    static constexpr auto I1 = Number<1>{};
+
+    static constexpr index_t M_BlockTileSize = MThreadClusterSize * MThreadSliceSize;
+    static constexpr index_t N_BlockTileSize = NThreadClusterSize * NThreadSliceSize;
+
+    __device__ static void Run(const EDataType* __restrict__ p_e_grid,
                                const MeanDataType* __restrict__ p_mean_grid,
                                const VarDataType* __restrict__ p_var_grid,
-                               YDataType* __restrict__ p_y_grid)
+                               HDataType* __restrict__ p_h_grid,
+                               /*const MeanVarGridDesc_M_N& mean_grid_desc_m_k,
+                               const MeanVarGridDesc_M_N& var_grid_desc_m_k,
+                               const GammaBetaGridDesc_N& gamma_grid_desc_m,
+                               const GammaBetaGridDesc_N& beta_grid_desc_m,
+                               const MeanVarGridDesc_M& mean_var_grid_desc_m,*/
+                               index_t blkgroup_size)
     {
-        ignore = p_x_grid;
+        ignore = p_e_grid;
         ignore = p_mean_grid;
         ignore = p_var_grid;
-        ignore = p_y_grid;
+        ignore = p_h_grid;
+
+        const index_t thread_local_id = get_thread_local_1d_id();
+        const index_t block_global_id = get_block_1d_id();
+        const index_t blkgroup_id     = block_global_id / blkgroup_size;
+        const index_t block_local_id  = block_global_id % blkgroup_size;
+
+        const auto thread_cluster_idx =
+            thread_cluster_desc.CalculateBottomIndex(make_multi_index(thread_local_id));
+
+        const auto thread_m_cluster_id = thread_cluster_idx[I0];
+        const auto thread_n_cluster_id = thread_cluster_idx[I1];
+
+        using ThreadBufferLengths_M_N = Sequence<MThreadSliceSize, NThreadSliceSize>;
+        using ThreadBufferLengths_M   = Sequence<MThreadSliceSize>;
+        using ThreadBufferLengths_M_1 = Sequence<MThreadSliceSize, 1>;
+
+        constexpr auto thread_buffer_desc_m_n = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MThreadSliceSize>{}, Number<NThreadSliceSize>{}));
+        constexpr auto thread_buffer_desc_m =
+            make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}));
+        constexpr auto thread_buffer_desc_m_1 = make_naive_tensor_descriptor_packed(
+            make_tuple(Number<MThreadSliceSize>{}, Number<1>{}));
+
+        /*
+        auto threadwise_mean_load_m_n =
+            ThreadwiseTensorSliceTransfer_v2<MeanDataType,
+                                             ComputeDataType,
+                                             MeanVarGridDesc_M_N,
+                                             decltype(thread_buffer_desc_m_1),
+                                             ThreadBufferLengths_M_1,
+                                             Sequence<0, 1>,
+                                             1,
+                                             1,
+                                             1,
+                                             true>(
+                mean_grid_desc_m_n,
+                make_multi_index(blkgroup_id * M_BlockTileSize +
+                                     thread_m_cluster_id * MThreadSliceSize,
+                                 thread_n_cluster_id * 1));*/
     } // run
 };
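The new members ThreadwiseWelfordMerge and BlockwiseWelford combine the partial statistics produced by the first-half kernel. A scalar sketch (ours, not CK code) of the standard pairwise Welford/Chan merge of two (count, mean, variance) partials, assuming at least one sample on each side:

// Sketch (ours): merge two partial (count, mean, biased variance) triples into
// one, where var == M2 / count in Welford's terms. Assumes a.count, b.count > 0.
struct WelfordPartial
{
    int count;
    float mean;
    float var; // biased variance of the partial, i.e. M2 / count
};

WelfordPartial welford_merge(WelfordPartial a, WelfordPartial b)
{
    int n       = a.count + b.count;
    float delta = b.mean - a.mean;
    float mean  = a.mean + delta * b.count / n;
    // M2 values combine as M2_a + M2_b + delta^2 * n_a * n_b / n
    float m2 = a.var * a.count + b.var * b.count +
               delta * delta * a.count * b.count / static_cast<float>(n);
    return {n, mean, m2 / n};
}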
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward
_nhwc_c
.hpp
→
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp
View file @
b7f500f0
...
@@ -4,13 +4,13 @@
...
@@ -4,13 +4,13 @@
#pragma once
#pragma once
#include <iostream>
#include <iostream>
#include <vector>
#include <array>
#include <array>
#include <algorithm>
#include <algorithm>
#include <thread>
#include <thread>
#include "ck/utility/math_v2.hpp"
#include "ck/utility/math_v2.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/utility/ignore.hpp"
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -23,20 +23,33 @@ template <typename XDataType,
...
@@ -23,20 +23,33 @@ template <typename XDataType,
typename
ScaleDataType
,
typename
ScaleDataType
,
typename
BiasDataType
,
typename
BiasDataType
,
typename
MeanVarDataType
,
typename
MeanVarDataType
,
typename
YElementwiseOp
>
typename
YElementwiseOp
,
struct
ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
index_t
Rank
,
:
public
device
::
DeviceBatchNormFwd
<
4
,
3
,
YElementwiseOp
>
index_t
NumBatchNormReduceDim
>
struct
ReferenceBatchNormFwd
:
public
device
::
DeviceBatchNormFwd
<
XDataType
,
YDataType
,
AccDataType
,
ScaleDataType
,
BiasDataType
,
MeanVarDataType
,
YElementwiseOp
,
Rank
,
NumBatchNormReduceDim
>
{
{
static_assert
(
Rank
<=
6
,
"Bigger Rank size is not supported!"
);
static
constexpr
index_t
NumInvariantDim
=
Rank
-
NumBatchNormReduceDim
;
struct
Argument
:
public
device
::
BaseArgument
struct
Argument
:
public
device
::
BaseArgument
{
{
Argument
(
const
std
::
array
<
index_t
,
4
>
xyLengths
,
Argument
(
const
std
::
array
<
index_t
,
Rank
>
xyLengths
,
const
std
::
array
<
index_t
,
4
>
xStrides
,
const
std
::
array
<
index_t
,
Rank
>
xStrides
,
const
std
::
array
<
index_t
,
4
>
yStrides
,
const
std
::
array
<
index_t
,
Rank
>
yStrides
,
const
std
::
array
<
int
,
3
>
reduceDims
,
const
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims
,
const
std
::
array
<
index_t
,
1
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths
,
const
std
::
array
<
index_t
,
1
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides
,
const
std
::
array
<
index_t
,
1
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnBiasStrides
,
const
std
::
array
<
index_t
,
1
>
bnMeanVarStrides
,
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides
,
const
XDataType
*
p_x
,
const
XDataType
*
p_x
,
const
ScaleDataType
*
bnScale
,
const
ScaleDataType
*
bnScale
,
const
BiasDataType
*
bnBias
,
const
BiasDataType
*
bnBias
,
...
@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
...
@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
double
averageFactor
,
double
averageFactor
,
MeanVarDataType
*
resultRunningMean
,
MeanVarDataType
*
resultRunningMean
,
MeanVarDataType
*
resultRunningVariance
)
MeanVarDataType
*
resultRunningVariance
)
:
p_x_
(
p_x
),
:
reduceDims_
(
reduceDims
),
bnScaleBiasMeanVarLengths_
(
bnScaleBiasMeanVarLengths
),
bnScaleStrides_
(
bnScaleStrides
),
bnBiasStrides_
(
bnBiasStrides
),
bnMeanVarStrides_
(
bnMeanVarStrides
),
p_x_
(
p_x
),
bnScale_
(
bnScale
),
bnScale_
(
bnScale
),
bnBias_
(
bnBias
),
bnBias_
(
bnBias
),
y_elementwise_op_
(
y_elementwise_op
),
y_elementwise_op_
(
y_elementwise_op
),
...
@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
...
@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunningMean_
(
resultRunningMean
),
resultRunningMean_
(
resultRunningMean
),
resultRunningVariance_
(
resultRunningVariance
)
resultRunningVariance_
(
resultRunningVariance
)
{
{
ignore
=
xStrides
;
using
ck
::
host_common
::
get_index_set
;
ignore
=
yStrides
;
ignore
=
bnScaleStrides
;
if
(
std
::
any_of
(
ignore
=
bnBiasStrides
;
reduceDims
.
begin
(),
reduceDims
.
end
(),
[](
int
d
)
{
return
d
<
0
||
d
>=
Rank
;
}))
ignore
=
bnMeanVarStrides
;
throw
std
::
runtime_error
(
"Invalid reduce dimensions!"
);
ignore
=
reduceDims
;
// get invariant_dims[] and invariant_lengths[]
if
(
xyLengths
.
size
()
!=
4
||
bnScaleBiasMeanVarLengths
.
size
()
!=
1
||
for
(
int
dim
=
0
,
i
=
0
;
dim
<
Rank
;
dim
++
)
bnScaleBiasMeanVarLengths
[
0
]
!=
xyLengths
[
3
])
if
(
std
::
none_of
(
throw
std
::
runtime_error
(
"Invalid tensor dimensions!"
);
reduceDims
.
begin
(),
reduceDims
.
end
(),
[
&
](
int
d
)
{
return
d
==
dim
;
}))
{
n
=
xyLengths
[
0
];
invariantDims_
[
i
]
=
dim
;
h
=
xyLengths
[
1
];
invariant_lengths_
[
i
]
=
xyLengths
[
dim
];
w
=
xyLengths
[
2
];
i
++
;
c
=
xyLengths
[
3
];
};
// get reduce_lengths_[]
for
(
int
j
=
0
,
i
=
0
;
j
<
NumBatchNormReduceDim
;
j
++
)
{
int
dim
=
reduceDims
[
j
];
reduce_lengths_
[
i
++
]
=
xyLengths
[
dim
];
};
for
(
int
i
=
0
;
i
<
NumInvariantDim
;
i
++
)
if
(
invariant_lengths_
[
i
]
!=
bnScaleBiasMeanVarLengths_
[
i
])
throw
std
::
runtime_error
(
"Invalid lengths parameters!"
);
for
(
int
j
=
0
,
i
=
0
;
j
<
NumInvariantDim
;
j
++
)
{
int
dim
=
invariantDims_
[
j
];
x_invariant_strides_
[
i
]
=
xStrides
[
dim
];
y_invariant_strides_
[
i
]
=
yStrides
[
dim
];
i
++
;
};
for
(
int
j
=
0
,
i
=
0
;
j
<
NumBatchNormReduceDim
;
j
++
)
{
int
dim
=
reduceDims_
[
j
];
x_reduce_strides_
[
i
]
=
xStrides
[
dim
];
y_reduce_strides_
[
i
]
=
yStrides
[
dim
];
i
++
;
};
invariant_index_set_
=
get_index_set
<
NumInvariantDim
>
(
invariant_lengths_
);
reduce_index_set_
=
get_index_set
<
NumBatchNormReduceDim
>
(
reduce_lengths_
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
epsilon_
=
type_convert
<
AccDataType
>
(
epsilon
);
averageFactor_
=
type_convert
<
AccDataType
>
(
averageFactor
);
averageFactor_
=
type_convert
<
AccDataType
>
(
averageFactor
);
...
@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
...
@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
resultRunning
=
(
resultRunningMean
!=
nullptr
&&
resultRunningVariance
!=
nullptr
);
resultRunning
=
(
resultRunningMean
!=
nullptr
&&
resultRunningVariance
!=
nullptr
);
}
}
std
::
array
<
int
,
NumBatchNormReduceDim
>
reduceDims_
;
std
::
array
<
int
,
NumInvariantDim
>
invariantDims_
;
std
::
array
<
index_t
,
NumInvariantDim
>
invariant_lengths_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
reduce_lengths_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleBiasMeanVarLengths_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnScaleStrides_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnBiasStrides_
;
const
std
::
array
<
index_t
,
NumInvariantDim
>
bnMeanVarStrides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
x_invariant_strides_
;
std
::
array
<
index_t
,
NumInvariantDim
>
y_invariant_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
x_reduce_strides_
;
std
::
array
<
index_t
,
NumBatchNormReduceDim
>
y_reduce_strides_
;
const
XDataType
*
p_x_
;
const
XDataType
*
p_x_
;
const
ScaleDataType
*
bnScale_
;
const
ScaleDataType
*
bnScale_
;
const
BiasDataType
*
bnBias_
;
const
BiasDataType
*
bnBias_
;
...
@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
...
@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
bool
resultSave
,
resultRunning
;
bool
resultSave
,
resultRunning
;
index_t
n
,
h
,
w
,
c
;
std
::
vector
<
std
::
array
<
index_t
,
NumInvariantDim
>>
invariant_index_set_
;
std
::
vector
<
std
::
array
<
index_t
,
NumBatchNormReduceDim
>>
reduce_index_set_
;
AccDataType
averageFactor_
;
AccDataType
averageFactor_
;
AccDataType
epsilon_
;
AccDataType
epsilon_
;
...
@@ -104,105 +168,119 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
...
@@ -104,105 +168,119 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
{
{
float
Run
(
const
Argument
&
arg
)
float
Run
(
const
Argument
&
arg
)
{
{
auto
thread_reduce_func
=
[
&
](
auto
iC
)
{
using
ck
::
host_common
::
get_offset_from_index
;
index_t
offset_C
=
iC
;
auto
thread_reduce_func
=
[
&
](
auto
invariant_index
)
{
size_t
x_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
x_invariant_strides_
,
invariant_index
);
size_t
y_invariant_offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
y_invariant_strides_
,
invariant_index
);
AccDataType
mean
=
type_convert
<
AccDataType
>
(
0.0
f
);
AccDataType
mean
=
type_convert
<
AccDataType
>
(
0.0
f
);
AccDataType
variance
=
type_convert
<
AccDataType
>
(
0.0
f
);
AccDataType
variance
=
type_convert
<
AccDataType
>
(
0.0
f
);
int32_t
curr_count
=
0
;
int32_t
curr_count
=
0
;
// compute mean, variance using welford method
// compute mean, variance using welford method
for
(
index_t
iN
=
0
;
iN
<
arg
.
n
;
iN
++
)
for
(
const
auto
&
reduce_index
:
arg
.
reduce_index_set_
)
{
{
index_t
offset_N
=
iN
*
arg
.
h
*
arg
.
w
*
arg
.
c
;
size_t
x_reduce_offset
=
get_offset_from_index
<
NumBatchNormReduceDim
>
(
for
(
index_t
iH
=
0
;
iH
<
arg
.
h
;
iH
++
)
arg
.
x_reduce_strides_
,
reduce_index
);
{
index_t
offset_H
=
iH
*
arg
.
w
*
arg
.
c
;
for
(
index_t
iW
=
0
;
iW
<
arg
.
w
;
iW
++
)
{
index_t
offset_W
=
iW
*
arg
.
c
;
auto
offset
=
offset_N
+
offset
_H
+
offset_W
+
offset
_C
;
auto
x_
offset
=
x_invariant_
offset
+
x_reduce_
offset
;
curr_count
++
;
curr_count
++
;
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
offset
]);
AccDataType
x
=
type_convert
<
AccDataType
>
(
arg
.
p_x_
[
x_
offset
]);
AccDataType
delta
=
x
-
mean
;
AccDataType
delta
=
x
-
mean
;
mean
+=
delta
/
curr_count
;
mean
+=
delta
/
curr_count
;
AccDataType
delta2
=
x
-
mean
;
AccDataType
delta2
=
x
-
mean
;
variance
+=
delta
*
delta2
;
variance
+=
delta
*
delta2
;
};
}
};
};
// actual variance
// actual variance
variance
=
variance
/
curr_count
;
variance
=
variance
/
curr_count
;
// inv-variance defined as 1/sqrt(epsilon+variance)
AccDataType
invVariance
=
AccDataType
invVariance
=
type_convert
<
AccDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
variance
);
type_convert
<
AccDataType
>
(
1.0
f
)
/
ck
::
math
::
sqrt
(
arg
.
epsilon_
+
variance
);
// save the mean/inv
V
ariance if required
// save the mean/inv
-v
ariance if required
if
(
arg
.
resultSave
)
if
(
arg
.
resultSave
)
{
{
arg
.
resultSaveMean_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
mean
);
size_t
offset
=
get_offset_from_index
<
NumInvariantDim
>
(
arg
.
bnMeanVarStrides_
,
arg
.
resultSaveInvVariance_
[
iC
]
=
type_convert
<
MeanVarDataType
>
(
invVariance
);
invariant_index
);
arg
.
resultSaveMean_
[
offset
]
=
type_convert
<
MeanVarDataType
>
(
mean
);
arg
.
resultSaveInvVariance_
[
offset
]
=
type_convert
<
MeanVarDataType
>
(
invVariance
);
};
};
// update the moving average if required
            // update the moving average if required
            if(arg.resultRunning)
            {
                size_t offset =
                    get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_, invariant_index);

                AccDataType oneMinusAverageFactor =
                    type_convert<AccDataType>(1.0) - arg.averageFactor_;

                arg.resultRunningMean_[offset] = type_convert<MeanVarDataType>(
                    type_convert<AccDataType>(arg.resultRunningMean_[offset]) *
                        oneMinusAverageFactor +
                    mean * arg.averageFactor_);

                arg.resultRunningVariance_[offset] = type_convert<MeanVarDataType>(
                    arg.resultRunningVariance_[offset] * oneMinusAverageFactor +
                    variance * arg.averageFactor_);
            };

            size_t scale_offset =
                get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
            size_t bias_offset =
                get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);

            AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
            AccDataType bias  = type_convert<AccDataType>(arg.bnBias_[bias_offset]);

            // Normalization
            for(const auto& reduce_index : arg.reduce_index_set_)
            {
                size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                    arg.x_reduce_strides_, reduce_index);
                size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                    arg.y_reduce_strides_, reduce_index);

                auto x_offset = x_invariant_offset + x_reduce_offset;
                auto y_offset = y_invariant_offset + y_reduce_offset;

                AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                AccDataType norm_x = (x - mean) * invVariance;

                AccDataType y = scale * norm_x + bias;

                arg.y_elementwise_op_(y, y);

                arg.p_y_[y_offset] = type_convert<YDataType>(y);
            };
        };

        std::size_t num_thread      = std::thread::hardware_concurrency();
        std::size_t work_per_thread =
            (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;

        std::vector<joinable_thread> threads(num_thread);

        for(std::size_t it = 0; it < num_thread; ++it)
        {
            std::size_t i_begin = it * work_per_thread;
            std::size_t i_end   = std::min(static_cast<size_t>((it + 1) * work_per_thread),
                                           arg.invariant_index_set_.size());

            auto f = [=] {
                for(std::size_t i = i_begin; i < i_end; ++i)
                {
                    thread_reduce_func(arg.invariant_index_set_[i]);
                }
            };
...
@@ -278,7 +356,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
        auto str = std::stringstream();

        // clang-format off
        str << "Reference_BatchNorm_Forward" << std::endl;
        // clang-format on

        return str.str();
...
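The moving-average branch above is the standard exponential running-statistics update: with averageFactor α it computes running ← (1 − α)·running + α·batch_statistic for both the mean and the variance. A minimal standalone sketch of the same rule (plain floats; the variable names here are illustrative, not taken from the library):

#include <cstdio>

int main()
{
    float running_mean = 0.5f, running_var = 1.0f; // previous running statistics
    float batch_mean = 0.8f, batch_var = 0.9f;     // statistics of the current mini-batch
    float average_factor = 0.1f;                   // plays the role of arg.averageFactor_

    // running <- (1 - alpha) * running + alpha * batch
    running_mean = (1.0f - average_factor) * running_mean + average_factor * batch_mean;
    running_var  = (1.0f - average_factor) * running_var + average_factor * batch_var;

    std::printf("running mean = %f, running variance = %f\n", running_mean, running_var);
    return 0;
}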
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer_nhwc_c.hpp → library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp
View file @ b7f500f0
...
@@ -8,6 +8,7 @@
#include <array>
#include <algorithm>
#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"

namespace ck {
...
@@ -19,114 +20,205 @@ template <typename XDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          index_t Rank,
          index_t NumBatchNormReduceDim>
struct ReferenceBatchNormInfer : public device::DeviceBatchNormInfer<XDataType,
                                                                     YDataType,
                                                                     AccDataType,
                                                                     ScaleDataType,
                                                                     BiasDataType,
                                                                     MeanVarDataType,
                                                                     YElementwiseOp,
                                                                     Rank,
                                                                     NumBatchNormReduceDim>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    struct Argument : public device::BaseArgument
    {
        Argument(const std::array<index_t, Rank> xyLengths,
                 const std::array<index_t, Rank> xStrides,
                 const std::array<index_t, Rank> yStrides,
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                 const std::array<index_t, NumInvariantDim> bnScaleStrides,
                 const std::array<index_t, NumInvariantDim> bnBiasStrides,
                 const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const ScaleDataType* bnScale,
                 const BiasDataType* bnBias,
                 double epsilon,
                 const YElementwiseOp y_elementwise_op,
                 const MeanVarDataType* estimatedMean,
                 const MeanVarDataType* estimatedVariance,
                 YDataType* p_y)
            : reduceDims_(reduceDims),
              bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
              bnBiasStrides_(bnBiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              bnScale_(bnScale),
              bnBias_(bnBias),
              y_elementwise_op_(y_elementwise_op),
              estimatedMean_(estimatedMean),
              estimatedVariance_(estimatedVariance),
              p_y_(p_y)
        {
            using ck::host_common::get_index_set;

            if(std::any_of(reduceDims.begin(), reduceDims.end(), [](int d) {
                   return d < 0 || d >= Rank;
               }))
                throw std::runtime_error("Invalid reduce dimensions!");

            // get invariant_dims[] and invariant_lengths[]
            for(int dim = 0, i = 0; dim < Rank; dim++)
                if(std::none_of(
                       reduceDims.begin(), reduceDims.end(), [&](int d) { return d == dim; }))
                {
                    invariantDims_[i]     = dim;
                    invariant_lengths_[i] = xyLengths[dim];
                    i++;
                };

            // get reduce_lengths_[]
            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
            {
                int dim              = reduceDims[j];
                reduce_lengths_[i++] = xyLengths[dim];
            };

            // check invariant_lengths_ and bnScaleBiasMeanVarLengths
            for(int i = 0; i < NumInvariantDim; i++)
                if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
                    throw std::runtime_error("Invalid lengths parameters!");

            for(int j = 0, i = 0; j < NumInvariantDim; j++)
            {
                int dim                 = invariantDims_[j];
                x_invariant_strides_[i] = xStrides[dim];
                y_invariant_strides_[i] = yStrides[dim];
                i++;
            };

            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
            {
                int dim              = reduceDims_[j];
                x_reduce_strides_[i] = xStrides[dim];
                y_reduce_strides_[i] = yStrides[dim];
                i++;
            };

            invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
            reduce_index_set_    = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);

            epsilon_ = type_convert<AccDataType>(epsilon);
        }

        std::array<int, NumBatchNormReduceDim> reduceDims_;
        std::array<int, NumInvariantDim> invariantDims_;
        std::array<index_t, NumInvariantDim> invariant_lengths_;
        std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;

        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
        const std::array<index_t, NumInvariantDim> bnScaleStrides_;
        const std::array<index_t, NumInvariantDim> bnBiasStrides_;
        const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;

        std::array<index_t, NumInvariantDim> x_invariant_strides_;
        std::array<index_t, NumInvariantDim> y_invariant_strides_;
        std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
        std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;

        const XDataType* p_x_;
        const ScaleDataType* bnScale_;
        const BiasDataType* bnBias_;
        const YElementwiseOp y_elementwise_op_;
        const MeanVarDataType* estimatedMean_;
        const MeanVarDataType* estimatedVariance_;
        YDataType* p_y_;

        std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
        std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;

        AccDataType epsilon_;
    };

    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            using ck::host_common::get_offset_from_index;

            auto thread_reduce_func = [&](auto invariant_index) {
                size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.x_invariant_strides_, invariant_index);
                size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.y_invariant_strides_, invariant_index);
                size_t mean_variance_offset = get_offset_from_index<NumInvariantDim>(
                    arg.bnMeanVarStrides_, invariant_index);

                AccDataType mean     = arg.estimatedMean_[mean_variance_offset];
                AccDataType variance = arg.estimatedVariance_[mean_variance_offset];

                // inv-variance defined as 1/sqrt(epsilon+variance)
                AccDataType invVariance =
                    type_convert<AccDataType>(1.0f) /
                    std::sqrt(type_convert<AccDataType>(arg.epsilon_) + variance);

                size_t scale_offset =
                    get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
                size_t bias_offset =
                    get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);

                AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
                AccDataType bias  = type_convert<AccDataType>(arg.bnBias_[bias_offset]);

                // normalization
                for(const auto& reduce_index : arg.reduce_index_set_)
                {
                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.x_reduce_strides_, reduce_index);
                    size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.y_reduce_strides_, reduce_index);

                    auto x_offset = x_invariant_offset + x_reduce_offset;
                    auto y_offset = y_invariant_offset + y_reduce_offset;

                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                    AccDataType norm_x = (x - mean) * invVariance;

                    AccDataType y = scale * norm_x + bias;

                    arg.y_elementwise_op_(y, y);

                    arg.p_y_[y_offset] = type_convert<YDataType>(y);
                };
            };

            std::size_t num_thread      = std::thread::hardware_concurrency();
            std::size_t work_per_thread =
                (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t i_begin = it * work_per_thread;
                std::size_t i_end   = std::min(static_cast<size_t>((it + 1) * work_per_thread),
                                               arg.invariant_index_set_.size());

                auto f = [=] {
                    for(std::size_t i = i_begin; i < i_end; ++i)
                    {
                        thread_reduce_func(arg.invariant_index_set_[i]);
                    }
                };
...
@@ -151,17 +243,19 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
    };

    std::unique_ptr<device::BaseArgument>
    MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
                        const std::array<index_t, Rank> xStrides,
                        const std::array<index_t, Rank> yStrides,
                        const std::array<int, NumBatchNormReduceDim> reduceDims,
                        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                        const std::array<index_t, NumInvariantDim> bnScaleStrides,
                        const std::array<index_t, NumInvariantDim> bnBiasStrides,
                        const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                        const void* p_x,
                        const void* bnScale,
                        const void* bnBias,
                        double epsilon,
                        const YElementwiseOp y_elementwise_op,
                        const void* estimatedMean,
                        const void* estimatedVariance,
                        void* p_y) override
...
@@ -169,6 +263,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          yStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnBiasStrides,
...
@@ -177,6 +272,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
                                          static_cast<const ScaleDataType*>(bnScale),
                                          static_cast<const BiasDataType*>(bnBias),
                                          epsilon,
                                          y_elementwise_op,
                                          static_cast<const MeanVarDataType*>(estimatedMean),
                                          static_cast<const MeanVarDataType*>(estimatedVariance),
                                          static_cast<YDataType*>(p_y));
...
@@ -192,7 +288,7 @@ struct ReferenceBatchNormInfer_Input_N_H_W_C_Output_C : public device::DeviceBat
        auto str = std::stringstream();

        // clang-format off
        str << "Reference_BatchNorm_Infer<" << std::endl;
        // clang-format on

        return str.str();
...
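A minimal usage sketch of the generalized interface (illustrative only, not part of the commit): running the reference infer op on a packed NHWC float tensor means Rank = 4 with dimensions 0, 1, 2 (N, H, W) reduced, leaving C as the single invariant dimension for scale/bias/mean/variance. All extents and fill values below are assumed examples, and ReferenceBatchNormInfer is used as if pulled in from the header above via a using-declaration.

#include <array>
#include <vector>

void run_reference_batchnorm_infer_example()
{
    constexpr ck::index_t N = 2, H = 4, W = 4, C = 8;

    std::vector<float> x(N * H * W * C, 1.0f), y(N * H * W * C, 0.0f);
    std::vector<float> scale(C, 1.0f), bias(C, 0.0f), mean(C, 0.0f), variance(C, 1.0f);

    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using ReferenceOp =
        ReferenceBatchNormInfer<float, float, float, float, float, float, PassThrough, 4, 3>;

    // packed NHWC strides are {H*W*C, W*C, C, 1}
    auto arg = ReferenceOp::Argument({N, H, W, C},
                                     {H * W * C, W * C, C, 1},
                                     {H * W * C, W * C, C, 1},
                                     {0, 1, 2}, // reduce over N, H, W
                                     {C},       // scale/bias/mean/var length
                                     {1},       // ... and their (packed) strides
                                     {1},
                                     {1},
                                     x.data(),
                                     scale.data(),
                                     bias.data(),
                                     1e-5, // epsilon
                                     PassThrough{},
                                     mean.data(),
                                     variance.data(),
                                     y.data());

    ReferenceOp::Invoker{}.Run(arg);
}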
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp
0 → 100644
View file @ b7f500f0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);

// FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);

// BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);

// Int8
void add_device_batchnorm_forward_rank_4_3_i8_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>&);

// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
    std::vector<
        std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);

template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchNormFwd<XDataType,
                                                     YDataType,
                                                     AccDataType,
                                                     ScaleDataType,
                                                     BiasDataType,
                                                     MeanVarDataType,
                                                     YElementwiseOp,
                                                     Rank,
                                                     NumReduceDim>>
{
    using DeviceOp = DeviceBatchNormFwd<XDataType,
                                        YDataType,
                                        AccDataType,
                                        ScaleDataType,
                                        BiasDataType,
                                        MeanVarDataType,
                                        YElementwiseOp,
                                        Rank,
                                        NumReduceDim>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
                     is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
                     is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 &&
                         is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
                          is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 &&
                         is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
                          is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 &&
                         is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, I8> && is_same_v<YDataType, I8> &&
                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, I8> &&
                          is_same_v<BiasDataType, I8> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 &&
                         is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_i8_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
                          is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
                          is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 &&
                         is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
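A brief usage sketch of the factory (illustrative, not from the commit): requesting every registered rank-4, 3-reduce-dim FP16 forward instance. The F16/F32 aliases above are assumed to resolve to ck::half_t and float, matching the instance files later in this commit.

using DeviceOpF16 = ck::tensor_operation::device::DeviceBatchNormFwd<
    ck::half_t, ck::half_t, float, ck::half_t, ck::half_t, float,
    ck::tensor_operation::element_wise::PassThrough, 4, 3>;

// One object per tuning variant; a caller would typically build an argument for each,
// skip the variants that report the argument as unsupported, and time the rest.
const auto op_ptrs = ck::tensor_operation::device::instance::
    DeviceOperationInstanceFactory<DeviceOpF16>::GetInstances();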
library/include/ck/library/utility/host_common_util.hpp
View file @ b7f500f0
...
@@ -4,9 +4,11 @@
#pragma once

#include <vector>
#include <array>
#include <iostream>
#include <fstream>
#include <string>
#include <algorithm>

#include "ck/ck.hpp"
...
@@ -72,5 +74,63 @@ static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
    return (values);
}
template <int NDim>
static inline std::vector<std::array<index_t, NDim>>
get_index_set(const std::array<index_t, NDim>& dim_lengths)
{
    static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");

    if constexpr(NDim == 1)
    {
        std::vector<std::array<index_t, NDim>> index_set;

        for(int i = 0; i < dim_lengths[0]; i++)
        {
            std::array<index_t, 1> index{i};

            index_set.push_back(index);
        };

        return index_set;
    }
    else
    {
        std::vector<std::array<index_t, NDim>> index_set;

        std::array<index_t, NDim - 1> partial_dim_lengths;

        std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin());

        std::vector<std::array<index_t, NDim - 1>> partial_index_set;

        partial_index_set = get_index_set<NDim - 1>(partial_dim_lengths);

        for(index_t i = 0; i < dim_lengths[0]; i++)
            for(const auto& partial_index : partial_index_set)
            {
                std::array<index_t, NDim> index;

                index[0] = i;

                std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1);

                index_set.push_back(index);
            };

        return index_set;
    };
};

template <int NDim>
static inline size_t get_offset_from_index(const std::array<index_t, NDim>& strides,
                                           const std::array<index_t, NDim>& index)
{
    size_t offset = 0;

    for(int i = 0; i < NDim; i++)
        offset += index[i] * strides[i];

    return (offset);
};

} // namespace host_common
} // namespace ck
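A worked example of the two helpers (illustrative): for dim_lengths = {2, 3}, get_index_set enumerates the six multi-indices {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} in row-major order, and get_offset_from_index folds each one against a stride array.

#include <array>
#include <cstdio>

void index_set_example()
{
    using ck::host_common::get_index_set;
    using ck::host_common::get_offset_from_index;

    std::array<ck::index_t, 2> lengths{2, 3};
    std::array<ck::index_t, 2> strides{3, 1}; // packed row-major strides for a 2x3 tensor

    for(const auto& idx : get_index_set<2>(lengths))
        std::printf("index (%d, %d) -> offset %zu\n",
                    static_cast<int>(idx[0]),
                    static_cast<int>(idx[1]),
                    get_offset_from_index<2>(strides, idx));
    // prints offsets 0, 1, 2, 3, 4, 5 - i.e. a packed traversal of the tensor
}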
library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt
0 → 100644
View file @ b7f500f0

add_instance_library(device_batchnorm_instance
    device_batchnorm_forward_f16_instance.cpp
    device_batchnorm_forward_f32_instance.cpp
    device_batchnorm_forward_bf16_instance.cpp
    device_batchnorm_forward_i8_instance.cpp
    device_batchnorm_forward_f64_instance.cpp
)
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp
0 → 100644
View file @ b7f500f0

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_bf16_blockwise_instances = std::tuple<
    // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on

// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_bf16_multiblock_instances = std::tuple<
    // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BlockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on

void add_device_batchnorm_forward_rank_4_3_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_bf16_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_bf16_multiblock_instances<4, 3, PassThrough>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
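The two tuples in each instance file differ only in the UseMultiBlockInK flag. As the flag name suggests, the blockwise variants (false) appear to keep the whole reduction (K) dimension inside a single workgroup, while the multiblock variants (true) split the reduction across several workgroups; the remaining integer columns follow the comment row: block size, the M/K thread-cluster sizes, the M/K thread-slice sizes, the vectorized dimension, and the per-tensor vector widths.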
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp
0 → 100644
View file @
b7f500f0
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
instance
{
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_f16_blockwise_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
false
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
// clang-format off
template
<
index_t
Rank
,
index_t
NumReduceDim
,
typename
YElementwiseOp
>
using
device_batchnorm_forward_f16_multiblock_instances
=
std
::
tuple
<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
128
,
2
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
64
,
4
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
32
,
8
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
16
,
16
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
8
,
32
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
4
,
64
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
2
,
128
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
1
,
1
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
0
,
2
,
2
,
1
,
1
,
1
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
2
,
2
,
2
>
,
DeviceBatchNormFwdImpl
<
F16
,
F16
,
F32
,
F16
,
F16
,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f16_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_f16_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_f16_multiblock_instances<4, 3, PassThrough>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
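For readers wiring these instance libraries into their own code, here is a minimal consumption sketch. It assumes the factory header `ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp` (added in this commit) declares the registration function above, that `F16` is `ck::half_t`, and that every instance inherits `GetTypeString()` from the common base operator, as elsewhere in composable_kernel; argument construction and kernel launch are deliberately omitted.

```cpp
// Minimal sketch: enumerate the registered rank-4 / 3-reduce-dim f16 instances.
// Include path and linking against the instance library are assumptions based
// on this commit's file layout, not a verified build recipe.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"

int main()
{
    using ck::half_t;
    using ck::tensor_operation::device::DeviceBatchNormFwd;
    using ck::tensor_operation::element_wise::PassThrough;

    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<half_t, half_t, float, half_t, half_t, float, PassThrough, 4, 3>>>
        instances;

    ck::tensor_operation::device::instance::
        add_device_batchnorm_forward_rank_4_3_f16_instances(instances);

    // Each instance encodes its tuning parameters (cluster/slice/vector sizes)
    // in its type string, which is handy when picking a kernel to benchmark.
    for(const auto& op : instances)
        std::cout << op->GetTypeString() << '\n';
}
```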
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f32_blockwise_instances = std::tuple<
    // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>>;
// clang-format on
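Because every entry above is just one point in the same tuning grid, it can help to spell a single row out against the column comment. The alias name below is ours, purely for illustration; the parameter values are copied from the first blockwise entry, instantiated with the Rank/NumReduceDim of 4 and 3 used by the registration function further down, and the axis reading of `XSrcYDstVectorDim` is our interpretation of the convention used by the reduction kernels.

```cpp
// First row of device_batchnorm_forward_f32_blockwise_instances, annotated.
// Alias name is illustrative only; values come from the tuple above.
using BnFwdF32_Block256_M128K2 = ck::tensor_operation::device::DeviceBatchNormFwdImpl<
    float, float, float, float, float, float,        // X, Y, Acc, Scale, Bias, MeanVar data types
    ck::tensor_operation::element_wise::PassThrough, // YElementwiseOp
    4, 3,     // Rank, NumReduceDim (e.g. NHWC input, reducing over N/H/W)
    false,    // UseMultiBlockInK: one workgroup computes each mean/variance slice
    256,      // BlockSize
    128, 2,   // M/K thread cluster sizes (128 * 2 == BlockSize)
    2, 2,     // M/K thread slice sizes
    0,        // XSrcYDstVectorDim (0 = invariant M axis, 1 = reduced K axis - our reading)
    2, 2,     // X source / Y destination vector sizes
    2, 2, 2>; // scale / bias / mean-var vector sizes
```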
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f32_multiblock_instances = std::tuple<
    // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f32_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_f32_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_f32_multiblock_instances<4, 3, PassThrough>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
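All of these registration functions funnel through `add_device_operation_instances`, whose job is simply to append one heap-allocated copy of every element of the instance tuple to the caller's vector, upcast to the interface type. The library's real helper lives in `add_device_operation_instance.hpp` and may differ in mechanics; the following is a behavioural stand-in under that assumption, with an illustrative name.

```cpp
// Illustrative stand-in for add_device_operation_instances: push one copy of
// each tuple element into the caller's vector of interface pointers.
#include <memory>
#include <tuple>
#include <type_traits>
#include <vector>

template <typename BaseOp, typename... Instances>
void add_device_operation_instances_sketch(std::vector<std::unique_ptr<BaseOp>>& ops,
                                           const std::tuple<Instances...>& instances)
{
    std::apply(
        [&](const auto&... instance) {
            // Fold over the tuple; each concrete instance must derive from BaseOp.
            (ops.push_back(std::make_unique<std::decay_t<decltype(instance)>>(instance)), ...);
        },
        instances);
}
```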
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F64 = double;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f64_blockwise_instances = std::tuple<
    // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f64_multiblock_instances = std::tuple<
    // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f64_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_f64_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_f64_multiblock_instances<4, 3, PassThrough>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_i8_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using I8  = int8_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_blockwise_instances = std::tuple<
    // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>>;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_i8_multiblock_instances = std::tuple<
    // XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<I8, I8, F32, I8, I8, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>>;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_i8_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<I8, I8, F32, I8, I8, F32, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_i8_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_i8_multiblock_instances<4, 3, PassThrough>{});
}

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
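Note that the i8 instances keep AccDataType and MeanVarDataType at F32 while x, y, scale, and bias are int8: statistics and normalization are accumulated in float, and only the final result is narrowed back. As a scalar illustration of that type split (textbook batchnorm arithmetic, not library code; the device kernels' exact rounding/saturation policy is not shown in this diff):

```cpp
#include <algorithm>
#include <cstdint>

// Scalar batchnorm forward mirroring the i8 instances' type split:
// int8 storage for x/y/scale/bias, float for statistics and accumulation.
// inv_std is 1 / sqrt(variance + epsilon), precomputed in float.
inline std::int8_t batchnorm_fwd_scalar_i8(std::int8_t x, float mean, float inv_std,
                                           std::int8_t scale, std::int8_t bias)
{
    const float normalized = (static_cast<float>(x) - mean) * inv_std;
    const float y = static_cast<float>(scale) * normalized + static_cast<float>(bias);
    // Clamp before narrowing; this saturation behaviour is our assumption.
    const float clamped = std::min(127.0f, std::max(-128.0f, y));
    return static_cast<std::int8_t>(clamped);
}
```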
profiler/CMakeLists.txt
@@ -26,6 +26,7 @@ set(PROFILER_SOURCE
     src/profile_groupnorm.cpp
     src/profile_layernorm.cpp
     src/profile_softmax.cpp
+    src/profile_batchnorm_fwd.cpp
 )
 add_executable(ckProfiler ${PROFILER_SOURCE})
@@ -57,5 +58,6 @@ target_link_libraries(ckProfiler PRIVATE device_conv2d_fwd_bias_relu_add_instanc
 target_link_libraries(ckProfiler PRIVATE device_normalization_instance)
 target_link_libraries(ckProfiler PRIVATE device_softmax_instance)
 target_link_libraries(ckProfiler PRIVATE device_reduce_instance)
+target_link_libraries(ckProfiler PRIVATE device_batchnorm_instance)
 rocm_install(TARGETS ckProfiler COMPONENT profiler)