gaoqiong / composable_kernel · Commit 463e2aa1

Authored Nov 30, 2022 by aska-0096

Merge branch 'develop' of https://github.com/ROCmSoftwarePlatform/composable_kernel into wmma_op

Parents: 6e106c19, 236bd148
Showing 20 changed files with 2126 additions and 5 deletions (+2126 -5).
Changed files:

  library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp   +368  -0
  library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp     +300  -0
  library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp             +7    -5
  library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp  +2    -0
  library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp  +56  -0
  library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp              +117  -0
  library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp      +39   -0
  library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp              +145  -0
  library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp                  +138  -0
  library/include/ck/library/utility/host_common_util.hpp                                     +60   -0
  library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt  +1    -0
  library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp  +133  -0
  library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt                          +6    -0
  library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp  +147  -0
  library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp   +147  -0
  library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp   +145  -0
  library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp   +145  -0
  library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt                    +4    -0
  library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp  +83  -0
  library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp  +83  -0
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward_nhwc_c.hpp
→ library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_forward.hpp
@@ -4,13 +4,13 @@
 #pragma once
 
 #include <iostream>
-#include <vector>
 #include <array>
 #include <algorithm>
 #include <thread>
 
 #include "ck/utility/math_v2.hpp"
 #include "ck/utility/ignore.hpp"
+#include "ck/library/utility/host_common_util.hpp"
 #include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
 
 namespace ck {
@@ -23,20 +23,33 @@ template <typename XDataType,
           typename ScaleDataType,
           typename BiasDataType,
           typename MeanVarDataType,
-          typename YElementwiseOp>
-struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
-    : public device::DeviceBatchNormFwd<4, 3, YElementwiseOp>
+          typename YElementwiseOp,
+          index_t Rank,
+          index_t NumBatchNormReduceDim>
+struct ReferenceBatchNormFwd
+    : public device::DeviceBatchNormFwd<XDataType, YDataType, AccDataType, ScaleDataType,
+                                        BiasDataType, MeanVarDataType, YElementwiseOp,
+                                        Rank, NumBatchNormReduceDim>
 {
+    static_assert(Rank <= 6, "Bigger Rank size is not supported!");
+
+    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;
+
     struct Argument : public device::BaseArgument
     {
-        Argument(const std::array<index_t, 4> xyLengths,
-                 const std::array<index_t, 4> xStrides,
-                 const std::array<index_t, 4> yStrides,
-                 const std::array<int, 3> reduceDims,
-                 const std::array<index_t, 1> bnScaleBiasMeanVarLengths,
-                 const std::array<index_t, 1> bnScaleStrides,
-                 const std::array<index_t, 1> bnBiasStrides,
-                 const std::array<index_t, 1> bnMeanVarStrides,
+        Argument(const std::array<index_t, Rank> xyLengths,
+                 const std::array<index_t, Rank> xStrides,
+                 const std::array<index_t, Rank> yStrides,
+                 const std::array<int, NumBatchNormReduceDim> reduceDims,
+                 const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
+                 const std::array<index_t, NumInvariantDim> bnScaleStrides,
+                 const std::array<index_t, NumInvariantDim> bnBiasStrides,
+                 const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                  const XDataType* p_x,
                  const ScaleDataType* bnScale,
                  const BiasDataType* bnBias,
@@ -48,7 +61,12 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
                  double averageFactor,
                  MeanVarDataType* resultRunningMean,
                  MeanVarDataType* resultRunningVariance)
-            : p_x_(p_x),
+            : reduceDims_(reduceDims),
+              bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
+              bnScaleStrides_(bnScaleStrides),
+              bnBiasStrides_(bnBiasStrides),
+              bnMeanVarStrides_(bnMeanVarStrides),
+              p_x_(p_x),
               bnScale_(bnScale),
               bnBias_(bnBias),
               y_elementwise_op_(y_elementwise_op),
@@ -58,21 +76,51 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
               resultRunningMean_(resultRunningMean),
               resultRunningVariance_(resultRunningVariance)
         {
-            ignore = xStrides;
-            ignore = yStrides;
-            ignore = bnScaleStrides;
-            ignore = bnBiasStrides;
-            ignore = bnMeanVarStrides;
-            ignore = reduceDims;
-
-            if(xyLengths.size() != 4 || bnScaleBiasMeanVarLengths.size() != 1 ||
-               bnScaleBiasMeanVarLengths[0] != xyLengths[3])
-                throw std::runtime_error("Invalid tensor dimensions!");
-
-            n = xyLengths[0];
-            h = xyLengths[1];
-            w = xyLengths[2];
-            c = xyLengths[3];
+            using ck::host_common::get_index_set;
+
+            if(std::any_of(reduceDims.begin(), reduceDims.end(),
+                           [](int d) { return d < 0 || d >= Rank; }))
+                throw std::runtime_error("Invalid reduce dimensions!");
+
+            // get invariant_dims[] and invariant_lengths[]
+            for(int dim = 0, i = 0; dim < Rank; dim++)
+                if(std::none_of(reduceDims.begin(), reduceDims.end(),
+                                [&](int d) { return d == dim; }))
+                {
+                    invariantDims_[i]     = dim;
+                    invariant_lengths_[i] = xyLengths[dim];
+                    i++;
+                };
+
+            // get reduce_lengths_[]
+            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
+            {
+                int dim              = reduceDims[j];
+                reduce_lengths_[i++] = xyLengths[dim];
+            };
+
+            for(int i = 0; i < NumInvariantDim; i++)
+                if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
+                    throw std::runtime_error("Invalid lengths parameters!");
+
+            for(int j = 0, i = 0; j < NumInvariantDim; j++)
+            {
+                int dim                 = invariantDims_[j];
+                x_invariant_strides_[i] = xStrides[dim];
+                y_invariant_strides_[i] = yStrides[dim];
+                i++;
+            };
+
+            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
+            {
+                int dim              = reduceDims_[j];
+                x_reduce_strides_[i] = xStrides[dim];
+                y_reduce_strides_[i] = yStrides[dim];
+                i++;
+            };
+
+            invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
+            reduce_index_set_    = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);
 
             epsilon_       = type_convert<AccDataType>(epsilon);
             averageFactor_ = type_convert<AccDataType>(averageFactor);
@@ -81,6 +129,21 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
             resultRunning = (resultRunningMean != nullptr && resultRunningVariance != nullptr);
         }
 
+        std::array<int, NumBatchNormReduceDim> reduceDims_;
+        std::array<int, NumInvariantDim> invariantDims_;
+        std::array<index_t, NumInvariantDim> invariant_lengths_;
+        std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;
+
+        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
+        const std::array<index_t, NumInvariantDim> bnScaleStrides_;
+        const std::array<index_t, NumInvariantDim> bnBiasStrides_;
+        const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;
+
+        std::array<index_t, NumInvariantDim> x_invariant_strides_;
+        std::array<index_t, NumInvariantDim> y_invariant_strides_;
+        std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
+        std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;
+
         const XDataType* p_x_;
         const ScaleDataType* bnScale_;
         const BiasDataType* bnBias_;
@@ -94,7 +157,8 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
         bool resultSave, resultRunning;
 
-        index_t n, h, w, c;
+        std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
+        std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;
 
         AccDataType averageFactor_;
         AccDataType epsilon_;
@@ -104,105 +168,119 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
     {
         float Run(const Argument& arg)
         {
-            auto thread_reduce_func = [&](auto iC) {
-                index_t offset_C = iC;
+            using ck::host_common::get_offset_from_index;
+
+            auto thread_reduce_func = [&](auto invariant_index) {
+                size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
+                    arg.x_invariant_strides_, invariant_index);
+                size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
+                    arg.y_invariant_strides_, invariant_index);
+
                 AccDataType mean     = type_convert<AccDataType>(0.0f);
                 AccDataType variance = type_convert<AccDataType>(0.0f);
                 int32_t curr_count   = 0;
 
                 // compute mean, variance using welford method
-                for(index_t iN = 0; iN < arg.n; iN++)
-                {
-                    index_t offset_N = iN * arg.h * arg.w * arg.c;
-                    for(index_t iH = 0; iH < arg.h; iH++)
-                    {
-                        index_t offset_H = iH * arg.w * arg.c;
-                        for(index_t iW = 0; iW < arg.w; iW++)
-                        {
-                            index_t offset_W = iW * arg.c;
-                            auto offset = offset_N + offset_H + offset_W + offset_C;
-
-                            curr_count++;
-
-                            AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
-
-                            AccDataType delta = x - mean;
-                            mean += delta / curr_count;
-                            AccDataType delta2 = x - mean;
-                            variance += delta * delta2;
-                        };
-                    }
-                };
+                for(const auto& reduce_index : arg.reduce_index_set_)
+                {
+                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
+                        arg.x_reduce_strides_, reduce_index);
+
+                    auto x_offset = x_invariant_offset + x_reduce_offset;
+
+                    curr_count++;
+
+                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
+
+                    AccDataType delta = x - mean;
+                    mean += delta / curr_count;
+                    AccDataType delta2 = x - mean;
+                    variance += delta * delta2;
+                };
 
                 // actual variance
                 variance = variance / curr_count;
 
                 // inv-variance defined as 1/sqrt(epsilon+variance)
                 AccDataType invVariance =
                     type_convert<AccDataType>(1.0f) / ck::math::sqrt(arg.epsilon_ + variance);
 
-                // save the mean/invVariance if required
+                // save the mean/inv-variance if required
                 if(arg.resultSave)
                 {
-                    arg.resultSaveMean_[iC]        = type_convert<MeanVarDataType>(mean);
-                    arg.resultSaveInvVariance_[iC] = type_convert<MeanVarDataType>(invVariance);
+                    size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
+                                                                           invariant_index);
+
+                    arg.resultSaveMean_[offset]        = type_convert<MeanVarDataType>(mean);
+                    arg.resultSaveInvVariance_[offset] = type_convert<MeanVarDataType>(invVariance);
                 };
 
                 // update the moving average if required
                 if(arg.resultRunning)
                 {
+                    size_t offset = get_offset_from_index<NumInvariantDim>(arg.bnMeanVarStrides_,
+                                                                           invariant_index);
+
                     AccDataType oneMinusAverageFactor =
                         type_convert<AccDataType>(1.0) - arg.averageFactor_;
-                    arg.resultRunningMean_[iC] = type_convert<MeanVarDataType>(
-                        type_convert<AccDataType>(arg.resultRunningMean_[iC]) *
-                            oneMinusAverageFactor +
-                        mean * arg.averageFactor_);
-                    arg.resultRunningVariance_[iC] = type_convert<MeanVarDataType>(
-                        arg.resultRunningVariance_[iC] * oneMinusAverageFactor +
-                        variance * arg.averageFactor_);
+
+                    arg.resultRunningMean_[offset] = type_convert<MeanVarDataType>(
+                        type_convert<AccDataType>(arg.resultRunningMean_[offset]) *
+                            oneMinusAverageFactor +
+                        mean * arg.averageFactor_);
+                    arg.resultRunningVariance_[offset] = type_convert<MeanVarDataType>(
+                        arg.resultRunningVariance_[offset] * oneMinusAverageFactor +
+                        variance * arg.averageFactor_);
                 };
 
+                size_t scale_offset =
+                    get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
+                size_t bias_offset =
+                    get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);
+
+                AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
+                AccDataType bias  = type_convert<AccDataType>(arg.bnBias_[bias_offset]);
+
                 // Normalization
-                for(index_t iN = 0; iN < arg.n; iN++)
-                {
-                    index_t offset_N = iN * arg.h * arg.w * arg.c;
-                    for(index_t iH = 0; iH < arg.h; iH++)
-                    {
-                        index_t offset_H = iH * arg.w * arg.c;
-                        for(index_t iW = 0; iW < arg.w; iW++)
-                        {
-                            index_t offset_W = iW * arg.c;
-                            auto offset = offset_N + offset_H + offset_W + offset_C;
-
-                            AccDataType x = type_convert<AccDataType>(arg.p_x_[offset]);
-
-                            AccDataType norm_x =
-                                arg.bnScale_[iC] * (x - mean) * invVariance + arg.bnBias_[iC];
-
-                            arg.p_y_[offset] = type_convert<YDataType>(norm_x);
-                        };
-                    }
-                };
+                for(const auto& reduce_index : arg.reduce_index_set_)
+                {
+                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
+                        arg.x_reduce_strides_, reduce_index);
+                    size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
+                        arg.y_reduce_strides_, reduce_index);
+
+                    auto x_offset = x_invariant_offset + x_reduce_offset;
+                    auto y_offset = y_invariant_offset + y_reduce_offset;
+
+                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);
+
+                    AccDataType norm_x = (x - mean) * invVariance;
+                    AccDataType y      = scale * norm_x + bias;
+
                    arg.y_elementwise_op_(y, y);
+
+                    arg.p_y_[y_offset] = type_convert<YDataType>(y);
+                };
             };
 
             std::size_t num_thread = std::thread::hardware_concurrency();
-            std::size_t work_per_thread = (arg.c + num_thread - 1) / num_thread;
+            std::size_t work_per_thread =
+                (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;
 
             std::vector<joinable_thread> threads(num_thread);
 
             for(std::size_t it = 0; it < num_thread; ++it)
             {
-                std::size_t ic_begin = it * work_per_thread;
-                std::size_t ic_end =
-                    std::min(static_cast<int>((it + 1) * work_per_thread), arg.c);
+                std::size_t i_begin = it * work_per_thread;
+                std::size_t i_end   = std::min(static_cast<size_t>((it + 1) * work_per_thread),
+                                               arg.invariant_index_set_.size());
 
                 auto f = [=] {
-                    for(std::size_t ic = ic_begin; ic < ic_end; ++ic)
+                    for(std::size_t i = i_begin; i < i_end; ++i)
                     {
-                        thread_reduce_func(ic);
+                        thread_reduce_func(arg.invariant_index_set_[i]);
                    }
                };
@@ -278,7 +356,7 @@ struct ReferenceBatchNormFwd_Input_N_H_W_C_Output_C
         auto str = std::stringstream();
 
         // clang-format off
-        str << "Reference_BatchNorm_Forward_NHWC_C<" << std::endl;
+        str << "Reference_BatchNorm_Forward" << std::endl;
         // clang-format on
 
         return str.str();
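For reference, this is a standalone sketch of the Welford update that thread_reduce_func above applies per invariant index (plain C++, not part of the commit; names and the use of std::vector are illustrative):

#include <cstdint>
#include <vector>

void welford_mean_variance(const std::vector<float>& x, float& mean, float& variance)
{
    mean          = 0.0f;
    variance      = 0.0f; // accumulates the sum of squared deviations (M2) during the loop
    int32_t count = 0;

    for(float v : x)
    {
        count++;
        float delta = v - mean;
        mean += delta / count;      // running mean
        float delta2 = v - mean;
        variance += delta * delta2; // M2 update
    }

    variance = variance / count; // population variance, matching "actual variance" above
}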
library/include/ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp  (new file, 0 → 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>
#include <array>
#include <algorithm>

#include "ck/library/utility/host_common_util.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_infer.hpp"

namespace ck {
namespace tensor_operation {
namespace host {

template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          index_t Rank,
          index_t NumBatchNormReduceDim>
struct ReferenceBatchNormInfer
    : public device::DeviceBatchNormInfer<XDataType, YDataType, AccDataType, ScaleDataType,
                                          BiasDataType, MeanVarDataType, YElementwiseOp,
                                          Rank, NumBatchNormReduceDim>
{
    static_assert(Rank <= 6, "Bigger Rank size is not supported!");

    static constexpr index_t NumInvariantDim = Rank - NumBatchNormReduceDim;

    struct Argument : public device::BaseArgument
    {
        Argument(const std::array<index_t, Rank> xyLengths,
                 const std::array<index_t, Rank> xStrides,
                 const std::array<index_t, Rank> yStrides,
                 const std::array<int, NumBatchNormReduceDim> reduceDims,
                 const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                 const std::array<index_t, NumInvariantDim> bnScaleStrides,
                 const std::array<index_t, NumInvariantDim> bnBiasStrides,
                 const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                 const XDataType* p_x,
                 const ScaleDataType* bnScale,
                 const BiasDataType* bnBias,
                 double epsilon,
                 const YElementwiseOp y_elementwise_op,
                 const MeanVarDataType* estimatedMean,
                 const MeanVarDataType* estimatedVariance,
                 YDataType* p_y)
            : reduceDims_(reduceDims),
              bnScaleBiasMeanVarLengths_(bnScaleBiasMeanVarLengths),
              bnScaleStrides_(bnScaleStrides),
              bnBiasStrides_(bnBiasStrides),
              bnMeanVarStrides_(bnMeanVarStrides),
              p_x_(p_x),
              bnScale_(bnScale),
              bnBias_(bnBias),
              y_elementwise_op_(y_elementwise_op),
              estimatedMean_(estimatedMean),
              estimatedVariance_(estimatedVariance),
              p_y_(p_y)
        {
            using ck::host_common::get_index_set;

            if(std::any_of(reduceDims.begin(), reduceDims.end(),
                           [](int d) { return d < 0 || d >= Rank; }))
                throw std::runtime_error("Invalid reduce dimensions!");

            // get invariant_dims[] and invariant_lengths[]
            for(int dim = 0, i = 0; dim < Rank; dim++)
                if(std::none_of(reduceDims.begin(), reduceDims.end(),
                                [&](int d) { return d == dim; }))
                {
                    invariantDims_[i]     = dim;
                    invariant_lengths_[i] = xyLengths[dim];
                    i++;
                };

            // get reduce_lengths_[]
            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
            {
                int dim              = reduceDims[j];
                reduce_lengths_[i++] = xyLengths[dim];
            };

            // check invariant_lengths_ and bnScaleBiasMeanVarLengths
            for(int i = 0; i < NumInvariantDim; i++)
                if(invariant_lengths_[i] != bnScaleBiasMeanVarLengths_[i])
                    throw std::runtime_error("Invalid lengths parameters!");

            for(int j = 0, i = 0; j < NumInvariantDim; j++)
            {
                int dim                 = invariantDims_[j];
                x_invariant_strides_[i] = xStrides[dim];
                y_invariant_strides_[i] = yStrides[dim];
                i++;
            };

            for(int j = 0, i = 0; j < NumBatchNormReduceDim; j++)
            {
                int dim              = reduceDims_[j];
                x_reduce_strides_[i] = xStrides[dim];
                y_reduce_strides_[i] = yStrides[dim];
                i++;
            };

            invariant_index_set_ = get_index_set<NumInvariantDim>(invariant_lengths_);
            reduce_index_set_    = get_index_set<NumBatchNormReduceDim>(reduce_lengths_);

            epsilon_ = type_convert<AccDataType>(epsilon);
        }

        std::array<int, NumBatchNormReduceDim> reduceDims_;
        std::array<int, NumInvariantDim> invariantDims_;
        std::array<index_t, NumInvariantDim> invariant_lengths_;
        std::array<index_t, NumBatchNormReduceDim> reduce_lengths_;

        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths_;
        const std::array<index_t, NumInvariantDim> bnScaleStrides_;
        const std::array<index_t, NumInvariantDim> bnBiasStrides_;
        const std::array<index_t, NumInvariantDim> bnMeanVarStrides_;

        std::array<index_t, NumInvariantDim> x_invariant_strides_;
        std::array<index_t, NumInvariantDim> y_invariant_strides_;
        std::array<index_t, NumBatchNormReduceDim> x_reduce_strides_;
        std::array<index_t, NumBatchNormReduceDim> y_reduce_strides_;

        const XDataType* p_x_;
        const ScaleDataType* bnScale_;
        const BiasDataType* bnBias_;
        const YElementwiseOp y_elementwise_op_;
        const MeanVarDataType* estimatedMean_;
        const MeanVarDataType* estimatedVariance_;
        YDataType* p_y_;

        std::vector<std::array<index_t, NumInvariantDim>> invariant_index_set_;
        std::vector<std::array<index_t, NumBatchNormReduceDim>> reduce_index_set_;

        AccDataType epsilon_;
    };

    struct Invoker : public device::BaseInvoker
    {
        float Run(const Argument& arg)
        {
            using ck::host_common::get_offset_from_index;

            auto thread_reduce_func = [&](auto invariant_index) {
                size_t x_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.x_invariant_strides_, invariant_index);
                size_t y_invariant_offset = get_offset_from_index<NumInvariantDim>(
                    arg.y_invariant_strides_, invariant_index);
                size_t mean_variance_offset = get_offset_from_index<NumInvariantDim>(
                    arg.bnMeanVarStrides_, invariant_index);

                AccDataType mean     = arg.estimatedMean_[mean_variance_offset];
                AccDataType variance = arg.estimatedVariance_[mean_variance_offset];

                // inv-variance defined as 1/sqrt(epsilon+variance)
                AccDataType invVariance =
                    type_convert<AccDataType>(1.0f) / std::sqrt(arg.epsilon_ + variance);

                size_t scale_offset =
                    get_offset_from_index<NumInvariantDim>(arg.bnScaleStrides_, invariant_index);
                size_t bias_offset =
                    get_offset_from_index<NumInvariantDim>(arg.bnBiasStrides_, invariant_index);

                AccDataType scale = type_convert<AccDataType>(arg.bnScale_[scale_offset]);
                AccDataType bias  = type_convert<AccDataType>(arg.bnBias_[bias_offset]);

                // normalization
                for(const auto& reduce_index : arg.reduce_index_set_)
                {
                    size_t x_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.x_reduce_strides_, reduce_index);
                    size_t y_reduce_offset = get_offset_from_index<NumBatchNormReduceDim>(
                        arg.y_reduce_strides_, reduce_index);

                    auto x_offset = x_invariant_offset + x_reduce_offset;
                    auto y_offset = y_invariant_offset + y_reduce_offset;

                    AccDataType x = type_convert<AccDataType>(arg.p_x_[x_offset]);

                    AccDataType norm_x = (x - mean) * invVariance;
                    AccDataType y      = scale * norm_x + bias;

                    arg.y_elementwise_op_(y, y);

                    arg.p_y_[y_offset] = type_convert<YDataType>(y);
                };
            };

            std::size_t num_thread = std::thread::hardware_concurrency();
            std::size_t work_per_thread =
                (arg.invariant_index_set_.size() + num_thread - 1) / num_thread;

            std::vector<joinable_thread> threads(num_thread);

            for(std::size_t it = 0; it < num_thread; ++it)
            {
                std::size_t i_begin = it * work_per_thread;
                std::size_t i_end   = std::min(static_cast<size_t>((it + 1) * work_per_thread),
                                               arg.invariant_index_set_.size());

                auto f = [=] {
                    for(std::size_t i = i_begin; i < i_end; ++i)
                    {
                        thread_reduce_func(arg.invariant_index_set_[i]);
                    }
                };

                threads[it] = joinable_thread(f);
            }

            return (0.0f);
        };

        float Run(const device::BaseArgument* p_arg,
                  const StreamConfig& /*stream_config*/ = StreamConfig{}) override
        {
            return Run(*dynamic_cast<const Argument*>(p_arg));
        };
    };

    bool IsSupportedArgument(const device::BaseArgument* p_arg) override
    {
        (void)p_arg;

        return (true);
    };

    std::unique_ptr<device::BaseArgument>
    MakeArgumentPointer(const std::array<index_t, Rank> xyLengths,
                        const std::array<index_t, Rank> xStrides,
                        const std::array<index_t, Rank> yStrides,
                        const std::array<int, NumBatchNormReduceDim> reduceDims,
                        const std::array<index_t, NumInvariantDim> bnScaleBiasMeanVarLengths,
                        const std::array<index_t, NumInvariantDim> bnScaleStrides,
                        const std::array<index_t, NumInvariantDim> bnBiasStrides,
                        const std::array<index_t, NumInvariantDim> bnMeanVarStrides,
                        const void* p_x,
                        const void* bnScale,
                        const void* bnBias,
                        double epsilon,
                        const YElementwiseOp y_elementwise_op,
                        const void* estimatedMean,
                        const void* estimatedVariance,
                        void* p_y) override
    {
        return std::make_unique<Argument>(xyLengths,
                                          xStrides,
                                          yStrides,
                                          reduceDims,
                                          bnScaleBiasMeanVarLengths,
                                          bnScaleStrides,
                                          bnBiasStrides,
                                          bnMeanVarStrides,
                                          static_cast<const XDataType*>(p_x),
                                          static_cast<const ScaleDataType*>(bnScale),
                                          static_cast<const BiasDataType*>(bnBias),
                                          epsilon,
                                          y_elementwise_op,
                                          static_cast<const MeanVarDataType*>(estimatedMean),
                                          static_cast<const MeanVarDataType*>(estimatedVariance),
                                          static_cast<YDataType*>(p_y));
    };

    std::unique_ptr<device::BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>();
    };

    std::string GetTypeString() const override
    {
        auto str = std::stringstream();

        // clang-format off
        str << "Reference_BatchNorm_Infer<" << std::endl;
        // clang-format on

        return str.str();
    }
};

} // namespace host
} // namespace tensor_operation
} // namespace ck
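The following is a minimal sketch of driving this host reference through its MakeArgumentPointer/MakeInvokerPointer interface. It assumes a packed rank-4 NHWC tensor with the reduction over N, H, W, float for every data type, and caller-owned buffers; the concrete sizes are illustrative, not part of the commit.

#include <array>
#include "ck/library/reference_tensor_operation/cpu/reference_batchnorm_infer.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using RefInfer    = ck::tensor_operation::host::
    ReferenceBatchNormInfer<float, float, float, float, float, float, PassThrough, 4, 3>;

void run_reference_batchnorm_infer(const float* x, const float* scale, const float* bias,
                                   const float* mean, const float* variance, float* y)
{
    std::array<ck::index_t, 4> lengths{2, 8, 8, 16};               // N, H, W, C (illustrative)
    std::array<ck::index_t, 4> strides{8 * 8 * 16, 8 * 16, 16, 1}; // packed NHWC strides
    std::array<int, 3> reduce_dims{0, 1, 2};                       // reduce over N, H, W
    std::array<ck::index_t, 1> c_lengths{16};                      // per-channel parameter length
    std::array<ck::index_t, 1> c_strides{1};

    auto ref = RefInfer{};

    // x/y strides are equal here; epsilon and PassThrough match the factory defaults above
    auto arg_ptr = ref.MakeArgumentPointer(lengths, strides, strides, reduce_dims,
                                           c_lengths, c_strides, c_strides, c_strides,
                                           x, scale, bias, 1e-5, PassThrough{},
                                           mean, variance, y);
    auto invoker_ptr = ref.MakeInvokerPointer();

    invoker_ptr->Run(arg_ptr.get());
}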
library/include/ck/library/reference_tensor_operation/cpu/reference_softmax.hpp

@@ -86,8 +86,8 @@ struct ReferenceSoftmax : public device::BaseOperator
             };

             arg.in_.ForEach([&](auto& self, auto idx) {
                 reduce_max(to_sm_scalar_idx(idx)) = std::max(reduce_max(to_sm_scalar_idx(idx)),
-                                                             static_cast<AccDataType>(self(idx)));
+                                                             ck::type_convert<AccDataType>(self(idx)));
             });

             // LogRangeAsType<float>(std::cout << "reduce_max: ", reduce_max.mData, ",") <<
@@ -96,7 +96,7 @@ struct ReferenceSoftmax : public device::BaseOperator
             Tensor<AccDataType> in_stable(arg.in_.mDesc);
             in_stable.ForEach([&](auto& self, auto idx) {
                 // numerator = exp(x - max(x))
-                self(idx) = std::exp(static_cast<AccDataType>(arg.in_(idx)) -
+                self(idx) = std::exp(ck::type_convert<AccDataType>(arg.in_(idx)) -
                                      reduce_max(to_sm_scalar_idx(idx)));
             });
@@ -111,8 +111,10 @@ struct ReferenceSoftmax : public device::BaseOperator
             // std::endl;

             arg.out_.ForEach([&](auto& self, auto idx) {
-                self(idx) = arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
-                            arg.beta_ * self(idx);
+                AccDataType temp_result =
+                    arg.alpha_ * in_stable(idx) / reduce_sum(to_sm_scalar_idx(idx)) +
+                    arg.beta_ * self(idx);
+                self(idx) = ck::type_convert<OutDataType>(temp_result);
             });

             // LogRangeAsType<float>(std::cout << "out: ", arg.out_.mData, ",") << std::endl;
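For context, this is a standalone sketch of the numerically stable softmax that the reference computes (subtract the row maximum before exponentiating). It is plain C++ for illustration only; the alpha/beta blending and OutDataType conversion from the diff above are omitted.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <vector>

std::vector<float> stable_softmax(const std::vector<float>& x)
{
    // subtracting the max keeps exp() from overflowing without changing the result
    float x_max = *std::max_element(x.begin(), x.end());

    std::vector<float> y(x.size());
    float sum = 0.0f;

    for(std::size_t i = 0; i < x.size(); ++i)
    {
        y[i] = std::exp(x[i] - x_max); // numerator = exp(x - max(x))
        sum += y[i];
    }

    for(auto& v : y)
        v /= sum; // denominator = sum of numerators

    return y;
}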
library/include/ck/library/tensor_operation_instance/device_operation_instance_factory.hpp

@@ -87,6 +87,8 @@ using Relu = ck::tensor_operation::element_wise::Relu;
 using Scale          = ck::tensor_operation::element_wise::Scale;
 using Bilinear       = ck::tensor_operation::element_wise::Bilinear;
 using AddAddFastGelu = ck::tensor_operation::element_wise::AddAddFastGelu;
+using AddFastGelu    = ck::tensor_operation::element_wise::AddFastGelu;
+using FastGelu       = ck::tensor_operation::element_wise::FastGelu;

 template <typename Activation>
 using Activation_Mul_Clamp = ck::tensor_operation::element_wise::Activation_Mul_Clamp<Activation>;
library/include/ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp

@@ -59,6 +59,48 @@ void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_g
                                                                 MaskingSpecialization::MaskDisabled>>>& instances);

+void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1, BF16, BF16, BF16, BF16,
+                                            ck::Tuple<>, ck::Tuple<>,
+                                            PassThrough, PassThrough, Scale, PassThrough, PassThrough,
+                                            MaskingSpecialization::MaskOutUpperTriangle>>>& instances);
+
+void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+    std::vector<std::unique_ptr<
+        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1, BF16, BF16, BF16, BF16,
+                                            ck::Tuple<>, ck::Tuple<>,
+                                            PassThrough, PassThrough, Scale, PassThrough, PassThrough,
+                                            MaskingSpecialization::MaskDisabled>>>& instances);
+
 template <typename ADataType,
           typename B0DataType,
           typename B1DataType,
@@ -119,6 +161,20 @@ struct DeviceOperationInstanceFactory<
                     op_ptrs);
             }
         }
+        else if constexpr(is_same_v<ADataType, BF16> && is_same_v<B0DataType, BF16> &&
+                          is_same_v<B1DataType, BF16> && is_same_v<CDataType, BF16>)
+        {
+            if constexpr(MaskingSpec == MaskingSpecialization::MaskOutUpperTriangle)
+            {
+                add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+                    op_ptrs);
+            }
+            else if(MaskingSpec == MaskingSpecialization::MaskDisabled)
+            {
+                add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
+                    op_ptrs);
+            }
+        }

         return op_ptrs;
     }
 };
library/include/ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

// FP16
void add_device_batchnorm_forward_rank_4_3_f16_instances(
    std::vector<std::unique_ptr<DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>&);

// FP32
void add_device_batchnorm_forward_rank_4_3_f32_instances(
    std::vector<std::unique_ptr<DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>&);

// BF16
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
    std::vector<std::unique_ptr<DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>&);

// FP64
void add_device_batchnorm_forward_rank_4_3_f64_instances(
    std::vector<std::unique_ptr<DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>&);

template <typename XDataType,
          typename YDataType,
          typename AccDataType,
          typename ScaleDataType,
          typename BiasDataType,
          typename MeanVarDataType,
          typename YElementwiseOp,
          index_t Rank,
          index_t NumReduceDim>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceBatchNormFwd<XDataType, YDataType, AccDataType,
                                                     ScaleDataType, BiasDataType, MeanVarDataType,
                                                     YElementwiseOp, Rank, NumReduceDim>>
{
    using DeviceOp = DeviceBatchNormFwd<XDataType, YDataType, AccDataType, ScaleDataType,
                                        BiasDataType, MeanVarDataType, YElementwiseOp,
                                        Rank, NumReduceDim>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<XDataType, F16> && is_same_v<YDataType, F16> &&
                     is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F16> &&
                     is_same_v<BiasDataType, F16> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f16_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, F32> && is_same_v<YDataType, F32> &&
                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, F32> &&
                          is_same_v<BiasDataType, F32> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f32_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, BF16> && is_same_v<YDataType, BF16> &&
                          is_same_v<AccDataType, F32> && is_same_v<ScaleDataType, BF16> &&
                          is_same_v<BiasDataType, BF16> && is_same_v<MeanVarDataType, F32>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_bf16_instances(op_ptrs);
            }
        }
        else if constexpr(is_same_v<XDataType, F64> && is_same_v<YDataType, F64> &&
                          is_same_v<AccDataType, F64> && is_same_v<ScaleDataType, F64> &&
                          is_same_v<BiasDataType, F64> && is_same_v<MeanVarDataType, F64>)
        {
            if constexpr(Rank == 4 && NumReduceDim == 3 && is_same_v<YElementwiseOp, PassThrough>)
            {
                add_device_batchnorm_forward_rank_4_3_f64_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
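A minimal sketch of how this factory is meant to be consumed, assuming the header above is included and the rank-4/3-reduce FP16 instance library is linked (the loop body and the use of ck::half_t for F16 are my assumptions, not shown in the diff):

#include <iostream>
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// XData, YData, AccData, ScaleData, BiasData, MeanVarData, YElementwiseOp, Rank, NumReduceDim
using DeviceOp = ck::tensor_operation::device::
    DeviceBatchNormFwd<ck::half_t, ck::half_t, float, ck::half_t, ck::half_t, float, PassThrough, 4, 3>;

void list_batchnorm_forward_instances()
{
    // returns one std::unique_ptr per registered kernel instance matching the template arguments
    auto op_ptrs = ck::tensor_operation::device::instance::
        DeviceOperationInstanceFactory<DeviceOp>::GetInstances();

    for(const auto& op : op_ptrs)
    {
        // each instance also exposes IsSupportedArgument()/MakeArgumentPointer()/MakeInvokerPointer()
        std::cout << op->GetTypeString() << std::endl;
    }
}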
library/include/ck/library/tensor_operation_instance/gpu/convolution_backward_data.hpp

@@ -101,6 +101,42 @@ void add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(
                                                    PassThrough,
                                                    PassThrough,
                                                    PassThrough>>>& instances);

+// conv2d dl
+void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
+    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, F16, F16, F16,
+                                                  PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
+    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, F32, F32, F32,
+                                                  PassThrough, PassThrough, PassThrough>>>& instances);
+
+void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(
+    std::vector<std::unique_ptr<DeviceConvBwdData<2, NHWC, KYXC, NHWK, int8_t, int8_t, int8_t,
+                                                  PassThrough, PassThrough, PassThrough>>>& instances);
+
 // conv3d backward data
 void add_device_conv3d_bwd_data_xdl_ndhwc_kzyxc_ndhwk_bf16_instances(
     std::vector<std::unique_ptr<DeviceConvBwdData<3,
@@ -216,11 +252,13 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
                           is_same_v<OutDataType, float>)
         {
             add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
+            add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(op_ptrs);
         }
         else if constexpr(is_same_v<InDataType, half_t> && is_same_v<WeiDataType, half_t> &&
                           is_same_v<OutDataType, half_t>)
         {
             add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
+            add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(op_ptrs);
         }
         else if constexpr(is_same_v<InDataType, ck::bhalf_t> &&
                           is_same_v<WeiDataType, ck::bhalf_t> &&
@@ -232,6 +270,7 @@ struct DeviceOperationInstanceFactory<ck::tensor_operation::device::DeviceConvBw
                           is_same_v<OutDataType, int8_t>)
         {
             add_device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
+            add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instances(op_ptrs);
         }
     }
     else if constexpr(NumDimSpatial == 3 && is_same_v<InLayout, NDHWC> &&
library/include/ck/library/tensor_operation_instance/gpu/gemm_add_fastgelu.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16,
                                                    PassThrough, PassThrough, AddFastGelu>>>&);

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16,
                                                    PassThrough, PassThrough, AddFastGelu>>>&);

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Row, Row_Tuple, Row, F16, F16, F16_Tuple, F16,
                                                    PassThrough, PassThrough, AddFastGelu>>>&);

void add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Col, Row_Tuple, Row, F16, F16, F16_Tuple, F16,
                                                    PassThrough, PassThrough, AddFastGelu>>>&);

// GEMM + Add + FastGelu
template <typename ALayout,
          typename BLayout,
          typename D0Layout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename D0DataType,
          typename EDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmMultipleD<ALayout, BLayout, ck::Tuple<D0Layout>, ELayout,
                                                      ADataType, BDataType, ck::Tuple<D0DataType>, EDataType,
                                                      PassThrough, PassThrough, AddFastGelu>>
{
    using DeviceOp = DeviceGemmMultipleD<ALayout, BLayout, ck::Tuple<D0Layout>, ELayout,
                                         ADataType, BDataType, ck::Tuple<D0DataType>, EDataType,
                                         PassThrough, PassThrough, AddFastGelu>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<D0DataType, half_t> && is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> &&
                         is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_kn_mn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> &&
                              is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_mk_nk_mn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> &&
                              is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_kn_mn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> &&
                              is_same_v<D0Layout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_add_fastgelu_xdl_c_shuffle_f16_f16_f16_f16_km_nk_mn_mn_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/tensor_operation_instance/gpu/gemm_fastgelu.hpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <cstdlib>
#include <vector>
#include <memory>

#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/device_operation_instance_factory.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm_multiple_d.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16,
                                                    PassThrough, PassThrough, FastGelu>>>&);

void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Row, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16,
                                                    PassThrough, PassThrough, FastGelu>>>&);

void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Row, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16,
                                                    PassThrough, PassThrough, FastGelu>>>&);

void add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(
    std::vector<std::unique_ptr<DeviceGemmMultipleD<Col, Col, Empty_Tuple, Row, F16, F16, Empty_Tuple, F16,
                                                    PassThrough, PassThrough, FastGelu>>>&);

// GEMM + FastGelu
template <typename ALayout,
          typename BLayout,
          typename ELayout,
          typename ADataType,
          typename BDataType,
          typename EDataType>
struct DeviceOperationInstanceFactory<
    ck::tensor_operation::device::DeviceGemmMultipleD<ALayout, BLayout, Empty_Tuple, ELayout,
                                                      ADataType, BDataType, Empty_Tuple, EDataType,
                                                      PassThrough, PassThrough, FastGelu>>
{
    using DeviceOp = DeviceGemmMultipleD<ALayout, BLayout, Empty_Tuple, ELayout,
                                         ADataType, BDataType, Empty_Tuple, EDataType,
                                         PassThrough, PassThrough, FastGelu>;

    static auto GetInstances()
    {
        std::vector<std::unique_ptr<DeviceOp>> op_ptrs;

        if constexpr(is_same_v<ADataType, half_t> && is_same_v<BDataType, half_t> &&
                     is_same_v<EDataType, half_t>)
        {
            if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Row> && is_same_v<BLayout, Col> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Row> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(op_ptrs);
            }
            else if constexpr(is_same_v<ALayout, Col> && is_same_v<BLayout, Col> && is_same_v<ELayout, Row>)
            {
                add_device_gemm_fastgelu_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(op_ptrs);
            }
        }

        return op_ptrs;
    }
};

} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/include/ck/library/utility/host_common_util.hpp

@@ -4,9 +4,11 @@
 #pragma once

 #include <vector>
+#include <array>
 #include <iostream>
 #include <fstream>
 #include <string>
+#include <algorithm>

 #include "ck/ck.hpp"
@@ -72,5 +74,63 @@ static inline std::vector<T> getTypeValuesFromString(const char* cstr_values)
     return (values);
 }

+template <int NDim>
+static inline std::vector<std::array<index_t, NDim>>
+get_index_set(const std::array<index_t, NDim>& dim_lengths)
+{
+    static_assert(NDim >= 1, "NDim >= 1 is required to use this function!");
+
+    if constexpr(NDim == 1)
+    {
+        std::vector<std::array<index_t, NDim>> index_set;
+
+        for(int i = 0; i < dim_lengths[0]; i++)
+        {
+            std::array<index_t, 1> index{i};
+
+            index_set.push_back(index);
+        };
+
+        return index_set;
+    }
+    else
+    {
+        std::vector<std::array<index_t, NDim>> index_set;
+
+        std::array<index_t, NDim - 1> partial_dim_lengths;
+
+        std::copy(dim_lengths.begin() + 1, dim_lengths.end(), partial_dim_lengths.begin());
+
+        std::vector<std::array<index_t, NDim - 1>> partial_index_set;
+
+        partial_index_set = get_index_set<NDim - 1>(partial_dim_lengths);
+
+        for(index_t i = 0; i < dim_lengths[0]; i++)
+            for(const auto& partial_index : partial_index_set)
+            {
+                std::array<index_t, NDim> index;
+
+                index[0] = i;
+
+                std::copy(partial_index.begin(), partial_index.end(), index.begin() + 1);
+
+                index_set.push_back(index);
+            };
+
+        return index_set;
+    };
+};
+
+template <int NDim>
+static inline size_t get_offset_from_index(const std::array<index_t, NDim>& strides,
+                                           const std::array<index_t, NDim>& index)
+{
+    size_t offset = 0;
+
+    for(int i = 0; i < NDim; i++)
+        offset += index[i] * strides[i];
+
+    return (offset);
+};

 } // namespace host_common
 } // namespace ck
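A small worked example of the two helpers added above: for lengths {2, 3}, get_index_set enumerates {0,0}, {0,1}, {0,2}, {1,0}, {1,1}, {1,2} in row-major order, and get_offset_from_index folds each index against the strides. The demo function and the packed strides {3, 1} are illustrative, not part of the commit.

#include <array>
#include <iostream>
#include "ck/library/utility/host_common_util.hpp"

void index_set_demo()
{
    std::array<ck::index_t, 2> lengths{2, 3};
    std::array<ck::index_t, 2> strides{3, 1}; // packed row-major strides

    auto index_set = ck::host_common::get_index_set<2>(lengths);

    for(const auto& idx : index_set)
    {
        // offset = idx[0] * 3 + idx[1] * 1, i.e. 0, 1, 2, 3, 4, 5 for the six indices above
        std::cout << ck::host_common::get_offset_from_index<2>(strides, idx) << std::endl;
    }
}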
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/CMakeLists.txt

 add_instance_library(device_batched_gemm_softmax_gemm_permute_instance
     device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_f16_f16_f16_f16_gmk_gnk_gno_gmo_instance.cpp
+    device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp
 )
library/src/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instance.cpp  (new file, 0 → 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <cstdlib>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_xdl_cshuffle.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"

namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using Scale       = ck::tensor_operation::element_wise::Scale;

static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
static constexpr auto GemmPadded  = ck::tensor_operation::device::GemmSpecialization::MNKOPadding;

static constexpr auto TensorDefault = ck::tensor_operation::device::TensorSpecialization::Default;

// c[g, m, n] = a[g, m, k] * b[g, n, k]
template <index_t NumDimG,
          index_t NumDimM,
          index_t NumDimN,
          index_t NumDimK,
          index_t NumDimO,
          MaskingSpecialization MaskingSpec>
using device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances =
    std::tuple<
// clang-format off
// #############################################| NumDimG| NumDimM| NumDimN| NumDimK| NumDimO| AData| B0Data| B1Data| CData| Acc0BiasData| Acc1BiasData| AccData| CShuffle| A| B0| Acc0| B1| C| GEMM| ATensorSpec| B0TensorSpec| B1TensorSpec| CTensorSpec| NumGemmK| Block| Gemm01| Gemm0| Gemm0| Gemm1| Gemm1| AK1| BK1| B1K1| MPer| NPer| Gemm0| Gemm0| Gemm1| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockTransfer| B0BlockLds| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockTransfer| B1BlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| MaskingSpec|
// #############################################| | | | | | Type| Type| Type| Type| Type| Type| Type| DataType| Elementwise| Elementwise| Elementwise| Elementwise| Elementwise| Specialization| | | | | Prefetch| Size| MPer| NPer| KPer| NPer| KPer| | | | XDL| XDL| MXdl| NXdl| NXdl| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MWaveMPerXdl| ScalarPerVector| |
// #############################################| | | | | | | | | | | | | | Operation| Operation| Operation| Operation| Operation| | | | | | Stage| | Block| Block| Block| Block| Block| | | | | | Per| Per| Per| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NWaveNPerXdl| _NWaveNPerXdl| |
// #############################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | Wave| Wave| Wave| | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 64, 32, 8, 8, 2, 32, 32, 2, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 256, 128, 32, 128, 32, 8, 8, 2, 32, 32, 2, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 64, 32, 8, 8, 2, 32, 32, 1, 8, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 256, 32, 128, 32, 8, 8, 2, 32, 32, 1, 8, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 64, 32, 8, 8, 2, 32, 32, 1, 4, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 32, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 32, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 128, 32, 8, 8, 2, 16, 16, 1, 16, 8, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 8, S<1, 16, 1, 16>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmDefault, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 64, 256, 64, 64, 32, 8, 8, 2, 16, 16, 1, 16, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 4, S<1, 32, 1, 8>, 8, MaskingSpec>,
    // Padded fallback kernel
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 128, 64, 128, 32, 8, 8, 2, 32, 32, 1, 4, 4, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, false, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>,
    DeviceBatchedGemmSoftmaxGemmPermute_Xdl_CShuffle<NumDimG, NumDimM, NumDimN, NumDimK, NumDimO, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>, F32, BF16, PassThrough, PassThrough, Scale, PassThrough, PassThrough, GemmPadded, TensorDefault, TensorDefault, TensorDefault, TensorDefault, 1, 256, 128, 64, 32, 128, 32, 8, 8, 2, 32, 32, 1, 2, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 8, 8, true, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 2, false, 1, 2, S<1, 32, 1, 8>, 8, MaskingSpec>
    // clang-format on
    >;
void add_device_batched_gemm_masking_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>,
                                            PassThrough, PassThrough, Scale, PassThrough, PassThrough,
                                            MaskingSpecialization::MaskOutUpperTriangle>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskOutUpperTriangle>{});
}
void add_device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances(
    std::vector<std::unique_ptr<
        DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1, BF16, BF16, BF16, BF16, ck::Tuple<>, ck::Tuple<>,
                                            PassThrough, PassThrough, Scale, PassThrough, PassThrough,
                                            MaskingSpecialization::MaskDisabled>>>& instances)
{
    add_device_operation_instances(
        instances,
        device_batched_gemm_softmax_gemm_permute_xdl_cshuffle_bf16_bf16_bf16_bf16_gmk_gnk_gno_gmo_instances<
            2, 1, 1, 1, 1, MaskingSpecialization::MaskDisabled>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
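The two registration functions above only append pre-built tuning instances to a caller-supplied vector; nothing is launched here. Below is a minimal, hypothetical consumer sketch; the `main`, its include path, and the availability of a `DeviceOperationInstanceFactory` specialization for this operation are assumptions based on the factory headers added elsewhere in this commit and on the library's usual GetInstances pattern, not something this file itself defines.

// Sketch only: enumerate the registered bf16 gemm+softmax+gemm instances and print their names.
#include <iostream>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batched_gemm_softmax_gemm_permute.hpp"

int main()
{
    using namespace ck::tensor_operation;
    using BF16        = ck::bhalf_t;
    using PassThrough = element_wise::PassThrough;
    using Scale       = element_wise::Scale;

    // Interface type matching the registration function above: 2 G dims, 1 each of M/N/K/O, causal masking.
    using DeviceOp = device::DeviceBatchedGemmSoftmaxGemmPermute<2, 1, 1, 1, 1, BF16, BF16, BF16, BF16,
                                                                 ck::Tuple<>, ck::Tuple<>, PassThrough, PassThrough,
                                                                 Scale, PassThrough, PassThrough,
                                                                 device::MaskingSpecialization::MaskOutUpperTriangle>;

    const auto op_ptrs = device::instance::DeviceOperationInstanceFactory<DeviceOp>::GetInstances();
    for(const auto& op_ptr : op_ptrs)
        std::cout << op_ptr->GetTypeString() << std::endl; // one line per tuning configuration
    return 0;
}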
library/src/tensor_operation_instance/gpu/batchnorm/CMakeLists.txt
0 → 100644
add_instance_library(device_batchnorm_instance
    device_batchnorm_forward_f16_instance.cpp
    device_batchnorm_forward_f32_instance.cpp
    device_batchnorm_forward_bf16_instance.cpp
    device_batchnorm_forward_f64_instance.cpp
)
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_bf16_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using BF16 = ck::bhalf_t;
using F32  = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_bf16_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_bf16_multiblock_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<BF16, BF16, F32, BF16, BF16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_bf16_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<BF16, BF16, F32, BF16, BF16, F32, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_bf16_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_bf16_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
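The batch-norm instance files follow the same registration pattern. A short, hypothetical consumer sketch of calling the rank-4/reduce-3 bf16 registration function above directly is given below; the include paths and the `main` wrapper are illustrative assumptions, not something defined in this commit.

// Sketch only: collect the bf16 batchnorm forward instances registered above and report how many exist.
#include <iostream>
#include <memory>
#include <vector>
#include "ck/ck.hpp"
#include "ck/library/tensor_operation_instance/gpu/batchnorm_forward.hpp"

int main()
{
    using namespace ck::tensor_operation;
    using BF16        = ck::bhalf_t;
    using PassThrough = element_wise::PassThrough;

    // Element type matches the DeviceBatchNormFwd interface used by the registration function above.
    std::vector<std::unique_ptr<
        device::DeviceBatchNormFwd<BF16, BF16, float, BF16, BF16, float, PassThrough, 4, 3>>>
        instances;
    device::instance::add_device_batchnorm_forward_rank_4_3_bf16_instances(instances);

    // Each entry is one blockwise or multi-block tuning configuration from the tuples above.
    std::cout << "registered " << instances.size() << " instances" << std::endl;
    return 0;
}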
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f16_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F16 = ck::half_t;
using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f16_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f16_multiblock_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F16, F16, F32, F16, F16, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f16_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<F16, F16, F32, F16, F16, F32, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_f16_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_f16_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f32_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F32 = float;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f32_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f32_multiblock_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32, F32, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F32, F32, F32, F32, F32,
F32
,
YElementwiseOp
,
Rank
,
NumReduceDim
,
true
,
256
,
1
,
256
,
2
,
2
,
1
,
1
,
1
,
1
,
1
,
1
>
>
;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f32_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_f32_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_f32_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
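For context on how a registration entry point like the one above is consumed: a client fills a vector with the type-erased `DeviceBatchNormFwd` pointers and then picks one of the tuning variants at run time. The snippet below is an illustrative sketch only, not part of this commit; the header path and the `GetTypeString()` query are assumptions about the surrounding CK interfaces.

```cpp
// Minimal sketch: enumerate the rank-4 / reduce-3 FP32 batchnorm-forward instances
// registered above. Header path and GetTypeString() are assumptions, not confirmed here.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_batchnorm_forward.hpp" // assumed path
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

int main()
{
    using F32         = float;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Op          = ck::tensor_operation::device::
        DeviceBatchNormFwd<F32, F32, F32, F32, F32, F32, PassThrough, 4, 3>;

    std::vector<std::unique_ptr<Op>> ops;

    // Appends both the blockwise and the multi-block tuning variants defined above.
    ck::tensor_operation::device::instance::add_device_batchnorm_forward_rank_4_3_f32_instances(ops);

    std::cout << ops.size() << " instances registered\n";
    for(const auto& op : ops)
        std::cout << "  " << op->GetTypeString() << '\n'; // assumed BaseOperator query

    return 0;
}
```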
library/src/tensor_operation_instance/gpu/batchnorm/device_batchnorm_forward_f64_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_batchnorm_forward_impl.hpp"
#include "ck/utility/data_type.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using F64         = double;
using PassThrough = ck::tensor_operation::element_wise::PassThrough;
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f64_blockwise_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, false, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on
// clang-format off
template <index_t Rank, index_t NumReduceDim, typename YElementwiseOp>
using device_batchnorm_forward_f64_multiblock_instances = std::tuple<
// XDataType, YDataType, AccDataType, ScaleDataType, BiasDataType, MeanVarDataType, YElementwiseOp, Rank, NumReduceDim, UseMultiBlockInK, BLockSize, MThreadClusterSize, KThreadClusterSize, MThreadSliceSize, KThreadSliceSize, XSrcYDstVectorDim, XSrcVectorSize, YDstVectorSize, ScaleSrcVectorSize, BiasSrcVectorSize, MeanVarSrcDstVectorSize
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 128, 2, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 64, 4, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 32, 8, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 16, 16, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 8, 32, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 4, 64, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 2, 128, 2, 2, 1, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 1, 1, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 0, 2, 2, 1, 1, 1>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 2, 2, 2>,
    DeviceBatchNormFwdImpl<F64, F64, F64, F64, F64, F64, YElementwiseOp, Rank, NumReduceDim, true, 256, 1, 256, 2, 2, 1, 1, 1, 1, 1, 1>
    >;
// clang-format on
void add_device_batchnorm_forward_rank_4_3_f64_instances(
    std::vector<std::unique_ptr<
        DeviceBatchNormFwd<F64, F64, F64, F64, F64, F64, PassThrough, 4, 3>>>& instances)
{
    add_device_operation_instances(
        instances, device_batchnorm_forward_f64_blockwise_instances<4, 3, PassThrough>{});
    add_device_operation_instances(
        instances, device_batchnorm_forward_f64_multiblock_instances<4, 3, PassThrough>{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/CMakeLists.txt
@@ -3,4 +3,8 @@ add_instance_library(device_conv2d_bwd_data_instance
    device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_f16_instance.cpp
    device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_bf16_instance.cpp
    device_conv2d_bwd_data_xdl_nhwc_kyxc_nhwk_int8_instance.cpp
    device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
    device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
    device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_int8_instance.cpp
)
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using InDataType  = ck::half_t;
using WeiDataType = ck::half_t;
using OutDataType = ck::half_t;
using AccDataType = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using NHWC = ck::tensor_layout::convolution::NHWC;
using KYXC = ck::tensor_layout::convolution::KYXC;
using NHWK = ck::tensor_layout::convolution::NHWK;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdDataDefault =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;

static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances = std::tuple<
// clang-format off
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConvNdBwdDataNwcKxcNwk_Dl<2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataDefault, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 8, 2>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
    >;
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances = std::tuple<
// clang-format off
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConvNdBwdDataNwcKxcNwk_Dl<2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 16, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<1, 1, 8, 2>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 8, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
    >;
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2,
                                                  NHWC,
                                                  KYXC,
                                                  NHWK,
                                                  InDataType,
                                                  WeiDataType,
                                                  OutDataType,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f16_instances{});
    add_device_operation_instances(
        instances, device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f16_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
library/src/tensor_operation_instance/gpu/conv2d_bwd_data/device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instance.cpp
0 → 100644
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <cstdlib>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_convnd_bwd_data_nwc_kxc_nwk_dl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "ck/library/tensor_operation_instance/add_device_operation_instance.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace instance {

using InDataType  = float;
using WeiDataType = float;
using OutDataType = float;
using AccDataType = float;

template <ck::index_t... Is>
using S = ck::Sequence<Is...>;

using NHWC = ck::tensor_layout::convolution::NHWC;
using KYXC = ck::tensor_layout::convolution::KYXC;
using NHWK = ck::tensor_layout::convolution::NHWK;

using PassThrough = ck::tensor_operation::element_wise::PassThrough;

using InElementOp  = ck::tensor_operation::element_wise::PassThrough;
using WeiElementOp = ck::tensor_operation::element_wise::PassThrough;
using OutElementOp = ck::tensor_operation::element_wise::PassThrough;

static constexpr auto ConvBwdDataDefault =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Default;

static constexpr auto ConvBwdDataFilter1x1Stride1Pad0 =
    ck::tensor_operation::device::ConvolutionBackwardDataSpecialization::Filter1x1Stride1Pad0;
// Compilation parameters for in[n, hi, wi, c] * wei[k, y, x, c] = out[n, ho, wo, k]
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances = std::tuple<
// clang-format off
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConvNdBwdDataNwcKxcNwk_Dl<2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataDefault, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 8, 1>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
    >;
using device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances = std::tuple<
// clang-format off
//#########################| NDim| InData| WeiData| OutData| AccData| In| Wei| Out| Convolution| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| BBlockTransfer| CThreadTransfer| CThreadTransfer| CThreadTransfer|
//#########################| Spatial| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise| Forward| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| SrcVectorTensor| SrcVectorTensor| DstVectorTensor| SrcDstAccess| SrcDstVectorDim| DstScalarPerVector|
//#########################| | | | | | Operation| Operation| Operation| Specialization| | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| Lengths_K0_M0_M1_K1| ContiguousDimOrder| Lengths_K0_M0_M1_K1| K0_N0_N1_K1| K0_N0_N1_K1| ArrangeOrder| Order| Lengths_K0_N0_N1_K1| ContiguousDimOrder| Lengths_K0_N0_N1_K1| Order| | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
    DeviceConvNdBwdDataNwcKxcNwk_Dl<2, InDataType, WeiDataType, OutDataType, AccDataType, InElementOp, WeiElementOp, OutElementOp, ConvBwdDataFilter1x1Stride1Pad0, 256, 128, 128, 16, 1, 4, 4, 1, S<8, 2>, S<8, 2>, S<8, 1, 1, 1>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 1>, S<1, 2, 0, 3>, S<1, 1, 1, 1>, S<1, 1, 8, 1>, S<16, 1, 16, 1>, S<0, 3, 1, 2>, S<0, 3, 1, 2>, S<1, 1, 4, 1>, S<0, 3, 1, 2>, S<1, 1, 1, 1>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on
    >;
void add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(
    std::vector<std::unique_ptr<DeviceConvBwdData<2,
                                                  NHWC,
                                                  KYXC,
                                                  NHWK,
                                                  InDataType,
                                                  WeiDataType,
                                                  OutDataType,
                                                  PassThrough,
                                                  PassThrough,
                                                  PassThrough>>>& instances)
{
    add_device_operation_instances(instances,
                                   device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances{});
    add_device_operation_instances(
        instances, device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_1x1_s1_p0_f32_instances{});
}
} // namespace instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
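As with the batchnorm instances above, these DL conv2d backward-data variants are consumed by filling a vector of type-erased `DeviceConvBwdData` pointers and choosing one at run time. The following is a hedged sketch, not library code from this commit: the `device_conv_bwd_data.hpp` header name and the `GetTypeString()` query are assumptions about the surrounding CK interfaces.

```cpp
// Minimal sketch: list the new DL FP32 conv2d bwd-data instances added above.
// Header path and GetTypeString() are assumptions, not confirmed by this commit.
#include <iostream>
#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_conv_bwd_data.hpp" // assumed path
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

int main()
{
    using NHWC        = ck::tensor_layout::convolution::NHWC;
    using KYXC        = ck::tensor_layout::convolution::KYXC;
    using NHWK        = ck::tensor_layout::convolution::NHWK;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;
    using Op          = ck::tensor_operation::device::DeviceConvBwdData<2,
                                                                        NHWC,
                                                                        KYXC,
                                                                        NHWK,
                                                                        float,
                                                                        float,
                                                                        float,
                                                                        PassThrough,
                                                                        PassThrough,
                                                                        PassThrough>;

    std::vector<std::unique_ptr<Op>> ops;

    // Appends the Default and Filter1x1Stride1Pad0 DL instances defined above.
    ck::tensor_operation::device::instance::
        add_device_conv2d_bwd_data_dl_nhwc_kyxc_nhwk_f32_instances(ops);

    for(const auto& op : ops)
        std::cout << op->GetTypeString() << '\n'; // assumed BaseOperator query

    return 0;
}
```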