Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
6046d8fb
Commit
6046d8fb
authored
Apr 25, 2023
by
yuguo960516yuguo
Browse files
dtk23.04
parent
a715222c
Changes
16
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
419 additions
and
1010 deletions
+419
-1010
cmake/oneflow.cmake
cmake/oneflow.cmake
+10
-10
cmake/third_party.cmake
cmake/third_party.cmake
+2
-2
oneflow/core/common/math_util.h
oneflow/core/common/math_util.h
+2
-2
oneflow/core/cuda/atomic.cuh
oneflow/core/cuda/atomic.cuh
+1
-1
oneflow/core/cuda/layer_norm.cuh
oneflow/core/cuda/layer_norm.cuh
+46
-4
oneflow/core/embedding/lru_cache.cu
oneflow/core/embedding/lru_cache.cu
+20
-7
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh
...w/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh
+1
-1
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu
...e/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu
+36
-36
oneflow/core/ep/cuda/primitive/unary_functor.cuh
oneflow/core/ep/cuda/primitive/unary_functor.cuh
+2
-2
oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
+156
-25
oneflow/user/kernels/fused_gelu_mul_kernel.cu
oneflow/user/kernels/fused_gelu_mul_kernel.cu
+4
-4
oneflow/user/kernels/group_norm_kernel.cu
oneflow/user/kernels/group_norm_kernel.cu
+4
-0
oneflow/user/kernels/layer_norm_gpu_kernel.cu
oneflow/user/kernels/layer_norm_gpu_kernel.cu
+26
-647
oneflow/user/kernels/math_binary_elementwise_func.h
oneflow/user/kernels/math_binary_elementwise_func.h
+1
-1
oneflow/user/kernels/normalization_kernel.cu
oneflow/user/kernels/normalization_kernel.cu
+95
-268
oneflow/user/ops/normalization_op.cpp
oneflow/user/ops/normalization_op.cpp
+13
-0
No files found.
cmake/oneflow.cmake
View file @
6046d8fb
...
@@ -334,16 +334,16 @@ elseif(WIN32)
...
@@ -334,16 +334,16 @@ elseif(WIN32)
set
(
CMAKE_EXE_LINKER_FLAGS
"
${
CMAKE_EXE_LINKER_FLAGS
}
/WHOLEARCHIVE:oneflow"
)
set
(
CMAKE_EXE_LINKER_FLAGS
"
${
CMAKE_EXE_LINKER_FLAGS
}
/WHOLEARCHIVE:oneflow"
)
endif
()
endif
()
if
(
BUILD_ROCM
)
#
if (BUILD_ROCM)
# AMD compiler fails to compile these three files with '-O1/2/3'.
#
# AMD compiler fails to compile these three files with '-O1/2/3'.
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
#
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# so '-O0' will override '-O1/2/3'.
#
# so '-O0' will override '-O1/2/3'.
set_source_files_properties
(
${
PROJECT_SOURCE_DIR
}
/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
#
set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
${
PROJECT_SOURCE_DIR
}
/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
#
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
${
PROJECT_SOURCE_DIR
}
/oneflow/user/kernels/arg_sort_kernel_hip.cpp
#
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
#${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
#
#
${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
PROPERTIES COMPILE_OPTIONS
"-O0"
)
#
PROPERTIES COMPILE_OPTIONS "-O0")
endif
()
#
endif()
if
(
BUILD_CUDA
)
if
(
BUILD_CUDA
)
string
(
JOIN
","
CUDA_REAL_ARCHS
${
CUDA_REAL_ARCHS_LIST
}
)
string
(
JOIN
","
CUDA_REAL_ARCHS
${
CUDA_REAL_ARCHS_LIST
}
)
...
...
cmake/third_party.cmake
View file @
6046d8fb
...
@@ -186,7 +186,7 @@ if (BUILD_ROCM)
...
@@ -186,7 +186,7 @@ if (BUILD_ROCM)
if
(
BUILD_ROCM_GRAPHS
)
if
(
BUILD_ROCM_GRAPHS
)
add_definitions
(
-DWITH_ROCM_GRAPHS
)
add_definitions
(
-DWITH_ROCM_GRAPHS
)
endif
()
endif
()
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-D__HIP_PLATFORM_HCC__ -D__HIPCC__"
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-D__HIP_PLATFORM_HCC__ -D__HIPCC__"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-D__HIP_PLATFORM_HCC__ -D__HIPCC__"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-D__HIP_PLATFORM_HCC__ -D__HIPCC__"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
--gpu-max-threads-per-block=1024"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
--gpu-max-threads-per-block=1024"
)
...
@@ -204,7 +204,7 @@ if (BUILD_ROCM)
...
@@ -204,7 +204,7 @@ if (BUILD_ROCM)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
-mcmodel=large"
)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
-mcmodel=large"
)
set
(
CMAKE_C_FLAGS_DEBUG
"
${
CMAKE_C_FLAGS_DEBUG
}
-mcmodel=large"
)
set
(
CMAKE_C_FLAGS_DEBUG
"
${
CMAKE_C_FLAGS_DEBUG
}
-mcmodel=large"
)
list
(
APPEND oneflow_third_party_libs hip::device
)
list
(
APPEND oneflow_third_party_libs hip::device
)
list
(
APPEND oneflow_third_party_libs hip::hipfft
)
list
(
APPEND oneflow_third_party_libs hip::hipfft
)
list
(
APPEND oneflow_third_party_libs roc::hipblas
)
list
(
APPEND oneflow_third_party_libs roc::hipblas
)
...
...
oneflow/core/common/math_util.h
View file @
6046d8fb
...
@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
...
@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
template
<
typename
T
>
template
<
typename
T
>
OF_DEVICE_FUNC
T
DeviceMin
(
T
a
,
T
b
)
{
OF_DEVICE_FUNC
T
DeviceMin
(
T
a
,
T
b
)
{
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
|| defined(__HIP_DEVICE_COMPILE__)
return
a
<
b
?
a
:
b
;
return
a
<
b
?
a
:
b
;
#else
#else
return
std
::
min
(
a
,
b
);
return
std
::
min
(
a
,
b
);
...
@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
...
@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
template
<
typename
T
>
template
<
typename
T
>
OF_DEVICE_FUNC
T
DeviceMax
(
T
a
,
T
b
)
{
OF_DEVICE_FUNC
T
DeviceMax
(
T
a
,
T
b
)
{
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
|| defined(__HIP_DEVICE_COMPILE__)
return
a
>
b
?
a
:
b
;
return
a
>
b
?
a
:
b
;
#else
#else
return
std
::
max
(
a
,
b
);
return
std
::
max
(
a
,
b
);
...
...
oneflow/core/cuda/atomic.cuh
View file @
6046d8fb
...
@@ -54,7 +54,7 @@ struct CastCASImpl {
...
@@ -54,7 +54,7 @@ struct CastCASImpl {
}
}
};
};
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
|| defined(WITH_ROCM)
template
<
typename
T
>
template
<
typename
T
>
struct
CastCASImpl
<
T
,
unsigned
short
int
>
{
struct
CastCASImpl
<
T
,
unsigned
short
int
>
{
...
...
oneflow/core/cuda/layer_norm.cuh
View file @
6046d8fb
...
@@ -308,6 +308,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
...
@@ -308,6 +308,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
count_shared
[
wid
]
=
warp_count
;
count_shared
[
wid
]
=
warp_count
;
}
}
__syncthreads
();
__syncthreads
();
#ifdef WITH_ROCM
if
(
threadIdx
.
x
<
blockDim
.
x
/
kWarpSize
)
{
warp_mean
=
mean_shared
[
lid
];
warp_m2
=
m2_shared
[
lid
];
warp_count
=
count_shared
[
lid
];
}
else
{
warp_mean
=
static_cast
<
T
>
(
0
);
warp_m2
=
static_cast
<
T
>
(
0
);
warp_count
=
static_cast
<
T
>
(
0
);
}
__syncthreads
();
if
(
wid
==
0
)
{
#else
if
(
wid
==
0
)
{
if
(
wid
==
0
)
{
if
(
threadIdx
.
x
<
blockDim
.
x
/
kWarpSize
)
{
if
(
threadIdx
.
x
<
blockDim
.
x
/
kWarpSize
)
{
warp_mean
=
mean_shared
[
lid
];
warp_mean
=
mean_shared
[
lid
];
...
@@ -318,10 +333,7 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
...
@@ -318,10 +333,7 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
warp_m2
=
static_cast
<
T
>
(
0
);
warp_m2
=
static_cast
<
T
>
(
0
);
warp_count
=
static_cast
<
T
>
(
0
);
warp_count
=
static_cast
<
T
>
(
0
);
}
}
#ifdef WITH_ROCM
__syncwarp
();
__syncthreads
();
#else
__syncwarp
();
#endif
#endif
T
block_mean
=
0
;
T
block_mean
=
0
;
T
block_m2
=
0
;
T
block_m2
=
0
;
...
@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
...
@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
const
double
epsilon
,
ComputeType
*
mean
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
ComputeType
*
inv_variance
)
{
constexpr
int
block_size
=
128
;
constexpr
int
block_size
=
128
;
#ifdef WITH_ROCM
constexpr
int
waves
=
64
;
#else
constexpr
int
waves
=
32
;
constexpr
int
waves
=
32
;
#endif
static_assert
(
block_size
%
thread_group_width
==
0
,
""
);
static_assert
(
block_size
%
thread_group_width
==
0
,
""
);
constexpr
int
thread_groups_per_block
=
block_size
/
thread_group_width
;
constexpr
int
thread_groups_per_block
=
block_size
/
thread_group_width
;
dim3
block_dim
(
thread_group_width
,
thread_groups_per_block
);
dim3
block_dim
(
thread_group_width
,
thread_groups_per_block
);
...
@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
...
@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
} \
}
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF
(
64
)
#else
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
32
)
#endif
#undef DEFINE_ONE_ELIF
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
else if (cols <= (max_col)*kWarpSize) { \
...
@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
...
@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
} \
}
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF
(
64
)
#else
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
32
)
#endif
#undef DEFINE_ONE_ELIF
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
...
@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
...
@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
const
double
epsilon
,
ComputeType
*
mean
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
ComputeType
*
inv_variance
)
{
constexpr
int
block_size
=
1024
;
constexpr
int
block_size
=
1024
;
#ifdef WITH_ROCM
constexpr
int
waves
=
64
;
#else
constexpr
int
waves
=
32
;
constexpr
int
waves
=
32
;
#endif
int
grid_dim_x
;
int
grid_dim_x
;
{
{
GPU
(
Error_t
)
err
=
GPU
(
Error_t
)
err
=
...
@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
...
@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
const
int64_t
cols
)
{
constexpr
int
block_size
=
128
;
constexpr
int
block_size
=
128
;
#ifdef WITH_ROCM
constexpr
int
waves
=
64
;
#else
constexpr
int
waves
=
32
;
constexpr
int
waves
=
32
;
#endif
static_assert
(
block_size
%
thread_group_width
==
0
,
""
);
static_assert
(
block_size
%
thread_group_width
==
0
,
""
);
constexpr
int
thread_groups_per_block
=
block_size
/
thread_group_width
;
constexpr
int
thread_groups_per_block
=
block_size
/
thread_group_width
;
dim3
block_dim
(
thread_group_width
,
thread_groups_per_block
);
dim3
block_dim
(
thread_group_width
,
thread_groups_per_block
);
...
@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
...
@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
} \
}
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF
(
64
)
#else
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
32
)
#endif
#undef DEFINE_ONE_ELIF
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
else if (cols <= (max_col)*kWarpSize) { \
...
...
oneflow/core/embedding/lru_cache.cu
View file @
6046d8fb
...
@@ -252,15 +252,21 @@ struct SetContext {
...
@@ -252,15 +252,21 @@ struct SetContext {
const
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
const
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
const
bool
lane_hit
=
(
lane_key
==
key
&&
lane_age
!=
0
);
const
bool
lane_hit
=
(
lane_key
==
key
&&
lane_age
!=
0
);
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
unsigned
hit_mask
=
__ballot
(
lane_hit
);
const
unsigned
long
long
int
hit_mask
=
__ballot
(
lane_hit
);
if
(
hit_mask
!=
0
)
{
return
__ffsll
(
static_cast
<
unsigned
long
long
int
>
(
hit_mask
))
-
1
;
}
else
{
return
-
1
;
}
#else
#else
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_hit
);
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_hit
);
#endif
if
(
hit_mask
!=
0
)
{
if
(
hit_mask
!=
0
)
{
return
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
return
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
}
else
{
}
else
{
return
-
1
;
return
-
1
;
}
}
#endif
}
}
__device__
void
Read
(
const
LruCacheContext
<
Key
,
Elem
>&
cache_ctx
,
const
ThreadContext
&
thread_ctx
,
__device__
void
Read
(
const
LruCacheContext
<
Key
,
Elem
>&
cache_ctx
,
const
ThreadContext
&
thread_ctx
,
...
@@ -277,15 +283,16 @@ struct SetContext {
...
@@ -277,15 +283,16 @@ struct SetContext {
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
unsigned
hit_mask
=
__ballot
(
lane_key
==
key
&&
lane_age
!=
0
);
const
unsigned
long
long
int
hit_mask
=
__ballot
(
lane_key
==
key
&&
lane_age
!=
0
);
#else
#else
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_key
==
key
&&
lane_age
!=
0
);
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_key
==
key
&&
lane_age
!=
0
);
#endif
#endif
if
(
hit_mask
!=
0
)
{
if
(
hit_mask
!=
0
)
{
insert_way
=
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
#ifdef WITH_ROCM
#ifdef WITH_ROCM
insert_way
=
__ffsll
(
static_cast
<
unsigned
long
long
int
>
(
hit_mask
))
-
1
;
const
int
insert_way_age
=
__shfl
(
lane_age
,
insert_way
);
const
int
insert_way_age
=
__shfl
(
lane_age
,
insert_way
);
#else
#else
insert_way
=
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
const
int
insert_way_age
=
__shfl_sync
(
kFullMask
,
lane_age
,
insert_way
);
const
int
insert_way_age
=
__shfl_sync
(
kFullMask
,
lane_age
,
insert_way
);
#endif
#endif
if
(
lane_age
>
insert_way_age
)
{
if
(
lane_age
>
insert_way_age
)
{
...
@@ -301,12 +308,16 @@ __syncwarp();
...
@@ -301,12 +308,16 @@ __syncwarp();
}
}
if
(
insert_way
==
-
1
)
{
if
(
insert_way
==
-
1
)
{
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
unsigned
valid_mask
=
__ballot
(
lane_age
!=
0
);
const
unsigned
long
long
int
valid_mask
=
__ballot
(
lane_age
!=
0
);
#else
#else
const
unsigned
valid_mask
=
__ballot_sync
(
kFullMask
,
lane_age
!=
0
);
const
unsigned
valid_mask
=
__ballot_sync
(
kFullMask
,
lane_age
!=
0
);
#endif
#endif
if
(
valid_mask
!=
kFullMask
)
{
if
(
valid_mask
!=
kFullMask
)
{
#ifdef WITH_ROCM
insert_way
=
__popcll
(
static_cast
<
unsigned
long
long
int
>
(
valid_mask
));
#else
insert_way
=
__popc
(
static_cast
<
int
>
(
valid_mask
));
insert_way
=
__popc
(
static_cast
<
int
>
(
valid_mask
));
#endif
if
(
lane_age
>
0
)
{
if
(
lane_age
>
0
)
{
lane_age
-=
1
;
lane_age
-=
1
;
}
else
if
(
thread_ctx
.
lane_id
==
insert_way
)
{
}
else
if
(
thread_ctx
.
lane_id
==
insert_way
)
{
...
@@ -329,7 +340,8 @@ __syncwarp();
...
@@ -329,7 +340,8 @@ __syncwarp();
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
int
insert_way
=
__ffs
(
static_cast
<
int
>
(
__ballot
(
lane_age
==
1
)))
-
1
;
unsigned
long
long
int
valid_mask_tmp
=
__ballot
(
lane_age
==
1
);
const
int
insert_way
=
__ffsll
(
valid_mask_tmp
)
-
1
;
#else
#else
const
int
insert_way
=
__ffs
(
__ballot_sync
(
kFullMask
,
lane_age
==
1
))
-
1
;
const
int
insert_way
=
__ffs
(
__ballot_sync
(
kFullMask
,
lane_age
==
1
))
-
1
;
#endif
#endif
...
@@ -594,7 +606,8 @@ __syncwarp();
...
@@ -594,7 +606,8 @@ __syncwarp();
warp_keys
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_key
;
warp_keys
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_key
;
warp_ages
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_age
;
warp_ages
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_age
;
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
int
key_count
=
__popc
(
static_cast
<
int
>
(
__ballot
(
lane_age
!=
0
)));
unsigned
long
long
int
valid_mask_tmp
=
__ballot
(
lane_age
!=
0
);
const
int
key_count
=
__popcll
(
valid_mask_tmp
);
#else
#else
const
int
key_count
=
__popc
(
__ballot_sync
(
kFullMask
,
lane_age
!=
0
));
const
int
key_count
=
__popc
(
__ballot_sync
(
kFullMask
,
lane_age
!=
0
));
#endif
#endif
...
...
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh
View file @
6046d8fb
...
@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/include/primitive/
/
broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
...
...
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu
View file @
6046d8fb
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
namespace
broadcast_elementwise_binary
{
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
template
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
<
\
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op
,
OF_PP_PAIR_FIRST
(
data_type_pair
),
OF_PP_PAIR_FIRST
(
data_type_pair
)
>
(
\
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar
attr0
,
Scalar
attr1
);
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
BINARY_MATH_OP_SEQ_1
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
);
BINARY_MATH_OP_SEQ_1
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
);
}
// namespace broadcast_elementwise_binary
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/ep/cuda/primitive/unary_functor.cuh
View file @
6046d8fb
...
@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
...
@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
float_functor
(
attr0
,
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
float_functor
(
attr0
,
attr1
)
{}
OF_DEVICE_FUNC
half
operator
()(
half
src
)
const
{
OF_DEVICE_FUNC
half
operator
()(
half
src
)
const
{
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
const
float
tanh_in
=
const
float
tanh_in
=
__half2float
(
__float2half_rn
(
alpha
)
*
(
src
+
__float2half_rn
(
beta
)
*
src
*
src
*
src
));
__half2float
(
__float2half_rn
(
alpha
)
*
(
src
+
__float2half_rn
(
beta
)
*
src
*
src
*
src
));
const
float
tanh_out
=
unary_functor_internal
::
TanhApprox
(
tanh_in
);
const
float
tanh_out
=
unary_functor_internal
::
TanhApprox
(
tanh_in
);
...
@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
...
@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
__device__
void
Apply2
(
half
*
dst
,
const
half
*
src
)
const
{
__device__
void
Apply2
(
half
*
dst
,
const
half
*
src
)
const
{
const
half2
src2
=
*
(
reinterpret_cast
<
const
half2
*>
(
src
));
const
half2
src2
=
*
(
reinterpret_cast
<
const
half2
*>
(
src
));
const
float2
tanh_in
=
__half22float2
(
__hmul2
(
const
float2
tanh_in
=
__half22float2
(
__hmul2
(
...
...
oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
View file @
6046d8fb
...
@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
...
@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
"reduce_sum_like"
,
"reduce_sum_like"
,
"layer_norm_grad"
,
"layer_norm_grad"
,
"layer_norm"
,
"layer_norm"
,
"layer_norm_param_grad"
"layer_norm_param_grad"
,
};
return
black_list
;
}
const
AMPList
&
AutoMixedPrecisionLists
::
GrayList
()
{
"add_n"
,
static
AMPList
gray_list
=
{
"add_n"
,
"tf_avg_pool_1d"
,
"tf_avg_pool_1d"
,
"tf_avg_pool_1d_grad"
,
"tf_avg_pool_1d_grad"
,
"tf_avg_pool_2d"
,
"tf_avg_pool_2d"
,
"tf_avg_pool_2d_grad"
,
"tf_avg_pool_2d_grad"
,
"tf_avg_pool_3d"
,
"tf_avg_pool_3d"
,
"tf_avg_pool_3d_grad"
,
"tf_avg_pool_3d_grad"
,
"avg_pool_1d"
,
//
"avg_pool_1d",
"avg_pool_1d_grad"
,
//
"avg_pool_1d_grad",
"avg_pool_2d"
,
//
"avg_pool_2d",
"avg_pool_2d_grad"
,
//
"avg_pool_2d_grad",
"avg_pool_3d"
,
//
"avg_pool_3d",
"avg_pool_3d_grad"
,
//
"avg_pool_3d_grad",
"bias_add"
,
"bias_add"
,
"sigmoid_grad"
,
"sigmoid_grad"
,
"tanh"
,
"tanh"
,
...
@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
...
@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
"group_norm_grad"
,
"group_norm_grad"
,
"silu"
,
"silu"
,
"silu_grad"
,
"silu_grad"
,
"fused_weighted_sum"
};
"fused_weighted_sum"
,
return
gray_list
;
}
const
AMPList
&
AutoMixedPrecisionLists
::
ClearList
()
{
"broadcast_like"
,
// TODO(niuchong): tuple_identity
static
AMPList
clear_list
=
{
"broadcast_like"
,
"gather"
,
"gather"
,
"gather_nd"
,
"gather_nd"
,
"scatter_nd"
,
"scatter_nd"
,
...
@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
...
@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"tf_max_pool_2d_grad"
,
"tf_max_pool_2d_grad"
,
"tf_max_pool_3d"
,
"tf_max_pool_3d"
,
"tf_max_pool_3d_grad"
,
"tf_max_pool_3d_grad"
,
"max_pool_1d"
,
//
"max_pool_1d",
"max_pool_1d_grad"
,
//
"max_pool_1d_grad",
"max_pool_2d"
,
//
"max_pool_2d",
"max_pool_2d_grad"
,
//
"max_pool_2d_grad",
"max_pool_3d"
,
//
"max_pool_3d",
"max_pool_3d_grad"
,
//
"max_pool_3d_grad",
"reshape"
,
"reshape"
,
"reshape_like"
,
"reshape_like"
,
"relu"
,
"relu"
,
...
@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
...
@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"pinned_identity"
,
"pinned_identity"
,
"to_contiguous"
,
"to_contiguous"
,
"copy"
,
"copy"
,
"upsample_nearest_2d"
};
"upsample_nearest_2d"
};
return
black_list
;
}
const
AMPList
&
AutoMixedPrecisionLists
::
GrayList
()
{
static
AMPList
gray_list
=
{
// "add_n",
// "tf_avg_pool_1d",
// "tf_avg_pool_1d_grad",
// "tf_avg_pool_2d",
// "tf_avg_pool_2d_grad",
// "tf_avg_pool_3d",
// "tf_avg_pool_3d_grad",
"avg_pool_1d"
,
"avg_pool_1d_grad"
,
"avg_pool_2d"
,
"avg_pool_2d_grad"
,
"avg_pool_3d"
,
"avg_pool_3d_grad"
// "bias_add",
// "sigmoid_grad",
// "tanh",
// "tanh_grad",
// "sqrt",
// "sqrt_grad",
// "scalar_mul",
// "scalar_mul_by_tensor",
// "scalar_add",
// "scalar_div",
// "scalar_pow",
// "broadcast_add",
// "broadcast_sub",
// "broadcast_mul",
// "broadcast_div",
// "rms_norm",
// "rms_norm_grad",
// "rms_norm_param_grad",
// "dropout",
// "dropout_grad",
// "softmax",
// "softmax_grad",
// "log_softmax",
// "log_softmax_grad",
// "gelu",
// "gelu_grad",
// "fast_gelu",
// "fast_gelu_grad",
// "normalization",
// "normalization_grad",
// "normalization_add_relu",
// "normalization_add_relu_grad",
// "sparse_softmax_cross_entropy",
// "sparse_softmax_cross_entropy_grad",
// "nll",
// "nll_grad",
// "fused_tril_scale_softmax_mask_scale",
// "fused_tril_scale_softmax_mask_scale_grad",
// "fused_scale_mask_softmax_dropout",
// "fused_scale_mask_softmax_dropout_grad",
// "fused_scale_mask_softmax",
// "fused_scale_mask_softmax_grad",
// "fused_bias_add_scale_mask_softmax_dropout",
// "fused_bias_add_gelu",
// "fused_bias_add_gelu_grad",
// "fused_bias_add_mask_scale",
// "fused_fast_gelu_mul",
// "fused_fast_gelu_mul_grad",
// "acc",
// "reciprocal",
// "reciprocal_no_nan",
// "group_norm",
// "group_norm_param_grad",
// "group_norm_grad",
// "silu",
// "silu_grad",
// "fused_weighted_sum"
};
return
gray_list
;
}
const
AMPList
&
AutoMixedPrecisionLists
::
ClearList
()
{
// TODO(niuchong): tuple_identity
static
AMPList
clear_list
=
{
// "broadcast_like",
// "gather",
// "gather_nd",
// "scatter_nd",
// "scatter_nd_like",
// "unsorted_segment_sum_like",
// "tf_max_pool_1d",
// "tf_max_pool_1d_grad",
// "tf_max_pool_2d",
// "tf_max_pool_2d_grad",
// "tf_max_pool_3d",
// "tf_max_pool_3d_grad",
"max_pool_1d"
,
"max_pool_1d_grad"
,
"max_pool_2d"
,
"max_pool_2d_grad"
,
"max_pool_3d"
,
"max_pool_3d_grad"
// "reshape",
// "reshape_like",
// "relu",
// "relu_grad",
// "transpose",
// "random_mask_like",
// "concat",
// "split_like",
// "pad",
// "same_padding",
// "same_padding_grad",
// "tril",
// "slice",
// "slice_grad",
// "fused_scale_tril",
// "identity",
// "squeeze",
// "embedding",
// "embedding_grad",
// "expand",
// "expand_dims",
// "cast_to_static_shape",
// "parallel_cast",
// "hierarchical_parallel_cast",
// "hierarchical_parallel_cast_like",
// "repeat",
// "unpack",
// "pack",
// "nvtx_start",
// "nvtx_end",
// "narrow",
// "narrow_grad",
// "ones_like",
// "pinned_identity",
// "to_contiguous",
// "copy",
// "upsample_nearest_2d"
};
return
clear_list
;
return
clear_list
;
}
}
...
...
oneflow/user/kernels/fused_gelu_mul_kernel.cu
View file @
6046d8fb
...
@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
...
@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
OF_DEVICE_FUNC
FusedFastGeluMulFunctor
()
{}
OF_DEVICE_FUNC
FusedFastGeluMulFunctor
()
{}
OF_DEVICE_FUNC
half
operator
()(
const
half
x
,
const
half
m
)
const
{
OF_DEVICE_FUNC
half
operator
()(
const
half
x
,
const
half
m
)
const
{
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
const
float
tanh_in
=
const
float
tanh_in
=
__half2float
(
__float2half_rn
(
alpha
)
*
(
x
+
__float2half_rn
(
beta
)
*
x
*
x
*
x
));
__half2float
(
__float2half_rn
(
alpha
)
*
(
x
+
__float2half_rn
(
beta
)
*
x
*
x
*
x
));
const
float
tanh_out
=
TanhApprox
(
tanh_in
);
const
float
tanh_out
=
TanhApprox
(
tanh_in
);
...
@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
...
@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
__device__
void
Apply2
(
half
*
y
,
const
half
*
x
,
const
half
*
m
)
const
{
__device__
void
Apply2
(
half
*
y
,
const
half
*
x
,
const
half
*
m
)
const
{
const
half2
x2
=
*
(
reinterpret_cast
<
const
half2
*>
(
x
));
const
half2
x2
=
*
(
reinterpret_cast
<
const
half2
*>
(
x
));
const
float2
tanh_in
=
__half22float2
(
const
float2
tanh_in
=
__half22float2
(
...
@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
...
@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
__device__
void
operator
()(
half
&
x_diff
,
half
&
m_diff
,
const
half
&
dy
,
const
half
&
x
,
__device__
void
operator
()(
half
&
x_diff
,
half
&
m_diff
,
const
half
&
dy
,
const
half
&
x
,
const
half
&
m
)
const
{
const
half
&
m
)
const
{
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
const
half
halpha
=
__float2half_rn
(
alpha
);
const
half
halpha
=
__float2half_rn
(
alpha
);
const
half
hbeta
=
__float2half_rn
(
beta
);
const
half
hbeta
=
__float2half_rn
(
beta
);
const
half
hone
=
__float2half_rn
(
1.0
F
);
const
half
hone
=
__float2half_rn
(
1.0
F
);
...
@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
...
@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
__device__
void
Apply2
(
half
*
x_diff
,
half
*
m_diff
,
const
half
*
dy
,
const
half
*
x
,
__device__
void
Apply2
(
half
*
x_diff
,
half
*
m_diff
,
const
half
*
dy
,
const
half
*
x
,
const
half
*
m
)
const
{
const
half
*
m
)
const
{
const
half2
dy2
=
*
(
reinterpret_cast
<
const
half2
*>
(
dy
));
const
half2
dy2
=
*
(
reinterpret_cast
<
const
half2
*>
(
dy
));
...
...
oneflow/user/kernels/group_norm_kernel.cu
View file @
6046d8fb
...
@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
...
@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
constexpr
int
kReduceBlockSize
=
512
;
constexpr
int
kReduceBlockSize
=
512
;
constexpr
int
kBlockSize
=
128
;
constexpr
int
kBlockSize
=
128
;
#ifdef WITH_ROCM
constexpr
int
kNumWaves
=
64
;
#else
constexpr
int
kNumWaves
=
32
;
constexpr
int
kNumWaves
=
32
;
#endif
inline
GPU
(
Error_t
)
GetReduceNumBlocks
(
int64_t
n
,
int
*
num_blocks
)
{
inline
GPU
(
Error_t
)
GetReduceNumBlocks
(
int64_t
n
,
int
*
num_blocks
)
{
int
dev
;
int
dev
;
...
...
oneflow/user/kernels/layer_norm_gpu_kernel.cu
View file @
6046d8fb
...
@@ -27,9 +27,12 @@ limitations under the License.
...
@@ -27,9 +27,12 @@ limitations under the License.
#include <cuda_bf16.h>
#include <cuda_bf16.h>
#endif // CUDA_VERSION >= 11000
#endif // CUDA_VERSION >= 11000
#ifdef WITH_CUDA
#ifdef WITH_ROCM
#include <thrust/pair.h>
#include <hipcub/hipcub.hpp>
#else
#include <cub/cub.cuh>
#include <cub/cub.cuh>
#endif
namespace
oneflow
{
namespace
oneflow
{
namespace
{
namespace
{
...
@@ -143,18 +146,31 @@ __inline__ __device__ T WarpReduce(T val) {
...
@@ -143,18 +146,31 @@ __inline__ __device__ T WarpReduce(T val) {
return
val
;
return
val
;
}
}
constexpr
int
tile_size
=
32
;
constexpr
int
num_per_block
=
4
;
constexpr
int
num_per_block
=
4
;
#ifdef WITH_ROCM
constexpr
int
tile_size
=
64
;
constexpr
int
block_dim_x
=
64
;
constexpr
int
block_dim_y
=
64
/
num_per_block
;
#else
constexpr
int
tile_size
=
32
;
constexpr
int
block_dim_x
=
32
;
constexpr
int
block_dim_x
=
32
;
constexpr
int
block_dim_y
=
32
/
num_per_block
;
constexpr
int
block_dim_y
=
32
/
num_per_block
;
#endif
template
<
typename
T
,
typename
ComputeType
>
template
<
typename
T
,
typename
ComputeType
>
__global__
void
LayerNormParamGrad
(
int
rows
,
int
cols
,
const
T
*
__restrict__
dy
,
__global__
void
LayerNormParamGrad
(
int
rows
,
int
cols
,
const
T
*
__restrict__
dy
,
const
T
*
__restrict__
x
,
const
ComputeType
*
__restrict__
mean
,
const
T
*
__restrict__
x
,
const
ComputeType
*
__restrict__
mean
,
const
ComputeType
*
__restrict__
inv_var
,
const
ComputeType
*
__restrict__
inv_var
,
T
*
__restrict__
tmp_gamma_diff
,
T
*
__restrict__
tmp_beta_diff
)
{
T
*
__restrict__
tmp_gamma_diff
,
T
*
__restrict__
tmp_beta_diff
)
{
#ifdef WITH_ROCM
__shared__
ComputeType
dgamma
[
64
][
65
];
__shared__
ComputeType
dbeta
[
64
][
65
];
#else
__shared__
ComputeType
dgamma
[
32
][
33
];
__shared__
ComputeType
dgamma
[
32
][
33
];
__shared__
ComputeType
dbeta
[
32
][
33
];
__shared__
ComputeType
dbeta
[
32
][
33
];
#endif
ComputeType
dgamma_sum
[
num_per_block
];
ComputeType
dgamma_sum
[
num_per_block
];
ComputeType
dbeta_sum
[
num_per_block
];
ComputeType
dbeta_sum
[
num_per_block
];
#pragma unroll
#pragma unroll
...
@@ -210,13 +226,13 @@ int GetGirdDimY(const int64_t num_instances, const int64_t norm_size) {
...
@@ -210,13 +226,13 @@ int GetGirdDimY(const int64_t num_instances, const int64_t norm_size) {
const
int
max_grid_dim_y
=
(
num_instances
+
tile_size
-
1
)
/
tile_size
;
const
int
max_grid_dim_y
=
(
num_instances
+
tile_size
-
1
)
/
tile_size
;
const
int
block_size
=
block_dim_x
*
block_dim_y
;
const
int
block_size
=
block_dim_x
*
block_dim_y
;
int
max_active_blocks
=
0
;
int
max_active_blocks
=
0
;
OF_CUDA_CHECK
(
cuda
OccupancyMaxActiveBlocksPerMultiprocessor
(
OF_CUDA_CHECK
(
GPU
(
OccupancyMaxActiveBlocksPerMultiprocessor
)
(
&
max_active_blocks
,
LayerNormParamGrad
<
T
,
ComputeType
>
,
block_size
,
0
));
&
max_active_blocks
,
LayerNormParamGrad
<
T
,
ComputeType
>
,
block_size
,
0
));
int
waves
=
1
;
int
waves
=
1
;
int
dev
;
int
dev
;
OF_CUDA_CHECK
(
cuda
GetDevice
(
&
dev
));
OF_CUDA_CHECK
(
GPU
(
GetDevice
)
(
&
dev
));
int
sm_count
;
int
sm_count
;
OF_CUDA_CHECK
(
cuda
DeviceGetAttribute
(
&
sm_count
,
cudaDevAttr
MultiProcessorCount
,
dev
));
OF_CUDA_CHECK
(
GPU
(
DeviceGetAttribute
)
(
&
sm_count
,
GPU
MultiProcessorCount
,
dev
));
int
num_blocks
=
max_active_blocks
*
sm_count
*
waves
;
int
num_blocks
=
max_active_blocks
*
sm_count
*
waves
;
int
grid_dim_y
=
std
::
min
(
max_grid_dim_y
,
static_cast
<
int
>
(
num_blocks
/
grid_dim_x
));
int
grid_dim_y
=
std
::
min
(
max_grid_dim_y
,
static_cast
<
int
>
(
num_blocks
/
grid_dim_x
));
return
std
::
max
(
grid_dim_y
,
1
);
return
std
::
max
(
grid_dim_y
,
1
);
...
@@ -420,6 +436,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
...
@@ -420,6 +436,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
const
int64_t
norm_size
=
x
->
shape_view
().
elem_cnt
()
/
num_instances
;
const
int64_t
norm_size
=
x
->
shape_view
().
elem_cnt
()
/
num_instances
;
user_op
::
Tensor
*
tmp_buffer
=
ctx
->
Tensor4ArgNameAndIndex
(
"tmp_buffer"
,
0
);
user_op
::
Tensor
*
tmp_buffer
=
ctx
->
Tensor4ArgNameAndIndex
(
"tmp_buffer"
,
0
);
const
DataType
data_type
=
dy
->
data_type
();
const
DataType
data_type
=
dy
->
data_type
();
const
int
grid_dim_x
=
(
norm_size
+
tile_size
-
1
)
/
tile_size
;
const
int
grid_dim_x
=
(
norm_size
+
tile_size
-
1
)
/
tile_size
;
const
int
grid_dim_y
=
GetGirdDimY
<
T
>
(
num_instances
,
norm_size
);
const
int
grid_dim_y
=
GetGirdDimY
<
T
>
(
num_instances
,
norm_size
);
const
size_t
tmp_gamma_diff_size
=
grid_dim_y
*
norm_size
*
sizeof
(
T
);
const
size_t
tmp_gamma_diff_size
=
grid_dim_y
*
norm_size
*
sizeof
(
T
);
...
@@ -428,7 +445,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
...
@@ -428,7 +445,7 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
T
*
reduce_buf_ptr
=
T
*
reduce_buf_ptr
=
reinterpret_cast
<
T
*>
(
tmp_buffer
->
mut_dptr
<
char
>
()
+
2
*
tmp_gamma_diff_size
);
reinterpret_cast
<
T
*>
(
tmp_buffer
->
mut_dptr
<
char
>
()
+
2
*
tmp_gamma_diff_size
);
using
ComputeType
=
typename
cuda
::
layer_norm
::
DefaultComputeType
<
T
>::
type
;
using
ComputeType
=
typename
cuda
::
layer_norm
::
DefaultComputeType
<
T
>::
type
;
LayerNormParamGrad
<
T
,
ComputeType
><<<
dim3
(
grid_dim_x
,
grid_dim_y
),
dim3
(
32
,
32
/
num_per_block
),
LayerNormParamGrad
<
T
,
ComputeType
><<<
dim3
(
grid_dim_x
,
grid_dim_y
),
dim3
(
block_dim_x
,
block_dim_y
),
0
,
ctx
->
stream
()
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()
>>>
(
0
,
ctx
->
stream
()
->
As
<
ep
::
CudaStream
>
()
->
cuda_stream
()
>>>
(
num_instances
,
norm_size
,
dy
->
dptr
<
T
>
(),
x
->
dptr
<
T
>
(),
mean
->
dptr
<
ComputeType
>
(),
num_instances
,
norm_size
,
dy
->
dptr
<
T
>
(),
x
->
dptr
<
T
>
(),
mean
->
dptr
<
ComputeType
>
(),
inv_variance
->
dptr
<
ComputeType
>
(),
tmp_gamma_diff_ptr
,
tmp_beta_diff_ptr
);
inv_variance
->
dptr
<
ComputeType
>
(),
tmp_gamma_diff_ptr
,
tmp_beta_diff_ptr
);
...
@@ -476,651 +493,13 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
...
@@ -476,651 +493,13 @@ class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
});
});
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL
(
float
)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL
(
float
)
#ifdef WITH_CUDA
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL
(
double
)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL
(
double
)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL
(
half
)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL
(
nv_bfloat16
)
#endif
}
// namespace oneflow
#endif
#endif
#ifdef WITH_ROCM
#include <hipcub/hipcub.hpp>
#include <thrust/pair.h>
template
<
typename
T
,
bool
is_cuda
>
struct
AccumulateType
{
};
#if defined(__HIPCC__)
template
<
>
struct
AccumulateType
<
half
,
true
>
{
using
type
=
float
;
};
#endif
template
<
>
struct
AccumulateType
<
float
,
true
>
{
using
type
=
float
;
};
template
<
>
struct
AccumulateType
<
double
,
true
>
{
using
type
=
double
;
};
template
<
>
struct
AccumulateType
<
int8_t
,
true
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
uint8_t
,
true
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
char
,
true
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
int16_t
,
true
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
int32_t
,
true
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
int64_t
,
true
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
bool
,
true
>
{
using
type
=
bool
;
};
template
<
>
struct
AccumulateType
<
float
,
false
>
{
using
type
=
double
;
};
template
<
>
struct
AccumulateType
<
double
,
false
>
{
using
type
=
double
;
};
template
<
>
struct
AccumulateType
<
int8_t
,
false
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
uint8_t
,
false
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
char
,
false
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
int16_t
,
false
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
int32_t
,
false
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
int64_t
,
false
>
{
using
type
=
int64_t
;
};
template
<
>
struct
AccumulateType
<
bool
,
false
>
{
using
type
=
bool
;
};
template
<
typename
T
,
bool
is_cuda
>
using
acc_type
=
typename
AccumulateType
<
T
,
is_cuda
>::
type
;
#define C10_HOST_DEVICE __host__ __device__
#define C10_DEVICE __device__
#define C10_HOST __host__
#define C10_WARP_SIZE 64
#define VEC 4
typedef
int64_t
IndexType
;
constexpr
int
BlockReduceNumThreads
=
512
;
constexpr
int
NumThreads
=
256
;
constexpr
int
ColwiseReduceTileSize
=
32
;
template
<
typename
scalar_t
,
typename
index_t
,
typename
combine_t
>
struct
WelfordData
{
scalar_t
mean
;
scalar_t
m2
;
index_t
n
;
combine_t
nf
;
C10_HOST_DEVICE
WelfordData
()
:
mean
(
0
),
m2
(
0
),
n
(
0
),
nf
(
0
)
{}
C10_HOST_DEVICE
WelfordData
(
scalar_t
mean
,
scalar_t
m2
,
index_t
n
,
combine_t
nf
)
:
mean
(
mean
),
m2
(
m2
),
n
(
n
),
nf
(
nf
)
{}
};
template
<
typename
scalar_t
,
typename
acc_scalar_t
,
typename
index_t
,
typename
combine_t
,
typename
res_t
>
struct
WelfordOps
{
public:
using
acc_t
=
WelfordData
<
acc_scalar_t
,
index_t
,
combine_t
>
;
inline
C10_DEVICE
acc_t
reduce
(
acc_t
acc
,
scalar_t
data
)
const
{
acc_scalar_t
delta
=
data
-
acc
.
mean
;
// using acc.nf(combine_t) here, as acc.n(index_t) would still be converted
// accumulation in reduce is done through index_T
acc_scalar_t
new_mean
=
acc
.
mean
+
delta
/
(
acc
.
nf
+
1
);
acc_scalar_t
new_delta
=
data
-
new_mean
;
return
{
new_mean
,
acc
.
m2
+
delta
*
new_delta
,
acc
.
n
+
1
,
combine_t
(
acc
.
n
+
1
),
// accumulate for combine_t uses index_t
};
}
inline
C10_DEVICE
acc_t
combine
(
acc_t
a
,
acc_t
b
)
const
{
if
(
a
.
nf
==
0
)
{
return
b
;
}
if
(
b
.
nf
==
0
)
{
return
a
;
}
acc_scalar_t
delta
=
b
.
mean
-
a
.
mean
;
combine_t
new_count
=
a
.
nf
+
b
.
nf
;
acc_scalar_t
nb_over_n
=
b
.
nf
/
new_count
;
return
{
a
.
mean
+
delta
*
nb_over_n
,
a
.
m2
+
b
.
m2
+
delta
*
delta
*
a
.
nf
*
nb_over_n
,
// setting acc.n as -1 since acc.n might not be able to represent the count
// correctly within its range, setting it to -1 to avoid confusion
-
1
,
new_count
};
}
inline
C10_DEVICE
res_t
project
(
acc_t
acc
)
const
{
return
res_t
(
acc
.
m2
/
acc
.
nf
,
static_cast
<
scalar_t
>
(
acc
.
mean
));
}
inline
__device__
acc_t
warp_shfl_down
(
acc_t
acc
,
int
offset
)
const
{
return
{
__shfl_down
(
acc
.
mean
,
offset
)
,
__shfl_down
(
acc
.
m2
,
offset
)
,
__shfl_down
(
acc
.
n
,
offset
)
,
__shfl_down
(
acc
.
nf
,
offset
)
};
}
};
template
<
int
max
=
32
,
typename
T
,
class
ReduceOp
>
__inline__
__device__
T
WarpReduce
(
T
val
,
const
ReduceOp
&
op
)
{
#pragma unroll
for
(
int
offset
=
max
;
offset
>
0
;
offset
>>=
1
)
{
val
=
op
.
combine
(
val
,
op
.
warp_shfl_down
(
val
,
offset
));
}
return
val
;
}
template
<
typename
T
,
class
ReduceOp
>
__inline__
__device__
T
BlockReduce
(
T
val
,
const
ReduceOp
&
op
,
T
*
shared
)
{
const
int
lid
=
threadIdx
.
x
%
C10_WARP_SIZE
;
const
int
wid
=
threadIdx
.
x
/
C10_WARP_SIZE
;
val
=
WarpReduce
(
val
,
op
);
__syncthreads
();
if
(
lid
==
0
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
if
(
wid
==
0
)
{
val
=
shared
[
lid
];
val
=
WarpReduce
<
4
>
(
val
,
op
);
}
return
val
;
}
template
<
int
max
=
32
,
typename
T
>
__inline__
__device__
T
WarpReduceSum
(
T
val
)
{
#pragma unroll
for
(
int
offset
=
max
;
offset
>
0
;
offset
>>=
1
)
{
val
+=
__shfl_down
(
val
,
offset
);
}
return
val
;
}
template
<
typename
T
>
__inline__
__device__
T
BlockReduceSum
(
T
val
,
T
*
shared
)
{
const
int
lid
=
threadIdx
.
x
%
C10_WARP_SIZE
;
const
int
wid
=
threadIdx
.
x
/
C10_WARP_SIZE
;
val
=
WarpReduceSum
(
val
);
__syncthreads
();
if
(
lid
==
0
)
{
shared
[
wid
]
=
val
;
}
__syncthreads
();
if
(
wid
==
0
)
{
val
=
shared
[
lid
];
val
=
WarpReduceSum
<
4
>
(
val
);
}
return
val
;
}
template
<
typename
scalar_t
>
__global__
void
layernorm_forward_kernel
(
const
scalar_t
*
input
,
scalar_t
*
ret
,
acc_type
<
scalar_t
,
true
>*
mean
,
acc_type
<
scalar_t
,
true
>*
rstd
,
const
scalar_t
*
gamma
,
const
scalar_t
*
beta
,
IndexType
cols
,
double
eps
)
{
//dropout do nothing in val mode
IndexType
i
=
blockIdx
.
x
;
// add + layernorm get mean and rstd
using
T_ACC
=
acc_type
<
scalar_t
,
true
>
;
using
WelfordType
=
WelfordData
<
T_ACC
,
IndexType
,
T_ACC
>
;
using
WelfordOp
=
WelfordOps
<
T_ACC
,
T_ACC
,
IndexType
,
T_ACC
,
thrust
::
pair
<
T_ACC
,
T_ACC
>>
;
__shared__
typename
std
::
aligned_storage
<
sizeof
(
WelfordType
),
alignof
(
WelfordType
)
>::
type
val_shared
[
BlockReduceNumThreads
/
C10_WARP_SIZE
];
WelfordType
*
val_shared_ptr
=
reinterpret_cast
<
WelfordType
*>
(
val_shared
);
WelfordOp
welford_op
;
WelfordType
val
;
#pragma unroll
for
(
IndexType
j
=
threadIdx
.
x
;
j
<
cols
;
j
+=
blockDim
.
x
)
{
IndexType
index
=
i
*
cols
+
j
;
val
=
welford_op
.
reduce
(
val
,
static_cast
<
T_ACC
>
(
input
[
index
]));
}
val
=
BlockReduce
(
val
,
welford_op
,
val_shared_ptr
);
__shared__
T_ACC
s_mean
;
__shared__
T_ACC
s_rstd
;
if
(
threadIdx
.
x
==
0
)
{
thrust
::
tie
(
s_rstd
,
s_mean
)
=
welford_op
.
project
(
val
);
mean
[
i
]
=
s_mean
;
s_rstd
=
rsqrt
(
s_rstd
+
static_cast
<
T_ACC
>
(
eps
));
rstd
[
i
]
=
s_rstd
;
}
__syncthreads
();
//layernorm (x-mean)*rstd*gamma+beta
#pragma unroll
for
(
IndexType
j
=
threadIdx
.
x
;
j
<
cols
;
j
+=
blockDim
.
x
)
{
IndexType
index
=
i
*
cols
+
j
;
ret
[
index
]
=
static_cast
<
scalar_t
>
((
static_cast
<
T_ACC
>
(
input
[
index
])
-
s_mean
)
*
s_rstd
*
(
gamma
==
nullptr
?
T_ACC
(
1
)
:
static_cast
<
T_ACC
>
(
gamma
[
j
]))
+
(
beta
==
nullptr
?
T_ACC
(
0
)
:
static_cast
<
T_ACC
>
(
beta
[
j
])));
}
}
template
<
typename
T
>
void
LayerNormKernelImplInternal
(
oneflow
::
ep
::
Stream
*
stream
,
const
T
*
X
,
const
T
*
gamma
,
const
T
*
beta
,
int64_t
M
,
int64_t
N
,
double
eps
,
T
*
Y
,
acc_type
<
T
,
true
>*
mean
,
acc_type
<
T
,
true
>*
rstd
)
{
using
T_ACC
=
acc_type
<
T
,
true
>
;
const
T
*
X_data
=
X
;
const
T
*
gamma_data
=
gamma
;
const
T
*
beta_data
=
beta
;
T
*
Y_data
=
Y
;
T_ACC
*
mean_data
=
mean
;
T_ACC
*
rstd_data
=
rstd
;
hipStream_t
cuda_stream
=
stream
->
As
<
oneflow
::
ep
::
CudaStream
>
()
->
cuda_stream
();
layernorm_forward_kernel
<
T
><<<
M
,
BlockReduceNumThreads
,
0
,
cuda_stream
>>>
(
X_data
,
Y_data
,
mean_data
,
rstd_data
,
gamma_data
,
beta_data
,
N
,
eps
);
}
template
<
typename
scalar_t
>
__global__
void
GammaBetaBackwardSimple
(
IndexType
M
,
IndexType
N
,
const
scalar_t
*
dY
,
const
scalar_t
*
X
,
const
acc_type
<
scalar_t
,
true
>*
mean
,
const
acc_type
<
scalar_t
,
true
>*
rstd
,
scalar_t
*
dg
,
scalar_t
*
db
)
{
using
T_ACC
=
acc_type
<
scalar_t
,
true
>
;
const
int64_t
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
j
<
N
)
{
T_ACC
sum1
=
0
;
T_ACC
sum2
=
0
;
for
(
int64_t
i
=
0
;
i
<
M
;
++
i
)
{
const
int64_t
index
=
i
*
N
+
j
;
sum1
+=
dg
==
nullptr
?
T_ACC
(
0
)
:
static_cast
<
T_ACC
>
(
dY
[
index
])
*
(
static_cast
<
T_ACC
>
(
X
[
index
])
-
static_cast
<
T_ACC
>
(
mean
[
i
]))
*
static_cast
<
T_ACC
>
(
rstd
[
i
]);
sum2
+=
db
==
nullptr
?
T_ACC
(
0
)
:
static_cast
<
T_ACC
>
(
dY
[
index
]);
}
if
(
dg
!=
nullptr
)
{
dg
[
j
]
=
static_cast
<
scalar_t
>
(
sum1
);
}
if
(
db
!=
nullptr
)
{
db
[
j
]
=
static_cast
<
scalar_t
>
(
sum2
);
}
}
}
template
<
typename
scalar_t
>
__global__
void
GammaBetaBackward
(
IndexType
M
,
IndexType
N
,
const
scalar_t
*
dY
,
const
scalar_t
*
X
,
const
acc_type
<
scalar_t
,
true
>*
mean
,
const
acc_type
<
scalar_t
,
true
>*
rstd
,
scalar_t
*
dg
,
scalar_t
*
db
)
{
using
T_ACC
=
acc_type
<
scalar_t
,
true
>
;
__shared__
T_ACC
g_shared
[
ColwiseReduceTileSize
][
ColwiseReduceTileSize
+
1
];
__shared__
T_ACC
b_shared
[
ColwiseReduceTileSize
][
ColwiseReduceTileSize
+
1
];
const
int64_t
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
T_ACC
dg_sum1
=
0
;
T_ACC
dg_sum2
=
0
;
T_ACC
db_sum1
=
0
;
T_ACC
db_sum2
=
0
;
if
(
j
<
N
)
{
for
(
int64_t
i
=
threadIdx
.
y
;
i
<
M
;
i
+=
blockDim
.
y
*
2
)
{
const
int64_t
i1
=
i
;
const
int64_t
i2
=
i
+
blockDim
.
y
;
const
int64_t
index1
=
i1
*
N
+
j
;
const
int64_t
index2
=
i2
*
N
+
j
;
dg_sum1
+=
dg
==
nullptr
?
T_ACC
(
0
)
:
static_cast
<
T_ACC
>
(
dY
[
index1
])
*
(
static_cast
<
T_ACC
>
(
X
[
index1
])
-
static_cast
<
T_ACC
>
(
mean
[
i1
]))
*
static_cast
<
T_ACC
>
(
rstd
[
i1
]);
db_sum1
+=
db
==
nullptr
?
T_ACC
(
0
)
:
static_cast
<
T_ACC
>
(
dY
[
index1
]);
if
(
i2
<
M
)
{
dg_sum2
+=
dg
==
nullptr
?
T_ACC
(
0
)
:
static_cast
<
T_ACC
>
(
dY
[
index2
])
*
(
static_cast
<
T_ACC
>
(
X
[
index2
])
-
static_cast
<
T_ACC
>
(
mean
[
i2
]))
*
static_cast
<
T_ACC
>
(
rstd
[
i2
]);
db_sum2
+=
db
==
nullptr
?
T_ACC
(
0
)
:
static_cast
<
T_ACC
>
(
dY
[
index2
]);
}
}
}
g_shared
[
threadIdx
.
y
][
threadIdx
.
x
]
=
dg_sum1
;
g_shared
[
threadIdx
.
y
+
blockDim
.
y
][
threadIdx
.
x
]
=
dg_sum2
;
b_shared
[
threadIdx
.
y
][
threadIdx
.
x
]
=
db_sum1
;
b_shared
[
threadIdx
.
y
+
blockDim
.
y
][
threadIdx
.
x
]
=
db_sum2
;
__syncthreads
();
T_ACC
sum1
=
g_shared
[
threadIdx
.
x
][
threadIdx
.
y
];
T_ACC
sum2
=
b_shared
[
threadIdx
.
x
][
threadIdx
.
y
];
sum1
=
WarpReduceSum
<
16
>
(
sum1
);
sum2
=
WarpReduceSum
<
16
>
(
sum2
);
if
(
threadIdx
.
x
==
0
)
{
const
int64_t
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
y
;
if
(
j
<
N
)
{
if
(
dg
!=
nullptr
)
{
dg
[
j
]
=
static_cast
<
scalar_t
>
(
sum1
);
}
if
(
db
!=
nullptr
)
{
db
[
j
]
=
static_cast
<
scalar_t
>
(
sum2
);
}
}
}
sum1
=
g_shared
[
threadIdx
.
x
][
threadIdx
.
y
+
blockDim
.
y
];
sum2
=
b_shared
[
threadIdx
.
x
][
threadIdx
.
y
+
blockDim
.
y
];
sum1
=
WarpReduceSum
<
16
>
(
sum1
);
sum2
=
WarpReduceSum
<
16
>
(
sum2
);
if
(
threadIdx
.
x
==
0
)
{
const
int64_t
j
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
y
+
blockDim
.
y
;
if
(
j
<
N
)
{
if
(
dg
!=
nullptr
)
{
dg
[
j
]
=
static_cast
<
scalar_t
>
(
sum1
);
}
if
(
db
!=
nullptr
)
{
db
[
j
]
=
static_cast
<
scalar_t
>
(
sum2
);
}
}
}
}
template
<
typename
scalar_t
>
__global__
void
LayerNormBackward_kernel
(
IndexType
N
,
const
scalar_t
*
dY
,
const
scalar_t
*
X
,
const
scalar_t
*
gamma
,
const
acc_type
<
scalar_t
,
true
>*
mean
,
const
acc_type
<
scalar_t
,
true
>*
rstd
,
scalar_t
*
dX
,
const
scalar_t
*
add_to_output
)
{
using
T_ACC
=
acc_type
<
scalar_t
,
true
>
;
__shared__
T_ACC
ds_shared
[
C10_WARP_SIZE
];
__shared__
T_ACC
db_shared
[
C10_WARP_SIZE
];
const
IndexType
i
=
blockIdx
.
x
;
T_ACC
sum1
=
0
;
T_ACC
sum2
=
0
;
#pragma unroll
for
(
IndexType
j
=
threadIdx
.
x
;
j
<
N
;
j
+=
blockDim
.
x
)
{
const
IndexType
index
=
i
*
N
+
j
;
const
T_ACC
gamma_v
=
gamma
==
nullptr
?
T_ACC
(
1
)
:
static_cast
<
T_ACC
>
(
gamma
[
j
]);
sum1
+=
static_cast
<
T_ACC
>
(
dY
[
index
])
*
static_cast
<
T_ACC
>
(
X
[
index
])
*
gamma_v
;
sum2
+=
static_cast
<
T_ACC
>
(
dY
[
index
])
*
gamma_v
;
}
sum1
=
BlockReduceSum
<
T_ACC
>
(
sum1
,
ds_shared
);
sum2
=
BlockReduceSum
<
T_ACC
>
(
sum2
,
db_shared
);
const
T_ACC
s
=
T_ACC
(
1
)
/
static_cast
<
T_ACC
>
(
N
);
__shared__
T_ACC
b
;
__shared__
T_ACC
c
;
if
(
threadIdx
.
x
==
0
)
{
b
=
(
sum2
*
static_cast
<
T_ACC
>
(
mean
[
i
])
-
sum1
)
*
static_cast
<
T_ACC
>
(
rstd
[
i
])
*
static_cast
<
T_ACC
>
(
rstd
[
i
])
*
static_cast
<
T_ACC
>
(
rstd
[
i
])
*
s
;
c
=
-
(
b
*
static_cast
<
T_ACC
>
(
mean
[
i
])
+
sum2
*
static_cast
<
T_ACC
>
(
rstd
[
i
])
*
s
);
}
__syncthreads
();
#pragma unroll
for
(
IndexType
j
=
threadIdx
.
x
;
j
<
N
;
j
+=
blockDim
.
x
)
{
const
IndexType
index
=
i
*
N
+
j
;
const
T_ACC
gamma_v
=
gamma
==
nullptr
?
T_ACC
(
1
)
:
static_cast
<
T_ACC
>
(
gamma
[
j
]);
dX
[
index
]
=
static_cast
<
scalar_t
>
(
static_cast
<
T_ACC
>
(
rstd
[
i
])
*
static_cast
<
T_ACC
>
(
dY
[
index
])
*
gamma_v
+
b
*
static_cast
<
T_ACC
>
(
X
[
index
])
+
c
+
(
add_to_output
==
nullptr
?
T_ACC
(
0
)
:
static_cast
<
T_ACC
>
(
add_to_output
[
index
])));
}
}
template
<
typename
T
>
void
LayerNormBackwardKernelImplInternal
(
oneflow
::
ep
::
Stream
*
stream
,
const
T
*
dY
,
const
T
*
X
,
const
acc_type
<
T
,
true
>*
mean
,
const
acc_type
<
T
,
true
>*
rstd
,
const
T
*
gamma
,
int64_t
M
,
int64_t
N
,
T
*
dX
,
const
T
*
add_to_output
)
{
using
T_ACC
=
acc_type
<
T
,
true
>
;
const
T
*
dY_data
=
dY
;
const
T
*
X_data
=
X
;
const
T_ACC
*
mean_data
=
mean
;
const
T_ACC
*
rstd_data
=
rstd
;
const
T
*
gamma_data
=
gamma
;
T
*
dX_data
=
dX
;
const
T
*
add_to_output_data
=
add_to_output
;
hipStream_t
cuda_stream
=
stream
->
As
<
oneflow
::
ep
::
CudaStream
>
()
->
cuda_stream
();
if
(
dX_data
!=
nullptr
)
{
LayerNormBackward_kernel
<
T
><<<
M
,
BlockReduceNumThreads
,
0
,
cuda_stream
>>>
(
N
,
dY_data
,
X_data
,
gamma_data
,
mean_data
,
rstd_data
,
dX_data
,
add_to_output_data
);
}
}
template
<
typename
T
>
void
LayerNormBackwardKernelImplInternalParam
(
oneflow
::
ep
::
Stream
*
stream
,
const
T
*
dY
,
const
T
*
X
,
const
acc_type
<
T
,
true
>*
mean
,
const
acc_type
<
T
,
true
>*
rstd
,
int64_t
M
,
int64_t
N
,
T
*
dgamma
,
T
*
dbeta
)
{
using
T_ACC
=
acc_type
<
T
,
true
>
;
const
T
*
dY_data
=
dY
;
const
T
*
X_data
=
X
;
const
T_ACC
*
mean_data
=
mean
;
const
T_ACC
*
rstd_data
=
rstd
;
hipStream_t
cuda_stream
=
stream
->
As
<
oneflow
::
ep
::
CudaStream
>
()
->
cuda_stream
();
T
*
dgamma_data
=
dgamma
;
T
*
dbeta_data
=
dbeta
;
if
(
M
<
512
)
{
// For small batch size, do colwise reduce directly.
const
int64_t
B
=
(
N
+
NumThreads
-
1
)
/
NumThreads
;
GammaBetaBackwardSimple
<
T
>
<<<
B
,
NumThreads
,
0
,
cuda_stream
>>>
(
M
,
N
,
dY_data
,
X_data
,
mean_data
,
rstd_data
,
dgamma_data
,
dbeta_data
);
}
else
{
const
int64_t
B
=
(
N
+
ColwiseReduceTileSize
-
1
)
/
ColwiseReduceTileSize
;
constexpr
int
kThreadX
=
ColwiseReduceTileSize
;
constexpr
int
kThreadY
=
ColwiseReduceTileSize
/
2
;
GammaBetaBackward
<
T
>
<<<
B
,
dim3
(
kThreadX
,
kThreadY
),
0
,
cuda_stream
>>>
(
M
,
N
,
dY_data
,
X_data
,
mean_data
,
rstd_data
,
dgamma_data
,
dbeta_data
);
}
}
namespace
oneflow
{
template
<
typename
T
>
class
LayerNormGpuKernel
final
:
public
user_op
::
OpKernel
,
public
user_op
::
CudaGraphSupport
{
public:
LayerNormGpuKernel
()
=
default
;
~
LayerNormGpuKernel
()
=
default
;
private:
using
user_op
::
OpKernel
::
Compute
;
bool
AlwaysComputeWhenAllOutputsEmpty
()
const
override
{
return
false
;
}
void
Compute
(
user_op
::
KernelComputeContext
*
ctx
)
const
override
{
const
user_op
::
Tensor
*
x
=
ctx
->
Tensor4ArgNameAndIndex
(
"x"
,
0
);
user_op
::
Tensor
*
y
=
ctx
->
Tensor4ArgNameAndIndex
(
"y"
,
0
);
user_op
::
Tensor
*
mean
=
ctx
->
Tensor4ArgNameAndIndex
(
"mean"
,
0
);
user_op
::
Tensor
*
inv_variance
=
ctx
->
Tensor4ArgNameAndIndex
(
"inv_variance"
,
0
);
double
epsilon
=
ctx
->
Attr
<
double
>
(
"epsilon"
);
int64_t
num_instances
=
mean
->
shape_view
().
elem_cnt
();
int64_t
norm_size
=
x
->
shape_view
().
elem_cnt
()
/
num_instances
;
const
T
*
gamma_ptr
=
nullptr
;
const
T
*
beta_ptr
=
nullptr
;
if
(
ctx
->
has_input
(
"gamma"
,
0
))
{
const
user_op
::
Tensor
*
gamma
=
ctx
->
Tensor4ArgNameAndIndex
(
"gamma"
,
0
);
gamma_ptr
=
gamma
->
dptr
<
T
>
();
CHECK_EQ
(
gamma
->
shape_view
().
elem_cnt
(),
norm_size
);
}
if
(
ctx
->
has_input
(
"beta"
,
0
))
{
beta_ptr
=
ctx
->
Tensor4ArgNameAndIndex
(
"beta"
,
0
)
->
dptr
<
T
>
();
}
// DispatchLayerNormForwardGpu<T>(ctx->stream(), num_instances, norm_size, epsilon, x->dptr<T>(),
// gamma_ptr, beta_ptr, y->mut_dptr<T>(), mean, inv_variance);
using
ComputeType
=
typename
cuda
::
layer_norm
::
DefaultComputeType
<
T
>::
type
;
LayerNormKernelImplInternal
<
T
>
(
ctx
->
stream
(),
x
->
dptr
<
T
>
(),
gamma_ptr
,
beta_ptr
,
num_instances
,
norm_size
,
epsilon
,
y
->
mut_dptr
<
T
>
(),
mean
->
mut_dptr
<
ComputeType
>
(),
inv_variance
->
mut_dptr
<
ComputeType
>
());
};
};
#define REGISTER_LAYER_NORM_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm") \
.SetCreateFn<LayerNormGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("x", 0) == GetDataType<dtype>::value));
REGISTER_LAYER_NORM_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_CUDA_KERNEL(double)
REGISTER_LAYER_NORM_CUDA_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_CUDA_KERNEL(nv_bfloat16)
#endif
template<typename T>
class LayerNormGradGpuKernel final : public user_op::OpKernel, public user_op::CudaGraphSupport {
 public:
  LayerNormGradGpuKernel() = default;
  ~LayerNormGradGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
    const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
    user_op::Tensor* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
    int64_t num_instances = mean->shape_view().elem_cnt();
    int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
    const T* gamma_ptr = nullptr;
    if (ctx->has_input("gamma", 0)) {
      gamma_ptr = ctx->Tensor4ArgNameAndIndex("gamma", 0)->dptr<T>();
    }
    const T* add_to_output_ptr = nullptr;
    if (ctx->has_input("_add_to_output", 0)) {
      const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
      CHECK_EQ(add_to_output->data_type(), dx->data_type());
      CHECK_EQ(add_to_output->shape_view(), dx->shape_view());
      add_to_output_ptr = add_to_output->dptr<T>();
    }
    // LaunchLayerNormBackward<T>(ctx->stream(), num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(),
    //                            mean, inv_variance, gamma_ptr, add_to_output_ptr, dx->mut_dptr<T>());
    using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
    LayerNormBackwardKernelImplInternal<T>(
        ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
        inv_variance->dptr<ComputeType>(), gamma_ptr, num_instances, norm_size, dx->mut_dptr<T>(),
        add_to_output_ptr);
  }
};
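When `_add_to_output` is wired in, the backward pass folds it into `dx` during the same launch instead of requiring a separate add. A hypothetical element-wise view of that accumulate (not the kernel itself), just to spell out the contract that the two `CHECK_EQ`s above enforce:

```cpp
#include <cstdint>

// Hypothetical scalar model: with add_to_output present, the kernel is expected to write
// dx[i] = layer_norm_dx[i] + add_to_output[i]; otherwise it writes layer_norm_dx[i] alone.
void AccumulateIntoDx(int64_t n, const float* layer_norm_dx, const float* add_to_output,
                      float* dx) {
  for (int64_t i = 0; i < n; ++i) {
    dx[i] = add_to_output ? layer_norm_dx[i] + add_to_output[i] : layer_norm_dx[i];
  }
}
```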
#define REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm_grad") \
.SetCreateFn<LayerNormGradGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)) \
.SetInplaceProposalFn( \
[](const user_op::InferContext& ctx, \
const user_op::AddInplaceArgPair& AddInplaceArgPairFn) -> Maybe<void> { \
if (ctx.has_input("_add_to_output", 0)) { \
OF_RETURN_IF_ERROR(AddInplaceArgPairFn("dx", 0, "_add_to_output", 0, true)); \
} \
return Maybe<void>::Ok(); \
});
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(float)
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(double)
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
#endif
template<typename T>
class LayerNormParamGradGpuKernel final : public user_op::OpKernel,
                                          public user_op::CudaGraphSupport {
 public:
  LayerNormParamGradGpuKernel() = default;
  ~LayerNormParamGradGpuKernel() = default;

 private:
  using user_op::OpKernel::Compute;
  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
  void Compute(user_op::KernelComputeContext* ctx) const override {
    const user_op::Tensor* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
    const user_op::Tensor* x = ctx->Tensor4ArgNameAndIndex("x", 0);
    const user_op::Tensor* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
    const user_op::Tensor* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
    int64_t num_instances = mean->shape_view().elem_cnt();
    int64_t norm_size = x->shape_view().elem_cnt() / num_instances;
    user_op::Tensor* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
    // const DataType data_type = dy->data_type();
    // const int grid_dim_x = (norm_size + tile_size - 1) / tile_size;
    // const int grid_dim_y = GetGirdDimY<T>(num_instances, norm_size);
    // const size_t tmp_gamma_diff_size = grid_dim_y * norm_size * sizeof(T);
    // T* tmp_gamma_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr());
    // T* tmp_beta_diff_ptr = reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + tmp_gamma_diff_size);
    // T* reduce_buf_ptr =
    //     reinterpret_cast<T*>(tmp_buffer->mut_dptr<char>() + 2 * tmp_gamma_diff_size);
    using ComputeType = typename cuda::layer_norm::DefaultComputeType<T>::type;
    // LayerNormParamGrad<T, ComputeType><<<dim3(grid_dim_x, grid_dim_y), dim3(32, 32 / num_per_block),
    //                                      0, ctx->stream()->As<ep::CudaStream>()->cuda_stream()>>>(
    //     num_instances, norm_size, dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
    //     inv_variance->dptr<ComputeType>(), tmp_gamma_diff_ptr, tmp_beta_diff_ptr);
    // const int32_t m = norm_size;
    // const int32_t n = 1;
    // const int32_t k = grid_dim_y;
    // std::unique_ptr<ep::primitive::Fill> fill =
    //     ep::primitive::NewPrimitive<ep::primitive::FillFactory>(ctx->stream()->device_type(),
    //                                                             data_type);
    // CHECK(fill);
    // fill->Launch(ctx->stream(), reduce_buf_ptr, 1.0, grid_dim_y);
    // std::unique_ptr<ep::primitive::Matmul> matmul =
    //     ep::primitive::NewPrimitive<ep::primitive::MatmulFactory>(
    //         ctx->stream()->device_type(), data_type, ep::primitive::BlasTransposeType::T,
    //         ep::primitive::BlasTransposeType::N);
    // CHECK(matmul);
    // if (ctx->has_output("gamma_diff", 0)) {
    //   user_op::Tensor* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
    //   matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_gamma_diff_ptr, reduce_buf_ptr, 0.0,
    //                  gamma_diff->mut_dptr());
    // }
    // if (ctx->has_output("beta_diff", 0)) {
    //   user_op::Tensor* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
    //   matmul->Launch(ctx->stream(), m, n, k, 1.0, tmp_beta_diff_ptr, reduce_buf_ptr, 0.0,
    //                  beta_diff->mut_dptr());
    // }
    T* gamma_diff_ptr = nullptr;
    T* beta_diff_ptr = nullptr;
    if (ctx->has_output("gamma_diff", 0)) {
      gamma_diff_ptr = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0)->mut_dptr<T>();
    }
    if (ctx->has_output("beta_diff", 0)) {
      beta_diff_ptr = ctx->Tensor4ArgNameAndIndex("beta_diff", 0)->mut_dptr<T>();
    }
    LayerNormBackwardKernelImplInternalParam<T>(
        ctx->stream(), dy->dptr<T>(), x->dptr<T>(), mean->dptr<ComputeType>(),
        inv_variance->dptr<ComputeType>(), num_instances, norm_size, gamma_diff_ptr,
        beta_diff_ptr);
  }
};
#define REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(dtype) \
REGISTER_USER_KERNEL("layer_norm_param_grad") \
.SetCreateFn<LayerNormParamGradGpuKernel<dtype>>() \
.SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA) \
&& (user_op::HobDataType("dy", 0) == GetDataType<dtype>::value)) \
.SetInferTmpSizeFn([](user_op::InferContext* ctx) { \
const int64_t begin_params_axis = ctx->Attr<int64_t>("begin_params_axis"); \
const bool has_gamma_diff = ctx->has_output("gamma_diff", 0); \
const bool has_beta_diff = ctx->has_output("beta_diff", 0); \
const auto& dy = ctx->InputTensorDesc("dy", 0); \
const int64_t num_instances = dy.shape().Count(0, begin_params_axis); \
const int64_t norm_size = dy.shape().Count(begin_params_axis); \
const int grid_dim_y = num_instances; \
size_t tmp_buffer_size = (2 * grid_dim_y * norm_size + grid_dim_y) * sizeof(dtype); \
return tmp_buffer_size; \
});
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(float)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(double)
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(half)
#if CUDA_VERSION >= 11000
REGISTER_LAYER_NORM_PARAM_GRAD_GPU_KERNEL(nv_bfloat16)
#endif
}
}  // namespace oneflow
#endif
oneflow/user/kernels/math_binary_elementwise_func.h
@@ -28,7 +28,7 @@ limitations under the License.
 #elif defined(__HIPCC__)
-#include <hip/hsa_detail/math_functions.h>
+#include <hip/math_functions.h>
 #include <hip/hip_fp16.h>
 #if defined(__HIP_DEVICE_COMPILE__)
oneflow/user/kernels/normalization_kernel.cu
@@ -882,6 +882,22 @@ namespace oneflow {
 namespace {

+template<typename T>
+void printTensor(const std::string& str, const T* devTensor, size_t size) {
+  T* hostTensor = (T*)malloc(size * sizeof(T));
+  hipMemcpy(hostTensor, devTensor, size * sizeof(T), hipMemcpyDeviceToHost);
+  std::cout << str << ": ";
+  for (size_t i = 0; i < size; i++) {
+    if (i % 16 == 0) { std::cout << std::endl; }
+    std::cout << hostTensor[i] << ", ";
+  }
+  std::cout << str << ": finish" << std::endl;
+  free(hostTensor);
+}
+
 hipdnnBatchNormMode_t getCudnnBatchNormMode(const int64_t dim) {
   if (dim == 2) {
     return HIPDNN_BATCHNORM_PER_ACTIVATION;
@@ -969,6 +985,15 @@ class CudnnTensorDescHelper final {
   int32_t param_size_ = 0;
 };

+size_t InferInferTmpSize(user_op::InferContext* ctx) {
+  const auto& y = ctx->OutputTensorDesc("y", 0);
+  if (ctx->has_input("_add_to_output", 0)) {
+    return y.shape().elem_cnt() * GetSizeOfDataType(y.data_type());
+  } else {
+    return 1;
+  }
+}
+
 size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
                                const int32_t axis) {
   return 1;
@@ -976,8 +1001,13 @@ size_t InferTrainWorkspaceSize(const ShapeView& x_shape, const DataType data_typ
 size_t InferTrainTmpSize(user_op::InferContext* ctx) {
   const auto& x = ctx->InputTensorDesc("x", 0);
+  const auto& y = ctx->OutputTensorDesc("y", 0);
   const auto axis = ctx->Attr<int32_t>("axis");
-  return InferTrainWorkspaceSize(x.shape(), x.data_type(), axis);
+  if (ctx->has_input("_add_to_output", 0)) {
+    return y.shape().elem_cnt() * GetSizeOfDataType(y.data_type());
+  } else {
+    return InferTrainWorkspaceSize(x.shape(), x.data_type(), axis);
+  }
 }

 size_t InferGradWorkspaceSize(const ShapeView& x_shape, const DataType data_type,
@@ -1016,6 +1046,9 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
     const auto axis = ctx->Attr<int32_t>("axis");
     const auto epsilon = ctx->Attr<float>("epsilon");

+    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
+    void* add_to_output_dev = tmp_buffer->mut_dptr<void>();
+
     const DataType data_type = x->data_type();
     CHECK_EQ(x->shape_view(), y->shape_view());
     CHECK_EQ(y->data_type(), data_type);
@@ -1030,17 +1063,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
     desc_helper.CheckParamTensor(moving_variance);
     const void* sp_alpha = CudnnSPOnePtr(data_type);
-    const void* sp_beta;
+    const void* sp_beta = CudnnSPZeroPtr(data_type);
     if (ctx->has_input("_add_to_output", 0)) {
       const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
       CHECK_EQ(add_to_output->data_type(), y->data_type());
       CHECK_EQ(add_to_output->shape_view(), y->shape_view());
       Memcpy<DeviceType::kCUDA>(
-          ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(),
+          ctx->stream(), add_to_output_dev, add_to_output->dptr<void>(),
           add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
-      sp_beta = CudnnSPOnePtr(data_type);
-    } else {
-      sp_beta = CudnnSPZeroPtr(data_type);
-    }
     }
     OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardInference(
@@ -1048,6 +1079,15 @@ class NormalizationInferenceKernel final : public user_op::OpKernel,
         desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(), y->mut_dptr(),
         desc_helper.param_desc(), gamma->dptr(), beta->dptr(), moving_mean->dptr(),
         moving_variance->dptr(), epsilon));

+    if (ctx->has_input("_add_to_output", 0)) {
+      sp_beta = CudnnSPOnePtr(data_type);
+      OF_CUDNN_CHECK(hipdnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
+                                     sp_alpha, desc_helper.xy_desc(), add_to_output_dev, sp_beta,
+                                     desc_helper.xy_desc(), y->mut_dptr()));
+    }
   }

   bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
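With the staging buffer in place, the accumulate happens after the batch-norm call: `hipdnnAddTensor` blends the staged `_add_to_output` copy into `y` using the two scaling pointers, i.e. conceptually `y = sp_alpha * staged + sp_beta * y` with both set to one here. A hypothetical scalar model of that blend, only to make the alpha/beta roles explicit (it is not library code):

```cpp
#include <cstdint>

// Hypothetical scalar model of the accumulate above: with sp_alpha == 1 and sp_beta == 1
// the result is y[i] = staged_add_to_output[i] + y[i] for every element.
void AddTensorReference(int64_t elem_cnt, float alpha, const float* a, float beta, float* c) {
  for (int64_t i = 0; i < elem_cnt; ++i) { c[i] = alpha * a[i] + beta * c[i]; }
}
```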
@@ -1057,6 +1097,7 @@ REGISTER_USER_KERNEL("normalization")
     .SetCreateFn<NormalizationInferenceKernel>()
     .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA)
                      && (user_op::HobAttr<bool>("training") == false))
+    .SetInferTmpSizeFn(InferInferTmpSize)
     .SetInplaceProposalFn([](const user_op::InferContext& ctx,
                              user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe<void> {
       if (ctx.has_input("_add_to_output", 0)) {
@@ -1068,76 +1109,78 @@ REGISTER_USER_KERNEL("normalization")
 constexpr int64_t kCudaWarpSize = 64;

 template<typename T>
-__global__ void ReluGpu(int64_t n, const T* x, T* y, int32_t* mask) {
+__global__ void ReluGpu(int64_t n, const T* x, T* y, int64_t* mask) {
   const int32_t lane_id = threadIdx.x % kCudaWarpSize;
   const T zero = static_cast<T>(0.f);
   CUDA_1D_KERNEL_LOOP(i, n) {
     const T x_val = x[i];
     const bool is_positive = (x_val > zero);
-    int32_t warp_mask = __ballot(static_cast<int>(is_positive));
-    if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }
+    unsigned long long int warp_mask_tmp = __ballot(static_cast<int>(is_positive));
+    int64_t* warp_mask = reinterpret_cast<int64_t*>(&warp_mask_tmp);
+    if (lane_id == 0) { mask[i / kCudaWarpSize] = *warp_mask; }
     y[i] = is_positive ? x_val : zero;
   }
 }

 template<typename T>
-__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {
+__global__ void AddReluGpu(int64_t n, const T* x, const T* addend, T* y, int64_t* mask) {
   const int32_t lane_id = threadIdx.x % kCudaWarpSize;
   const T zero = static_cast<T>(0.f);
   CUDA_1D_KERNEL_LOOP(i, n) {
     const T sum = x[i] + addend[i];
     const bool is_positive = (sum > zero);
-    int32_t warp_mask = __ballot(static_cast<int>(is_positive));
-    if (lane_id == 0) { mask[i / kCudaWarpSize] = warp_mask; }
+    unsigned long long int warp_mask_tmp = __ballot(static_cast<int>(is_positive));
+    int64_t* warp_mask = reinterpret_cast<int64_t*>(&warp_mask_tmp);
+    if (lane_id == 0) { mask[i / kCudaWarpSize] = *warp_mask; }
     y[i] = is_positive ? sum : zero;
   }
 }

 template<typename T>
-void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int32_t* mask) {
+void Relu(ep::Stream* stream, int64_t n, const T* x, T* y, int64_t* mask) {
   ReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
                stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, y, mask);
 }

 template<typename T>
-void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int32_t* mask) {
+void AddRelu(ep::Stream* stream, int64_t n, const T* x, const T* addend, T* y, int64_t* mask) {
   AddReluGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
                   stream->As<ep::CudaStream>()->cuda_stream()>>>(n, x, addend, y, mask);
 }

 template<typename T>
-__global__ void ReluBackwardGpu(int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {
-  int32_t lane_id = threadIdx.x % kCudaWarpSize;
+__global__ void ReluBackwardGpu(int64_t n, const int64_t* mask, const T* dy, T* addend_diff) {
+  int64_t lane_id = threadIdx.x % kCudaWarpSize;
   CUDA_1D_KERNEL_LOOP(i, n) {
-    int32_t mask_val = mask[i / kCudaWarpSize];
-    bool is_positive = mask_val & (1 << lane_id);
+    int64_t mask_val = mask[i / kCudaWarpSize];
+    bool is_positive = mask_val & ((int64_t)1 << lane_id);
     addend_diff[i] = static_cast<T>(is_positive) * dy[i];
   }
 }

-#if CUDA_VERSION >= 11000
-template<>
-__global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,
-                                             nv_bfloat16* addend_diff) {
-  int32_t lane_id = threadIdx.x % kCudaWarpSize;
-  CUDA_1D_KERNEL_LOOP(i, n) {
-    int32_t mask_val = mask[i / kCudaWarpSize];
-    bool is_positive = mask_val & (1 << lane_id);
-    addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];
-  }
-}
-#endif
+// #if CUDA_VERSION >= 11000
+// template<>
+// __global__ void ReluBackwardGpu<nv_bfloat16>(int64_t n, const int32_t* mask, const nv_bfloat16* dy,
+//                                              nv_bfloat16* addend_diff) {
+//   int32_t lane_id = threadIdx.x % kCudaWarpSize;
+//   CUDA_1D_KERNEL_LOOP(i, n) {
+//     int32_t mask_val = mask[i / kCudaWarpSize];
+//     bool is_positive = mask_val & (1 << lane_id);
+//     addend_diff[i] = static_cast<nv_bfloat16>(static_cast<float>(is_positive)) * dy[i];
+//   }
+// }
+// #endif

 template<typename T>
-void ReluBackward(ep::Stream* stream, int64_t n, const int32_t* mask, const T* dy, T* addend_diff) {
+void ReluBackward(ep::Stream* stream, int64_t n, const int64_t* mask, const T* dy, T* addend_diff) {
   ReluBackwardGpu<T><<<BlocksNum4ThreadsNum(n), kCudaThreadsNumPerBlock, 0,
                        stream->As<ep::CudaStream>()->cuda_stream()>>>(n, mask, dy, addend_diff);
 }

 void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x, void* y,
-          int32_t* mask) {
+          int64_t* mask) {
   if (data_type == kFloat) {
     Relu<float>(stream, n, reinterpret_cast<const float*>(x), reinterpret_cast<float*>(y), mask);
   } else if (data_type == kDouble) {
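The int32_t-to-int64_t change running through this hunk follows from the warp width: `kCudaWarpSize` is 64 here, `__ballot` over a 64-lane wavefront yields a 64-bit lane mask, and each stored mask word therefore needs 64 bits. A small host-side sketch (hypothetical helpers, not code from the commit) of the packing and of the per-lane test that `ReluBackwardGpu` performs with `((int64_t)1 << lane_id)`:

```cpp
#include <cstdint>

// Pack one predicate per lane into a 64-bit word, mirroring what a 64-lane ballot produces.
uint64_t PackLaneMask(const bool* is_positive, int lane_count /* 64 on ROCm */) {
  uint64_t mask = 0;
  for (int lane = 0; lane < lane_count; ++lane) {
    if (is_positive[lane]) { mask |= (uint64_t)1 << lane; }
  }
  return mask;
}

// Recover a single lane's predicate from the stored word, as the backward kernel does.
bool LaneWasPositive(uint64_t mask, int lane_id) { return (mask >> lane_id) & 1u; }
```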
@@ -1156,7 +1199,7 @@ void Relu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x
   }
 }

 void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void* x,
-             const void* addend, void* y, int32_t* mask) {
+             const void* addend, void* y, int64_t* mask) {
   if (data_type == kFloat) {
     AddRelu<float>(stream, n, reinterpret_cast<const float*>(x),
                    reinterpret_cast<const float*>(addend), reinterpret_cast<float*>(y), mask);
@@ -1178,7 +1221,7 @@ void AddRelu(ep::Stream* stream, int64_t n, const DataType data_type, const void
     UNIMPLEMENTED();
   }
 }

-void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int32_t* mask,
+void ReluBackward(ep::Stream* stream, int64_t n, const DataType data_type, const int64_t* mask,
                   const void* dy, void* addend_diff) {
   if (data_type == kFloat) {
     ReluBackward<float>(stream, n, mask, reinterpret_cast<const float*>(dy),
@@ -1225,6 +1268,9 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
     hipdnnBatchNormMode_t mode = getCudnnBatchNormMode(x->shape_view().NumAxes());
     const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis, mode);

+    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
+    void* add_to_output_dev = tmp_buffer->mut_dptr<void>();
+
     const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
     const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
     auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
@@ -1244,17 +1290,15 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
       desc_helper.CheckParamTensor(moving_variance);
     }
     const void* sp_alpha = CudnnSPOnePtr(data_type);
-    const void* sp_beta;
+    const void* sp_beta = CudnnSPZeroPtr(data_type);
     if (ctx->has_input("_add_to_output", 0)) {
       const user_op::Tensor* add_to_output = ctx->Tensor4ArgNameAndIndex("_add_to_output", 0);
       CHECK_EQ(add_to_output->data_type(), y->data_type());
       CHECK_EQ(add_to_output->shape_view(), y->shape_view());
       Memcpy<DeviceType::kCUDA>(
-          ctx->stream(), y->mut_dptr<void>(), add_to_output->dptr<void>(),
+          ctx->stream(), add_to_output_dev, add_to_output->dptr<void>(),
           add_to_output->shape_view().elem_cnt() * GetSizeOfDataType(add_to_output->data_type()));
-      sp_beta = CudnnSPOnePtr(data_type);
-    } else {
-      sp_beta = CudnnSPZeroPtr(data_type);
-    }
     }
     OF_CUDNN_CHECK(hipdnnBatchNormalizationForwardTraining(
@@ -1265,6 +1309,14 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
         moving_variance ? moving_variance->mut_dptr() : NULL, epsilon, mean->mut_dptr(),
         inv_variance->mut_dptr()));

+    if (ctx->has_input("_add_to_output", 0)) {
+      sp_beta = CudnnSPOnePtr(data_type);
+      OF_CUDNN_CHECK(hipdnnAddTensor(ctx->stream()->As<ep::CudaStream>()->cudnn_handle(),
+                                     sp_alpha, desc_helper.xy_desc(), add_to_output_dev, sp_beta,
+                                     desc_helper.xy_desc(), y->mut_dptr()));
+    }
+
     if (ctx->op_type_name() == "normalization_add_relu") {
       CHECK(!ctx->has_input("_add_to_output", 0));
       const int64_t elem_cnt = x->shape_view().elem_cnt();
@@ -1272,10 +1324,10 @@ class NormalizationTrainKernel final : public user_op::OpKernel, public user_op:
       if (ctx->has_input("addend", 0)) {
         const auto* addend = ctx->Tensor4ArgNameAndIndex("addend", 0);
         AddRelu(ctx->stream(), elem_cnt, data_type, y->dptr(), addend->dptr(), y->mut_dptr(),
-                mask->mut_dptr<int32_t>());
+                mask->mut_dptr<int64_t>());
       } else {
         Relu(ctx->stream(), elem_cnt, data_type, y->dptr(), y->mut_dptr(),
-             mask->mut_dptr<int32_t>());
+             mask->mut_dptr<int64_t>());
       }
     }
   }
@@ -1351,7 +1403,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel,
     user_op::Tensor* y = ctx->Tensor4ArgNameAndIndex("y", 0);
     if (ctx->has_output("addend_diff", 0)) {
       user_op::Tensor* addend_diff = ctx->Tensor4ArgNameAndIndex("addend_diff", 0);
-      ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(),
+      ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int64_t>(), dy->dptr(),
                    addend_diff->mut_dptr());
       bn_workspace_ptr = tmp_buffer->mut_dptr();
       bn_workspace_size = tmp_buffer->shape_view().elem_cnt();
@@ -1361,7 +1413,7 @@ class NormalizationGradUserKernel final : public user_op::OpKernel,
       const size_t relu_dx_size =
          GetCudaAlignedSize(dy->shape_view().elem_cnt() * GetSizeOfDataType(data_type));
       CHECK_GE(tmp_buffer_size, relu_dx_size);
-      ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int32_t>(), dy->dptr(),
+      ReluBackward(ctx->stream(), elem_cnt, data_type, mask->dptr<int64_t>(), dy->dptr(),
                    tmp_buffer->mut_dptr());
       bn_workspace_ptr = tmp_buffer->mut_dptr<char>() + relu_dx_size;
       bn_workspace_size = tmp_buffer_size - relu_dx_size;
@@ -1393,231 +1445,6 @@ REGISTER_USER_KERNEL("normalization_add_relu_grad")
     .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
     .SetInferTmpSizeFn(InferGradTmpSize);

-#if (HIPDNN_VERSION >= 7401)
-
-size_t InferFusedNormalizationAddReluTmpSize(user_op::InferContext* ctx) {
-  const auto& x = ctx->InputTensorDesc("x", 0);
-  const auto axis = ctx->Attr<int32_t>("axis");
-  const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,
-                                          HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
-  size_t size_in_bytes;
-  hipdnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();
-  CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
-  cudnnBatchNormOps_t ops;
-  hipdnnTensorDescriptor_t z_desc;
-  if (ctx->has_input("addend", 0)) {
-    ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
-    z_desc = desc_helper.xy_desc();
-  } else {
-    ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
-    z_desc = nullptr;
-  }
-  OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
-      handle, HIPDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(), z_desc,
-      desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes));
-  Singleton<CudnnHandlePool>::Get()->Put(handle);
-  return std::max(size_in_bytes, static_cast<size_t>(1));
-}
-
-size_t InferFusedNormalizationAddReluGradTmpSize(user_op::InferContext* ctx) {
-  const auto& x = ctx->InputTensorDesc("x", 0);
-  const auto axis = ctx->Attr<int32_t>("axis");
-  const CudnnTensorDescHelper desc_helper(x.shape(), x.data_type(), axis,
-                                          HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
-  size_t size_in_bytes;
-  hipdnnHandle_t handle = Singleton<CudnnHandlePool>::Get()->Get();
-  CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
-  cudnnBatchNormOps_t ops;
-  hipdnnTensorDescriptor_t z_desc;
-  if (ctx->has_output("addend_diff", 0)) {
-    ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
-    z_desc = desc_helper.xy_desc();
-  } else {
-    ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
-    z_desc = nullptr;
-  }
-  OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
-      handle, HIPDNN_BATCHNORM_SPATIAL_PERSISTENT, ops, desc_helper.xy_desc(),
-      desc_helper.xy_desc(), desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(),
-      desc_helper.param_desc(), activation_desc.Get(), &size_in_bytes));
-  Singleton<CudnnHandlePool>::Get()->Put(handle);
-  return std::max(size_in_bytes, static_cast<size_t>(1));
-}
-
-class FusedNormalizationAddReluKernel final : public user_op::OpKernel,
-                                              public user_op::CudaGraphSupport {
- public:
-  FusedNormalizationAddReluKernel() = default;
-  ~FusedNormalizationAddReluKernel() override = default;
-
- private:
-  using user_op::OpKernel::Compute;
-  void Compute(user_op::KernelComputeContext* ctx) const override {
-    const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0);
-    auto* y = ctx->Tensor4ArgNameAndIndex("y", 0);
-    const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
-    const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
-    auto* moving_mean = ctx->Tensor4ArgNameAndIndex("moving_mean", 0);
-    auto* moving_variance = ctx->Tensor4ArgNameAndIndex("moving_variance", 0);
-    const auto axis = ctx->Attr<int32_t>("axis");
-    const auto epsilon = ctx->Attr<float>("epsilon");
-    const auto momentum = ctx->Attr<float>("momentum");
-    auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
-    auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
-    auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
-    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
-    const DataType data_type = x->data_type();
-    CHECK_EQ(x->shape_view(), y->shape_view());
-    CHECK_EQ(y->data_type(), data_type);
-    CHECK_GE(axis, 0);
-    CHECK_LT(axis, x->shape_view().NumAxes());
-    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,
-                                            HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
-    desc_helper.CheckParamTensor(gamma);
-    desc_helper.CheckParamTensor(beta);
-    desc_helper.CheckParamTensor(moving_mean);
-    desc_helper.CheckParamTensor(moving_variance);
-    desc_helper.CheckParamTensor(mean);
-    desc_helper.CheckParamTensor(inv_variance);
-    CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
-    hipdnnTensorDescriptor_t z_desc;
-    const void* z_ptr;
-    cudnnBatchNormOps_t ops;
-    if (ctx->has_input("addend", 0)) {
-      z_desc = desc_helper.xy_desc();
-      z_ptr = ctx->Tensor4ArgNameAndIndex("addend", 0)->dptr();
-      ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
-    } else {
-      z_desc = nullptr;
-      z_ptr = nullptr;
-      ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
-    }
-    size_t min_workspace_size;
-    OF_CUDNN_CHECK(cudnnGetBatchNormalizationForwardTrainingExWorkspaceSize(
-        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
-        ops, desc_helper.xy_desc(), z_desc, desc_helper.xy_desc(), desc_helper.param_desc(),
-        activation_desc.Get(), &min_workspace_size));
-    const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();
-    CHECK_GE(workspace_size, min_workspace_size);
-    size_t min_reserve_space_size;
-    OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
-        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
-        ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));
-    const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();
-    CHECK_GE(reserve_space_size, min_reserve_space_size);
-    OF_CUDNN_CHECK(cudnnBatchNormalizationForwardTrainingEx(
-        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
-        ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(),
-        z_desc, z_ptr, desc_helper.xy_desc(), y->mut_dptr(), desc_helper.param_desc(),
-        gamma->dptr(), beta->dptr(), 1.0 - momentum, moving_mean->mut_dptr(),
-        moving_variance->mut_dptr(), epsilon, mean->mut_dptr(), inv_variance->mut_dptr(),
-        activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size, reserve_space->mut_dptr(),
-        reserve_space_size));
-  }
-  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
-};
-
-REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu")
-    .SetCreateFn<FusedNormalizationAddReluKernel>()
-    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
-    .SetInferTmpSizeFn(InferFusedNormalizationAddReluTmpSize);
-
-class FusedNormalizationAddReluGradUserKernel final : public user_op::OpKernel,
-                                                      public user_op::CudaGraphSupport {
- public:
-  FusedNormalizationAddReluGradUserKernel() = default;
-  ~FusedNormalizationAddReluGradUserKernel() override = default;
-
- private:
-  using user_op::OpKernel::Compute;
-  void Compute(user_op::KernelComputeContext* ctx) const override {
-    const auto* x = ctx->Tensor4ArgNameAndIndex("x", 0);
-    const auto* y = ctx->Tensor4ArgNameAndIndex("y", 0);
-    auto* dx = ctx->Tensor4ArgNameAndIndex("dx", 0);
-    const auto* dy = ctx->Tensor4ArgNameAndIndex("dy", 0);
-    const auto* gamma = ctx->Tensor4ArgNameAndIndex("gamma", 0);
-    const auto* beta = ctx->Tensor4ArgNameAndIndex("beta", 0);
-    auto* gamma_diff = ctx->Tensor4ArgNameAndIndex("gamma_diff", 0);
-    auto* beta_diff = ctx->Tensor4ArgNameAndIndex("beta_diff", 0);
-    const auto* mean = ctx->Tensor4ArgNameAndIndex("mean", 0);
-    const auto* inv_variance = ctx->Tensor4ArgNameAndIndex("inv_variance", 0);
-    const auto* reserve_space = ctx->Tensor4ArgNameAndIndex("reserve_space", 0);
-    auto* tmp_buffer = ctx->Tensor4ArgNameAndIndex("tmp_buffer", 0);
-    const auto axis = ctx->Attr<int32_t>("axis");
-    const auto epsilon = ctx->Attr<float>("epsilon");
-    const DataType data_type = x->data_type();
-    CHECK_EQ(dy->shape_view(), x->shape_view());
-    CHECK_EQ(dy->data_type(), data_type);
-    CHECK_EQ(dx->shape_view(), x->shape_view());
-    CHECK_EQ(dx->data_type(), data_type);
-    CHECK_GE(axis, 0);
-    CHECK_LT(axis, x->shape_view().NumAxes());
-    const CudnnTensorDescHelper desc_helper(x->shape_view(), data_type, axis,
-                                            HIPDNN_BATCHNORM_SPATIAL_PERSISTENT);
-    desc_helper.CheckParamTensor(gamma);
-    desc_helper.CheckParamTensor(beta);
-    desc_helper.CheckParamTensor(gamma_diff);
-    desc_helper.CheckParamTensor(beta_diff);
-    desc_helper.CheckParamTensor(mean);
-    desc_helper.CheckParamTensor(inv_variance);
-    CudnnActivationDesc activation_desc(HIPDNN_ACTIVATION_RELU, HIPDNN_PROPAGATE_NAN, 0);
-    hipdnnTensorDescriptor_t dz_desc;
-    void* dz_ptr;
-    cudnnBatchNormOps_t ops;
-    if (ctx->has_output("addend_diff", 0)) {
-      dz_desc = desc_helper.xy_desc();
-      dz_ptr = ctx->Tensor4ArgNameAndIndex("addend_diff", 0)->mut_dptr();
-      ops = CUDNN_BATCHNORM_OPS_BN_ADD_ACTIVATION;
-    } else {
-      dz_desc = nullptr;
-      dz_ptr = nullptr;
-      ops = CUDNN_BATCHNORM_OPS_BN_ACTIVATION;
-    }
-    size_t min_workspace_size;
-    OF_CUDNN_CHECK(cudnnGetBatchNormalizationBackwardExWorkspaceSize(
-        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
-        ops, desc_helper.xy_desc(), desc_helper.xy_desc(), desc_helper.xy_desc(), dz_desc,
-        desc_helper.xy_desc(), desc_helper.param_desc(), activation_desc.Get(),
-        &min_workspace_size));
-    const size_t workspace_size = tmp_buffer->shape_view().elem_cnt();
-    CHECK_GE(workspace_size, min_workspace_size);
-    size_t min_reserve_space_size;
-    OF_CUDNN_CHECK(cudnnGetBatchNormalizationTrainingExReserveSpaceSize(
-        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
-        ops, activation_desc.Get(), desc_helper.xy_desc(), &min_reserve_space_size));
-    const size_t reserve_space_size = reserve_space->shape_view().elem_cnt();
-    CHECK_GE(reserve_space_size, min_reserve_space_size);
-    OF_CUDNN_CHECK(cudnnBatchNormalizationBackwardEx(
-        ctx->stream()->As<ep::CudaStream>()->cudnn_handle(), HIPDNN_BATCHNORM_SPATIAL_PERSISTENT,
-        ops, CudnnSPOnePtr(data_type), CudnnSPZeroPtr(data_type), CudnnSPOnePtr(data_type),
-        CudnnSPZeroPtr(data_type), desc_helper.xy_desc(), x->dptr(), desc_helper.xy_desc(),
-        y->dptr(), desc_helper.xy_desc(), dy->dptr(), dz_desc, dz_ptr, desc_helper.xy_desc(),
-        dx->mut_dptr(), desc_helper.param_desc(), gamma->dptr(), beta->dptr(),
-        gamma_diff->mut_dptr(), beta_diff->mut_dptr(), epsilon, mean->dptr(),
-        inv_variance->dptr(), activation_desc.Get(), tmp_buffer->mut_dptr(), workspace_size,
-        const_cast<void*>(reserve_space->dptr()), reserve_space_size));
-  }
-  bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; }
-};
-
-REGISTER_USER_KERNEL("cudnn_fused_normalization_add_relu_grad")
-    .SetCreateFn<FusedNormalizationAddReluGradUserKernel>()
-    .SetIsMatchedHob((user_op::HobDeviceType() == DeviceType::kCUDA))
-    .SetInferTmpSizeFn(InferFusedNormalizationAddReluGradTmpSize);
-
-#endif

 }  // namespace

 }  // namespace oneflow
oneflow/user/ops/normalization_op.cpp
@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
       CHECK_EQ_OR_RETURN(reserve_space_bits % split_num, 0);
       reserve_space_bits = reserve_space_bits / split_num;
     }
+#ifdef WITH_ROCM
+    reserve_space->set_shape(
+        Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 64) / 64)}));
+#else
     reserve_space->set_shape(
         Shape({static_cast<int64_t>(RoundUp(reserve_space_bits, 32) / 32)}));
+#endif
     return Maybe<void>::Ok();
   })(ctx);
 }
@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
   return MakeFwTensorDescInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
                                     user_op::TensorDesc* reserve_space) -> Maybe<void> {
     const auto& x_desc = ctx->InputTensorDesc("x", 0);
+#ifdef WITH_ROCM
+    reserve_space->set_shape(
+        Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 64) / 64)}));
+#else
     reserve_space->set_shape(
         Shape({static_cast<int64_t>(RoundUp(x_desc.shape().elem_cnt(), 32) / 32)}));
+#endif
     return Maybe<void>::Ok();
   })(ctx);
 }
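Both ROCm branches above size `reserve_space` the same way: one ReLU mask bit per element of `x`, packed into 64-bit words instead of the 32-bit words used on CUDA. A tiny sketch of that word count (hypothetical helper, equivalent to `RoundUp(elem_cnt, bits) / bits`):

```cpp
#include <cstdint>

// Number of mask words needed for elem_cnt bits, one bit per tensor element.
// bits_per_word is 64 on ROCm (int64 mask words) and 32 on CUDA (int32 mask words).
int64_t ReserveSpaceWordCount(int64_t elem_cnt, int64_t bits_per_word) {
  return (elem_cnt + bits_per_word - 1) / bits_per_word;
}
```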
@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
 /* static */ Maybe<void> NormalizationAddReluOp::InferDataType(user_op::InferContext* ctx) {
   return MakeFwDataTypeInferFn([](user_op::InferContext* ctx, const user_op::TensorDesc* x,
                                   user_op::TensorDesc* reserve_space) -> Maybe<void> {
+#ifdef WITH_ROCM
+    reserve_space->set_data_type(DataType::kInt64);
+#else
     reserve_space->set_data_type(DataType::kInt32);
+#endif
     return Maybe<void>::Ok();
   })(ctx);
 }