Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
6046d8fb
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "b3e6120736e45cc47ed96fe46c8cf418cb3d8cff"
Commit
6046d8fb
authored
Apr 25, 2023
by
yuguo960516yuguo
Browse files
dtk23.04
parent
a715222c
Changes
16
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
419 additions
and
1010 deletions
+419
-1010
cmake/oneflow.cmake
cmake/oneflow.cmake
+10
-10
cmake/third_party.cmake
cmake/third_party.cmake
+2
-2
oneflow/core/common/math_util.h
oneflow/core/common/math_util.h
+2
-2
oneflow/core/cuda/atomic.cuh
oneflow/core/cuda/atomic.cuh
+1
-1
oneflow/core/cuda/layer_norm.cuh
oneflow/core/cuda/layer_norm.cuh
+46
-4
oneflow/core/embedding/lru_cache.cu
oneflow/core/embedding/lru_cache.cu
+20
-7
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh
...w/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh
+1
-1
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu
...e/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu
+36
-36
oneflow/core/ep/cuda/primitive/unary_functor.cuh
oneflow/core/ep/cuda/primitive/unary_functor.cuh
+2
-2
oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
+156
-25
oneflow/user/kernels/fused_gelu_mul_kernel.cu
oneflow/user/kernels/fused_gelu_mul_kernel.cu
+4
-4
oneflow/user/kernels/group_norm_kernel.cu
oneflow/user/kernels/group_norm_kernel.cu
+4
-0
oneflow/user/kernels/layer_norm_gpu_kernel.cu
oneflow/user/kernels/layer_norm_gpu_kernel.cu
+26
-647
oneflow/user/kernels/math_binary_elementwise_func.h
oneflow/user/kernels/math_binary_elementwise_func.h
+1
-1
oneflow/user/kernels/normalization_kernel.cu
oneflow/user/kernels/normalization_kernel.cu
+95
-268
oneflow/user/ops/normalization_op.cpp
oneflow/user/ops/normalization_op.cpp
+13
-0
No files found.
cmake/oneflow.cmake
View file @
6046d8fb
...
@@ -334,16 +334,16 @@ elseif(WIN32)
...
@@ -334,16 +334,16 @@ elseif(WIN32)
set
(
CMAKE_EXE_LINKER_FLAGS
"
${
CMAKE_EXE_LINKER_FLAGS
}
/WHOLEARCHIVE:oneflow"
)
set
(
CMAKE_EXE_LINKER_FLAGS
"
${
CMAKE_EXE_LINKER_FLAGS
}
/WHOLEARCHIVE:oneflow"
)
endif
()
endif
()
if
(
BUILD_ROCM
)
#
if (BUILD_ROCM)
# AMD compiler fails to compile these three files with '-O1/2/3'.
#
# AMD compiler fails to compile these three files with '-O1/2/3'.
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
#
# The value of `COMPILE_OPTIONS` target property is added after CMAKE_<LANG>_FLAGS_<CONFIG>,
# so '-O0' will override '-O1/2/3'.
#
# so '-O0' will override '-O1/2/3'.
set_source_files_properties
(
${
PROJECT_SOURCE_DIR
}
/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
#
set_source_files_properties(${PROJECT_SOURCE_DIR}/oneflow/user/kernels/median_with_indices_kernel_hip.cpp
${
PROJECT_SOURCE_DIR
}
/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
#
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/radix_sort_top_k_kernel_hip.cpp
${
PROJECT_SOURCE_DIR
}
/oneflow/user/kernels/arg_sort_kernel_hip.cpp
#
${PROJECT_SOURCE_DIR}/oneflow/user/kernels/arg_sort_kernel_hip.cpp
#${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
#
#
${PROJECT_SOURCE_DIR}/oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1_hip.cpp
PROPERTIES COMPILE_OPTIONS
"-O0"
)
#
PROPERTIES COMPILE_OPTIONS "-O0")
endif
()
#
endif()
if
(
BUILD_CUDA
)
if
(
BUILD_CUDA
)
string
(
JOIN
","
CUDA_REAL_ARCHS
${
CUDA_REAL_ARCHS_LIST
}
)
string
(
JOIN
","
CUDA_REAL_ARCHS
${
CUDA_REAL_ARCHS_LIST
}
)
...
...
cmake/third_party.cmake
View file @
6046d8fb
...
@@ -186,7 +186,7 @@ if (BUILD_ROCM)
...
@@ -186,7 +186,7 @@ if (BUILD_ROCM)
if
(
BUILD_ROCM_GRAPHS
)
if
(
BUILD_ROCM_GRAPHS
)
add_definitions
(
-DWITH_ROCM_GRAPHS
)
add_definitions
(
-DWITH_ROCM_GRAPHS
)
endif
()
endif
()
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-D__HIP_PLATFORM_HCC__ -D__HIPCC__"
)
set
(
CMAKE_C_FLAGS
"
${
CMAKE_C_FLAGS
}
-D__HIP_PLATFORM_HCC__ -D__HIPCC__"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-D__HIP_PLATFORM_HCC__ -D__HIPCC__"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
-D__HIP_PLATFORM_HCC__ -D__HIPCC__"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
--gpu-max-threads-per-block=1024"
)
set
(
CMAKE_CXX_FLAGS
"
${
CMAKE_CXX_FLAGS
}
--gpu-max-threads-per-block=1024"
)
...
@@ -204,7 +204,7 @@ if (BUILD_ROCM)
...
@@ -204,7 +204,7 @@ if (BUILD_ROCM)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
-mcmodel=large"
)
set
(
CMAKE_CXX_FLAGS_DEBUG
"
${
CMAKE_CXX_FLAGS_DEBUG
}
-mcmodel=large"
)
set
(
CMAKE_C_FLAGS_DEBUG
"
${
CMAKE_C_FLAGS_DEBUG
}
-mcmodel=large"
)
set
(
CMAKE_C_FLAGS_DEBUG
"
${
CMAKE_C_FLAGS_DEBUG
}
-mcmodel=large"
)
list
(
APPEND oneflow_third_party_libs hip::device
)
list
(
APPEND oneflow_third_party_libs hip::device
)
list
(
APPEND oneflow_third_party_libs hip::hipfft
)
list
(
APPEND oneflow_third_party_libs hip::hipfft
)
list
(
APPEND oneflow_third_party_libs roc::hipblas
)
list
(
APPEND oneflow_third_party_libs roc::hipblas
)
...
...
oneflow/core/common/math_util.h
View file @
6046d8fb
...
@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
...
@@ -27,7 +27,7 @@ int64_t Lcm(int64_t m, int64_t n);
template
<
typename
T
>
template
<
typename
T
>
OF_DEVICE_FUNC
T
DeviceMin
(
T
a
,
T
b
)
{
OF_DEVICE_FUNC
T
DeviceMin
(
T
a
,
T
b
)
{
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
|| defined(__HIP_DEVICE_COMPILE__)
return
a
<
b
?
a
:
b
;
return
a
<
b
?
a
:
b
;
#else
#else
return
std
::
min
(
a
,
b
);
return
std
::
min
(
a
,
b
);
...
@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
...
@@ -36,7 +36,7 @@ OF_DEVICE_FUNC T DeviceMin(T a, T b) {
template
<
typename
T
>
template
<
typename
T
>
OF_DEVICE_FUNC
T
DeviceMax
(
T
a
,
T
b
)
{
OF_DEVICE_FUNC
T
DeviceMax
(
T
a
,
T
b
)
{
#if defined(__CUDA_ARCH__)
#if defined(__CUDA_ARCH__)
|| defined(__HIP_DEVICE_COMPILE__)
return
a
>
b
?
a
:
b
;
return
a
>
b
?
a
:
b
;
#else
#else
return
std
::
max
(
a
,
b
);
return
std
::
max
(
a
,
b
);
...
...
oneflow/core/cuda/atomic.cuh
View file @
6046d8fb
...
@@ -54,7 +54,7 @@ struct CastCASImpl {
...
@@ -54,7 +54,7 @@ struct CastCASImpl {
}
}
};
};
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
#if __CUDA_ARCH__ < 700 || (defined(__clang__) && defined(__CUDA__))
|| defined(WITH_ROCM)
template
<
typename
T
>
template
<
typename
T
>
struct
CastCASImpl
<
T
,
unsigned
short
int
>
{
struct
CastCASImpl
<
T
,
unsigned
short
int
>
{
...
...
oneflow/core/cuda/layer_norm.cuh
View file @
6046d8fb
...
@@ -308,6 +308,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
...
@@ -308,6 +308,21 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
count_shared
[
wid
]
=
warp_count
;
count_shared
[
wid
]
=
warp_count
;
}
}
__syncthreads
();
__syncthreads
();
#ifdef WITH_ROCM
if
(
threadIdx
.
x
<
blockDim
.
x
/
kWarpSize
)
{
warp_mean
=
mean_shared
[
lid
];
warp_m2
=
m2_shared
[
lid
];
warp_count
=
count_shared
[
lid
];
}
else
{
warp_mean
=
static_cast
<
T
>
(
0
);
warp_m2
=
static_cast
<
T
>
(
0
);
warp_count
=
static_cast
<
T
>
(
0
);
}
__syncthreads
();
if
(
wid
==
0
)
{
#else
if
(
wid
==
0
)
{
if
(
wid
==
0
)
{
if
(
threadIdx
.
x
<
blockDim
.
x
/
kWarpSize
)
{
if
(
threadIdx
.
x
<
blockDim
.
x
/
kWarpSize
)
{
warp_mean
=
mean_shared
[
lid
];
warp_mean
=
mean_shared
[
lid
];
...
@@ -318,10 +333,7 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
...
@@ -318,10 +333,7 @@ __inline__ __device__ void WelfordBlockAllReduce(T thread_mean, T thread_m2, T t
warp_m2
=
static_cast
<
T
>
(
0
);
warp_m2
=
static_cast
<
T
>
(
0
);
warp_count
=
static_cast
<
T
>
(
0
);
warp_count
=
static_cast
<
T
>
(
0
);
}
}
#ifdef WITH_ROCM
__syncwarp
();
__syncthreads
();
#else
__syncwarp
();
#endif
#endif
T
block_mean
=
0
;
T
block_mean
=
0
;
T
block_m2
=
0
;
T
block_m2
=
0
;
...
@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
...
@@ -445,7 +457,11 @@ inline GPU(Error_t) LaunchLayerNormWarpImpl(GPU(Stream_t) stream, LOAD load, STO
const
double
epsilon
,
ComputeType
*
mean
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
ComputeType
*
inv_variance
)
{
constexpr
int
block_size
=
128
;
constexpr
int
block_size
=
128
;
#ifdef WITH_ROCM
constexpr
int
waves
=
64
;
#else
constexpr
int
waves
=
32
;
constexpr
int
waves
=
32
;
#endif
static_assert
(
block_size
%
thread_group_width
==
0
,
""
);
static_assert
(
block_size
%
thread_group_width
==
0
,
""
);
constexpr
int
thread_groups_per_block
=
block_size
/
thread_group_width
;
constexpr
int
thread_groups_per_block
=
block_size
/
thread_group_width
;
dim3
block_dim
(
thread_group_width
,
thread_groups_per_block
);
dim3
block_dim
(
thread_group_width
,
thread_groups_per_block
);
...
@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
...
@@ -502,10 +518,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
} \
}
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF
(
64
)
#else
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
32
)
#endif
#undef DEFINE_ONE_ELIF
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
else if (cols <= (max_col)*kWarpSize) { \
...
@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
...
@@ -545,10 +567,16 @@ typename std::enable_if<pack_size == 2, GPU(Error_t)>::type DispatchLayerNormWar
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
stream, load, store, rows, cols, epsilon, mean, inv_variance); \
} \
} \
}
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF
(
64
)
#else
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
32
)
#endif
#undef DEFINE_ONE_ELIF
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
else if ((cols <= (max_col)*kWarpSize) && (cols > (min_col)*kWarpSize)) { \
...
@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
...
@@ -869,7 +897,11 @@ inline GPU(Error_t) LaunchLayerNormBlockUncachedImpl(GPU(Stream_t) stream, LOAD
const
double
epsilon
,
ComputeType
*
mean
,
const
double
epsilon
,
ComputeType
*
mean
,
ComputeType
*
inv_variance
)
{
ComputeType
*
inv_variance
)
{
constexpr
int
block_size
=
1024
;
constexpr
int
block_size
=
1024
;
#ifdef WITH_ROCM
constexpr
int
waves
=
64
;
#else
constexpr
int
waves
=
32
;
constexpr
int
waves
=
32
;
#endif
int
grid_dim_x
;
int
grid_dim_x
;
{
{
GPU
(
Error_t
)
err
=
GPU
(
Error_t
)
err
=
...
@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
...
@@ -1080,7 +1112,11 @@ inline GPU(Error_t) LaunchLayerNormGradWarpImpl(GPU(Stream_t) stream, LOAD_X loa
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
ComputeType
*
inv_variance
,
const
int64_t
rows
,
const
int64_t
cols
)
{
const
int64_t
cols
)
{
constexpr
int
block_size
=
128
;
constexpr
int
block_size
=
128
;
#ifdef WITH_ROCM
constexpr
int
waves
=
64
;
#else
constexpr
int
waves
=
32
;
constexpr
int
waves
=
32
;
#endif
static_assert
(
block_size
%
thread_group_width
==
0
,
""
);
static_assert
(
block_size
%
thread_group_width
==
0
,
""
);
constexpr
int
thread_groups_per_block
=
block_size
/
thread_group_width
;
constexpr
int
thread_groups_per_block
=
block_size
/
thread_group_width
;
dim3
block_dim
(
thread_group_width
,
thread_groups_per_block
);
dim3
block_dim
(
thread_group_width
,
thread_groups_per_block
);
...
@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
...
@@ -1144,10 +1180,16 @@ typename std::enable_if<pack_size == 1, GPU(Error_t)>::type DispatchLayerNormGra
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
stream, load_x, load_scaled_dy, store, mean, inv_variance, rows, cols); \
} \
} \
}
}
#ifdef WITH_ROCM
DEFINE_ONE_ELIF
(
64
)
#else
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
4
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
8
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
16
)
DEFINE_ONE_ELIF
(
32
)
DEFINE_ONE_ELIF
(
32
)
#endif
#undef DEFINE_ONE_ELIF
#undef DEFINE_ONE_ELIF
#define DEFINE_ONE_ELIF(max_col, min_col) \
#define DEFINE_ONE_ELIF(max_col, min_col) \
else if (cols <= (max_col)*kWarpSize) { \
else if (cols <= (max_col)*kWarpSize) { \
...
...
oneflow/core/embedding/lru_cache.cu
View file @
6046d8fb
...
@@ -252,15 +252,21 @@ struct SetContext {
...
@@ -252,15 +252,21 @@ struct SetContext {
const
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
const
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
const
bool
lane_hit
=
(
lane_key
==
key
&&
lane_age
!=
0
);
const
bool
lane_hit
=
(
lane_key
==
key
&&
lane_age
!=
0
);
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
unsigned
hit_mask
=
__ballot
(
lane_hit
);
const
unsigned
long
long
int
hit_mask
=
__ballot
(
lane_hit
);
if
(
hit_mask
!=
0
)
{
return
__ffsll
(
static_cast
<
unsigned
long
long
int
>
(
hit_mask
))
-
1
;
}
else
{
return
-
1
;
}
#else
#else
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_hit
);
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_hit
);
#endif
if
(
hit_mask
!=
0
)
{
if
(
hit_mask
!=
0
)
{
return
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
return
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
}
else
{
}
else
{
return
-
1
;
return
-
1
;
}
}
#endif
}
}
__device__
void
Read
(
const
LruCacheContext
<
Key
,
Elem
>&
cache_ctx
,
const
ThreadContext
&
thread_ctx
,
__device__
void
Read
(
const
LruCacheContext
<
Key
,
Elem
>&
cache_ctx
,
const
ThreadContext
&
thread_ctx
,
...
@@ -277,15 +283,16 @@ struct SetContext {
...
@@ -277,15 +283,16 @@ struct SetContext {
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
unsigned
hit_mask
=
__ballot
(
lane_key
==
key
&&
lane_age
!=
0
);
const
unsigned
long
long
int
hit_mask
=
__ballot
(
lane_key
==
key
&&
lane_age
!=
0
);
#else
#else
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_key
==
key
&&
lane_age
!=
0
);
const
unsigned
hit_mask
=
__ballot_sync
(
kFullMask
,
lane_key
==
key
&&
lane_age
!=
0
);
#endif
#endif
if
(
hit_mask
!=
0
)
{
if
(
hit_mask
!=
0
)
{
insert_way
=
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
#ifdef WITH_ROCM
#ifdef WITH_ROCM
insert_way
=
__ffsll
(
static_cast
<
unsigned
long
long
int
>
(
hit_mask
))
-
1
;
const
int
insert_way_age
=
__shfl
(
lane_age
,
insert_way
);
const
int
insert_way_age
=
__shfl
(
lane_age
,
insert_way
);
#else
#else
insert_way
=
__ffs
(
static_cast
<
int
>
(
hit_mask
))
-
1
;
const
int
insert_way_age
=
__shfl_sync
(
kFullMask
,
lane_age
,
insert_way
);
const
int
insert_way_age
=
__shfl_sync
(
kFullMask
,
lane_age
,
insert_way
);
#endif
#endif
if
(
lane_age
>
insert_way_age
)
{
if
(
lane_age
>
insert_way_age
)
{
...
@@ -301,12 +308,16 @@ __syncwarp();
...
@@ -301,12 +308,16 @@ __syncwarp();
}
}
if
(
insert_way
==
-
1
)
{
if
(
insert_way
==
-
1
)
{
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
unsigned
valid_mask
=
__ballot
(
lane_age
!=
0
);
const
unsigned
long
long
int
valid_mask
=
__ballot
(
lane_age
!=
0
);
#else
#else
const
unsigned
valid_mask
=
__ballot_sync
(
kFullMask
,
lane_age
!=
0
);
const
unsigned
valid_mask
=
__ballot_sync
(
kFullMask
,
lane_age
!=
0
);
#endif
#endif
if
(
valid_mask
!=
kFullMask
)
{
if
(
valid_mask
!=
kFullMask
)
{
#ifdef WITH_ROCM
insert_way
=
__popcll
(
static_cast
<
unsigned
long
long
int
>
(
valid_mask
));
#else
insert_way
=
__popc
(
static_cast
<
int
>
(
valid_mask
));
insert_way
=
__popc
(
static_cast
<
int
>
(
valid_mask
));
#endif
if
(
lane_age
>
0
)
{
if
(
lane_age
>
0
)
{
lane_age
-=
1
;
lane_age
-=
1
;
}
else
if
(
thread_ctx
.
lane_id
==
insert_way
)
{
}
else
if
(
thread_ctx
.
lane_id
==
insert_way
)
{
...
@@ -329,7 +340,8 @@ __syncwarp();
...
@@ -329,7 +340,8 @@ __syncwarp();
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
const
Key
lane_key
=
keys
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
int
lane_age
=
ages
[
thread_ctx
.
lane_id
];
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
int
insert_way
=
__ffs
(
static_cast
<
int
>
(
__ballot
(
lane_age
==
1
)))
-
1
;
unsigned
long
long
int
valid_mask_tmp
=
__ballot
(
lane_age
==
1
);
const
int
insert_way
=
__ffsll
(
valid_mask_tmp
)
-
1
;
#else
#else
const
int
insert_way
=
__ffs
(
__ballot_sync
(
kFullMask
,
lane_age
==
1
))
-
1
;
const
int
insert_way
=
__ffs
(
__ballot_sync
(
kFullMask
,
lane_age
==
1
))
-
1
;
#endif
#endif
...
@@ -594,7 +606,8 @@ __syncwarp();
...
@@ -594,7 +606,8 @@ __syncwarp();
warp_keys
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_key
;
warp_keys
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_key
;
warp_ages
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_age
;
warp_ages
[
thread_ctx
.
warp_id_in_block
][
thread_ctx
.
lane_id
]
=
lane_age
;
#ifdef WITH_ROCM
#ifdef WITH_ROCM
const
int
key_count
=
__popc
(
static_cast
<
int
>
(
__ballot
(
lane_age
!=
0
)));
unsigned
long
long
int
valid_mask_tmp
=
__ballot
(
lane_age
!=
0
);
const
int
key_count
=
__popcll
(
valid_mask_tmp
);
#else
#else
const
int
key_count
=
__popc
(
__ballot_sync
(
kFullMask
,
lane_age
!=
0
));
const
int
key_count
=
__popc
(
__ballot_sync
(
kFullMask
,
lane_age
!=
0
));
#endif
#endif
...
...
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh
View file @
6046d8fb
...
@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
...
@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/include/primitive/
/
broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
...
...
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu
View file @
6046d8fb
/*
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
See the License for the specific language governing permissions and
limitations under the License.
limitations under the License.
*/
*/
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"
namespace
oneflow
{
namespace
oneflow
{
namespace
ep
{
namespace
ep
{
namespace
primitive
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
namespace
broadcast_elementwise_binary
{
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
template
std
::
unique_ptr
<
BroadcastElementwiseBinary
>
NewBroadcastElementwiseBinary
<
\
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op
,
OF_PP_PAIR_FIRST
(
data_type_pair
),
OF_PP_PAIR_FIRST
(
data_type_pair
)
>
(
\
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar
attr0
,
Scalar
attr1
);
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE
(
INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
,
BINARY_MATH_OP_SEQ_1
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
);
BINARY_MATH_OP_SEQ_1
,
CUDA_PRIMITIVE_ALL_TYPE_SEQ
);
}
// namespace broadcast_elementwise_binary
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace primitive
}
// namespace ep
}
// namespace ep
}
// namespace oneflow
}
// namespace oneflow
oneflow/core/ep/cuda/primitive/unary_functor.cuh
View file @
6046d8fb
...
@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
...
@@ -94,7 +94,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
float_functor
(
attr0
,
attr1
)
{}
OF_DEVICE_FUNC
UnaryFunctor
(
Scalar
attr0
,
Scalar
attr1
)
:
float_functor
(
attr0
,
attr1
)
{}
OF_DEVICE_FUNC
half
operator
()(
half
src
)
const
{
OF_DEVICE_FUNC
half
operator
()(
half
src
)
const
{
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
const
float
tanh_in
=
const
float
tanh_in
=
__half2float
(
__float2half_rn
(
alpha
)
*
(
src
+
__float2half_rn
(
beta
)
*
src
*
src
*
src
));
__half2float
(
__float2half_rn
(
alpha
)
*
(
src
+
__float2half_rn
(
beta
)
*
src
*
src
*
src
));
const
float
tanh_out
=
unary_functor_internal
::
TanhApprox
(
tanh_in
);
const
float
tanh_out
=
unary_functor_internal
::
TanhApprox
(
tanh_in
);
...
@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
...
@@ -104,7 +104,7 @@ struct UnaryFunctor<DeviceType::kCUDA, UnaryOp::kFastGelu, half, half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
__device__
void
Apply2
(
half
*
dst
,
const
half
*
src
)
const
{
__device__
void
Apply2
(
half
*
dst
,
const
half
*
src
)
const
{
const
half2
src2
=
*
(
reinterpret_cast
<
const
half2
*>
(
src
));
const
half2
src2
=
*
(
reinterpret_cast
<
const
half2
*>
(
src
));
const
float2
tanh_in
=
__half22float2
(
__hmul2
(
const
float2
tanh_in
=
__half22float2
(
__hmul2
(
...
...
oneflow/core/job_rewriter/auto_mixed_precision_lists.cpp
View file @
6046d8fb
...
@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
...
@@ -64,25 +64,21 @@ const AMPList& AutoMixedPrecisionLists::BlackList() {
"reduce_sum_like"
,
"reduce_sum_like"
,
"layer_norm_grad"
,
"layer_norm_grad"
,
"layer_norm"
,
"layer_norm"
,
"layer_norm_param_grad"
"layer_norm_param_grad"
,
};
return
black_list
;
}
const
AMPList
&
AutoMixedPrecisionLists
::
GrayList
()
{
"add_n"
,
static
AMPList
gray_list
=
{
"add_n"
,
"tf_avg_pool_1d"
,
"tf_avg_pool_1d"
,
"tf_avg_pool_1d_grad"
,
"tf_avg_pool_1d_grad"
,
"tf_avg_pool_2d"
,
"tf_avg_pool_2d"
,
"tf_avg_pool_2d_grad"
,
"tf_avg_pool_2d_grad"
,
"tf_avg_pool_3d"
,
"tf_avg_pool_3d"
,
"tf_avg_pool_3d_grad"
,
"tf_avg_pool_3d_grad"
,
"avg_pool_1d"
,
//
"avg_pool_1d",
"avg_pool_1d_grad"
,
//
"avg_pool_1d_grad",
"avg_pool_2d"
,
//
"avg_pool_2d",
"avg_pool_2d_grad"
,
//
"avg_pool_2d_grad",
"avg_pool_3d"
,
//
"avg_pool_3d",
"avg_pool_3d_grad"
,
//
"avg_pool_3d_grad",
"bias_add"
,
"bias_add"
,
"sigmoid_grad"
,
"sigmoid_grad"
,
"tanh"
,
"tanh"
,
...
@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
...
@@ -139,13 +135,9 @@ const AMPList& AutoMixedPrecisionLists::GrayList() {
"group_norm_grad"
,
"group_norm_grad"
,
"silu"
,
"silu"
,
"silu_grad"
,
"silu_grad"
,
"fused_weighted_sum"
};
"fused_weighted_sum"
,
return
gray_list
;
}
const
AMPList
&
AutoMixedPrecisionLists
::
ClearList
()
{
"broadcast_like"
,
// TODO(niuchong): tuple_identity
static
AMPList
clear_list
=
{
"broadcast_like"
,
"gather"
,
"gather"
,
"gather_nd"
,
"gather_nd"
,
"scatter_nd"
,
"scatter_nd"
,
...
@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
...
@@ -157,12 +149,12 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"tf_max_pool_2d_grad"
,
"tf_max_pool_2d_grad"
,
"tf_max_pool_3d"
,
"tf_max_pool_3d"
,
"tf_max_pool_3d_grad"
,
"tf_max_pool_3d_grad"
,
"max_pool_1d"
,
//
"max_pool_1d",
"max_pool_1d_grad"
,
//
"max_pool_1d_grad",
"max_pool_2d"
,
//
"max_pool_2d",
"max_pool_2d_grad"
,
//
"max_pool_2d_grad",
"max_pool_3d"
,
//
"max_pool_3d",
"max_pool_3d_grad"
,
//
"max_pool_3d_grad",
"reshape"
,
"reshape"
,
"reshape_like"
,
"reshape_like"
,
"relu"
,
"relu"
,
...
@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
...
@@ -199,7 +191,146 @@ const AMPList& AutoMixedPrecisionLists::ClearList() {
"pinned_identity"
,
"pinned_identity"
,
"to_contiguous"
,
"to_contiguous"
,
"copy"
,
"copy"
,
"upsample_nearest_2d"
};
"upsample_nearest_2d"
};
return
black_list
;
}
const
AMPList
&
AutoMixedPrecisionLists
::
GrayList
()
{
static
AMPList
gray_list
=
{
// "add_n",
// "tf_avg_pool_1d",
// "tf_avg_pool_1d_grad",
// "tf_avg_pool_2d",
// "tf_avg_pool_2d_grad",
// "tf_avg_pool_3d",
// "tf_avg_pool_3d_grad",
"avg_pool_1d"
,
"avg_pool_1d_grad"
,
"avg_pool_2d"
,
"avg_pool_2d_grad"
,
"avg_pool_3d"
,
"avg_pool_3d_grad"
// "bias_add",
// "sigmoid_grad",
// "tanh",
// "tanh_grad",
// "sqrt",
// "sqrt_grad",
// "scalar_mul",
// "scalar_mul_by_tensor",
// "scalar_add",
// "scalar_div",
// "scalar_pow",
// "broadcast_add",
// "broadcast_sub",
// "broadcast_mul",
// "broadcast_div",
// "rms_norm",
// "rms_norm_grad",
// "rms_norm_param_grad",
// "dropout",
// "dropout_grad",
// "softmax",
// "softmax_grad",
// "log_softmax",
// "log_softmax_grad",
// "gelu",
// "gelu_grad",
// "fast_gelu",
// "fast_gelu_grad",
// "normalization",
// "normalization_grad",
// "normalization_add_relu",
// "normalization_add_relu_grad",
// "sparse_softmax_cross_entropy",
// "sparse_softmax_cross_entropy_grad",
// "nll",
// "nll_grad",
// "fused_tril_scale_softmax_mask_scale",
// "fused_tril_scale_softmax_mask_scale_grad",
// "fused_scale_mask_softmax_dropout",
// "fused_scale_mask_softmax_dropout_grad",
// "fused_scale_mask_softmax",
// "fused_scale_mask_softmax_grad",
// "fused_bias_add_scale_mask_softmax_dropout",
// "fused_bias_add_gelu",
// "fused_bias_add_gelu_grad",
// "fused_bias_add_mask_scale",
// "fused_fast_gelu_mul",
// "fused_fast_gelu_mul_grad",
// "acc",
// "reciprocal",
// "reciprocal_no_nan",
// "group_norm",
// "group_norm_param_grad",
// "group_norm_grad",
// "silu",
// "silu_grad",
// "fused_weighted_sum"
};
return
gray_list
;
}
const
AMPList
&
AutoMixedPrecisionLists
::
ClearList
()
{
// TODO(niuchong): tuple_identity
static
AMPList
clear_list
=
{
// "broadcast_like",
// "gather",
// "gather_nd",
// "scatter_nd",
// "scatter_nd_like",
// "unsorted_segment_sum_like",
// "tf_max_pool_1d",
// "tf_max_pool_1d_grad",
// "tf_max_pool_2d",
// "tf_max_pool_2d_grad",
// "tf_max_pool_3d",
// "tf_max_pool_3d_grad",
"max_pool_1d"
,
"max_pool_1d_grad"
,
"max_pool_2d"
,
"max_pool_2d_grad"
,
"max_pool_3d"
,
"max_pool_3d_grad"
// "reshape",
// "reshape_like",
// "relu",
// "relu_grad",
// "transpose",
// "random_mask_like",
// "concat",
// "split_like",
// "pad",
// "same_padding",
// "same_padding_grad",
// "tril",
// "slice",
// "slice_grad",
// "fused_scale_tril",
// "identity",
// "squeeze",
// "embedding",
// "embedding_grad",
// "expand",
// "expand_dims",
// "cast_to_static_shape",
// "parallel_cast",
// "hierarchical_parallel_cast",
// "hierarchical_parallel_cast_like",
// "repeat",
// "unpack",
// "pack",
// "nvtx_start",
// "nvtx_end",
// "narrow",
// "narrow_grad",
// "ones_like",
// "pinned_identity",
// "to_contiguous",
// "copy",
// "upsample_nearest_2d"
};
return
clear_list
;
return
clear_list
;
}
}
...
...
oneflow/user/kernels/fused_gelu_mul_kernel.cu
View file @
6046d8fb
...
@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
...
@@ -57,7 +57,7 @@ struct FusedFastGeluMulFunctor<half> {
OF_DEVICE_FUNC
FusedFastGeluMulFunctor
()
{}
OF_DEVICE_FUNC
FusedFastGeluMulFunctor
()
{}
OF_DEVICE_FUNC
half
operator
()(
const
half
x
,
const
half
m
)
const
{
OF_DEVICE_FUNC
half
operator
()(
const
half
x
,
const
half
m
)
const
{
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
const
float
tanh_in
=
const
float
tanh_in
=
__half2float
(
__float2half_rn
(
alpha
)
*
(
x
+
__float2half_rn
(
beta
)
*
x
*
x
*
x
));
__half2float
(
__float2half_rn
(
alpha
)
*
(
x
+
__float2half_rn
(
beta
)
*
x
*
x
*
x
));
const
float
tanh_out
=
TanhApprox
(
tanh_in
);
const
float
tanh_out
=
TanhApprox
(
tanh_in
);
...
@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
...
@@ -67,7 +67,7 @@ struct FusedFastGeluMulFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
__device__
void
Apply2
(
half
*
y
,
const
half
*
x
,
const
half
*
m
)
const
{
__device__
void
Apply2
(
half
*
y
,
const
half
*
x
,
const
half
*
m
)
const
{
const
half2
x2
=
*
(
reinterpret_cast
<
const
half2
*>
(
x
));
const
half2
x2
=
*
(
reinterpret_cast
<
const
half2
*>
(
x
));
const
float2
tanh_in
=
__half22float2
(
const
float2
tanh_in
=
__half22float2
(
...
@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
...
@@ -170,7 +170,7 @@ struct FusedFastGeluMulGradFunctor<half> {
__device__
void
operator
()(
half
&
x_diff
,
half
&
m_diff
,
const
half
&
dy
,
const
half
&
x
,
__device__
void
operator
()(
half
&
x_diff
,
half
&
m_diff
,
const
half
&
dy
,
const
half
&
x
,
const
half
&
m
)
const
{
const
half
&
m
)
const
{
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
const
half
halpha
=
__float2half_rn
(
alpha
);
const
half
halpha
=
__float2half_rn
(
alpha
);
const
half
hbeta
=
__float2half_rn
(
beta
);
const
half
hbeta
=
__float2half_rn
(
beta
);
const
half
hone
=
__float2half_rn
(
1.0
F
);
const
half
hone
=
__float2half_rn
(
1.0
F
);
...
@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
...
@@ -193,7 +193,7 @@ struct FusedFastGeluMulGradFunctor<half> {
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#endif // (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
}
}
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
#if (__CUDA_ARCH__ >= 750 && CUDA_VERSION >= 11000)
|| defined(WITH_ROCM)
__device__
void
Apply2
(
half
*
x_diff
,
half
*
m_diff
,
const
half
*
dy
,
const
half
*
x
,
__device__
void
Apply2
(
half
*
x_diff
,
half
*
m_diff
,
const
half
*
dy
,
const
half
*
x
,
const
half
*
m
)
const
{
const
half
*
m
)
const
{
const
half2
dy2
=
*
(
reinterpret_cast
<
const
half2
*>
(
dy
));
const
half2
dy2
=
*
(
reinterpret_cast
<
const
half2
*>
(
dy
));
...
...
oneflow/user/kernels/group_norm_kernel.cu
View file @
6046d8fb
...
@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
...
@@ -499,7 +499,11 @@ REGISTER_GROUP_NORM_GRAD_CUDA_KERNEL(nv_bfloat16)
constexpr
int
kReduceBlockSize
=
512
;
constexpr
int
kReduceBlockSize
=
512
;
constexpr
int
kBlockSize
=
128
;
constexpr
int
kBlockSize
=
128
;
#ifdef WITH_ROCM
constexpr
int
kNumWaves
=
64
;
#else
constexpr
int
kNumWaves
=
32
;
constexpr
int
kNumWaves
=
32
;
#endif
inline
GPU
(
Error_t
)
GetReduceNumBlocks
(
int64_t
n
,
int
*
num_blocks
)
{
inline
GPU
(
Error_t
)
GetReduceNumBlocks
(
int64_t
n
,
int
*
num_blocks
)
{
int
dev
;
int
dev
;
...
...
oneflow/user/kernels/layer_norm_gpu_kernel.cu
View file @
6046d8fb
This diff is collapsed.
Click to expand it.
oneflow/user/kernels/math_binary_elementwise_func.h
View file @
6046d8fb
...
@@ -28,7 +28,7 @@ limitations under the License.
...
@@ -28,7 +28,7 @@ limitations under the License.
#elif defined(__HIPCC__)
#elif defined(__HIPCC__)
#include <hip/
hsa_detail/
math_functions.h>
#include <hip/math_functions.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
#if defined(__HIP_DEVICE_COMPILE__)
#if defined(__HIP_DEVICE_COMPILE__)
...
...
oneflow/user/kernels/normalization_kernel.cu
View file @
6046d8fb
This diff is collapsed.
Click to expand it.
oneflow/user/ops/normalization_op.cpp
View file @
6046d8fb
...
@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
...
@@ -274,7 +274,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
CHECK_EQ_OR_RETURN
(
reserve_space_bits
%
split_num
,
0
);
CHECK_EQ_OR_RETURN
(
reserve_space_bits
%
split_num
,
0
);
reserve_space_bits
=
reserve_space_bits
/
split_num
;
reserve_space_bits
=
reserve_space_bits
/
split_num
;
}
}
#ifdef WITH_ROCM
reserve_space
->
set_shape
(
Shape
({
static_cast
<
int64_t
>
(
RoundUp
(
reserve_space_bits
,
64
)
/
64
)}));
#else
reserve_space
->
set_shape
(
Shape
({
static_cast
<
int64_t
>
(
RoundUp
(
reserve_space_bits
,
32
)
/
32
)}));
reserve_space
->
set_shape
(
Shape
({
static_cast
<
int64_t
>
(
RoundUp
(
reserve_space_bits
,
32
)
/
32
)}));
#endif
return
Maybe
<
void
>::
Ok
();
return
Maybe
<
void
>::
Ok
();
})(
ctx
);
})(
ctx
);
}
}
...
@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
...
@@ -284,8 +288,13 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
return
MakeFwTensorDescInferFn
([](
user_op
::
InferContext
*
ctx
,
const
user_op
::
TensorDesc
*
x
,
return
MakeFwTensorDescInferFn
([](
user_op
::
InferContext
*
ctx
,
const
user_op
::
TensorDesc
*
x
,
user_op
::
TensorDesc
*
reserve_space
)
->
Maybe
<
void
>
{
user_op
::
TensorDesc
*
reserve_space
)
->
Maybe
<
void
>
{
const
auto
&
x_desc
=
ctx
->
InputTensorDesc
(
"x"
,
0
);
const
auto
&
x_desc
=
ctx
->
InputTensorDesc
(
"x"
,
0
);
#ifdef WITH_ROCM
reserve_space
->
set_shape
(
Shape
({
static_cast
<
int64_t
>
(
RoundUp
(
x_desc
.
shape
().
elem_cnt
(),
64
)
/
64
)}));
#else
reserve_space
->
set_shape
(
reserve_space
->
set_shape
(
Shape
({
static_cast
<
int64_t
>
(
RoundUp
(
x_desc
.
shape
().
elem_cnt
(),
32
)
/
32
)}));
Shape
({
static_cast
<
int64_t
>
(
RoundUp
(
x_desc
.
shape
().
elem_cnt
(),
32
)
/
32
)}));
#endif
return
Maybe
<
void
>::
Ok
();
return
Maybe
<
void
>::
Ok
();
})(
ctx
);
})(
ctx
);
}
}
...
@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
...
@@ -302,7 +311,11 @@ user_op::DataTypeInferFn MakeFwDataTypeInferFn() {
/* static */
Maybe
<
void
>
NormalizationAddReluOp
::
InferDataType
(
user_op
::
InferContext
*
ctx
)
{
/* static */
Maybe
<
void
>
NormalizationAddReluOp
::
InferDataType
(
user_op
::
InferContext
*
ctx
)
{
return
MakeFwDataTypeInferFn
([](
user_op
::
InferContext
*
ctx
,
const
user_op
::
TensorDesc
*
x
,
return
MakeFwDataTypeInferFn
([](
user_op
::
InferContext
*
ctx
,
const
user_op
::
TensorDesc
*
x
,
user_op
::
TensorDesc
*
reserve_space
)
->
Maybe
<
void
>
{
user_op
::
TensorDesc
*
reserve_space
)
->
Maybe
<
void
>
{
#ifdef WITH_ROCM
reserve_space
->
set_data_type
(
DataType
::
kInt64
);
#else
reserve_space
->
set_data_type
(
DataType
::
kInt32
);
reserve_space
->
set_data_type
(
DataType
::
kInt32
);
#endif
return
Maybe
<
void
>::
Ok
();
return
Maybe
<
void
>::
Ok
();
})(
ctx
);
})(
ctx
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment