OpenDAS / Oneflow · Commits

Commit 6046d8fb
Authored Apr 25, 2023 by yuguo960516yuguo
dtk23.04
Parent: a715222c

Showing 16 changed files with 419 additions and 1010 deletions (+419, -1010)
cmake/oneflow.cmake                                                     +10   -10
cmake/third_party.cmake                                                 +2    -2
oneflow/core/common/math_util.h                                         +2    -2
oneflow/core/cuda/atomic.cuh                                            +1    -1
oneflow/core/cuda/layer_norm.cuh                                        +46   -4
oneflow/core/embedding/lru_cache.cu                                     +20   -7
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh         +1    -1
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu    +36   -36
oneflow/core/ep/cuda/primitive/unary_functor.cuh                        +2    -2
oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp                +156  -25
oneflow/user/kernels/fused_gelu_mul_kernel.cu                           +4    -4
oneflow/user/kernels/group_norm_kernel.cu                               +4    -0
oneflow/user/kernels/layer_norm_gpu_kernel.cu                           +26   -647
oneflow/user/kernels/math_binary_elementwise_func.h                     +1    -1
oneflow/user/kernels/normalization_kernel.cu                            +95   -268
oneflow/user/ops/normalization_op.cpp                                   +13   -0
cmake/oneflow.cmake  (+10 -10)

@@ -334,16 +334,16 @@ elseif(WIN32)
   set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} /WHOLEARCHIVE:oneflow")
 endif()
-if (BUILD_ROCM)
-  # AMD compiler fails to compile these three files with '-O1/2/3'.
-  # The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
-  # so '-O0' will override '-O1/2/3'.
-  set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
-                              ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
-                              ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
-                              #${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
-                              PROPERTIES COMPILE_OPTIONS "-O0")
-endif()
+# if (BUILD_ROCM)
+#   # AMD compiler fails to compile these three files with '-O1/2/3'.
+#   # The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
+#   # so '-O0' will override '-O1/2/3'.
+#   set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
+#     ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
+#     ${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
+#     # ${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
+#     PROPERTIES COMPILE_OPTIONS "-O0")
+# endif()
 if (BUILD_CUDA)
   string(JOIN "," CUDA_REAL_ARCHS ${CUDA_REAL_ARCHS_LIST})
cmake/third_party.cmake  (+2 -2)
oneflow/core/common/math_util.h  (+2 -2)

@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);

 template<typename T>
 OF_DEVICE_FUNC T DeviceMin(T a, T b) {
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
   return a < b ? a : b;
 #else
   return std::min(a, b);

@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {

 template<typename T>
 OF_DEVICE_FUNC T DeviceMax(T a, T b) {
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
   return a > b ? a : b;
 #else
   return std::max(a, b);
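A brief usage sketch (mine, not part of the commit) of why the guard above matters: with __HIP_DEVICE_COMPILE__ added, the same helper compiles for the host (std::min/std::max) and for CUDA or HIP device code (plain ternary). The kernel below is hypothetical and only assumes that DeviceMin lives in the oneflow namespace, as the header path suggests.

#include "oneflow/core/common/math_util.h"

// Hypothetical illustration: clamp values on the device using the dual-target helper.
__global__ void ClampUpper(const float* in, float* out, int n, float upper) {
  const int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { out[i] = oneflow::DeviceMin(in[i], upper); }
}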
oneflow/core/cuda/atomic.cuh  (+1 -1)

@@ -54,7 +54,7 @@ struct CastCASImpl {
   }
 };

-#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
+#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__)) || defined(WITH_ROCM)

 template<typename T>
 struct CastCASImpl<T, unsigned short int> {
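For context, a hedged sketch of the cast-and-compare-and-swap idiom that CastCASImpl generalizes (my own illustration, not code from this commit): emulate an atomic operation by reinterpreting the value as an integer word and retrying atomicCAS until the word is unchanged.

#include <cuda_runtime.h>

// Illustration only: atomic max on float built from a 32-bit CAS loop.
__device__ float AtomicMaxFloat(float* address, float value) {
  unsigned int* address_as_u32 = reinterpret_cast<unsigned int*>(address);
  unsigned int observed = *address_as_u32;
  unsigned int assumed;
  do {
    assumed = observed;
    if (__uint_as_float(assumed) >= value) { break; }  // already large enough
    observed = atomicCAS(address_as_u32, assumed, __float_as_uint(value));
  } while (observed != assumed);
  return __uint_as_float(observed);
}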
oneflow/core/cuda/layer_norm.cuh  (+46 -4)

@@ -308,7 +308,8 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
     count_shared[wid] = warp_count;
   }
   __syncthreads();
   if (wid == 0) {
+#ifdef WITH_ROCM
     if (threadIdx.x < blockDim.x / kWarpSize) {
       warp_mean = mean_shared[lid];
       warp_m2 = m2_shared[lid];

@@ -318,10 +319,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
       warp_m2 = static_cast<T>(0);
       warp_count = static_cast<T>(0);
     }
+#ifdef WITH_ROCM
+    __syncthreads();
+    __syncthreads();
+    if (wid == 0) {
+#else
     __syncwarp();
+    if (wid == 0) {
+      if (threadIdx.x < blockDim.x / kWarpSize) {
+        warp_mean = mean_shared[lid];
+        warp_m2 = m2_shared[lid];
+        warp_count = count_shared[lid];
+      } else {
+        warp_mean = static_cast<T>(0);
+        warp_m2 = static_cast<T>(0);
+        warp_count = static_cast<T>(0);
+      }
+      __syncwarp();
+#endif
     T block_mean = 0;
     T block_m2 = 0;

@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
                                             const double epsilon, ComputeType* mean,
                                             ComputeType* inv_variance) {
   constexpr int block_size = 128;
+#ifdef WITH_ROCM
+  constexpr int waves = 64;
+#else
   constexpr int waves = 32;
+#endif
   static_assert(block_size % thread_group_width == 0, "");
   constexpr int thread_groups_per_block = block_size / thread_group_width;
   dim3 block_dim(thread_group_width, thread_groups_per_block);

@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
         stream, load, store, rows, cols, epsilon, mean, inv_variance); \
   }                                                                    \
   }
+#ifdef WITH_ROCM
+  DEFINE_ONE_ELIF(64)
+#else
   DEFINE_ONE_ELIF(4)
   DEFINE_ONE_ELIF(8)
   DEFINE_ONE_ELIF(16)
   DEFINE_ONE_ELIF(32)
+#endif
 #undef DEFINE_ONE_ELIF
 #define DEFINE_ONE_ELIF(max_col, min_col)                              \
   else if (cols <= (max_col)*kWarpSize) {                              \

@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
         stream, load, store, rows, cols, epsilon, mean, inv_variance); \
   }                                                                    \
   }
+#ifdef WITH_ROCM
+  DEFINE_ONE_ELIF(64)
+#else
   DEFINE_ONE_ELIF(4)
   DEFINE_ONE_ELIF(8)
   DEFINE_ONE_ELIF(16)
   DEFINE_ONE_ELIF(32)
+#endif
 #undef DEFINE_ONE_ELIF
 #define DEFINE_ONE_ELIF(max_col, min_col)                                   \
   else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \

@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
                                                      const double epsilon, ComputeType* mean,
                                                      ComputeType* inv_variance) {
   constexpr int block_size = 1024;
+#ifdef WITH_ROCM
+  constexpr int waves = 64;
+#else
   constexpr int waves = 32;
+#endif
   int grid_dim_x;
   {
     GPU(Error_t) err =

@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
                                                 const ComputeType* inv_variance,
                                                 const int64_t rows, const int64_t cols) {
   constexpr int block_size = 128;
+#ifdef WITH_ROCM
+  constexpr int waves = 64;
+#else
   constexpr int waves = 32;
+#endif
   static_assert(block_size % thread_group_width == 0, "");
   constexpr int thread_groups_per_block = block_size / thread_group_width;
   dim3 block_dim(thread_group_width, thread_groups_per_block);

@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
         stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
   }                                                                             \
   }
+#ifdef WITH_ROCM
+  DEFINE_ONE_ELIF(64)
+#else
   DEFINE_ONE_ELIF(4)
   DEFINE_ONE_ELIF(8)
   DEFINE_ONE_ELIF(16)
   DEFINE_ONE_ELIF(32)
+#endif
 #undef DEFINE_ONE_ELIF
 #define DEFINE_ONE_ELIF(max_col, min_col)                              \
   else if (cols <= (max_col)*kWarpSize) {                              \
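The repeated waves/DEFINE_ONE_ELIF switches above track the wider SIMD width on AMD hardware: NVIDIA warps have 32 lanes while AMD GCN/CDNA wavefronts have 64, so any constant or dispatch threshold derived from the warp size has to be guarded the same way. A hedged aside (my own sketch, not code from this header):

// Illustration only: a wavefront-size constant and a dispatch bound derived from it.
#ifdef WITH_ROCM
constexpr int kWarpSize = 64;   // AMD wavefront
#else
constexpr int kWarpSize = 32;   // NVIDIA warp
#endif

// A row fits the "one warp per row" path only while its length stays within
// max_cols_per_thread * kWarpSize, so the threshold doubles on ROCm.
constexpr bool FitsWarpImpl(int cols, int max_cols_per_thread) {
  return cols <= max_cols_per_thread * kWarpSize;
}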
oneflow/core/embedding/lru_cache.cu  (+20 -7)

@@ -252,15 +252,21 @@ struct SetContext {
     const int lane_age = ages[thread_ctx.lane_id];
     const bool lane_hit = (lane_key == key && lane_age != 0);
 #ifdef WITH_ROCM
-    const unsigned hit_mask = __ballot(lane_hit);
+    const unsigned long long int hit_mask = __ballot(lane_hit);
+    if (hit_mask != 0) {
+      return __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
+    } else {
+      return -1;
+    }
 #else
     const unsigned hit_mask = __ballot_sync(kFullMask, lane_hit);
-#endif
     if (hit_mask != 0) {
       return __ffs(static_cast<int>(hit_mask)) - 1;
     } else {
       return -1;
     }
+#endif
   }

   __device__ void Read(const LruCacheContext<Key, Elem>& cache_ctx, const ThreadContext& thread_ctx,

@@ -277,15 +283,16 @@ struct SetContext {
     const Key lane_key = keys[thread_ctx.lane_id];
     int lane_age = ages[thread_ctx.lane_id];
 #ifdef WITH_ROCM
-    const unsigned hit_mask = __ballot(lane_key == key && lane_age != 0);
+    const unsigned long long int hit_mask = __ballot(lane_key == key && lane_age != 0);
 #else
     const unsigned hit_mask = __ballot_sync(kFullMask, lane_key == key && lane_age != 0);
 #endif
     if (hit_mask != 0) {
-      insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
 #ifdef WITH_ROCM
+      insert_way = __ffsll(static_cast<unsigned long long int>(hit_mask)) - 1;
       const int insert_way_age = __shfl(lane_age, insert_way);
 #else
+      insert_way = __ffs(static_cast<int>(hit_mask)) - 1;
       const int insert_way_age = __shfl_sync(kFullMask, lane_age, insert_way);
 #endif
       if (lane_age > insert_way_age) {

@@ -301,12 +308,16 @@ __syncwarp();
     }
     if (insert_way == -1) {
 #ifdef WITH_ROCM
-      const unsigned valid_mask = __ballot(lane_age != 0);
+      const unsigned long long int valid_mask = __ballot(lane_age != 0);
 #else
       const unsigned valid_mask = __ballot_sync(kFullMask, lane_age != 0);
 #endif
       if (valid_mask != kFullMask) {
+#ifdef WITH_ROCM
+        insert_way = __popcll(static_cast<unsigned long long int>(valid_mask));
+#else
         insert_way = __popc(static_cast<int>(valid_mask));
+#endif
         if (lane_age > 0) {
           lane_age -= 1;
         } else if (thread_ctx.lane_id == insert_way) {

@@ -329,7 +340,8 @@ __syncwarp();
     const Key lane_key = keys[thread_ctx.lane_id];
     int lane_age = ages[thread_ctx.lane_id];
 #ifdef WITH_ROCM
-    const int insert_way = __ffs(static_cast<int>(__ballot(lane_age == 1))) - 1;
+    unsigned long long int valid_mask_tmp = __ballot(lane_age == 1);
+    const int insert_way = __ffsll(valid_mask_tmp) - 1;
 #else
     const int insert_way = __ffs(__ballot_sync(kFullMask, lane_age == 1)) - 1;
 #endif

@@ -594,7 +606,8 @@ __syncwarp();
     warp_keys[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_key;
     warp_ages[thread_ctx.warp_id_in_block][thread_ctx.lane_id] = lane_age;
 #ifdef WITH_ROCM
-    const int key_count = __popc(static_cast<int>(__ballot(lane_age != 0)));
+    unsigned long long int valid_mask_tmp = __ballot(lane_age != 0);
+    const int key_count = __popcll(valid_mask_tmp);
 #else
     const int key_count = __popc(__ballot_sync(kFullMask, lane_age != 0));
 #endif
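A hedged aside (my own illustration, not code from this commit) on the pattern the hunks above converge on: on CUDA, __ballot_sync returns a 32-bit mask that pairs with __ffs/__popc, while on ROCm __ballot returns a 64-bit wavefront mask, so the holder type and the bit-scan/population-count intrinsics have to widen together.

// Illustration only: find the first lane whose predicate is true, CUDA vs. ROCm.
__device__ int FirstActiveLane(bool pred) {
#ifdef WITH_ROCM
  const unsigned long long int mask = __ballot(pred);      // 64-bit wavefront mask
  return mask != 0 ? __ffsll(static_cast<unsigned long long int>(mask)) - 1 : -1;
#else
  const unsigned mask = __ballot_sync(0xffffffffu, pred);  // 32-bit warp mask
  return mask != 0 ? __ffs(static_cast<int>(mask)) - 1 : -1;
#endif
}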
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh  (+1 -1)

@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 */
-#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
+#include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h"
 #include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
 #include "oneflow/core/ep/cuda/primitive/type_seq.h"
 #include "oneflow/core/ep/cuda/cuda_stream.h"
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu  (+36 -36)
oneflow/core/ep/cuda/primitive/unary_functor.cuh  (+2 -2)

@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
   OF_DEVICE_FUNC UnaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}

   OF_DEVICE_FUNC half operator()(half src) const {
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
     const float tanh_in =
         __half2float(__float2half_rn(alpha) * (src + __float2half_rn(beta) * src * src * src));
     const float tanh_out = unary_functor_internal::TanhApprox(tanh_in);

@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
 #endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
   }

-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
   __device__ void Apply2(half* dst, const half* src) const {
     const half2 src2 = *(reinterpret_cast<const half2*>(src));
     const float2 tanh_in = __half22float2(__hmul2(
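For readers unfamiliar with the functor being guarded here: FastGelu is the tanh approximation of GELU. A hedged float-only sketch with the commonly used constants (my own, not code lifted from unary_functor.cuh):

#include <cmath>

// Illustration only: gelu(x) ~= 0.5 * x * (1 + tanh(alpha * (x + beta * x^3))).
inline float FastGeluApprox(float x) {
  const float alpha = 0.7978845608028654f;  // sqrt(2 / pi)
  const float beta = 0.044715f;
  const float tanh_in = alpha * (x + beta * x * x * x);
  return 0.5f * x * (1.0f + std::tanh(tanh_in));
}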
oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp  (+156 -25)

@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
                                "reduce_sum_like",
                                "layer_norm_grad",
                                "layer_norm",
-                               "layer_norm_param_grad"};
-  return black_list;
-}
-
-const AMPList& AutoMixedPrecisionLists::GrayList() {
-  static AMPList gray_list = {"add_n",
+                               "layer_norm_param_grad",
+                               "add_n",
                                "tf_avg_pool_1d",
                                "tf_avg_pool_1d_grad",
                                "tf_avg_pool_2d",
                                "tf_avg_pool_2d_grad",
                                "tf_avg_pool_3d",
                                "tf_avg_pool_3d_grad",
-                               "avg_pool_1d",
-                               "avg_pool_1d_grad",
-                               "avg_pool_2d",
-                               "avg_pool_2d_grad",
-                               "avg_pool_3d",
-                               "avg_pool_3d_grad",
+                               // "avg_pool_1d",
+                               // "avg_pool_1d_grad",
+                               // "avg_pool_2d",
+                               // "avg_pool_2d_grad",
+                               // "avg_pool_3d",
+                               // "avg_pool_3d_grad",
                                "bias_add",
                                "sigmoid_grad",
                                "tanh",

@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
                                "group_norm_grad",
                                "silu",
                                "silu_grad",
-                               "fused_weighted_sum"};
-  return gray_list;
-}
-
-const AMPList& AutoMixedPrecisionLists::ClearList() {
-  // TODO(niuchong): tuple_identity
-  static AMPList clear_list = {"broadcast_like",
+                               "fused_weighted_sum",
+                               "broadcast_like",
                                "gather",
                                "gather_nd",
                                "scatter_nd",

@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
                                "tf_max_pool_2d_grad",
                                "tf_max_pool_3d",
                                "tf_max_pool_3d_grad",
-                               "max_pool_1d",
-                               "max_pool_1d_grad",
-                               "max_pool_2d",
-                               "max_pool_2d_grad",
-                               "max_pool_3d",
-                               "max_pool_3d_grad",
+                               // "max_pool_1d",
+                               // "max_pool_1d_grad",
+                               // "max_pool_2d",
+                               // "max_pool_2d_grad",
+                               // "max_pool_3d",
+                               // "max_pool_3d_grad",
                                "reshape",
                                "reshape_like",
                                "relu",

@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
                                "pinned_identity",
                                "to_contiguous",
                                "copy",
                                "upsample_nearest_2d"};
+  return black_list;
+}
+
+const AMPList& AutoMixedPrecisionLists::GrayList() {
+  static AMPList gray_list = {
+      // "add_n",
+      // "tf_avg_pool_1d",
+      // "tf_avg_pool_1d_grad",
+      // "tf_avg_pool_2d",
+      // "tf_avg_pool_2d_grad",
+      // "tf_avg_pool_3d",
+      // "tf_avg_pool_3d_grad",
+      "avg_pool_1d",
+      "avg_pool_1d_grad",
+      "avg_pool_2d",
+      "avg_pool_2d_grad",
+      "avg_pool_3d",
+      "avg_pool_3d_grad"
+      // "bias_add",
+      // "sigmoid_grad",
+      // "tanh",
+      // "tanh_grad",
+      // "sqrt",
+      // "sqrt_grad",
+      // "scalar_mul",
+      // "scalar_mul_by_tensor",
+      // "scalar_add",
+      // "scalar_div",
+      // "scalar_pow",
+      // "broadcast_add",
+      // "broadcast_sub",
+      // "broadcast_mul",
+      // "broadcast_div",
+      // "rms_norm",
+      // "rms_norm_grad",
+      // "rms_norm_param_grad",
+      // "dropout",
+      // "dropout_grad",
+      // "softmax",
+      // "softmax_grad",
+      // "log_softmax",
+      // "log_softmax_grad",
+      // "gelu",
+      // "gelu_grad",
+      // "fast_gelu",
+      // "fast_gelu_grad",
+      // "normalization",
+      // "normalization_grad",
+      // "normalization_add_relu",
+      // "normalization_add_relu_grad",
+      // "sparse_softmax_cross_entropy",
+      // "sparse_softmax_cross_entropy_grad",
+      // "nll",
+      // "nll_grad",
+      // "fused_tril_scale_softmax_mask_scale",
+      // "fused_tril_scale_softmax_mask_scale_grad",
+      // "fused_scale_mask_softmax_dropout",
+      // "fused_scale_mask_softmax_dropout_grad",
+      // "fused_scale_mask_softmax",
+      // "fused_scale_mask_softmax_grad",
+      // "fused_bias_add_scale_mask_softmax_dropout",
+      // "fused_bias_add_gelu",
+      // "fused_bias_add_gelu_grad",
+      // "fused_bias_add_mask_scale",
+      // "fused_fast_gelu_mul",
+      // "fused_fast_gelu_mul_grad",
+      // "acc",
+      // "reciprocal",
+      // "reciprocal_no_nan",
+      // "group_norm",
+      // "group_norm_param_grad",
+      // "group_norm_grad",
+      // "silu",
+      // "silu_grad",
+      // "fused_weighted_sum"
+  };
+  return gray_list;
+}
+
+const AMPList& AutoMixedPrecisionLists::ClearList() {
+  // TODO(niuchong): tuple_identity
+  static AMPList clear_list = {
+      // "broadcast_like",
+      // "gather",
+      // "gather_nd",
+      // "scatter_nd",
+      // "scatter_nd_like",
+      // "unsorted_segment_sum_like",
+      // "tf_max_pool_1d",
+      // "tf_max_pool_1d_grad",
+      // "tf_max_pool_2d",
+      // "tf_max_pool_2d_grad",
+      // "tf_max_pool_3d",
+      // "tf_max_pool_3d_grad",
+      "max_pool_1d",
+      "max_pool_1d_grad",
+      "max_pool_2d",
+      "max_pool_2d_grad",
+      "max_pool_3d",
+      "max_pool_3d_grad"
+      // "reshape",
+      // "reshape_like",
+      // "relu",
+      // "relu_grad",
+      // "transpose",
+      // "random_mask_like",
+      // "concat",
+      // "split_like",
+      // "pad",
+      // "same_padding",
+      // "same_padding_grad",
+      // "tril",
+      // "slice",
+      // "slice_grad",
+      // "fused_scale_tril",
+      // "identity",
+      // "squeeze",
+      // "embedding",
+      // "embedding_grad",
+      // "expand",
+      // "expand_dims",
+      // "cast_to_static_shape",
+      // "parallel_cast",
+      // "hierarchical_parallel_cast",
+      // "hierarchical_parallel_cast_like",
+      // "repeat",
+      // "unpack",
+      // "pack",
+      // "nvtx_start",
+      // "nvtx_end",
+      // "narrow",
+      // "narrow_grad",
+      // "ones_like",
+      // "pinned_identity",
+      // "to_contiguous",
+      // "copy",
+      // "upsample_nearest_2d"
+  };
   return clear_list;
 }
oneflow/user/kernels/fused_gelu_mul_kernel.cu  (+4 -4)

@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
   OF_DEVICE_FUNC FusedFastGeluMulFunctor() {}

   OF_DEVICE_FUNC half operator()(const half x, const half m) const {
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
     const float tanh_in =
         __half2float(__float2half_rn(alpha) * (x + __float2half_rn(beta) * x * x * x));
     const float tanh_out = TanhApprox(tanh_in);

@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
 #endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
   }

-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
   __device__ void Apply2(half* y, const half* x, const half* m) const {
     const half2 x2 = *(reinterpret_cast<const half2*>(x));
     const float2 tanh_in = __half22float2(

@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
   __device__ void operator()(half& x_diff, half& m_diff, const half& dy, const half& x,
                              const half& m) const {
-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
     const half halpha = __float2half_rn(alpha);
     const half hbeta = __float2half_rn(beta);
     const half hone = __float2half_rn(1.0F);

@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
 #endif  // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
   }

-#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
+#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000) || defined(WITH_ROCM)
   __device__ void Apply2(half* x_diff, half* m_diff, const half* dy, const half* x,
                          const half* m) const {
     const half2 dy2 = *(reinterpret_cast<const half2*>(dy));
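The Apply2 overloads guarded above follow the packed-fp16 pattern: widen a half2 to float2, do the math in fp32, and pack back. A hedged, self-contained sketch of that pattern (my own illustration, not the functor from this file):

#include <cuda_fp16.h>

// Illustration only: process two half values at once.
__device__ void ScaleTwoHalves(half* dst, const half* src, float scale) {
  const half2 packed = *reinterpret_cast<const half2*>(src);
  float2 v = __half22float2(packed);                       // widen both lanes
  v.x *= scale;
  v.y *= scale;
  *reinterpret_cast<half2*>(dst) = __float22half2_rn(v);   // round back to fp16
}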
oneflow/user/kernels/group_norm_kernel.cu  (+4 -0)

@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)

 constexpr int kReduceBlockSize = 512;
 constexpr int kBlockSize = 128;
+#ifdef WITH_ROCM
+constexpr int kNumWaves = 64;
+#else
 constexpr int kNumWaves = 32;
+#endif

 inline GPU(Error_t) GetReduceNumBlocks(int64_t n, int* num_blocks) {
   int dev;
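kNumWaves feeds the usual grid-sizing heuristic: launch enough blocks to cover the data, but cap the grid at a few "waves" of blocks per multiprocessor. A hedged sketch of that heuristic written against the plain CUDA runtime API (my own; GetReduceNumBlocks itself may differ in detail):

#include <algorithm>
#include <cstdint>
#include <cuda_runtime.h>

// Illustration only: pick a grid size of at most `num_waves` blocks per SM.
inline cudaError_t GetNumBlocks(int64_t n, int block_size, int num_waves, int* num_blocks) {
  int dev = 0;
  cudaError_t err = cudaGetDevice(&dev);
  if (err != cudaSuccess) { return err; }
  int sm_count = 0;
  err = cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev);
  if (err != cudaSuccess) { return err; }
  const int64_t needed = (n + block_size - 1) / block_size;
  const int64_t cap = static_cast<int64_t>(sm_count) * num_waves;
  *num_blocks = static_cast<int>(std::max<int64_t>(1, std::min(needed, cap)));
  return cudaSuccess;
}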
oneflow/user/kernels/layer_norm_gpu_kernel.cu  (+26 -647)

@@ -27,9 +27,12 @@ limitations under the License.
#include <cuda_bf16.h>
#endif  // CUDA_VERSION >= 11000
#ifdef WITH_CUDA
#include <thrust/pair.h>
#ifdef WITH_ROCM
#include <hipcub/hipcub.hpp>
#else
#include <cub/cub.cuh>
#endif

namespace oneflow {

namespace {

@@ -143,18 +146,31 @@ __inline__ __device__ T WarpReduce(T val) {
   return val;
 }

-constexpr int tile_size = 32;
 constexpr int num_per_block = 4;
+#ifdef WITH_ROCM
+constexpr int tile_size = 64;
+constexpr int block_dim_x = 64;
+constexpr int block_dim_y = 64 / num_per_block;
+#else
+constexpr int tile_size = 32;
 constexpr int block_dim_x = 32;
 constexpr int block_dim_y = 32 / num_per_block;
+#endif

 template<typename T, typename ComputeType>
 __global__ void LayerNormParamGrad(int rows, int cols, const T* __restrict__ dy,
                                    const T* __restrict__ x, const ComputeType* __restrict__ mean,
                                    const ComputeType* __restrict__ inv_var,
                                    T* __restrict__ tmp_gamma_diff, T* __restrict__ tmp_beta_diff) {
+#ifdef WITH_ROCM
+  __shared__ ComputeType dgamma[64][65];
+  __shared__ ComputeType dbeta[64][65];
+#else
   __shared__ ComputeType dgamma[32][33];
   __shared__ ComputeType dbeta[32][33];
+#endif
   ComputeType dgamma_sum[num_per_block];
   ComputeType dbeta_sum[num_per_block];
 #pragma unroll

@@ -210,13 +226,13 @@ int GetGirdDimY(const int64_t num_instances, const int64_t norm_size) {
   const int max_grid_dim_y = (num_instances + tile_size - 1) / tile_size;
   const int block_size = block_dim_x * block_dim_y;
   int max_active_blocks = 0;
-  OF_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
+  OF_CUDA_CHECK(GPU(OccupancyMaxActiveBlocksPerMultiprocessor)(
       &max_active_blocks, LayerNormParamGrad<T, ComputeType>, block_size, 0));
   int waves = 1;
   int dev;
-  OF_CUDA_CHECK(cudaGetDevice(&dev));
+  OF_CUDA_CHECK(GPU(GetDevice)(&dev));
   int sm_count;
-  OF_CUDA_CHECK(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev));
+  OF_CUDA_CHECK(GPU(DeviceGetAttribute)(&sm_count, GPUMultiProcessorCount, dev));
   int num_blocks = max_active_blocks * sm_count * waves;
   int grid_dim_y = std::min(max_grid_dim_y, static_cast<int>(num_blocks / grid_dim_x));
   return std::max(grid_dim_y, 1);

@@ -420,6 +436,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
     const int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
     user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
     const DataType data_type = dy->data_type();
     const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;
     const int grid_dim_y = GetGirdDimY<T>(num_instances, norm_size);
     const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T);

@@ -428,7 +445,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
     T* reduce_buf_ptr =
         reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size);
     using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
-    LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(32, 32 / num_per_block),
+    LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(block_dim_x, block_dim_y),
                                          0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
         num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
         inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr);

@@ -476,651 +493,13 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
      });

REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)
#ifdef WITH_CUDA
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)
#endif

}  // namespace oneflow
#endif

#ifdef WITH_ROCM
#include <hipcub/hipcub.hpp>
#include <thrust/pair.h>

template<typename T, bool is_cuda>
struct AccumulateType {};

#if defined(__HIPCC__)
template<> struct AccumulateType<half, true> { using type = float; };
#endif
template<> struct AccumulateType<float, true> { using type = float; };
template<> struct AccumulateType<double, true> { using type = double; };
template<> struct AccumulateType<int8_t, true> { using type = int64_t; };
template<> struct AccumulateType<uint8_t, true> { using type = int64_t; };
template<> struct AccumulateType<char, true> { using type = int64_t; };
template<> struct AccumulateType<int16_t, true> { using type = int64_t; };
template<> struct AccumulateType<int32_t, true> { using type = int64_t; };
template<> struct AccumulateType<int64_t, true> { using type = int64_t; };
template<> struct AccumulateType<bool, true> { using type = bool; };
template<> struct AccumulateType<float, false> { using type = double; };
template<> struct AccumulateType<double, false> { using type = double; };
template<> struct AccumulateType<int8_t, false> { using type = int64_t; };
template<> struct AccumulateType<uint8_t, false> { using type = int64_t; };
template<> struct AccumulateType<char, false> { using type = int64_t; };
template<> struct AccumulateType<int16_t, false> { using type = int64_t; };
template<> struct AccumulateType<int32_t, false> { using type = int64_t; };
template<> struct AccumulateType<int64_t, false> { using type = int64_t; };
template<> struct AccumulateType<bool, false> { using type = bool; };

template<typename T, bool is_cuda>
using acc_type = typename AccumulateType<T, is_cuda>::type;

#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__
#define C10_HOST __host__
#define C10_WARP_SIZE 64
#define VEC 4
typedef int64_t IndexType;
constexpr int BlockReduceNumThreads = 512;
constexpr int NumThreads = 256;
constexpr int ColwiseReduceTileSize = 32;

template<typename scalar_t, typename index_t, typename combine_t>
struct WelfordData {
  scalar_t mean;
  scalar_t m2;
  index_t n;
  combine_t nf;

  C10_HOST_DEVICE WelfordData() : mean(0), m2(0), n(0), nf(0) {}

  C10_HOST_DEVICE WelfordData(scalar_t mean, scalar_t m2, index_t n, combine_t nf)
      : mean(mean), m2(m2), n(n), nf(nf) {}
};

template<typename scalar_t, typename acc_scalar_t, typename index_t, typename combine_t,
         typename res_t>
struct WelfordOps {
 public:
  using acc_t = WelfordData<acc_scalar_t, index_t, combine_t>;
  inline C10_DEVICE acc_t reduce(acc_t acc, scalar_t data) const {
    acc_scalar_t delta = data - acc.mean;
    // using acc.nf(combine_t) here, as acc.n(index_t) would still be converted
    // accumulation in reduce is done through index_T
    acc_scalar_t new_mean = acc.mean + delta / (acc.nf + 1);
    acc_scalar_t new_delta = data - new_mean;
    return {
        new_mean,
        acc.m2 + delta * new_delta,
        acc.n + 1,
        combine_t(acc.n + 1),  // accumulate for combine_t uses index_t
    };
  }
  inline C10_DEVICE acc_t combine(acc_t a, acc_t b) const {
    if (a.nf == 0) { return b; }
    if (b.nf == 0) { return a; }
    acc_scalar_t delta = b.mean - a.mean;
    combine_t new_count = a.nf + b.nf;
    acc_scalar_t nb_over_n = b.nf / new_count;
    return {a.mean + delta * nb_over_n,
            a.m2 + b.m2 + delta * delta * a.nf * nb_over_n,
            // setting acc.n as -1 since acc.n might not be able to represent the count
            // correctly within its range, setting it to -1 to avoid confusion
            -1, new_count};
  }
  inline C10_DEVICE res_t project(acc_t acc) const {
    return res_t(acc.m2 / acc.nf, static_cast<scalar_t>(acc.mean));
  }
  inline __device__ acc_t warp_shfl_down(acc_t acc, int offset) const {
    return {__shfl_down(acc.mean, offset), __shfl_down(acc.m2, offset),
            __shfl_down(acc.n, offset), __shfl_down(acc.nf, offset)};
  }
};

template<int max = 32, typename T, class ReduceOp>
__inline__ __device__ T WarpReduce(T val, const ReduceOp& op) {
#pragma unroll
  for (int offset = max; offset > 0; offset >>= 1) {
    val = op.combine(val, op.warp_shfl_down(val, offset));
  }
  return val;
}

template<typename T, class ReduceOp>
__inline__ __device__ T BlockReduce(T val, const ReduceOp& op, T* shared) {
  const int lid = threadIdx.x % C10_WARP_SIZE;
  const int wid = threadIdx.x / C10_WARP_SIZE;
  val = WarpReduce(val, op);
  __syncthreads();
  if (lid == 0) { shared[wid] = val; }
  __syncthreads();
  if (wid == 0) {
    val = shared[lid];
    val = WarpReduce<4>(val, op);
  }
  return val;
}

template<int max = 32, typename T>
__inline__ __device__ T WarpReduceSum(T val) {
#pragma unroll
  for (int offset = max; offset > 0; offset >>= 1) { val += __shfl_down(val, offset); }
  return val;
}

template<typename T>
__inline__ __device__ T BlockReduceSum(T val, T* shared) {
  const int lid = threadIdx.x % C10_WARP_SIZE;
  const int wid = threadIdx.x / C10_WARP_SIZE;
  val = WarpReduceSum(val);
  __syncthreads();
  if (lid == 0) { shared[wid] = val; }
  __syncthreads();
  if (wid == 0) {
    val = shared[lid];
    val = WarpReduceSum<4>(val);
  }
  return val;
}

template<typename scalar_t>
__global__ void layernorm_forward_kernel(const scalar_t* input, scalar_t* ret,
                                         acc_type<scalar_t, true>* mean,
                                         acc_type<scalar_t, true>* rstd, const scalar_t* gamma,
                                         const scalar_t* beta, IndexType cols, double eps) {
  // dropout do nothing in val mode
  IndexType i = blockIdx.x;
  // add + layernorm get mean and rstd
  using T_ACC = acc_type<scalar_t, true>;
  using WelfordType = WelfordData<T_ACC, IndexType, T_ACC>;
  using WelfordOp = WelfordOps<T_ACC, T_ACC, IndexType, T_ACC, thrust::pair<T_ACC, T_ACC>>;

  __shared__
      typename std::aligned_storage<sizeof(WelfordType), alignof(WelfordType)>::type
          val_shared[BlockReduceNumThreads / C10_WARP_SIZE];
  WelfordType* val_shared_ptr = reinterpret_cast<WelfordType*>(val_shared);

  WelfordOp welford_op;
  WelfordType val;
#pragma unroll
  for (IndexType j = threadIdx.x; j < cols; j += blockDim.x) {
    IndexType index = i * cols + j;
    val = welford_op.reduce(val, static_cast<T_ACC>(input[index]));
  }
  val = BlockReduce(val, welford_op, val_shared_ptr);

  __shared__ T_ACC s_mean;
  __shared__ T_ACC s_rstd;
  if (threadIdx.x == 0) {
    thrust::tie(s_rstd, s_mean) = welford_op.project(val);
    mean[i] = s_mean;
    s_rstd = rsqrt(s_rstd + static_cast<T_ACC>(eps));
    rstd[i] = s_rstd;
  }
  __syncthreads();

  // layernorm (x-mean)*rstd*gamma+beta
#pragma unroll
  for (IndexType j = threadIdx.x; j < cols; j += blockDim.x) {
    IndexType index = i * cols + j;
    ret[index] = static_cast<scalar_t>(
        (static_cast<T_ACC>(input[index]) - s_mean) * s_rstd
            * (gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]))
        + (beta == nullptr ? T_ACC(0) : static_cast<T_ACC>(beta[j])));
  }
}

template<typename T>
void LayerNormKernelImplInternal(oneflow::ep::Stream* stream, const T* X, const T* gamma,
                                 const T* beta, int64_t M, int64_t N, double eps, T* Y,
                                 acc_type<T, true>* mean, acc_type<T, true>* rstd) {
  using T_ACC = acc_type<T, true>;
  const T* X_data = X;
  const T* gamma_data = gamma;
  const T* beta_data = beta;
  T* Y_data = Y;
  T_ACC* mean_data = mean;
  T_ACC* rstd_data = rstd;
  hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
  layernorm_forward_kernel<T><<<M, BlockReduceNumThreads, 0, cuda_stream>>>(
      X_data, Y_data, mean_data, rstd_data, gamma_data, beta_data, N, eps);
}

template<typename scalar_t>
__global__ void GammaBetaBackwardSimple(IndexType M, IndexType N, const scalar_t* dY,
                                        const scalar_t* X, const acc_type<scalar_t, true>* mean,
                                        const acc_type<scalar_t, true>* rstd, scalar_t* dg,
                                        scalar_t* db) {
  using T_ACC = acc_type<scalar_t, true>;
  const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
  if (j < N) {
    T_ACC sum1 = 0;
    T_ACC sum2 = 0;
    for (int64_t i = 0; i < M; ++i) {
      const int64_t index = i * N + j;
      sum1 += dg == nullptr ? T_ACC(0)
                            : static_cast<T_ACC>(dY[index])
                                  * (static_cast<T_ACC>(X[index]) - static_cast<T_ACC>(mean[i]))
                                  * static_cast<T_ACC>(rstd[i]);
      sum2 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index]);
    }
    if (dg != nullptr) { dg[j] = static_cast<scalar_t>(sum1); }
    if (db != nullptr) { db[j] = static_cast<scalar_t>(sum2); }
  }
}

template<typename scalar_t>
__global__ void GammaBetaBackward(IndexType M, IndexType N, const scalar_t* dY, const scalar_t* X,
                                  const acc_type<scalar_t, true>* mean,
                                  const acc_type<scalar_t, true>* rstd, scalar_t* dg, scalar_t* db) {
  using T_ACC = acc_type<scalar_t, true>;
  __shared__ T_ACC g_shared[ColwiseReduceTileSize][ColwiseReduceTileSize + 1];
  __shared__ T_ACC b_shared[ColwiseReduceTileSize][ColwiseReduceTileSize + 1];
  const int64_t j = blockIdx.x * blockDim.x + threadIdx.x;
  T_ACC dg_sum1 = 0;
  T_ACC dg_sum2 = 0;
  T_ACC db_sum1 = 0;
  T_ACC db_sum2 = 0;
  if (j < N) {
    for (int64_t i = threadIdx.y; i < M; i += blockDim.y * 2) {
      const int64_t i1 = i;
      const int64_t i2 = i + blockDim.y;
      const int64_t index1 = i1 * N + j;
      const int64_t index2 = i2 * N + j;
      dg_sum1 += dg == nullptr ? T_ACC(0)
                               : static_cast<T_ACC>(dY[index1])
                                     * (static_cast<T_ACC>(X[index1]) - static_cast<T_ACC>(mean[i1]))
                                     * static_cast<T_ACC>(rstd[i1]);
      db_sum1 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index1]);
      if (i2 < M) {
        dg_sum2 += dg == nullptr ? T_ACC(0)
                                 : static_cast<T_ACC>(dY[index2])
                                       * (static_cast<T_ACC>(X[index2]) - static_cast<T_ACC>(mean[i2]))
                                       * static_cast<T_ACC>(rstd[i2]);
        db_sum2 += db == nullptr ? T_ACC(0) : static_cast<T_ACC>(dY[index2]);
      }
    }
  }
  g_shared[threadIdx.y][threadIdx.x] = dg_sum1;
  g_shared[threadIdx.y + blockDim.y][threadIdx.x] = dg_sum2;
  b_shared[threadIdx.y][threadIdx.x] = db_sum1;
  b_shared[threadIdx.y + blockDim.y][threadIdx.x] = db_sum2;
  __syncthreads();
  T_ACC sum1 = g_shared[threadIdx.x][threadIdx.y];
  T_ACC sum2 = b_shared[threadIdx.x][threadIdx.y];
  sum1 = WarpReduceSum<16>(sum1);
  sum2 = WarpReduceSum<16>(sum2);
  if (threadIdx.x == 0) {
    const int64_t j = blockIdx.x * blockDim.x + threadIdx.y;
    if (j < N) {
      if (dg != nullptr) { dg[j] = static_cast<scalar_t>(sum1); }
      if (db != nullptr) { db[j] = static_cast<scalar_t>(sum2); }
    }
  }
  sum1 = g_shared[threadIdx.x][threadIdx.y + blockDim.y];
  sum2 = b_shared[threadIdx.x][threadIdx.y + blockDim.y];
  sum1 = WarpReduceSum<16>(sum1);
  sum2 = WarpReduceSum<16>(sum2);
  if (threadIdx.x == 0) {
    const int64_t j = blockIdx.x * blockDim.x + threadIdx.y + blockDim.y;
    if (j < N) {
      if (dg != nullptr) { dg[j] = static_cast<scalar_t>(sum1); }
      if (db != nullptr) { db[j] = static_cast<scalar_t>(sum2); }
    }
  }
}

template<typename scalar_t>
__global__ void LayerNormBackward_kernel(IndexType N, const scalar_t* dY, const scalar_t* X,
                                         const scalar_t* gamma,
                                         const acc_type<scalar_t, true>* mean,
                                         const acc_type<scalar_t, true>* rstd, scalar_t* dX,
                                         const scalar_t* add_to_output) {
  using T_ACC = acc_type<scalar_t, true>;
  __shared__ T_ACC ds_shared[C10_WARP_SIZE];
  __shared__ T_ACC db_shared[C10_WARP_SIZE];
  const IndexType i = blockIdx.x;
  T_ACC sum1 = 0;
  T_ACC sum2 = 0;
#pragma unroll
  for (IndexType j = threadIdx.x; j < N; j += blockDim.x) {
    const IndexType index = i * N + j;
    const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]);
    sum1 += static_cast<T_ACC>(dY[index]) * static_cast<T_ACC>(X[index]) * gamma_v;
    sum2 += static_cast<T_ACC>(dY[index]) * gamma_v;
  }
  sum1 = BlockReduceSum<T_ACC>(sum1, ds_shared);
  sum2 = BlockReduceSum<T_ACC>(sum2, db_shared);

  const T_ACC s = T_ACC(1) / static_cast<T_ACC>(N);
  __shared__ T_ACC b;
  __shared__ T_ACC c;
  if (threadIdx.x == 0) {
    b = (sum2 * static_cast<T_ACC>(mean[i]) - sum1) * static_cast<T_ACC>(rstd[i])
        * static_cast<T_ACC>(rstd[i]) * static_cast<T_ACC>(rstd[i]) * s;
    c = -(b * static_cast<T_ACC>(mean[i]) + sum2 * static_cast<T_ACC>(rstd[i]) * s);
  }
  __syncthreads();
#pragma unroll
  for (IndexType j = threadIdx.x; j < N; j += blockDim.x) {
    const IndexType index = i * N + j;
    const T_ACC gamma_v = gamma == nullptr ? T_ACC(1) : static_cast<T_ACC>(gamma[j]);
    dX[index] = static_cast<scalar_t>(
        static_cast<T_ACC>(rstd[i]) * static_cast<T_ACC>(dY[index]) * gamma_v
        + b * static_cast<T_ACC>(X[index]) + c
        + (add_to_output == nullptr ? T_ACC(0) : static_cast<T_ACC>(add_to_output[index])));
  }
}

template<typename T>
void LayerNormBackwardKernelImplInternal(oneflow::ep::Stream* stream, const T* dY, const T* X,
                                         const acc_type<T, true>* mean,
                                         const acc_type<T, true>* rstd, const T* gamma, int64_t M,
                                         int64_t N, T* dX, const T* add_to_output) {
  using T_ACC = acc_type<T, true>;
  const T* dY_data = dY;
  const T* X_data = X;
  const T_ACC* mean_data = mean;
  const T_ACC* rstd_data = rstd;
  const T* gamma_data = gamma;
  T* dX_data = dX;
  const T* add_to_output_data = add_to_output;
  hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
  if (dX_data != nullptr) {
    LayerNormBackward_kernel<T><<<M, BlockReduceNumThreads, 0, cuda_stream>>>(
        N, dY_data, X_data, gamma_data, mean_data, rstd_data, dX_data, add_to_output_data);
  }
}

template<typename T>
void LayerNormBackwardKernelImplInternalParam(oneflow::ep::Stream* stream, const T* dY, const T* X,
                                              const acc_type<T, true>* mean,
                                              const acc_type<T, true>* rstd, int64_t M, int64_t N,
                                              T* dgamma, T* dbeta) {
  using T_ACC = acc_type<T, true>;
  const T* dY_data = dY;
  const T* X_data = X;
  const T_ACC* mean_data = mean;
  const T_ACC* rstd_data = rstd;
  hipStream_t cuda_stream = stream->As<oneflow::ep::CudaStream>()->cuda_stream();
  T* dgamma_data = dgamma;
  T* dbeta_data = dbeta;
  if (M < 512) {
    // For small batch size, do colwise reduce directly.
    const int64_t B = (N + NumThreads - 1) / NumThreads;
    GammaBetaBackwardSimple<T><<<B, NumThreads, 0, cuda_stream>>>(
        M, N, dY_data, X_data, mean_data, rstd_data, dgamma_data, dbeta_data);
  } else {
    const int64_t B = (N + ColwiseReduceTileSize - 1) / ColwiseReduceTileSize;
    constexpr int kThreadX = ColwiseReduceTileSize;
    constexpr int kThreadY = ColwiseReduceTileSize / 2;
    GammaBetaBackward<T><<<B, dim3(kThreadX, kThreadY), 0, cuda_stream>>>(
        M, N, dY_data, X_data, mean_data, rstd_data, dgamma_data, dbeta_data);
  }
}

namespace oneflow {

template<typename T>
class LayerNormGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
 public:
  LayerNormGpuKernel() = default;
  ~LayerNormGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
    user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
    user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
    double epsilon = ctx->Attr<double>("epsilon");
    int64_t num_instances = mean->shape_view().elem_cnt();
    int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
    const T* gamma_ptr = nullptr;
    const T* beta_ptr = nullptr;
    if (ctx->has_input("gamma", 0)) {
      const user_op::Tensor* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
      gamma_ptr = gamma->dptr<T>();
      CHECK_EQ(gamma->shape_view().elem_cnt(), norm_size);
    }
    if (ctx->has_input("beta", 0)) { beta_ptr = ctx->Tensor4ArgNameAndIndex("beta", 0)->dptr<T>(); }
    // DispatchLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon, x->dptr<T>(),
    //                                gamma_ptr, beta_ptr, y->mut_dptr<T>(), mean, inv_variance);
    using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
    LayerNormKernelImplInternal<T>(ctx->stream(), x->dptr<T>(), gamma_ptr, beta_ptr, num_instances,
                                   norm_size, epsilon, y->mut_dptr<T>(),
                                   mean->mut_dptr<ComputeType>(),
                                   inv_variance->mut_dptr<ComputeType>());
  };
};

#define REGISTER_LAYER_NORM_CUDA_KERNEL(dtype)                         \
  REGISTER_USER_KERNEL("layer_norm")                                   \
      .SetCreateFn<LayerNormGpuKernel<dtype>>()                        \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
                       && (user_op::HobDataType("x", 0) == GetDataType<dtype>::value));

REGISTER_LAYER_NORM_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_CUDA_KERNEL(double)
REGISTER_LAYER_NORM_CUDA_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_CUDA_KERNEL(nv_bfloat16)
#endif

template<typename T>
class LayerNormGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
 public:
  LayerNormGradGpuKernel() = default;
  ~LayerNormGradGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
    const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
    int64_t num_instances = mean->shape_view().elem_cnt();
    int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
    const T* gamma_ptr = nullptr;
    if (ctx->has_input("gamma", 0)) {
      gamma_ptr = ctx->Tensor4ArgNameAndIndex("gamma", 0)->dptr<T>();
    }
    const T* add_to_output_ptr = nullptr;
    if (ctx->has_input("_add_to_output", 0)) {
      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
      CHECK_EQ(add_to_output->data_type(), dx->data_type());
      CHECK_EQ(add_to_output->shape_view(), dx->shape_view());
      add_to_output_ptr = add_to_output->dptr<T>();
    }
    // LaunchLayerNormBackward<T>(ctx->stream(), num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(),
    //                            mean, inv_variance, gamma_ptr, add_to_output_ptr, dx->mut_dptr<T>());
    using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
    LayerNormBackwardKernelImplInternal<T>(
        ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
        inv_variance->dptr<ComputeType>(), gamma_ptr, num_instances, norm_size, dx->mut_dptr<T>(),
        add_to_output_ptr);
  };
};

#define REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(dtype)                                        \
  REGISTER_USER_KERNEL("layer_norm_grad")                                                  \
      .SetCreateFn<LayerNormGradGpuKernel<dtype>>()                                        \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                     \
                       && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value))    \
      .SetInplaceProposalFn(                                                               \
          [](const user_op::InferContext& ctx,                                             \
             const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> {       \
            if (ctx.has_input("_add_to_output", 0)) {                                      \
              OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "_add_to_output", 0, true)); \
            }                                                                              \
            return Maybe<void>::Ok();                                                      \
          });

REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(double)
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
#endif

template<typename T>
class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
                                          public user_op::CudaGraphSupport {
 public:
  LayerNormParamGradGpuKernel() = default;
  ~LayerNormParamGradGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
    const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
    int64_t num_instances = mean->shape_view().elem_cnt();
    int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
    // const DataType data_type = dy->data_type();
    // const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;
    // const int grid_dim_y = GetGirdDimY<T>(num_instances, norm_size);
    // const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T);
    // T* tmp_gamma_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr());
    // T* tmp_beta_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + tmp_gamma_diff_size);
    // T* reduce_buf_ptr =
    //     reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size);
    using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
    // LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(32, 32 / num_per_block),
    //                                      0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
    //     num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
    //     inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr);
    // const int32_t m = norm_size;
    // const int32_t n = 1;
    // const int32_t k = grid_dim_y;
    // std::unique_ptr<ep::primitive::Fill> fill =
    //     ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),
    //                                                             data_type);
    // CHECK(fill);
    // fill->Launch(ctx->stream(), reduce_buf_ptr, 1.0, grid_dim_y);
    // std::unique_ptr<ep::primitive::Matmul> matmul =
    //     ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(
    //         ctx->stream()->device_type(), data_type, ep::primitive::BlasTransposeType::T,
    //         ep::primitive::BlasTransposeType::N);
    // CHECK(matmul);
    // if (ctx->has_output("gamma_diff", 0)) {
    //   user_op::Tensor* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
    //   matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_gamma_diff_ptr, reduce_buf_ptr, 0.0,
    //                  gamma_diff->mut_dptr());
    // }
    // if (ctx->has_output("beta_diff", 0)) {
    //   user_op::Tensor* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
    //   matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_beta_diff_ptr, reduce_buf_ptr, 0.0,
    //                  beta_diff->mut_dptr());
    // }
    T* gamma_diff_ptr = nullptr;
    T* beta_diff_ptr = nullptr;
    if (ctx->has_output("gamma_diff", 0)) {
      gamma_diff_ptr = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0)->mut_dptr<T>();
    }
    if (ctx->has_output("beta_diff", 0)) {
      beta_diff_ptr = ctx->Tensor4ArgNameAndIndex("beta_diff", 0)->mut_dptr<T>();
    }
    LayerNormBackwardKernelImplInternalParam<T>(
        ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
        inv_variance->dptr<ComputeType>(), num_instances, norm_size, gamma_diff_ptr,
        beta_diff_ptr);
  };
};

#define REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(dtype)                                    \
  REGISTER_USER_KERNEL("layer_norm_param_grad")                                             \
      .SetCreateFn<LayerNormParamGradGpuKernel<dtype>>()                                    \
      .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)                      \
                       && (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value))     \
      .SetInferTmpSizeFn([](user_op::InferContext* ctx) {                                   \
        const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis");          \
        const bool has_gamma_diff = ctx->has_output("gamma_diff", 0);                       \
        const bool has_beta_diff = ctx->has_output("beta_diff", 0);                         \
        const auto& dy = ctx->InputTensorDesc("dy", 0);                                     \
        const int64_t num_instances = dy.shape().Count(0, begin_params_axis);               \
        const int64_t norm_size = dy.shape().Count(begin_params_axis);                      \
        const int grid_dim_y = num_instances;                                               \
        size_t tmp_buffer_size = (2 * grid_dim_y * norm_size + grid_dim_y) * sizeof(dtype); \
        return tmp_buffer_size;                                                             \
      });

REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)
#endif

}
}  // namespace oneflow
#endif
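The WelfordOps::reduce/combine pair in the block above implements Welford's online mean/variance update and the Chan et al. pairwise merge. A hedged scalar restatement of that math for orientation (my own, not code from the kernel):

#include <cstdint>

// Illustration only: scalar Welford accumulator with the same update and merge rules.
struct Welford {
  double mean = 0.0;
  double m2 = 0.0;     // sum of squared deviations from the running mean
  int64_t n = 0;

  void Push(double x) {
    n += 1;
    const double delta = x - mean;
    mean += delta / static_cast<double>(n);
    m2 += delta * (x - mean);          // delta * new_delta, as in WelfordOps::reduce
  }

  // Merge two partial accumulators (Chan et al.), as in WelfordOps::combine.
  static Welford Merge(const Welford& a, const Welford& b) {
    if (a.n == 0) { return b; }
    if (b.n == 0) { return a; }
    Welford out;
    out.n = a.n + b.n;
    const double delta = b.mean - a.mean;
    const double nb_over_n = static_cast<double>(b.n) / static_cast<double>(out.n);
    out.mean = a.mean + delta * nb_over_n;
    out.m2 = a.m2 + b.m2 + delta * delta * static_cast<double>(a.n) * nb_over_n;
    return out;
  }

  double Variance() const { return n > 0 ? m2 / static_cast<double>(n) : 0.0; }
};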
oneflow/user/kernels/math_binary_elementwise_func.h  (+1 -1)

@@ -28,7 +28,7 @@ limitations under the License.
 #elif defined(__HIPCC__)
-#include <hip/hsa_detail/math_functions.h>
+#include <hip/math_functions.h>
 #include <hip/hip_fp16.h>
 #if defined(__HIP_DEVICE_COMPILE__)
oneflow/user/kernels/normalization_kernel.cu  (+95 -268)

@@ -882,6 +882,22 @@ namespace oneflow {
 namespace {

+template<typename T>
+void printTensor(const std::string& str, const T* devTensor, size_t size) {
+  T* hostTensor;
+  hostTensor = (T*)malloc(size * sizeof(T));
+  hipMemcpy(hostTensor, devTensor, size * sizeof(T), hipMemcpyDeviceToHost);
+  std::cout << str << ": ";
+  for (int i; i < size; i++) {
+    if (i % 16 == 0) { std::cout << std::endl; }
+    std::cout << hostTensor[i] << ", ";
+  }
+  std::cout << str << ": finish" << std::endl;
+  free(hostTensor);
+}
+
 hipdnnBatchNormMode_t getCudnnBatchNormMode(const int64_t dim) {
   if (dim == 2) {
     return HIPDNN_BATCHNORM_PER_ACTIVATION;

@@ -969,6 +985,15 @@ class CudnnTensorDescHelper final {
   int32_t param_size_ = 0;
 };

+size_t InferInferTmpSize(user_op::InferContext* ctx) {
+  const auto& y = ctx->OutputTensorDesc("y", 0);
+  if (ctx->has_input("_add_to_output", 0)) {
+    return y.shape().elem_cnt() * GetSizeOfDataType(y.data_type());
+  } else {
+    return 1;
+  }
+}
+
 size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
                                const int32_t axis) {
   return 1;

@@ -976,8 +1001,13 @@ size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_typ
 size_t InferTrainTmpSize(user_op::InferContext* ctx) {
   const auto& x = ctx->InputTensorDesc("x", 0);
+  const auto& y = ctx->OutputTensorDesc("y", 0);
   const auto axis = ctx->Attr<int32_t>("axis");
+  if (ctx->has_input("_add_to_output", 0)) {
+    return y.shape().elem_cnt() * GetSizeOfDataType(y.data_type());
+  } else {
     return InferTrainWorkspaceSize(x.shape(), x.data_type(), axis);
+  }
 }

 size_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type,

@@ -1016,6 +1046,9 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
     const auto axis = ctx->Attr<int32_t>("axis");
     const auto epsilon = ctx->Attr<float>("epsilon");
+    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
+    void* add_to_output_dev = tmp_buffer->mut_dptr<void>();
+
     const DataType data_type = x->data_type();
     CHECK_EQ(x->shape_view(), y->shape_view());
     CHECK_EQ(y->data_type(), data_type);

@@ -1030,17 +1063,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
     desc_helper.CheckParamTensor(moving_variance);

     const void* sp_alpha = CudnnSPOnePtr(data_type);
-    const void* sp_beta;
+    const void* sp_beta = CudnnSPZeroPtr(data_type);
     if (ctx->has_input("_add_to_output", 0)) {
       const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
       CHECK_EQ(add_to_output->data_type(), y->data_type());
       CHECK_EQ(add_to_output->shape_view(), y->shape_view());
-      Memcpy<DeviceType::kCUDA>(ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(),
+      Memcpy<DeviceType::kCUDA>(ctx->stream(), add_to_output_dev, add_to_output->dptr<void>(),
                                 add_to_output->shape_view().elem_cnt()
                                     * GetSizeOfDataType(add_to_output->data_type()));
-      sp_beta = CudnnSPOnePtr(data_type);
-    } else {
-      sp_beta = CudnnSPZeroPtr(data_type);
     }
     OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardInference(

@@ -1048,6 +1079,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
         desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),
         desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(),
         moving_variance->dptr(), epsilon));
+    if (ctx->has_input("_add_to_output", 0)) {
+      sp_beta = CudnnSPOnePtr(data_type);
+      OF_CUDNN_CHECK(hipdnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
+                                     sp_alpha, desc_helper.xy_desc(), add_to_output_dev, sp_beta,
+                                     desc_helper.xy_desc(), y->mut_dptr()));
+    }
   }

   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }

@@ -1057,6 +1097,7 @@ REGISTER_USER_KERNEL("normalization")
     .SetCreateFn<NormalizationInferenceKernel>()
     .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
                      && (user_op::HobAttr<bool>("training") == false))
+    .SetInferTmpSizeFn(InferInferTmpSize)
     .SetInplaceProposalFn([](const user_op::InferContext& ctx,
                              user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
       if (ctx.has_input("_add_to_output", 0)) {
@@ -1068,76 +1109,78 @@ REGISTER_USER_KERNEL("normalization")
constexpr
int64_t
kCudaWarpSize
=
64
;
template
<
typename
T
>
__global__
void
ReluGpu
(
int64_t
n
,
const
T
*
x
,
T
*
y
,
int
32
_t
*
mask
)
{
__global__
void
ReluGpu
(
int64_t
n
,
const
T
*
x
,
T
*
y
,
int
64
_t
*
mask
)
{
const
int32_t
lane_id
=
threadIdx
.
x
%
kCudaWarpSize
;
const
T
zero
=
static_cast
<
T
>
(
0.
f
);
CUDA_1D_KERNEL_LOOP
(
i
,
n
)
{
const
T
x_val
=
x
[
i
];
const
bool
is_positive
=
(
x_val
>
zero
);
int32_t
warp_mask
=
__ballot
(
static_cast
<
int
>
(
is_positive
));
if
(
lane_id
==
0
)
{
mask
[
i
/
kCudaWarpSize
]
=
warp_mask
;
}
unsigned
long
long
int
warp_mask_tmp
=
__ballot
(
static_cast
<
int
>
(
is_positive
));
int64_t
*
warp_mask
=
reinterpret_cast
<
int64_t
*>
(
&
warp_mask_tmp
);
if
(
lane_id
==
0
)
{
mask
[
i
/
kCudaWarpSize
]
=
*
warp_mask
;
}
y
[
i
]
=
is_positive
?
x_val
:
zero
;
}
}
template
<
typename
T
>
__global__
void
AddReluGpu
(
int64_t
n
,
const
T
*
x
,
const
T
*
addend
,
T
*
y
,
int
32
_t
*
mask
)
{
__global__
void
AddReluGpu
(
int64_t
n
,
const
T
*
x
,
const
T
*
addend
,
T
*
y
,
int
64
_t
*
mask
)
{
const
int32_t
lane_id
=
threadIdx
.
x
%
kCudaWarpSize
;
const
T
zero
=
static_cast
<
T
>
(
0.
f
);
CUDA_1D_KERNEL_LOOP
(
i
,
n
)
{
const
T
sum
=
x
[
i
]
+
addend
[
i
];
const
bool
is_positive
=
(
sum
>
zero
);
int32_t
warp_mask
=
__ballot
(
static_cast
<
int
>
(
is_positive
));
if
(
lane_id
==
0
)
{
mask
[
i
/
kCudaWarpSize
]
=
warp_mask
;
}
unsigned
long
long
int
warp_mask_tmp
=
__ballot
(
static_cast
<
int
>
(
is_positive
));
int64_t
*
warp_mask
=
reinterpret_cast
<
int64_t
*>
(
&
warp_mask_tmp
);
if
(
lane_id
==
0
)
{
mask
[
i
/
kCudaWarpSize
]
=
*
warp_mask
;
}
y
[
i
]
=
is_positive
?
sum
:
zero
;
}
}
template<typename T>
void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int32_t* mask) {
void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int64_t* mask) {
  ReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
               stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, mask);
}

template<typename T>
void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {
void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int64_t* mask) {
  AddReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
                  stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, addend, y, mask);
}

template<typename T>
__global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {
  int32_t lane_id = threadIdx.x % kCudaWarpSize;
__global__ void ReluBackwardGpu(int64_t n, const int64_t* mask, const T* dy, T* addend_diff) {
  int64_t lane_id = threadIdx.x % kCudaWarpSize;
  CUDA_1D_KERNEL_LOOP(i, n) {
    int32_t mask_val = mask[i / kCudaWarpSize];
    bool is_positive = mask_val & (1 << lane_id);
    int64_t mask_val = mask[i / kCudaWarpSize];
    bool is_positive = mask_val & ((int64_t)1 << lane_id);
    addend_diff[i] = static_cast<T>(is_positive) * dy[i];
  }
}
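The backward change is the counterpart of the forward one: lane ids now run up to 63, so the bit
test must shift a 64-bit constant; a 32-bit `1 << lane_id` is undefined once lane_id reaches 32.
A hedged C++ sketch of that decoding step (helper name is hypothetical):

// C++ sketch: reading lane bits back out of a 64-bit wavefront mask.
#include <cstdint>
#include <cassert>

inline bool LaneBit(int64_t mask_word, int lane_id) {
  // Shift on a 64-bit operand; (1 << lane_id) with a 32-bit literal breaks for lane_id >= 32,
  // which is exactly the upper half of a 64-lane wavefront.
  return (mask_word & (static_cast<int64_t>(1) << lane_id)) != 0;
}

int main() {
  const int64_t mask = (static_cast<int64_t>(1) << 40) | 1;  // lanes 0 and 40 active
  assert(LaneBit(mask, 0) && LaneBit(mask, 40) && !LaneBit(mask, 1));
  return 0;
}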
#if CUDA_VERSION >= 11000
// #if CUDA_VERSION >= 11000
template<>
__global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,
                                             nv_bfloat16* addend_diff) {
  int32_t lane_id = threadIdx.x % kCudaWarpSize;
  CUDA_1D_KERNEL_LOOP(i, n) {
    int32_t mask_val = mask[i / kCudaWarpSize];
    bool is_positive = mask_val & (1 << lane_id);
    addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];
  }
}
// template<>
// __global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,
//                                              nv_bfloat16* addend_diff) {
//   int32_t lane_id = threadIdx.x % kCudaWarpSize;
//   CUDA_1D_KERNEL_LOOP(i, n) {
//     int32_t mask_val = mask[i / kCudaWarpSize];
//     bool is_positive = mask_val & (1 << lane_id);
//     addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];
//   }
// }
#endif
// #endif
template<typename T>
void ReluBackward(ep::Stream* stream, int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {
void ReluBackward(ep::Stream* stream, int64_t n, const int64_t* mask, const T* dy, T* addend_diff) {
  ReluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
                       stream->As<ep::CudaStream>()->cuda_stream()>>>(n, mask, dy, addend_diff);
}

void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, void* y,
          int32_t* mask) {
          int64_t* mask) {
  if (data_type == kFloat) {
    Relu<float>(stream, n, reinterpret_cast<const float*>(x), reinterpret_cast<float*>(y), mask);
  } else if (data_type == kDouble) {
...
...
@@ -1156,7 +1199,7 @@ void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x
  }
}

void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x,
             const void* addend, void* y, int32_t* mask) {
             const void* addend, void* y, int64_t* mask) {
  if (data_type == kFloat) {
    AddRelu<float>(stream, n, reinterpret_cast<const float*>(x),
                   reinterpret_cast<const float*>(addend), reinterpret_cast<float*>(y), mask);
...
...
@@ -1178,7 +1221,7 @@ void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void
    UNIMPLEMENTED();
  }
}

void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int32_t* mask,
void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int64_t* mask,
                  const void* dy, void* addend_diff) {
  if (data_type == kFloat) {
    ReluBackward<float>(stream, n, mask, reinterpret_cast<const float*>(dy),
...
...
@@ -1225,6 +1268,9 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
    hipdnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes());
    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode);
    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
    void* add_to_output_dev = tmp_buffer->mut_dptr<void>();
    const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
    const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
    auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
...
...
@@ -1244,17 +1290,15 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
      desc_helper.CheckParamTensor(moving_variance);
    }
    const void* sp_alpha = CudnnSPOnePtr(data_type);
    const void* sp_beta;
    const void* sp_beta = CudnnSPZeroPtr(data_type);
    if (ctx->has_input("_add_to_output", 0)) {
      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
      CHECK_EQ(add_to_output->data_type(), y->data_type());
      CHECK_EQ(add_to_output->shape_view(), y->shape_view());
      Memcpy<DeviceType::kCUDA>(
          ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(),
          ctx->stream(), add_to_output_dev, add_to_output->dptr<void>(),
          add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
      sp_beta = CudnnSPOnePtr(data_type);
    } else {
      sp_beta = CudnnSPZeroPtr(data_type);
    }
    OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardTraining(
...
...
@@ -1265,6 +1309,14 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
        moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(),
        inv_variance->mut_dptr()));
    if (ctx->has_input("_add_to_output", 0)) {
      sp_beta = CudnnSPOnePtr(data_type);
      OF_CUDNN_CHECK(hipdnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
                                     sp_alpha, desc_helper.xy_desc(), add_to_output_dev,
                                     sp_beta, desc_helper.xy_desc(), y->mut_dptr()));
    }
    if (ctx->op_type_name() == "normalization_add_relu") {
      CHECK(!ctx->has_input("_add_to_output", 0));
      const int64_t elem_cnt = x->shape_view().elem_cnt();
...
...
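The two hunks above restructure how "_add_to_output" is fused into the training path: the addend
is staged in tmp_buffer (add_to_output_dev), the batch-norm forward then overwrites y, and
hipdnnAddTensor folds the staged addend back in with both scaling factors set to one. A rough
host-side sketch of that ordering, with hypothetical buffer names and no hipDNN calls:

// Plain C++ sketch of the staging order (illustrative only; not part of the commit).
#include <vector>
#include <cstddef>

void AddToOutputStagingSketch(std::vector<float>& y, const std::vector<float>& add_to_output,
                              std::vector<float>& tmp_buffer) {
  tmp_buffer.assign(add_to_output.begin(), add_to_output.end());  // Memcpy into the scratch buffer
  // ... batch-norm forward runs here and overwrites y entirely ...
  for (size_t i = 0; i < y.size(); ++i) {
    y[i] = 1.0f * tmp_buffer[i] + 1.0f * y[i];  // AddTensor with alpha = beta = 1
  }
}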
@@ -1272,10 +1324,10 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
      if (ctx->has_input("addend", 0)) {
        const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0);
        AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(),
                mask->mut_dptr<int32_t>());
                mask->mut_dptr<int64_t>());
      } else {
        Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(),
             mask->mut_dptr<int32_t>());
             mask->mut_dptr<int64_t>());
      }
    }
  }
...
...
@@ -1351,7 +1403,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel,
    user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
    if (ctx->has_output("addend_diff", 0)) {
      user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex("addend_diff", 0);
      ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(),
      ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int64_t>(), dy->dptr(),
                   addend_diff->mut_dptr());
      bn_workspace_ptr = tmp_buffer->mut_dptr();
      bn_workspace_size = tmp_buffer->shape_view().elem_cnt();
...
...
@@ -1361,7 +1413,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel,
      const size_t relu_dx_size =
          GetCudaAlignedSize(dy->shape_view().elem_cnt() * GetSizeOfDataType(data_type));
      CHECK_GE(tmp_buffer_size, relu_dx_size);
      ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(),
      ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int64_t>(), dy->dptr(),
                   tmp_buffer->mut_dptr());
      bn_workspace_ptr = tmp_buffer->mut_dptr<char>() + relu_dx_size;
      bn_workspace_size = tmp_buffer_size - relu_dx_size;
...
...
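In the gradient path above, one temporary allocation is carved into two regions: the first
relu_dx_size bytes receive the ReLU-backward result, and the remainder is handed to the batch-norm
backward as workspace. A minimal sketch of that partitioning (hypothetical helper, not from the
commit):

// C++ sketch: split a single scratch buffer into a ReLU-dx region and a BN workspace region.
#include <cstddef>

struct TmpBufferSplit {
  void* relu_dx;
  void* bn_workspace;
  size_t bn_workspace_size;
};

inline TmpBufferSplit SplitTmpBuffer(char* tmp_buffer, size_t tmp_buffer_size,
                                     size_t relu_dx_size) {
  TmpBufferSplit split{};
  split.relu_dx = tmp_buffer;                              // first relu_dx_size bytes
  split.bn_workspace = tmp_buffer + relu_dx_size;          // remainder goes to cudnn/hipdnn
  split.bn_workspace_size = tmp_buffer_size - relu_dx_size;
  return split;
}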
@@ -1393,231 +1445,6 @@ REGISTER_USER_KERNEL("normalization_add_relu_grad")
    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
    .SetInferTmpSizeFn(InferGradTmpSize);

#if (HIPDNN_VERSION >= 7401)

size_t InferFusedNormalizationAddReluTmpSize(user_op::InferContext* ctx) {
  const auto& x = ctx->InputTensorDesc("x", 0);
  const auto axis = ctx->Attr<int32_t>("axis");
  const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,
                                          HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
  size_t size_in_bytes;
  hipdnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();
  CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
  cudnnBatchNormOps_t ops;
  hipdnnTensorDescriptor_t z_desc;
  if (ctx->has_input("addend", 0)) {
    ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
    z_desc = desc_helper.xy_desc();
  } else {
    ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
    z_desc = nullptr;
  }
  OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
      handle, HIPDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), z_desc,
      desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes));
  Singleton<CudnnHandlePool>::Get()->Put(handle);
  return std::max(size_in_bytes, static_cast<size_t>(1));
}

size_t InferFusedNormalizationAddReluGradTmpSize(user_op::InferContext* ctx) {
  const auto& x = ctx->InputTensorDesc("x", 0);
  const auto axis = ctx->Attr<int32_t>("axis");
  const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,
                                          HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
  size_t size_in_bytes;
  hipdnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();
  CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
  cudnnBatchNormOps_t ops;
  hipdnnTensorDescriptor_t z_desc;
  if (ctx->has_output("addend_diff", 0)) {
    ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
    z_desc = desc_helper.xy_desc();
  } else {
    ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
    z_desc = nullptr;
  }
  OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
      handle, HIPDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(),
      desc_helper.xy_desc(), desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(),
      desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes));
  Singleton<CudnnHandlePool>::Get()->Put(handle);
  return std::max(size_in_bytes, static_cast<size_t>(1));
}

class FusedNormalizationAddReluKernel final : public user_op::OpKernel,
                                              public user_op::CudaGraphSupport {
 public:
  FusedNormalizationAddReluKernel() = default;
  ~FusedNormalizationAddReluKernel() override = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0);
    auto* y = ctx->Tensor4ArgNameAndIndex("y", 0);
    const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
    const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
    auto* moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0);
    auto* moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0);
    const auto axis = ctx->Attr<int32_t>("axis");
    const auto epsilon = ctx->Attr<float>("epsilon");
    const auto momentum = ctx->Attr<float>("momentum");
    auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
    auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
    auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
    const DataType data_type = x->data_type();
    CHECK_EQ(x->shape_view(), y->shape_view());
    CHECK_EQ(y->data_type(), data_type);
    CHECK_GE(axis, 0);
    CHECK_LT(axis, x->shape_view().NumAxes());
    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,
                                            HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
    desc_helper.CheckParamTensor(gamma);
    desc_helper.CheckParamTensor(beta);
    desc_helper.CheckParamTensor(moving_mean);
    desc_helper.CheckParamTensor(moving_variance);
    desc_helper.CheckParamTensor(mean);
    desc_helper.CheckParamTensor(inv_variance);
    CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
    hipdnnTensorDescriptor_t z_desc;
    const void* z_ptr;
    cudnnBatchNormOps_t ops;
    if (ctx->has_input("addend", 0)) {
      z_desc = desc_helper.xy_desc();
      z_ptr = ctx->Tensor4ArgNameAndIndex("addend", 0)->dptr();
      ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
    } else {
      z_desc = nullptr;
      z_ptr = nullptr;
      ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
    }
    size_t min_workspace_size;
    OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
        ops, desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(),
        activation_desc.Get(), &min_workspace_size));
    const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();
    CHECK_GE(workspace_size, min_workspace_size);
    size_t min_reserve_space_size;
    OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
        ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));
    const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();
    CHECK_GE(reserve_space_size, min_reserve_space_size);
    OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
        ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(),
        z_desc, z_ptr, desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(),
        gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(),
        moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr(),
        activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size, reserve_space->mut_dptr(),
        reserve_space_size));
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu")
    .SetCreateFn<FusedNormalizationAddReluKernel>()
    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
    .SetInferTmpSizeFn(InferFusedNormalizationAddReluTmpSize);

class FusedNormalizationAddReluGradUserKernel final : public user_op::OpKernel,
                                                      public user_op::CudaGraphSupport {
 public:
  FusedNormalizationAddReluGradUserKernel() = default;
  ~FusedNormalizationAddReluGradUserKernel() override = default;

 private:
  using user_op::OpKernel::Compute;
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0);
    const auto* y = ctx->Tensor4ArgNameAndIndex("y", 0);
    auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
    const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
    const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
    auto* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
    auto* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
    const auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
    const auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
    const auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
    const auto axis = ctx->Attr<int32_t>("axis");
    const auto epsilon = ctx->Attr<float>("epsilon");
    const DataType data_type = x->data_type();
    CHECK_EQ(dy->shape_view(), x->shape_view());
    CHECK_EQ(dy->data_type(), data_type);
    CHECK_EQ(dx->shape_view(), x->shape_view());
    CHECK_EQ(dx->data_type(), data_type);
    CHECK_GE(axis, 0);
    CHECK_LT(axis, x->shape_view().NumAxes());
    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,
                                            HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
    desc_helper.CheckParamTensor(gamma);
    desc_helper.CheckParamTensor(beta);
    desc_helper.CheckParamTensor(gamma_diff);
    desc_helper.CheckParamTensor(beta_diff);
    desc_helper.CheckParamTensor(mean);
    desc_helper.CheckParamTensor(inv_variance);
    CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
    hipdnnTensorDescriptor_t dz_desc;
    void* dz_ptr;
    cudnnBatchNormOps_t ops;
    if (ctx->has_output("addend_diff", 0)) {
      dz_desc = desc_helper.xy_desc();
      dz_ptr = ctx->Tensor4ArgNameAndIndex("addend_diff", 0)->mut_dptr();
      ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
    } else {
      dz_desc = nullptr;
      dz_ptr = nullptr;
      ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
    }
    size_t min_workspace_size;
    OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
        ops, desc_helper.xy_desc(), desc_helper.xy_desc(), desc_helper.xy_desc(), dz_desc,
        desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(),
        &min_workspace_size));
    const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();
    CHECK_GE(workspace_size, min_workspace_size);
    size_t min_reserve_space_size;
    OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
        ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));
    const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();
    CHECK_GE(reserve_space_size, min_reserve_space_size);
    OF_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(
        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
        ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type),
        CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(),
        y->dptr(), desc_helper.xy_desc(), dy->dptr(), dz_desc, dz_ptr, desc_helper.xy_desc(),
        dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(),
        gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(),
        inv_variance->dptr(), activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size,
        const_cast<void*>(reserve_space->dptr()), reserve_space_size));
  }
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
};

REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu_grad")
    .SetCreateFn<FusedNormalizationAddReluGradUserKernel>()
    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
    .SetInferTmpSizeFn(InferFusedNormalizationAddReluGradTmpSize);

#endif

}  // namespace

}  // namespace oneflow
...
...
oneflow/user/ops/normalization_op.cpp
View file @ 6046d8fb
...
...
@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
          CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0);
          reserve_space_bits = reserve_space_bits / split_num;
        }
#ifdef WITH_ROCM
        reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 64) / 64)}));
#else
        reserve_space->set_shape(Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 32)}));
#endif
        return Maybe<void>::Ok();
      })(ctx);
}
...
...
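The op-level change mirrors the kernel change: under WITH_ROCM the ReLU mask stored in
reserve_space is packed into 64-bit words (one bit per lane of a 64-wide wavefront), so its
element count becomes RoundUp(bits, 64) / 64 instead of RoundUp(bits, 32) / 32. A small
constexpr sketch of that arithmetic (helper names are hypothetical, not from the commit):

// C++ sketch: number of mask words needed for elem_cnt one-bit flags.
#include <cstdint>

constexpr int64_t RoundUpSketch(int64_t n, int64_t align) { return (n + align - 1) / align * align; }

constexpr int64_t MaskWords(int64_t elem_cnt, int64_t lanes /* 64 on ROCm, 32 on CUDA */) {
  return RoundUpSketch(elem_cnt, lanes) / lanes;
}

static_assert(MaskWords(100, 64) == 2, "100 elements fit in two 64-bit mask words");
static_assert(MaskWords(100, 32) == 4, "100 elements need four 32-bit mask words");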
@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
  return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
                                    user_op::TensorDesc* reserve_space) -> Maybe<void> {
    const auto& x_desc = ctx->InputTensorDesc("x", 0);
#ifdef WITH_ROCM
    reserve_space->set_shape(
        Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 64) / 64)}));
#else
    reserve_space->set_shape(
        Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}));
#endif
    return Maybe<void>::Ok();
  })(ctx);
}
...
...
@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
/* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) {
  return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
                                  user_op::TensorDesc* reserve_space) -> Maybe<void> {
#ifdef WITH_ROCM
    reserve_space->set_data_type(DataType::kInt64);
#else
    reserve_space->set_data_type(DataType::kInt32);
#endif
    return Maybe<void>::Ok();
  })(ctx);
}
...
...