OpenDAS / Oneflow · Commits

Commit a715222c
Authored Feb 28, 2023 by yuguo
0.9.1-rocm
Parent: f262efc9
Changes: 469
Showing 20 changed files with 1678 additions and 31 deletions (+1678 -31)
oneflow/core/ep/cuda/cuda_device_manager.cpp  +52 -0
oneflow/core/ep/cuda/cuda_device_manager.h  +32 -0
oneflow/core/ep/cuda/cuda_device_manager_factory.cpp  +99 -0
oneflow/core/ep/cuda/cuda_event.cpp  +41 -0
oneflow/core/ep/cuda/cuda_event.h  +29 -0
oneflow/core/ep/cuda/cuda_stream.cpp  +253 -1
oneflow/core/ep/cuda/cuda_stream.h  +136 -0
oneflow/core/ep/cuda/primitive/add.cu  +5 -5
oneflow/core/ep/cuda/primitive/binary_functor.cuh  +304 -13
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu  +5 -1
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh  +3 -3
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math.cu  +1 -1
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu  +36 -0
oneflow/core/ep/cuda/primitive/broadcast_elementwise_unary.cu  +420 -0
oneflow/core/ep/cuda/primitive/broadcast_matmul.cpp  +208 -1
oneflow/core/ep/cuda/primitive/constant_pad.cu  +5 -0
oneflow/core/ep/cuda/primitive/copy_nd.cu  +5 -2
oneflow/core/ep/cuda/primitive/elementwise_unary.cu  +4 -0
oneflow/core/ep/cuda/primitive/fill.cu  +4 -4
oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad.cu  +36 -0
Too many changes to show. To preserve performance, only 469 of 469+ files are displayed.
oneflow/core/ep/cuda/cuda_device_manager.cpp
@@ -66,3 +66,55 @@ void CudaDeviceManager::SetActiveDeviceByIndex(size_t device_index) {
}  // namespace oneflow
#endif  // WITH_CUDA

#ifdef WITH_ROCM

namespace oneflow {

namespace ep {

CudaDeviceManager::CudaDeviceManager(DeviceManagerRegistry* registry) : registry_(registry) {}
CudaDeviceManager::~CudaDeviceManager() = default;

DeviceManagerRegistry* CudaDeviceManager::registry() const { return registry_; }

std::shared_ptr<Device> CudaDeviceManager::GetDevice(size_t device_index) {
  std::lock_guard<std::mutex> lock(devices_mutex_);
  if (device_index < devices_.size() && devices_.at(device_index)) {
    return devices_.at(device_index);
  }
  auto device = std::make_shared<CudaDevice>(device_index, this);
  if (device_index >= devices_.size()) { devices_.resize(device_index + 1); }
  devices_.at(device_index) = device;
  return device;
}

size_t CudaDeviceManager::GetDeviceCount(size_t primary_device_index) {
  CudaCurrentDeviceGuard guard(primary_device_index);
  return this->GetDeviceCount();
}

size_t CudaDeviceManager::GetDeviceCount() {
  int count = 0;
  hipError_t err = hipGetDeviceCount(&count);
  if (err == hipErrorNoDevice || err == hipErrorInsufficientDriver) { return 0; }
  OF_CUDA_CHECK(err);
  return count;
}

size_t CudaDeviceManager::GetActiveDeviceIndex() {
  int device = 0;
  OF_CUDA_CHECK(hipGetDevice(&device));
  return static_cast<size_t>(device);
}

void CudaDeviceManager::SetActiveDeviceByIndex(size_t device_index) {
  OF_CUDA_CHECK(hipSetDevice(static_cast<int>(device_index)));
}

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_ROCM
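For orientation: the ROCm manager above is a thin, mutex-guarded cache over the HIP runtime's device API. A minimal standalone sketch of those underlying calls, with the same "no device / bad driver means zero devices" handling as GetDeviceCount(); the file layout and main() are illustrative, not part of this commit:

// probe.cpp — build with: hipcc probe.cpp -o probe
#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
  int count = 0;
  hipError_t err = hipGetDeviceCount(&count);
  // Mirror the manager: treat "no device / insufficient driver" as zero devices.
  if (err == hipErrorNoDevice || err == hipErrorInsufficientDriver) {
    count = 0;
  } else if (err != hipSuccess) {
    std::fprintf(stderr, "hip error: %s\n", hipGetErrorString(err));
    return 1;
  }
  std::printf("visible devices: %d\n", count);
  for (int i = 0; i < count; ++i) {
    if (hipSetDevice(i) != hipSuccess) { continue; }  // like SetActiveDeviceByIndex
    int active = -1;
    (void)hipGetDevice(&active);                      // like GetActiveDeviceIndex
    std::printf("active device index: %d\n", active);
  }
  return 0;
}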
oneflow/core/ep/cuda/cuda_device_manager.h
@@ -50,4 +50,36 @@ class CudaDeviceManager : public DeviceManager {
#endif  // WITH_CUDA

#ifdef WITH_ROCM

namespace oneflow {

namespace ep {

class CudaDevice;

class CudaDeviceManager : public DeviceManager {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManager);
  CudaDeviceManager(DeviceManagerRegistry* registry);
  ~CudaDeviceManager() override;

  DeviceManagerRegistry* registry() const override;
  std::shared_ptr<Device> GetDevice(size_t device_index) override;
  size_t GetDeviceCount(size_t primary_device_index) override;
  size_t GetDeviceCount() override;
  size_t GetActiveDeviceIndex() override;
  void SetActiveDeviceByIndex(size_t device_index) override;

 private:
  std::mutex devices_mutex_;
  std::vector<std::shared_ptr<CudaDevice>> devices_;
  DeviceManagerRegistry* registry_;
};

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_ROCM
#endif  // ONEFLOW_CORE_EP_CUDA_CUDA_DEVICE_MANAGER_H_
oneflow/core/ep/cuda/cuda_device_manager_factory.cpp
@@ -117,3 +117,102 @@ COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(
}  // namespace oneflow
#endif  // WITH_CUDA

#ifdef WITH_ROCM

#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
#include <rccl.h>

namespace oneflow {

namespace ep {

namespace {

std::string GetCudaVersionString(int version) {
  return std::to_string(version / 1000) + "." + std::to_string((version % 1000) / 10);
}

bool GetCudnnVersion(size_t* major, size_t* minor, size_t* patch) {
  miopenStatus_t status = miopenGetVersion(major, minor, patch);
  if (status == miopenStatusSuccess) {
    return true;
  } else {
    LOG(ERROR) << "Failed to get cuDNN version: " << miopenGetErrorString(status);
    return false;
  }
}

bool GetCudnnVersionString(std::string* version) {
  size_t version_major = 0;
  size_t version_minor = 0;
  size_t version_patch = 0;
  if (!GetCudnnVersion(&version_major, &version_minor, &version_patch)) { return false; }
  *version = std::to_string(version_major) + "." + std::to_string(version_minor) + "."
             + std::to_string(version_patch);
  return true;
}

void CudaDumpVersionInfo() {
  {
    int cuda_runtime_version = 0;
    hipError_t err = hipRuntimeGetVersion(&cuda_runtime_version);
    if (err == hipSuccess) {
      LOG(INFO) << "CUDA runtime version: " << GetCudaVersionString(cuda_runtime_version);
    } else {
      LOG(ERROR) << "Failed to get cuda runtime version: " << hipGetErrorString(err);
    }
  }
  {
    std::string cudnn_version_string;
    if (GetCudnnVersionString(&cudnn_version_string)) {
      LOG(INFO) << "cuDNN version: " << cudnn_version_string;
    }
  }
  {
    int nccl_version = 0;
    ncclResult_t result = ncclGetVersion(&nccl_version);
    if (result == ncclSuccess) {
      int nccl_version_major = (nccl_version >= 20900) ? (nccl_version / 10000)
                                                       : (nccl_version / 1000);
      int nccl_version_minor = (nccl_version >= 20900) ? (nccl_version % 10000) / 100
                                                       : (nccl_version % 1000) / 100;
      int nccl_version_patch = (nccl_version % 100);
      LOG(INFO) << "NCCL version: " << nccl_version_major << "." << nccl_version_minor << "."
                << nccl_version_patch;
    } else {
      LOG(ERROR) << "Failed to get NCCL version: " << ncclGetErrorString(result);
    }
  }
}

class CudaDeviceManagerFactory : public DeviceManagerFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManagerFactory);
  CudaDeviceManagerFactory() = default;
  ~CudaDeviceManagerFactory() override = default;

  std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {
    return std::make_unique<CudaDeviceManager>(registry);
  }

  DeviceType device_type() const override { return DeviceType::kCUDA; }

  std::string device_type_name() const override { return "cuda"; }

  void DumpVersionInfo() const override { CudaDumpVersionInfo(); }
};

COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(
    std::make_unique<CudaDeviceManagerFactory>()))

}  // namespace

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_ROCM
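The version-decoding arithmetic above follows the CUDA/HIP and NCCL packing conventions. A small host-only sketch of the same arithmetic with illustrative sample values (not taken from this commit):

#include <iostream>
#include <string>

std::string RuntimeVersionString(int version) {
  // e.g. 11020 -> "11.2": thousands hold the major, the next two digits the minor.
  return std::to_string(version / 1000) + "." + std::to_string((version % 1000) / 10);
}

std::string NcclVersionString(int v) {
  // NCCL >= 2.9 packs versions as MMmmp (e.g. 21205 -> 2.12.5); older releases
  // used Mmmp, hence the 20900 threshold in the code above.
  int major = (v >= 20900) ? v / 10000 : v / 1000;
  int minor = (v >= 20900) ? (v % 10000) / 100 : (v % 1000) / 100;
  int patch = v % 100;
  return std::to_string(major) + "." + std::to_string(minor) + "." + std::to_string(patch);
}

int main() {
  std::cout << RuntimeVersionString(11020) << "\n";  // prints 11.2
  std::cout << NcclVersionString(21205) << "\n";     // prints 2.12.5
}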
oneflow/core/ep/cuda/cuda_event.cpp
@@ -54,3 +54,44 @@ cudaEvent_t CudaEvent::cuda_event() { return cuda_event_; }
}  // namespace oneflow
#endif  // WITH_CUDA

#ifdef WITH_ROCM

namespace oneflow {

namespace ep {

CudaEvent::CudaEvent(unsigned int flags) : cuda_event_{} {
  OF_CUDA_CHECK(hipEventCreateWithFlags(&cuda_event_, flags));
}

CudaEvent::~CudaEvent() { OF_CUDA_CHECK(hipEventDestroy(cuda_event_)); }

Maybe<bool> CudaEvent::QueryDone() {
  hipError_t err = hipEventQuery(cuda_event_);
  if (err == hipSuccess) {
    return Maybe<bool>(true);
  } else if (err == hipErrorNotReady) {
    return Maybe<bool>(false);
  } else {
    return Error::RuntimeError() << hipGetErrorString(err);
  }
}

Maybe<void> CudaEvent::Sync() {
  hipError_t err = hipEventSynchronize(cuda_event_);
  if (err == hipSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << hipGetErrorString(err);
  }
}

hipEvent_t CudaEvent::cuda_event() { return cuda_event_; }

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_ROCM
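The Maybe<bool> tri-state in QueryDone() maps directly onto the HIP event API (done / not ready / error). A standalone sketch of that pattern; the flags choice and main() are illustrative:

#include <hip/hip_runtime.h>
#include <cstdio>

int main() {
  hipEvent_t event;
  // Same creation path as CudaEvent's constructor, with an example flag.
  if (hipEventCreateWithFlags(&event, hipEventDisableTiming) != hipSuccess) { return 1; }

  (void)hipEventRecord(event, 0);  // record on the default stream
  hipError_t err = hipEventQuery(event);
  if (err == hipSuccess) {
    std::printf("done\n");
  } else if (err == hipErrorNotReady) {
    std::printf("not ready yet\n");
    (void)hipEventSynchronize(event);  // block until complete, like Sync()
  } else {
    std::fprintf(stderr, "error: %s\n", hipGetErrorString(err));
  }
  (void)hipEventDestroy(event);
  return 0;
}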
oneflow/core/ep/cuda/cuda_event.h
@@ -47,4 +47,33 @@ class CudaEvent : public Event {
#endif  // WITH_CUDA

#ifdef WITH_ROCM

#include "oneflow/core/device/cuda_util.h"

namespace oneflow {

namespace ep {

class CudaEvent : public Event {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaEvent);
  explicit CudaEvent(unsigned int flags);
  ~CudaEvent() override;

  Maybe<bool> QueryDone() override;
  Maybe<void> Sync() override;

  hipEvent_t cuda_event();

 private:
  hipEvent_t cuda_event_;
};

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_ROCM
#endif  // ONEFLOW_CORE_EP_CUDA_CUDA_EVENT_H_
oneflow/core/ep/cuda/cuda_stream.cpp
@@ -42,6 +42,61 @@ void SetAffinityByDevice(int dev_id) {
  node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID());
}

void CheckVersionCompatibility(int compiletime_major, int compiletime_minor, int runtime_major,
                               int runtime_minor, const std::string& name) {
  if (runtime_major != compiletime_major || runtime_minor < compiletime_minor) {
    LOG(WARNING) << "Runtime version " << runtime_major << "." << runtime_minor << " of " << name
                 << " incompatible with compiletime version " << compiletime_major << "."
                 << compiletime_minor << ".";
  }
}

void CheckCudaRuntimeVersion() {
#if !defined(CUDART_VERSION)
#error
#endif  // !defined(CUDART_VERSION)
  const int compiletime_major = CUDART_VERSION / 1000;
  const int compiletime_minor = CUDART_VERSION % 1000 / 10;
  int runtime_version = 0;
  OF_CUDA_CHECK(cudaRuntimeGetVersion(&runtime_version));
  const int runtime_major = runtime_version / 1000;
  const int runtime_minor = runtime_version % 1000 / 10;
  CheckVersionCompatibility(compiletime_major, compiletime_minor, runtime_major, runtime_minor,
                            "CUDA Runtime");
}

void CheckCublasVersion(cublasHandle_t handle) {
#if CUDA_VERSION >= 10020
#if (!defined(CUBLAS_VER_MAJOR)) || (!defined(CUBLAS_VER_MINOR))
#error
#endif  // (!defined(CUBLAS_VER_MAJOR)) || (!defined(CUBLAS_VER_MINOR))
  int runtime_version = 0;
  OF_CUBLAS_CHECK(cublasGetVersion(handle, &runtime_version));
  int runtime_major = 0;
  int runtime_minor = 0;
  if (runtime_version >= 100000) {
    runtime_major = runtime_version / 10000;
    runtime_minor = runtime_version % 10000 / 100;
  } else {
    runtime_major = runtime_version / 1000;
    runtime_minor = runtime_version % 1000 / 100;
  }
  CheckVersionCompatibility(CUBLAS_VER_MAJOR, CUBLAS_VER_MINOR, runtime_major, runtime_minor,
                            "cuBLAS");
#endif  // CUDA_VERSION >= 10020
}

void CheckCudnnVersion() {
#if (!defined(CUDNN_MAJOR)) || (!defined(CUDNN_MINOR))
#error
#endif  // (!defined(CUDNN_MAJOR)) || (!defined(CUDNN_MINOR))
  int runtime_major = 0;
  int runtime_minor = 0;
  OF_CUDNN_CHECK(cudnnGetProperty(libraryPropertyType::MAJOR_VERSION, &runtime_major));
  OF_CUDNN_CHECK(cudnnGetProperty(libraryPropertyType::MINOR_VERSION, &runtime_minor));
  CheckVersionCompatibility(CUDNN_MAJOR, CUDNN_MINOR, runtime_major, runtime_minor, "cuDNN");
}

}  // namespace

#ifdef WITH_CUDA_GRAPHS
@@ -83,11 +138,26 @@ void CudaGraphExecutable::Reset() {
CudaStream::CudaStream(CudaDevice* device)
    : device_index_(device->device_index()), device_(device) {
  CudaCurrentDeviceGuard guard(device_index_);
  const bool need_check_version = []() {
    static std::atomic<bool> version_checked(false);
    return version_checked.exchange(true) == false;
  }();
  if (need_check_version) { CheckCudaRuntimeVersion(); }
  // cuda_stream
-  OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_));
+  const char* stream_flags_env_name = "ONEFLOW_EP_CUDA_STREAM_FLAGS";
+  if (std::getenv(stream_flags_env_name) != nullptr) {
+    const unsigned int stream_flags = ParseIntegerFromEnv(stream_flags_env_name, 0);
+    OF_CUDA_CHECK(cudaStreamCreateWithFlags(&cuda_stream_, stream_flags));
+  } else {
+    OF_CUDA_CHECK(cudaStreamCreate(&cuda_stream_));
+  }
  // cublas_handle
  OF_CUBLAS_CHECK(cublasCreate(&cublas_handle_));
  OF_CUBLAS_CHECK(cublasSetStream(cublas_handle_, cuda_stream_));
  if (need_check_version) { CheckCublasVersion(cublas_handle_); }
#if CUDA_VERSION >= 10010
  // cublas_lt_handle
  OF_CUBLAS_CHECK(cublasLtCreate(&cublas_lt_handle_));
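The need_check_version lambda above is a run-once idiom: the first exchange(true) returns false, so exactly one constructor call performs the version checks, even across threads. A minimal standalone sketch (names illustrative):

#include <atomic>
#include <cstdio>

bool FirstCall() {
  static std::atomic<bool> called(false);
  // exchange() returns the previous value; only the first caller sees false.
  return called.exchange(true) == false;
}

int main() {
  if (FirstCall()) { std::printf("version checks run once\n"); }
  if (FirstCall()) { std::printf("never printed again\n"); }
}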
@@ -107,6 +177,7 @@ CudaStream::CudaStream(CudaDevice* device)
  // cudnn_handle
  OF_CUDNN_CHECK(cudnnCreate(&cudnn_handle_));
  OF_CUDNN_CHECK(cudnnSetStream(cudnn_handle_, cuda_stream_));
  if (need_check_version) { CheckCudnnVersion(); }
}

CudaStream::~CudaStream() {
@@ -147,6 +218,15 @@ void CudaStream::RecordEvent(Event* event) {
  OF_CUDA_CHECK(cudaEventRecord(cuda_event->cuda_event(), cuda_stream_));
}

Maybe<void> CudaStream::GetAsyncError() {
  cudaError_t err = cudaGetLastError();
  if (err == cudaSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << cudaGetErrorString(err) << " (" << err << ") ";
  }
}

cudaStream_t CudaStream::cuda_stream() const { return cuda_stream_; }

cublasHandle_t CudaStream::cublas_handle() const { return cublas_handle_; }
@@ -196,3 +276,175 @@ void CudaStream::LaunchGraph(const CudaGraphExecutable* executable) {
}  // namespace oneflow
#endif  // WITH_CUDA

#ifdef WITH_ROCM

namespace oneflow {

namespace ep {

namespace {

constexpr size_t kDefaultWorkspaceSize = 4 * 1024 * 1024;  // 4M

void SetAffinityByDevice(int dev_id) {
  auto node_device_desc_mgr = Singleton<hardware::NodeDeviceDescriptorManager>::Get();
  if (node_device_desc_mgr == nullptr) { return; }
  auto node_device_desc = node_device_desc_mgr->GetLocalNodeDeviceDescriptor();
  auto cuda_device = std::dynamic_pointer_cast<const hardware::CudaDeviceDescriptor>(
      node_device_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev_id));
  if (!cuda_device) { return; }
  node_device_desc->Topology()->SetCPUAffinityByPCIBusID(cuda_device->PCIBusID());
  node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID());
}

}  // namespace

#ifdef WITH_ROCM_GRAPHS

CudaGraphExecutable::CudaGraphExecutable() : graph_exec_(nullptr), dev_(-1) {}

CudaGraphExecutable::~CudaGraphExecutable() { Reset(); }

void CudaGraphExecutable::Update(hipGraph_t graph) {
  int dev = -1;
  OF_CUDA_CHECK(hipGetDevice(&dev));
  if (dev != dev_) { Reset(); }
  dev_ = dev;
  if (graph_exec_ != nullptr) {
    hipGraphExecUpdateResult update_result{};
    hipGraphNode_t error_node = nullptr;
    OF_CUDA_CHECK(hipGraphExecUpdate(graph_exec_, graph, &error_node, &update_result));
    if (update_result == hipGraphExecUpdateSuccess) { return; }
  }
  Reset();
  OF_CUDA_CHECK(hipGraphInstantiate(&graph_exec_, graph, NULL, NULL, 0));
}

void CudaGraphExecutable::Launch(hipStream_t stream) const {
  OF_CUDA_CHECK(hipGraphLaunch(graph_exec_, stream));
}

bool CudaGraphExecutable::IsInstantiated() const { return graph_exec_ != nullptr; }

void CudaGraphExecutable::Reset() {
  if (graph_exec_ != nullptr) {
    CudaCurrentDeviceGuard guard(dev_);
    OF_CUDA_CHECK(hipGraphExecDestroy(graph_exec_));
  }
}

#endif  // WITH_ROCM_GRAPHS

CudaStream::CudaStream(CudaDevice* device)
    : device_index_(device->device_index()), device_(device) {
  CudaCurrentDeviceGuard guard(device_index_);
  // cuda_stream
  const char* stream_flags_env_name = "ONEFLOW_EP_CUDA_STREAM_FLAGS";
  if (std::getenv(stream_flags_env_name) != nullptr) {
    const unsigned int stream_flags = ParseIntegerFromEnv(stream_flags_env_name, 0);
    OF_CUDA_CHECK(hipStreamCreateWithFlags(&cuda_stream_, stream_flags));
  } else {
    OF_CUDA_CHECK(hipStreamCreate(&cuda_stream_));
  }
  // cublas_handle
  OF_CUBLAS_CHECK(hipblasCreate(&cublas_handle_));
  OF_CUBLAS_CHECK(hipblasSetStream(cublas_handle_, cuda_stream_));
  workspace_size_ = kDefaultWorkspaceSize;
  OF_CUDA_CHECK(hipMalloc(&workspace_, workspace_size_));
  OF_CUDNN_CHECK(hipdnnCreate(&cudnn_handle_));
  OF_CUDNN_CHECK(hipdnnSetStream(cudnn_handle_, cuda_stream_));
}

CudaStream::~CudaStream() {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipStreamSynchronize(cuda_stream_));
  OF_CUDNN_CHECK(hipdnnDestroy(cudnn_handle_));
  OF_CUBLAS_CHECK(hipblasDestroy(cublas_handle_));
  OF_CUDA_CHECK(hipStreamDestroy(cuda_stream_));
  OF_CUDA_CHECK(hipFree(workspace_));
}

Maybe<void> CudaStream::OnExecutionContextSetup() {
  OF_CUDA_CHECK(hipSetDevice(device_index_));
  SetAffinityByDevice(device_index_);
  return Maybe<void>::Ok();
}

Maybe<void> CudaStream::OnExecutionContextTeardown() { return Maybe<void>::Ok(); }

DeviceType CudaStream::device_type() const { return DeviceType::kCUDA; }

CudaDevice* CudaStream::device() const { return device_; }

Maybe<void> CudaStream::Sync() {
  hipError_t err = hipStreamSynchronize(cuda_stream_);
  if (err == hipSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << hipGetErrorString(err) << " (" << err << ") ";
  }
}

void CudaStream::RecordEvent(Event* event) {
  auto* cuda_event = static_cast<CudaEvent*>(event);  // NOLINT
  OF_CUDA_CHECK(hipEventRecord(cuda_event->cuda_event(), cuda_stream_));
}

Maybe<void> CudaStream::GetAsyncError() {
  hipError_t err = hipGetLastError();
  if (err == hipSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << hipGetErrorString(err) << " (" << err << ") ";
  }
}

hipStream_t CudaStream::cuda_stream() const { return cuda_stream_; }

hipblasHandle_t CudaStream::cublas_handle() const { return cublas_handle_; }

void* CudaStream::cublas_workspace() const { return workspace_; }

size_t CudaStream::cublas_workspace_size() const { return workspace_size_; }

hipdnnHandle_t CudaStream::cudnn_handle() const { return cudnn_handle_; }

const hipDeviceProp_t& CudaStream::device_properties() const { return device_->properties(); }

int CudaStream::cuda_arch() const {
  return device_->properties().major * 100 + device_->properties().minor * 10;
}

#ifdef WITH_ROCM_GRAPHS

void CudaStream::BeginGraphCapture() {
  CHECK(!is_graph_capturing_);
  is_graph_capturing_ = true;
  OF_CUDA_CHECK(hipStreamBeginCapture(cuda_stream_, hipStreamCaptureModeThreadLocal));
}

void CudaStream::EndGraphCapture(CudaGraphExecutable* executable) {
  hipGraph_t graph = nullptr;
  OF_CUDA_CHECK(hipStreamEndCapture(cuda_stream_, &graph));
  executable->Update(graph);
  OF_CUDA_CHECK(hipGraphDestroy(graph));
  is_graph_capturing_ = false;
}

bool CudaStream::IsGraphCapturing() const { return is_graph_capturing_; }

void CudaStream::LaunchGraph(const CudaGraphExecutable* executable) {
  executable->Launch(cuda_stream_);
}

#endif  // WITH_ROCM_GRAPHS

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_ROCM
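For context, the BeginGraphCapture / EndGraphCapture / LaunchGraph methods above wrap the HIP stream-capture cycle. A minimal standalone sketch of that cycle, assuming a ROCm version with graph support; the kernel and variable names are illustrative:

#include <hip/hip_runtime.h>

__global__ void Inc(int* p) { ++(*p); }

int main() {
  int* d = nullptr;
  (void)hipMalloc(&d, sizeof(int));
  hipStream_t stream;
  (void)hipStreamCreate(&stream);

  // Capture the stream's work into a graph, like BeginGraphCapture/EndGraphCapture.
  hipGraph_t graph = nullptr;
  (void)hipStreamBeginCapture(stream, hipStreamCaptureModeThreadLocal);
  hipLaunchKernelGGL(Inc, dim3(1), dim3(1), 0, stream, d);  // work to capture
  (void)hipStreamEndCapture(stream, &graph);

  // Instantiate once, then replay cheaply, like CudaGraphExecutable::Update/Launch.
  hipGraphExec_t exec = nullptr;
  (void)hipGraphInstantiate(&exec, graph, NULL, NULL, 0);
  (void)hipGraphDestroy(graph);
  (void)hipGraphLaunch(exec, stream);
  (void)hipStreamSynchronize(stream);

  (void)hipGraphExecDestroy(exec);
  (void)hipStreamDestroy(stream);
  (void)hipFree(d);
  return 0;
}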
oneflow/core/ep/cuda/cuda_stream.h
@@ -79,6 +79,7 @@ class CudaStream : public Stream {
  CudaDevice* device() const override;
  Maybe<void> Sync() override;
  void RecordEvent(Event* event) override;
  Maybe<void> GetAsyncError() override;
  Maybe<void> OnExecutionContextSetup() override;
  Maybe<void> OnExecutionContextTeardown() override;
@@ -165,4 +166,139 @@ class CudaStream : public Stream {
#endif  // WITH_CUDA

#ifdef WITH_ROCM

#include <hip/hip_runtime.h>
#include "oneflow/core/hipdnn/hipdnn.h"

// #if CUDA_VERSION >= 11000
// #define WITH_ROCM_GRAPHS
// #endif  // CUDA_VERSION >= 11000

#include "oneflow/core/device/cuda_util.h"

namespace oneflow {

namespace ep {

class CudaDevice;

#ifdef WITH_ROCM_GRAPHS

class CudaGraphExecutable {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaGraphExecutable);
  CudaGraphExecutable();
  ~CudaGraphExecutable();

  void Update(hipGraph_t graph);
  void Launch(hipStream_t stream) const;
  bool IsInstantiated() const;

 private:
  void Reset();

  hipGraphExec_t graph_exec_;
  int dev_;
};

#endif  // WITH_ROCM_GRAPHS

struct CudaLaunchConfig {
  dim3 grid_dim;
  dim3 block_dim;
  size_t shared_mem_size;

  CudaLaunchConfig() : grid_dim{}, block_dim{}, shared_mem_size(0) {}
  CudaLaunchConfig(unsigned int grid_size, unsigned int block_size, size_t shared_mem_size)
      : grid_dim(grid_size), block_dim(block_size), shared_mem_size(shared_mem_size) {}
};

class CudaStream : public Stream {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaStream);
  explicit CudaStream(CudaDevice* device);
  ~CudaStream() override;

  static constexpr uint32_t kDefaultBlockSize = 256;

  DeviceType device_type() const override;
  CudaDevice* device() const override;
  Maybe<void> Sync() override;
  void RecordEvent(Event* event) override;
  Maybe<void> GetAsyncError() override;
  Maybe<void> OnExecutionContextSetup() override;
  Maybe<void> OnExecutionContextTeardown() override;

  hipStream_t cuda_stream() const;
  hipblasHandle_t cublas_handle() const;
  hipdnnHandle_t cudnn_handle() const;
  void* cublas_workspace() const;
  size_t cublas_workspace_size() const;
  const hipDeviceProp_t& device_properties() const;
  int cuda_arch() const;

  void InitLaunchConfigWithWaves(CudaLaunchConfig* config, size_t elem_cnt, size_t block_size,
                                 size_t max_waves) const {
    const uint32_t max_grid_size =
        max_waves * device_properties().multiProcessorCount
        * (device_properties().maxThreadsPerMultiProcessor / block_size);
    const uint32_t grid_size =
        std::min<uint32_t>(max_grid_size, (elem_cnt + block_size - 1) / block_size);
    config->grid_dim = dim3(grid_size);
    config->block_dim = dim3(block_size);
    config->shared_mem_size = 0;
  }

#ifdef __HIPCC__
  template<typename... Params, typename... Args>
  void LaunchKernel(void (*kernel)(Params...), const CudaLaunchConfig& launch_config,
                    Args... args) {
    kernel<<<launch_config.grid_dim, launch_config.block_dim, launch_config.shared_mem_size,
             cuda_stream()>>>(args...);
  }

  template<typename... Params, typename... Args>
  void LaunchKernel(void (*kernel)(Params...), size_t elem_cnt, size_t max_waves, Args... args) {
    constexpr uint32_t block_size = kDefaultBlockSize;
    CudaLaunchConfig config{};
    InitLaunchConfigWithWaves(&config, elem_cnt, block_size, max_waves);
    LaunchKernel(kernel, config, args...);
  }

  template<typename... Params, typename... Args>
  void LaunchKernelDefaultWaves(void (*kernel)(Params...), size_t elem_cnt, Args... args) {
    const size_t default_waves = 32;
    LaunchKernel(kernel, elem_cnt, default_waves, args...);
  }
#endif  // __HIPCC__

#ifdef WITH_ROCM_GRAPHS
  void BeginGraphCapture();
  void EndGraphCapture(CudaGraphExecutable* executable);
  bool IsGraphCapturing() const;
  void LaunchGraph(const CudaGraphExecutable* executable);
#endif  // WITH_ROCM_GRAPHS

 private:
  hipStream_t cuda_stream_{};
  hipblasHandle_t cublas_handle_{};
  hipdnnHandle_t cudnn_handle_{};
  int device_index_;
  void* workspace_{};
  size_t workspace_size_{};
#ifdef WITH_ROCM_GRAPHS
  bool is_graph_capturing_{};
#endif  // WITH_ROCM_GRAPHS
  CudaDevice* device_;
};

}  // namespace ep

}  // namespace oneflow

#endif  // WITH_ROCM
#endif  // ONEFLOW_CORE_EP_CUDA_CUDA_STREAM_H_
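InitLaunchConfigWithWaves above caps the grid at max_waves "waves" of resident blocks per compute unit. A host-only sketch of the same arithmetic with illustrative sample device numbers (not taken from any real device query):

#include <algorithm>
#include <cstdint>
#include <cstdio>

uint32_t GridSizeWithWaves(size_t elem_cnt, size_t block_size, size_t max_waves,
                           uint32_t multi_processor_count, uint32_t max_threads_per_mp) {
  // blocks resident per multiprocessor per wave = max_threads_per_mp / block_size
  const uint32_t max_grid_size =
      max_waves * multi_processor_count * (max_threads_per_mp / block_size);
  // never launch more blocks than the element count needs
  return std::min<uint32_t>(max_grid_size, (elem_cnt + block_size - 1) / block_size);
}

int main() {
  // e.g. 120 CUs, 2048 threads/CU, block 256 -> 8 blocks/CU per wave;
  // 32 waves cap the grid at 30720 blocks even for a huge elem_cnt.
  std::printf("%u\n", GridSizeWithWaves(1u << 30, 256, 32, 120, 2048));  // 30720
  std::printf("%u\n", GridSizeWithWaves(1000, 256, 32, 120, 2048));      // 4
}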
oneflow/core/ep/cuda/primitive/add.cu
@@ -47,17 +47,17 @@ __global__ void AddGpu(const Args*... srcs, T* dst, size_t count) {
}

template<typename T, typename... Args>
-void LaunchAddGpu(cudaStream_t stream, const Args*... srcs, T* dst, size_t count) {
+void LaunchAddGpu(GPU(Stream_t) stream, const Args*... srcs, T* dst, size_t count) {
  AddGpu<T, Args...>
      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(srcs..., dst, count);
}

template<typename T>
-void DispatchLaunch(cudaStream_t stream, const T* const* srcs, size_t arity, T* dst,
-                    size_t count) {
+void DispatchLaunch(GPU(Stream_t) stream, const T* const* srcs, size_t arity, T* dst,
+                    size_t count) {
  if (arity == 0) {
-    OF_CUDA_CHECK(cudaMemsetAsync(dst, 0, count * sizeof(T), stream));
+    OF_CUDA_CHECK(GPU(MemsetAsync)(dst, 0, count * sizeof(T), stream));
  } else if (arity == 1) {
-    OF_CUDA_CHECK(cudaMemcpyAsync(dst, srcs[0], count * sizeof(T), cudaMemcpyDefault, stream));
+    OF_CUDA_CHECK(GPU(MemcpyAsync)(dst, srcs[0], count * sizeof(T), GPU(MemcpyDefault), stream));
  } else if (arity == 2) {
    OF_CUDA_CHECK((cuda::elementwise::Binary<AddFunctor<T, T>, T, T, T>(
        AddFunctor<T, T>(), count, dst, srcs[0], srcs[1], stream)));

@@ -94,7 +94,7 @@ class AddImpl : public Add {
  using Add::Launch;
  void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,
              size_t count) override {
-    cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
+    GPU(Stream_t) cuda_stream = stream->As<CudaStream>()->cuda_stream();
    DispatchLaunch(cuda_stream, reinterpret_cast<const T* const*>(srcs), arity,
                   reinterpret_cast<T*>(dst), count);
  }
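The GPU(...) macro used above is not defined anywhere in this commit's visible hunks. A plausible token-pasting sketch of how such a prefix-dispatch macro could work — this is an assumption for illustration, not the repository's actual definition:

// Hypothetical sketch: map a portable name onto the cuda*/hip* runtime API.
#ifdef WITH_ROCM
#define GPU(name) hip##name
#else
#define GPU(name) cuda##name
#endif
// GPU(Stream_t)         -> hipStream_t          or cudaStream_t
// GPU(MemsetAsync)(...)  -> hipMemsetAsync(...)  or cudaMemsetAsync(...)
// GPU(MemcpyDefault)    -> hipMemcpyDefault     or cudaMemcpyDefault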
oneflow/core/ep/cuda/primitive/binary_functor.cuh
@@ -29,27 +29,77 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {
};

template<>
-struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> {
+struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFmod, float, float> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

-  OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
-    return static_cast<bool>(pow(static_cast<double>(src0), static_cast<double>(src1)));
-  }
+  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return fmod(src0, src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFmod, double, double> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC double operator()(double src0, double src1) const { return fmod(src0, src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorDiv, float, float> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return floor(src0 / src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorDiv, double, double> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC double operator()(double src0, double src1) const { return floor(src0 / src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTruncDiv, float, float> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC float operator()(float src0, float src1) const { return truncf(src0 / src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTruncDiv, double, double> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC double operator()(double src0, double src1) const { return trunc(src0 / src1); }
};

template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorMod, float, float> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC float operator()(float src0, float src1) const {
    float trunc_mod = fmod(src0, src1);
    return (trunc_mod != static_cast<float>(0))
                   && ((src1 < static_cast<float>(0)) != (trunc_mod < static_cast<float>(0)))
               ? trunc_mod + src1
               : trunc_mod;
  }
};

template<>
-struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> {
+struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFloorMod, double, double> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

-  OF_DEVICE_FUNC half operator()(half src0, half src1) const {
-    return static_cast<half>(pow(static_cast<float>(src0), static_cast<float>(src1)));
-  }
+  OF_DEVICE_FUNC double operator()(double src0, double src1) const {
+    double trunc_mod = fmod(src0, src1);
+    return (trunc_mod != static_cast<double>(0))
+                   && ((src1 < static_cast<double>(0)) != (trunc_mod < static_cast<double>(0)))
+               ? trunc_mod + src1
+               : trunc_mod;
+  }
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {
-#if defined(__CUDA_ARCH__)
+#if defined(__CUDA_ARCH__) || defined(__HIP_DEVICE_COMPILE__)
    coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#else
    coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0)));
@@ -65,6 +115,39 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst
  Src coef;
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kFastGeluBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    // ref to: https://mlfromscratch.com/activation-functions-explained/#gelu
    const Src one = static_cast<Src>(1);
    const Src half = static_cast<Src>(0.5);
    const Src pow3 = x * x * x;
    const Src tanh_out = std::tanh(alpha * (x + beta * pow3));
    const Src dtanh = alpha * (half * x + beta * static_cast<Src>(1.5) * pow3);
    return dy * (half + half * tanh_out + dtanh * (one - tanh_out * tanh_out));
  }

 private:
  const Src alpha = static_cast<Src>(0.7978845608028654);
  const Src beta = static_cast<Src>(0.044714998453855515);
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kQuickGeluBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    const Src sigmoid = one / (one + exp(-x * alpha));
    return dy * (sigmoid + alpha * x * (sigmoid * (one - sigmoid)));
  }

 private:
  const Src alpha = static_cast<Src>(1.702);
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
@@ -75,19 +158,114 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst
  }
};

-/*********nv_bfloat16_kernel*******/
template<typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, int, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
-#if CUDA_VERSION >= 11000
  OF_DEVICE_FUNC Dst operator()(int src0, int src1) const {
    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
  }
};

-template<>
-struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
template<typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, int8_t, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
  OF_DEVICE_FUNC Dst operator()(int8_t src0, int8_t src1) const {
    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
  }
};

template<typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, uint8_t, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
  OF_DEVICE_FUNC Dst operator()(uint8_t src0, uint8_t src1) const {
    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
  }
};

template<typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, int64_t, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}
  BinaryFunctor<DeviceType::kCUDA, BinaryOp::kScalarExpPowerGrad, float, float> float_functor;
  OF_DEVICE_FUNC Dst operator()(int src0, int src1) const {
    return static_cast<Dst>(float_functor(static_cast<float>(src0), static_cast<float>(src1)));
  }
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kAtanhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    const Src one = static_cast<Src>(1.0);
    return dy * one / (one - static_cast<Src>(pow(x, 2)));
  }
};

template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kIsCloseEqualNan, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)
      : atol(attr0.Value<float>()), rtol(attr1.Value<float>()) {}

  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {
    bool close = src0 == src1;
    close |= (isnan(src0) and isnan(src1));
    if (atol == 0 and rtol == 0) return close;
    Src allowed_error = static_cast<Src>(atol) + abs(static_cast<Src>(rtol) * src1);
    Src actual_error = abs(src0 - src1);
    close |= (isfinite(actual_error) and (actual_error <= allowed_error));
    return close;
  }
  float atol, rtol;
};

-  OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
-    return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kIsClose, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1)
      : atol(attr0.Value<float>()), rtol(attr1.Value<float>()) {}

  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const {
    bool close = src0 == src1;
    if (atol == 0 and rtol == 0) return close;
    Src allowed_error = static_cast<Src>(atol) + abs(static_cast<Src>(rtol) * src1);
    Src actual_error = abs(src0 - src1);
    close |= (isfinite(actual_error) and (actual_error <= allowed_error));
    return close;
  }
  float atol, rtol;
};

#define SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(op, type) \
  template<typename Dst> \
  struct BinaryFunctor<DeviceType::kCUDA, op, type, Dst> { \
    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
    OF_DEVICE_FUNC Dst operator()(type src0, type src1) const { \
      return float_functor(static_cast<float>(src0), static_cast<float>(src1)); \
    } \
    BinaryFunctor<DeviceType::kCUDA, op, float, Dst> float_functor; \
  };
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, bool);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, char);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int8_t);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, uint8_t);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsClose, int64_t);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, bool);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, char);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int8_t);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, uint8_t);
SPECIALIZATION_INTEGRAL_CLOSENESS_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan, int64_t);

+/*********nv_bfloat16_kernel*******/
+#if CUDA_VERSION >= 11000
#define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op) \
  template<> \
  struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> { \
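The integral kScalarExpPowerGrad specializations above all delegate to the float functor: promote the integral inputs to float, compute, then cast back. A minimal standalone sketch of that pattern with an illustrative stand-in operation (names mine, not OneFlow's):

#include <cstdint>
#include <cstdio>

struct FloatOp {
  float operator()(float a, float b) const { return a * b; }  // stand-in op
};

template<typename T, typename Dst>
struct IntegralOp {
  FloatOp float_functor;
  Dst operator()(T a, T b) const {
    // promote, compute in float, cast to the destination type
    return static_cast<Dst>(float_functor(static_cast<float>(a), static_cast<float>(b)));
  }
};

int main() {
  IntegralOp<int8_t, int> op;
  std::printf("%d\n", op(int8_t(3), int8_t(4)));  // prints 12
}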
@@ -99,6 +277,13 @@ struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16
  } \
};
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kPow);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFmod);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorDiv);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTruncDiv);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFloorMod);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
@@ -115,6 +300,42 @@ SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAsinBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAsinhBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCosBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCoshBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kErfBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kErfcBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kExpBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kExpm1BackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLog2BackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLog10BackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLogSigmoidBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kReciprocalNoNanBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kRsqrtBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSinBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSinhBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSqrtBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanBackwardWithDyX);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyY);
SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kAtanhBackwardWithDyX);

#define SPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(op) \
  template<typename Dst> \
  struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, Dst> { \
    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
    BinaryFunctor<DeviceType::kCUDA, op, float, Dst> float_functor; \
    OF_DEVICE_FUNC Dst operator()(nv_bfloat16 src0, nv_bfloat16 src1) const { \
      return float_functor(__bfloat162float(src0), __bfloat162float(src1)); \
    } \
  };
SPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan)
SPECIALIZATION_BFLOAT16_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsClose)
#endif  // CUDA_VERSION >= 11000
@@ -129,6 +350,13 @@ SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDy
  } \
};
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kPow);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFmod);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorDiv);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTruncDiv);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFloorMod);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
@@ -142,6 +370,69 @@ SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kFastGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kQuickGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcosBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAcoshBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAsinBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAsinhBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCosBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCoshBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kErfBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kErfcBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kExpBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kExpm1BackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLog2BackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLog10BackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kLogSigmoidBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kReciprocalNoNanBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kRsqrtBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSinBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSinhBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSqrtBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSigmoidBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kAtanhBackwardWithDyX);

#define SPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(op) \
  template<typename Dst> \
  struct BinaryFunctor<DeviceType::kCUDA, op, half, Dst> { \
    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
    BinaryFunctor<DeviceType::kCUDA, op, float, Dst> float_functor; \
    OF_DEVICE_FUNC Dst operator()(half src0, half src1) const { \
      return float_functor(__half2float(src0), __half2float(src1)); \
    } \
  };
SPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsCloseEqualNan)
SPECIALIZATION_HALF_COMPARISON_BINARY_FUNCTOR(BinaryOp::kIsClose)

#define SPECIALIZATION_GPU_BINARY_FUNCTOR(op, type) \
  template<> \
  struct BinaryFunctor<DeviceType::kCUDA, op, type, type> { \
    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : int_functor(attr0, attr1) {} \
    \
    BinaryFunctor<DeviceType::kCUDA, op, int, int> int_functor; \
    OF_DEVICE_FUNC type operator()(type src0, type src1) const { \
      return static_cast<type>(int_functor(static_cast<int>(src0), static_cast<int>(src1))); \
    } \
  };
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kPow, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFmod, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, bool);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kPow, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFmod, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorDiv, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kTruncDiv, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kFloorMod, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarBasePowerGrad, char);
SPECIALIZATION_GPU_BINARY_FUNCTOR(BinaryOp::kScalarExpPowerGrad, char);

}  // namespace broadcast_elementwise_binary

}  // namespace primitive
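The kIsClose / kIsCloseEqualNan functors above implement the closeness test |src0 - src1| <= atol + |rtol * src1|, with an optional NaN == NaN case. A host-side sketch of the same formula; sample values are mine:

#include <cmath>
#include <cstdio>

bool IsClose(float src0, float src1, float atol, float rtol, bool equal_nan) {
  bool close = (src0 == src1);
  if (equal_nan) { close |= (std::isnan(src0) && std::isnan(src1)); }
  if (atol == 0 && rtol == 0) { return close; }  // exact-match mode
  float allowed = atol + std::fabs(rtol * src1);
  float actual = std::fabs(src0 - src1);
  close |= (std::isfinite(actual) && actual <= allowed);
  return close;
}

int main() {
  std::printf("%d\n", IsClose(1.0f, 1.0f + 1e-6f, 1e-5f, 0.0f, false));  // 1
  std::printf("%d\n", IsClose(NAN, NAN, 0.0f, 0.0f, true));              // 1
}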
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cu
@@ -85,7 +85,11 @@ class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryF
-      OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
-                                       BINARY_ACTIVATION_BACKWARD_OP_SEQ,
-                                       CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
+      OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
+                                       BINARY_ACTIVATION_BACKWARD_OP_SEQ,
+                                       CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
+      OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
+                                       BINARY_MATH_BACKWARD_OP_SEQ,
+                                       CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh
@@ -13,7 +13,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
-#include "oneflow/core/ep/include/primitive//broadcast_elementwise_binary.h"
+#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"

@@ -299,8 +299,8 @@ void DispatchLaunch(Stream* stream, size_t num_src0_dims, const int64_t* src0_di
  SimplifyBroadcastDims<kMaxNumDims>(num_src0_dims, src0_dims, num_src1_dims, src1_dims,
                                     &simplified_num_dims, simplified_src0_dims,
                                     simplified_src1_dims, simplified_dst_dims);
-  CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_src1_dims, src1,
-               simplified_dst_dims, dst);
+  CheckInplace(simplified_num_dims, simplified_src0_dims, src0, simplified_dst_dims, dst);
+  CheckInplace(simplified_num_dims, simplified_src1_dims, src1, simplified_dst_dims, dst);
  if (IsDimsEquals(simplified_num_dims, simplified_src0_dims, simplified_num_dims,
                   simplified_src1_dims)) {
    const int64_t elem_cnt = GetElementCount(simplified_num_dims, simplified_src0_dims);
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math.cu
@@ -27,7 +27,7 @@ namespace broadcast_elementwise_binary {
    Scalar attr0, Scalar attr1);
-OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
-                                 BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ);
+OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
+                                 BINARY_MATH_OP_SEQ_0, CUDA_PRIMITIVE_ALL_TYPE_SEQ);
}  // namespace broadcast_elementwise_binary
}  // namespace primitive
oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary_math1.cu (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
  template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<      \
      binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>(      \
      Scalar attr0, Scalar attr1);

OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_MATH_OP_SEQ_1, CUDA_PRIMITIVE_ALL_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
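This new _math1.cu file moves part of the math-op instantiations into a second translation unit (BINARY_MATH_OP_SEQ_0 vs BINARY_MATH_OP_SEQ_1), which presumably keeps per-file compile time and memory in check. A minimal single-file sketch of the explicit-instantiation mechanism the macro emits; the names are mine, not OneFlow's:

#include <cstdio>

template<int Op>
int Run(int x) { return x + Op; }

// In a header this would be declared: extern template int Run<1>(int);
// and exactly one translation unit would carry the definition below.
template int Run<1>(int);  // explicit instantiation, like the macro above emits

int main() { std::printf("%d\n", Run<1>(41)); }  // prints 42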
oneflow/core/ep/cuda/primitive/broadcast_elementwise_unary.cu (new file, mode 100644)
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_unary.h"
#include "oneflow/core/ep/cuda/primitive/unary_functor.cuh"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include "oneflow/core/cuda/elementwise.cuh"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_unary
{
namespace
{
constexpr
size_t
kMaxPackSize
=
4
;
template
<
size_t
max_pack_size
,
typename
Src
,
typename
Dst
>
size_t
GetPackSize
(
size_t
num_dims
,
const
int64_t
*
src_dims
,
const
void
*
src
,
const
int64_t
*
dst_dims
,
const
void
*
dst
)
{
static_assert
(
max_pack_size
>
0
&&
(
max_pack_size
&
(
max_pack_size
-
1
))
==
0
,
""
);
for
(
size_t
pack_size
=
max_pack_size
;
pack_size
>
2
;
pack_size
/=
2
)
{
bool
is_src_supported
=
IsPackSizeSupported
<
Src
>
(
pack_size
,
num_dims
,
src_dims
,
src
);
bool
is_dst_supported
=
IsPackSizeSupported
<
Dst
>
(
pack_size
,
num_dims
,
dst_dims
,
dst
);
if
(
is_src_supported
&&
is_dst_supported
)
{
return
pack_size
;
}
}
return
1
;
}
template
<
typename
Src
,
typename
Dst
,
size_t
max_dims
,
typename
IndexType
>
struct
BroadcastElementwiseUnaryParams
{
IndexToOffsetWithStrideCalculator
<
IndexType
,
max_dims
>
src_index_to_offset_helper
;
OffsetToIndexWithStrideCalculator
<
IndexType
,
max_dims
>
dst_offset_to_index_helper
;
IndexToOffsetWithStrideCalculator
<
IndexType
,
max_dims
>
dst_index_to_offset_helper
;
size_t
num_dims
;
IndexType
src_index_mask
[
max_dims
];
IndexType
count
{};
const
Src
*
src
{};
Dst
*
dst
{};
bool
dst_is_contiguous
;
Scalar
attr0
;
Scalar
attr1
;
};
template
<
UnaryOp
unary_op
,
typename
Src
,
typename
Dst
>
struct
UnaryScalarFunctor
{
__host__
__device__
explicit
UnaryScalarFunctor
(
Src
scalar
)
:
scalar
(
scalar
)
{}
__device__
Dst
operator
()()
const
{
return
UnaryFunctor
<
DeviceType
::
kCUDA
,
unary_op
,
Src
,
Dst
>
()(
scalar
);
}
const
Src
scalar
;
};
template
<
UnaryOp
unary_op
,
typename
Src
,
typename
Dst
>
struct
UnaryScalarPtrFunctorFactory
{
__host__
__device__
explicit
UnaryScalarPtrFunctorFactory
(
const
Src
*
scalar_ptr
)
:
scalar_ptr
(
scalar_ptr
)
{}
__device__
UnaryScalarFunctor
<
unary_op
,
Src
,
Dst
>
operator
()()
const
{
return
UnaryScalarFunctor
<
unary_op
,
Src
,
Dst
>
(
*
scalar_ptr
);
}
const
Src
*
scalar_ptr
;
};
template
<
UnaryOp
op
,
typename
Src
,
typename
Dst
,
size_t
max_dims
,
size_t
pack_size
,
typename
IndexType
>
__global__
void
BroadcastElementwiseUnaryGpu
(
BroadcastElementwiseUnaryParams
<
Src
,
Dst
,
max_dims
,
IndexType
>
params
)
{
using
LoadPack
=
cuda
::
elementwise
::
Packed
<
Src
,
pack_size
>
;
using
StorePack
=
cuda
::
elementwise
::
Packed
<
Dst
,
pack_size
>
;
const
LoadPack
*
src
=
reinterpret_cast
<
const
LoadPack
*>
(
params
.
src
);
StorePack
*
dst
=
reinterpret_cast
<
StorePack
*>
(
params
.
dst
);
IndexType
src_index
[
max_dims
];
IndexType
dst_index
[
max_dims
];
size_t
num_dims
=
params
.
num_dims
;
auto
functor
=
UnaryFunctor
<
DeviceType
::
kCUDA
,
op
,
Src
,
Dst
>
(
params
.
attr0
,
params
.
attr1
);
CUDA_1D_KERNEL_LOOP_T
(
IndexType
,
offset
,
params
.
count
)
{
params
.
dst_offset_to_index_helper
.
OffsetToNdIndex
(
offset
,
dst_index
,
num_dims
);
#pragma unroll
for
(
int
i
=
0
;
i
<
max_dims
;
++
i
)
{
if
(
i
<
num_dims
)
{
src_index
[
i
]
=
params
.
src_index_mask
[
i
]
*
dst_index
[
i
];
}
}
const
IndexType
src_offset
=
params
.
src_index_to_offset_helper
.
NdIndexToOffset
(
src_index
,
num_dims
);
LoadPack
src_pack
=
src
[
src_offset
];
StorePack
dst_pack
;
#pragma unroll
for
(
int
j
=
0
;
j
<
pack_size
;
++
j
)
{
dst_pack
.
elem
[
j
]
=
functor
(
src_pack
.
elem
[
j
]);
}
IndexType
dst_offset
=
offset
;
if
(
!
params
.
dst_is_contiguous
)
{
dst_offset
=
params
.
dst_index_to_offset_helper
.
NdIndexToOffset
(
dst_index
,
num_dims
);
}
dst
[
dst_offset
]
=
dst_pack
;
}
}
template<UnaryOp op, typename Src, typename Dst, size_t max_dims, size_t pack_size,
         typename IndexType>
void LaunchKernel(CudaStream* stream, size_t num_dims, const int64_t* src_dims,
                  const int64_t* src_strides, const Src* src, const int64_t* dst_dims,
                  const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0,
                  Scalar attr1, size_t count) {
  BroadcastElementwiseUnaryParams<Src, Dst, max_dims, IndexType> params;
  for (size_t i = 0; i < num_dims; ++i) {
    params.src_index_mask[i] = (src_dims[i] == 1) ? 0 : 1;
  }
  params.src_index_to_offset_helper =
      IndexToOffsetWithStrideCalculator<IndexType, max_dims>(src_strides, num_dims);
  params.dst_offset_to_index_helper =
      OffsetToIndexWithStrideCalculator<IndexType, max_dims>(dst_dims, num_dims);
  params.dst_index_to_offset_helper =
      IndexToOffsetWithStrideCalculator<IndexType, max_dims>(dst_strides, num_dims);
  params.num_dims = num_dims;
  params.src = src;
  params.dst = dst;
  params.count = static_cast<IndexType>(count);
  params.attr0 = attr0;
  params.attr1 = attr1;
  params.dst_is_contiguous = continuous_output;
  BroadcastElementwiseUnaryGpu<op, Src, Dst, max_dims, pack_size, IndexType>
      <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, stream->cuda_stream()>>>(
          params);
}
template<UnaryOp op, typename Src, typename Dst, size_t max_dims, size_t pack_size>
void DispatchIndexType(CudaStream* stream, size_t num_dims, const int64_t* src_dims,
                       const int64_t* src_strides, const Src* src, const int64_t* dst_dims,
                       const int64_t* dst_strides, Dst* dst, bool continuous_output, Scalar attr0,
                       Scalar attr1) {
  size_t count = GetElementCount(num_dims, dst_dims);
  if (count < GetMaxVal<int32_t>() / 2) {
    LaunchKernel<op, Src, Dst, max_dims, pack_size, int32_t>(
        stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst,
        continuous_output, attr0, attr1, count);
  } else {
    LaunchKernel<op, Src, Dst, max_dims, pack_size, int64_t>(
        stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst,
        continuous_output, attr0, attr1, count);
  }
}
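Index arithmetic dominates this kernel, and 32-bit integer math is cheaper on GPUs, so the dispatcher drops to `int32_t` whenever the element count allows; the `/ 2` margin leaves headroom for intermediate offset arithmetic near the top of the range. The same pattern in isolation (hypothetical `Run`/`Dispatch` names):

#include <cstddef>
#include <cstdint>
#include <limits>

// Standalone sketch of count-based index-type dispatch (hypothetical names).
template<typename IndexType>
void Run(std::size_t count) { /* launch a kernel templated on IndexType */ }

void Dispatch(std::size_t count) {
  // Prefer 32-bit indexing: it halves register pressure for index math.
  if (count < static_cast<std::size_t>(std::numeric_limits<int32_t>::max() / 2)) {
    Run<int32_t>(count);
  } else {
    Run<int64_t>(count);
  }
}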
template<UnaryOp op, typename Src, typename Dst, size_t max_dims>
void DispatchPackSize(CudaStream* stream, size_t pack_size, size_t num_dims,
                      const int64_t* src_dims, const int64_t* src_strides, const Src* src,
                      const int64_t* dst_dims, const int64_t* dst_strides, Dst* dst,
                      bool continuous_output, Scalar attr0, Scalar attr1) {
  void (*func)(CudaStream* /*stream*/, size_t /*num_dims*/, const int64_t* /*src_dims*/,
               const int64_t* /*src_strides*/, const Src* /*src*/, const int64_t* /*dst_dims*/,
               const int64_t* /*dst_strides*/, Dst* /*dst*/, bool /*continuous_output*/,
               Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;
  if (pack_size == 1) {
    func = DispatchIndexType<op, Src, Dst, max_dims, 1>;
  } else if (pack_size == 4) {
    func = DispatchIndexType<op, Src, Dst, max_dims, 4>;
  } else {
    UNIMPLEMENTED();
  }
  func(stream, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst,
       continuous_output, attr0, attr1);
}
template<UnaryOp op, typename Src, typename Dst>
void DispatchNumDims(CudaStream* stream, size_t pack_size, size_t num_dims,
                     const int64_t* src_dims, const int64_t* src_strides, const Src* src,
                     const int64_t* dst_dims, const int64_t* dst_strides, Dst* dst,
                     bool continuous_output, Scalar attr0, Scalar attr1) {
  void (*func)(CudaStream* /*stream*/, size_t /*pack_size*/, size_t /*num_dims*/,
               const int64_t* /*src_dims*/, const int64_t* /*src_strides*/, const Src* /*src*/,
               const int64_t* /*dst_dims*/, const int64_t* /*dst_strides*/, Dst* /*dst*/,
               bool /*continuous_output*/, Scalar /*attr0*/, Scalar /*attr1*/) = nullptr;
  if (num_dims == 1) {
    func = DispatchPackSize<op, Src, Dst, 1>;
  } else if (num_dims == 2) {
    func = DispatchPackSize<op, Src, Dst, 2>;
  } else if (num_dims == 3) {
    func = DispatchPackSize<op, Src, Dst, 3>;
  } else if (num_dims == 4) {
    func = DispatchPackSize<op, Src, Dst, 4>;
  } else if (num_dims <= kMaxNumDims) {
    func = DispatchPackSize<op, Src, Dst, kMaxNumDims>;
  } else {
    UNIMPLEMENTED();
  }
  func(stream, pack_size, num_dims, src_dims, src_strides, src, dst_dims, dst_strides, dst,
       continuous_output, attr0, attr1);
}
template<UnaryOp op, typename Src, typename Dst>
void LaunchWithSimplified(CudaStream* stream, size_t simplified_num_dims,
                          int64_t* simplified_src_dims, int64_t* simplified_src_strides,
                          const Src* src, int64_t* simplified_dst_dims,
                          int64_t* simplified_dst_strides, Dst* dst, Scalar attr0, Scalar attr1) {
  CHECK_LE(simplified_num_dims, kMaxNumDims);
  bool src_enable_pack = (simplified_src_strides[simplified_num_dims - 1] == 1);
  bool dst_enable_pack = (simplified_dst_strides[simplified_num_dims - 1] == 1);
  size_t pack_size = 1;
  // TODO(zzk): this pack path has a bug; it will be fixed in the future.
  // if (src_enable_pack && dst_enable_pack) {
  //   pack_size = GetPackSize<kMaxPackSize, Src, Dst>(simplified_num_dims, simplified_src_dims,
  //                                                   src, simplified_dst_dims, dst);
  // }
  bool continuous_output = true;
  for (int i = simplified_num_dims - 1; i >= 0; i--) {
    if ((i == simplified_num_dims - 1 && simplified_dst_strides[i] != 1)
        || (i != simplified_num_dims - 1
            && simplified_dst_strides[i]
                   != simplified_dst_strides[i + 1] * simplified_dst_dims[i + 1])) {
      continuous_output = false;
      break;
    }
  }
  simplified_src_dims[simplified_num_dims - 1] /= pack_size;
  simplified_dst_dims[simplified_num_dims - 1] /= pack_size;
  DispatchNumDims<op, Src, Dst>(stream, pack_size, simplified_num_dims, simplified_src_dims,
                                simplified_src_strides, src, simplified_dst_dims,
                                simplified_dst_strides, dst, continuous_output, attr0, attr1);
}
template<UnaryOp op, typename Src, typename Dst, size_t pack, bool tail>
__global__ void LaunchFillKernel(UnaryFunctor<DeviceType::kCUDA, op, Src, Dst> functor, Dst* dst,
                                 const Src* src, size_t pack_count, size_t count,
                                 size_t tail_count, Dst* tail_dst) {
  using StorePack = cuda::elementwise::Packed<Dst, pack>;
  StorePack pack_value;
  Dst value = functor(*src);
#pragma unroll
  for (size_t i = 0; i < pack; ++i) { pack_value.elem[i] = value; }
  StorePack* pack_dst = reinterpret_cast<StorePack*>(dst);
  CUDA_1D_KERNEL_LOOP_T(size_t, i, pack_count) { pack_dst[i] = pack_value; }
  if (tail) {
    CUDA_1D_KERNEL_LOOP_T(size_t, i, tail_count) { tail_dst[i] = value; }
  }
}
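The fill path widens stores to a `Packed<Dst, pack>` vector and writes whatever does not divide evenly as scalar `tail` elements; the split is plain integer division, as in this sketch:

#include <cstddef>
#include <cstdio>

// Sketch of the pack/tail split used by the fill kernel: for `count` elements
// written `pack` at a time, pack_count vector stores cover the front and
// tail_count scalar stores cover the remainder.
int main() {
  const std::size_t count = 1030, pack = 4;
  const std::size_t pack_count = count / pack;         // 257 vectorized stores
  const std::size_t tail_offset = pack_count * pack;   // 1028
  const std::size_t tail_count = count - tail_offset;  // 2 scalar stores
  std::printf("%zu packs + %zu tail elems\n", pack_count, tail_count);
  return 0;
}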
template<UnaryOp op, typename Src, typename Dst, size_t pack>
typename std::enable_if<(pack != 0), void>::type LaunchPackFill(CudaStream* stream, Dst* dst,
                                                                const Src* src, size_t count,
                                                                Scalar attr0, Scalar attr1) {
  const size_t pack_count = count / pack;
  const size_t tail_offset = pack_count * pack;
  const size_t tail_count = count - tail_offset;
  auto functor = UnaryFunctor<DeviceType::kCUDA, op, Src, Dst>(attr0, attr1);
  if (tail_count > 0) {
    LaunchFillKernel<op, Src, Dst, pack, true>
        <<<BlocksNum4ThreadsNum(pack_count), kCudaThreadsNumPerBlock, 0, stream->cuda_stream()>>>(
            functor, dst, src, pack_count, count, tail_count, dst + tail_offset);
  } else {
    LaunchFillKernel<op, Src, Dst, pack, false>
        <<<BlocksNum4ThreadsNum(pack_count), kCudaThreadsNumPerBlock, 0, stream->cuda_stream()>>>(
            functor, dst, src, pack_count, count, tail_count, dst + tail_offset);
  }
}

template<UnaryOp op, typename Src, typename Dst, size_t pack>
typename std::enable_if<(pack == 0), void>::type LaunchPackFill(CudaStream* stream, Dst* dst,
                                                                const Src* src, size_t count,
                                                                Scalar attr0, Scalar attr1) {
  LOG(FATAL) << "wrong alignment";
}
template<UnaryOp op, typename Src, typename Dst>
void LaunchFill(CudaStream* stream, Dst* dst, const Src* src, size_t count, Scalar attr0,
                Scalar attr1) {
  auto uintptr = reinterpret_cast<std::uintptr_t>(dst);
  if (uintptr % 16 == 0 && count * sizeof(Dst) >= 16) {
    LaunchPackFill<op, Src, Dst, 16 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);
  } else if (uintptr % 8 == 0 && count * sizeof(Dst) >= 8) {
    LaunchPackFill<op, Src, Dst, 8 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);
  } else if (uintptr % 4 == 0 && count * sizeof(Dst) >= 4) {
    LaunchPackFill<op, Src, Dst, 4 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);
  } else if (uintptr % 2 == 0 && count * sizeof(Dst) >= 2) {
    LaunchPackFill<op, Src, Dst, 2 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);
  } else {
    LaunchPackFill<op, Src, Dst, 1 / sizeof(Dst)>(stream, dst, src, count, attr0, attr1);
  }
}
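`LaunchFill` probes the destination pointer's alignment from widest to narrowest and picks the largest store width that both the address and the total byte count support; when `sizeof(Dst)` exceeds the chosen width, the integer division makes `pack` 0 and selection falls into the `LOG(FATAL)` overload. The probe in isolation:

#include <cstddef>
#include <cstdint>

// Sketch of the alignment probe in LaunchFill: returns the widest
// power-of-two store width (in bytes, up to 16) usable for `bytes` bytes
// starting at `p`.
std::size_t MaxStoreWidth(const void* p, std::size_t bytes) {
  const auto addr = reinterpret_cast<std::uintptr_t>(p);
  for (std::size_t w = 16; w >= 2; w /= 2) {
    if (addr % w == 0 && bytes >= w) { return w; }
  }
  return 1;
}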
template<UnaryOp unary_op, typename Src, typename Dst>
class BroadcastElementwiseUnaryImpl : public BroadcastElementwiseUnary {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryImpl);
  BroadcastElementwiseUnaryImpl(Scalar attr0, Scalar attr1) : attr0(attr0), attr1(attr1) {}
  ~BroadcastElementwiseUnaryImpl() override = default;

  void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims, const void* src,
              size_t num_dst_dims, const int64_t* dst_dims, void* dst) override {
    CHECK_GT(num_src_dims, 0) << "num_src_dims must be greater than 0";
    CHECK_GT(num_dst_dims, 0) << "num_dst_dims must be greater than 0";
    int64_t src_strides[kMaxNumDims];
    int64_t dst_strides[kMaxNumDims];
    // Initialize strides as dense row-major.
    for (int i = num_src_dims - 1; i < kMaxNumDims; ++i) { src_strides[i] = 1; }
    for (int i = num_src_dims - 2; i >= 0; --i) {
      src_strides[i] = src_dims[i + 1] * src_strides[i + 1];
    }
    for (int i = num_dst_dims - 1; i < kMaxNumDims; ++i) { dst_strides[i] = 1; }
    for (int i = num_dst_dims - 2; i >= 0; --i) {
      dst_strides[i] = dst_dims[i + 1] * dst_strides[i + 1];
    }
    Launch(stream, num_src_dims, src_dims, src_strides, src, num_dst_dims, dst_dims, dst_strides,
           dst);
  }

  void Launch(Stream* stream, size_t num_src_dims, const int64_t* src_dims,
              const int64_t* src_strides, const void* src_ptr, size_t num_dst_dims,
              const int64_t* dst_dims, const int64_t* dst_strides, void* dst_ptr) override {
    CHECK_GT(num_src_dims, 0) << "num_src_dims must be greater than 0";
    CHECK_GT(num_dst_dims, 0) << "num_dst_dims must be greater than 0";
    auto* cuda_stream = stream->As<CudaStream>();
    Dst* dst = reinterpret_cast<Dst*>(dst_ptr);
    const Src* src = reinterpret_cast<const Src*>(src_ptr);
    size_t simplified_num_dims = 0;
    int64_t simplified_src_dims[kMaxNumDims];
    int64_t simplified_dst_dims[kMaxNumDims];
    int64_t simplified_src_strides[kMaxNumDims];
    int64_t simplified_dst_strides[kMaxNumDims];
    SimplifyBroadcastDims<kMaxNumDims>(num_src_dims, src_dims, src_strides, num_dst_dims,
                                       dst_dims, dst_strides, &simplified_num_dims,
                                       simplified_src_dims, simplified_src_strides,
                                       simplified_dst_dims, simplified_dst_strides);
    CheckInplace(simplified_num_dims, simplified_src_dims, src, simplified_dst_dims, dst);
    CheckInplace(simplified_num_dims, simplified_src_strides, src, simplified_dst_strides, dst);
    if (simplified_num_dims == 1 && simplified_src_dims[0] == 1) {
      const int64_t elem_cnt = simplified_dst_dims[0];
      LaunchFill<unary_op, Src, Dst>(cuda_stream, dst, src, elem_cnt, attr0, attr1);
    } else if (simplified_num_dims == 1 && simplified_src_strides[0] == 1
               && simplified_dst_strides[0] == 1) {
      const int64_t elem_cnt = simplified_src_dims[0];
      auto functor = UnaryFunctor<DeviceType::kCUDA, unary_op, Src, Dst>(attr0, attr1);
      OF_CUDA_CHECK((cuda::elementwise::Unary<decltype(functor), Dst, Src>(
          functor, elem_cnt, dst, src, cuda_stream->cuda_stream())));
    } else {
      LaunchWithSimplified<unary_op, Src, Dst>(cuda_stream, simplified_num_dims,
                                               simplified_src_dims, simplified_src_strides, src,
                                               simplified_dst_dims, simplified_dst_strides, dst,
                                               attr0, attr1);
    }
  }

 protected:
  Scalar attr0, attr1;
};
template<UnaryOp unary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseUnary> NewBroadcastElementwiseUnary(Scalar attr0,
                                                                        Scalar attr1) {
  return std::unique_ptr<BroadcastElementwiseUnary>(
      new BroadcastElementwiseUnaryImpl<unary_op, Src, Dst>(attr0, attr1));
}
class BroadcastElementwiseUnaryFactoryImpl : public BroadcastElementwiseUnaryFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseUnaryFactoryImpl);
  BroadcastElementwiseUnaryFactoryImpl() = default;
  ~BroadcastElementwiseUnaryFactoryImpl() override = default;

  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,
                                                 size_t max_num_dims) override {
    return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());
  }

  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp op, DataType src_type, DataType dst_type,
                                                 size_t max_num_dims, Scalar attr0) override {
    return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());
  }

  std::unique_ptr<BroadcastElementwiseUnary> New(UnaryOp unary_op, DataType src_type,
                                                 DataType dst_type, size_t max_num_dims,
                                                 Scalar attr0, Scalar attr1) override {
    if (max_num_dims > kMaxNumDims) { return nullptr; }
#define MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY(unary_op, dtype_pair)           \
  {std::make_tuple(unary_op, OF_PP_PAIR_SECOND(dtype_pair), OF_PP_PAIR_SECOND(dtype_pair)),   \
   NewBroadcastElementwiseUnary<unary_op, OF_PP_PAIR_FIRST(dtype_pair),                       \
                                OF_PP_PAIR_FIRST(dtype_pair)>},

    static const std::map<
        std::tuple<UnaryOp, DataType, DataType>,
        std::function<std::unique_ptr<BroadcastElementwiseUnary>(Scalar, Scalar)>>
        new_broadcast_elementwise_unary_handle{
            // For All Type OP
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY,
                                             UNARY_BROADCAST_OP_SEQ,
                                             CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_SAME_DTYPE_BROADCAST_ELEMENTWISE_UNARY_ENTRY
    const auto iter =
        new_broadcast_elementwise_unary_handle.find(std::make_tuple(unary_op, src_type, dst_type));
    if (iter != new_broadcast_elementwise_unary_handle.end()) {
      return iter->second(attr0, attr1);
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseUnaryFactory,
                           BroadcastElementwiseUnaryFactoryImpl);

}  // namespace
}  // namespace broadcast_elementwise_unary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
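Callers never name `BroadcastElementwiseUnaryImpl` directly; they reach it through the registry populated by `REGISTER_PRIMITIVE_FACTORY`. A hedged usage sketch — the `NewPrimitive` helper and the enumerator spellings are assumptions based on OneFlow's primitive API and are not shown in this commit:

// Hedged usage sketch; NewPrimitive<> and enumerator spellings are assumptions.
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_unary.h"

void BroadcastIdentity(oneflow::ep::Stream* stream, const float* src, float* dst) {
  using namespace oneflow;
  using namespace oneflow::ep::primitive;
  auto prim = NewPrimitive<BroadcastElementwiseUnaryFactory>(
      DeviceType::kCUDA, UnaryOp::kIdentity, DataType::kFloat, DataType::kFloat,
      /*max_num_dims=*/2);
  const int64_t src_dims[2] = {1, 4};  // dim 0 broadcasts: {1,4} -> {3,4}
  const int64_t dst_dims[2] = {3, 4};
  if (prim) { prim->Launch(stream, 2, src_dims, src, 2, dst_dims, dst); }
}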
oneflow/core/ep/cuda/primitive/broadcast_matmul.cpp
@@ -57,6 +57,7 @@ cudaDataType_t GetCudaDataType(DataType data_type) {
union CublasScalarParameter {
  double d;
  float s;
  half h;
};

CublasScalarParameter GetCublasScalarParameter(Scalar scalar, cudaDataType_t compute_type) {
@@ -65,6 +66,8 @@ CublasScalarParameter GetCublasScalarParameter(Scalar scalar, cudaDataType_t com
    sp.d = scalar.Value<double>();
  } else if (compute_type == CUDA_R_32F) {
    sp.s = scalar.Value<float>();
  } else if (compute_type == CUDA_R_16F) {
    sp.h = static_cast<half>(scalar.Value<float>());
  } else {
    UNIMPLEMENTED();
  }
@@ -75,7 +78,15 @@ cudaDataType_t GetComputeType(DataType data_type) {
  switch (data_type) {
    case kFloat: return CUDA_R_32F;
    case kDouble: return CUDA_R_64F;
-   case kFloat16: return CUDA_R_32F;
+   case kFloat16: {
+     const bool allow_half_accumulation =
+         ParseBooleanFromEnv("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", false);
+     if (allow_half_accumulation) {
+       return CUDA_R_16F;
+     } else {
+       return CUDA_R_32F;
+     }
+   }
#if CUDA_VERSION >= 11000
    case kBFloat16: return CUDA_R_32F;
#endif  // CUDA_VERSION >= 11000
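With this change, fp16 GEMMs accumulate in fp32 by default and can opt into fp16 accumulation (faster, less accurate) via an environment variable. A standalone approximation of the gate — `ParseBooleanFromEnv` is OneFlow's helper; this sketch assumes "1"/"true" count as true:

#include <cstdlib>
#include <cstring>

// Standalone approximation of the env-var gate above (assumption: the real
// ParseBooleanFromEnv accepts more spellings).
bool AllowHalfAccumulation() {
  const char* v = std::getenv("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION");
  if (v == nullptr) { return false; }  // default on the CUDA path
  return std::strcmp(v, "1") == 0 || std::strcmp(v, "true") == 0;
}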
@@ -207,3 +218,199 @@ REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastMatmulFactory, BroadcastM
}  // namespace oneflow

#endif  // WITH_CUDA
#ifdef WITH_ROCM

#include "oneflow/core/ep/include/primitive/primitive.h"
#include "oneflow/core/ep/include/primitive/broadcast_matmul.h"
#include "oneflow/core/ep/common/primitive/broadcast_matmul.h"
#include "oneflow/core/common/optional.h"
#include "oneflow/core/device/cuda_util.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#include <hip/hip_runtime.h>

namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_matmul {
namespace internal {
namespace {

constexpr size_t kMaxNumDims = 8;

Optional<hipblasDatatype_t> OptCudaDataType(DataType data_type) {
  switch (data_type) {
    case kFloat: return HIPBLAS_R_32F;
    case kDouble: return HIPBLAS_R_64F;
    case kFloat16: return HIPBLAS_R_16F;
    // #if CUDA_VERSION >= 11000
    // case kBFloat16: return CUDA_R_16BF;
    // #endif  // CUDA_VERSION >= 11000
    default: return NullOpt;
  }
}

hipblasDatatype_t GetCudaDataType(DataType data_type) {
  auto cuda_data_type = OptCudaDataType(data_type);
  CHECK(cuda_data_type.has_value());
  return cuda_data_type.value_or(HIPBLAS_R_32F);
}

union CublasScalarParameter {
  double d;
  float s;
  half h;
};

CublasScalarParameter GetCublasScalarParameter(Scalar scalar, hipblasDatatype_t compute_type) {
  CublasScalarParameter sp{};
  if (compute_type == HIPBLAS_R_64F) {
    sp.d = scalar.Value<double>();
  } else if (compute_type == HIPBLAS_R_32F) {
    sp.s = scalar.Value<float>();
  } else if (compute_type == HIPBLAS_R_16F) {
    sp.h = static_cast<half>(scalar.Value<float>());
  } else {
    UNIMPLEMENTED();
  }
  return sp;
}

hipblasDatatype_t GetComputeType(DataType data_type) {
  switch (data_type) {
    case kFloat: return HIPBLAS_R_32F;
    case kDouble: return HIPBLAS_R_64F;
    case kFloat16: {
      const bool allow_half_accumulation =
          ParseBooleanFromEnv("ONEFLOW_MATMUL_ALLOW_HALF_PRECISION_ACCUMULATION", true);
      if (allow_half_accumulation) {
        return HIPBLAS_R_16F;
      } else {
        return HIPBLAS_R_32F;
      }
    }
    // #if CUDA_VERSION >= 11000
    // case kBFloat16: return HIPBLAS_R_32F;
    // #endif  // CUDA_VERSION >= 11000
    default: UNIMPLEMENTED(); return HIPBLAS_R_32F;
  }
}

void LaunchBroadcastMatmul(Stream* stream, DataType data_type, BlasTransposeType transpose_a,
                           BlasTransposeType transpose_b, int64_t num_batch_dims,
                           const int64_t* broadcast_batch_dims, const int64_t* a_batch_dims,
                           const int64_t* b_batch_dims, const int64_t* c_batch_dims, int64_t m,
                           int64_t n, int64_t k, Scalar alpha, const void* a, const void* b,
                           Scalar beta, void* c) {
  auto* cuda_stream = stream->As<CudaStream>();
  const auto cuda_data_type = GetCudaDataType(data_type);
  const auto compute_type = GetComputeType(data_type);
  const auto sp_alpha = GetCublasScalarParameter(alpha, compute_type);
  const auto GetCublasOperation = [](BlasTransposeType transpose_type) {
    if (transpose_type == BlasTransposeType::N) {
      return HIPBLAS_OP_N;
    } else if (transpose_type == BlasTransposeType::T) {
      return HIPBLAS_OP_T;
    } else {
      UNIMPLEMENTED();
      return HIPBLAS_OP_N;
    }
  };
  const hipblasOperation_t cublas_trans_a = GetCublasOperation(transpose_b);
  const hipblasOperation_t cublas_trans_b = GetCublasOperation(transpose_a);
  const int cublas_m = n;
  const int cublas_n = m;
  const int cublas_k = k;
  int cublas_lda = 0;
  if (transpose_b == BlasTransposeType::N) {
    cublas_lda = n;
  } else if (transpose_b == BlasTransposeType::T) {
    cublas_lda = k;
  } else {
    UNIMPLEMENTED();
  }
  int cublas_ldb = 0;
  if (transpose_a == BlasTransposeType::N) {
    cublas_ldb = k;
  } else if (transpose_a == BlasTransposeType::T) {
    cublas_ldb = m;
  } else {
    UNIMPLEMENTED();
  }
  const int cublas_ldc = n;
  hipblasGemmAlgo_t algo = HIPBLAS_GEMM_DEFAULT;
  if (num_batch_dims == 1 && c_batch_dims[0] != 1) {
    const void* cublas_a = b;
    const void* cublas_b = a;
    void* cublas_c = c;
    const int64_t a_batch_count = a_batch_dims[0];
    const int64_t b_batch_count = b_batch_dims[0];
    CHECK(a_batch_count == 1 || b_batch_count == 1 || a_batch_count == b_batch_count);
    CHECK_GT(a_batch_count, 0);
    CHECK_GT(b_batch_count, 0);
    const int batch_count = std::max(a_batch_count, b_batch_count);
    const long long int cublas_stride_a = b_batch_count == 1 ? 0 : cublas_m * cublas_k;
    const long long int cublas_stride_b = a_batch_count == 1 ? 0 : cublas_k * cublas_n;
    const long long int cublas_stride_c = cublas_m * cublas_n;
    const auto sp_beta = GetCublasScalarParameter(beta, compute_type);
    OF_CUBLAS_CHECK(hipblasGemmStridedBatchedEx(
        cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b, cublas_m, cublas_n,
        cublas_k, &sp_alpha, cublas_a, cuda_data_type, cublas_lda, cublas_stride_a, cublas_b,
        cuda_data_type, cublas_ldb, cublas_stride_b, &sp_beta, cublas_c, cuda_data_type,
        cublas_ldc, cublas_stride_c, batch_count, compute_type, algo));
  } else {
    auto func = [&](const void* batch_a, const void* batch_b, void* batch_c, Scalar batch_beta) {
      const auto sp_beta = GetCublasScalarParameter(batch_beta, compute_type);
      const void* cublas_a = batch_b;
      const void* cublas_b = batch_a;
      void* cublas_c = batch_c;
      OF_CUBLAS_CHECK(hipblasGemmEx(cuda_stream->cublas_handle(), cublas_trans_a, cublas_trans_b,
                                    cublas_m, cublas_n, cublas_k, &sp_alpha, cublas_a,
                                    cuda_data_type, cublas_lda, cublas_b, cuda_data_type,
                                    cublas_ldb, &sp_beta, cublas_c, cuda_data_type, cublas_ldc,
                                    compute_type, algo));
    };
    ForEachMatmul<kMaxNumDims>(data_type, m, n, k, beta, num_batch_dims, broadcast_batch_dims,
                               a_batch_dims, b_batch_dims, c_batch_dims, a, b, c, func);
  }
}
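The operand swap above (`cublas_a = b`, m and n trading places) is the standard row-major trick: hipBLAS, like cuBLAS, expects column-major matrices, and a row-major C = A·B occupies exactly the same bytes as a column-major Cᵀ = Bᵀ·Aᵀ, so calling the column-major GEMM with (B, A) and dimensions (n, m, k) writes row-major C directly. A numeric check of the identity:

#include <cstdio>

// Row-major C = A*B via a column-major-style loop over (B, A) with swapped
// dimensions, mirroring the hipBLAS call above.
int main() {
  const int m = 2, k = 2, n = 3;
  const double A[] = {1, 2, 3, 4};         // row-major m x k
  const double B[] = {5, 6, 7, 8, 9, 10};  // row-major k x n
  double C[6] = {};                        // row-major m x n
  for (int col = 0; col < m; ++col) {      // columns of C^T == rows of C
    for (int row = 0; row < n; ++row) {    // rows of C^T == columns of C
      double acc = 0;
      for (int p = 0; p < k; ++p) { acc += B[p * n + row] * A[col * k + p]; }
      C[col * n + row] = acc;              // same bytes as row-major C
    }
  }
  for (int i = 0; i < 6; ++i) { std::printf("%g ", C[i]); }  // 21 24 27 47 54 61
  return 0;
}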
class BroadcastMatmulFactoryImpl : public BroadcastMatmulFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastMatmulFactoryImpl);
  BroadcastMatmulFactoryImpl() = default;
  ~BroadcastMatmulFactoryImpl() override = default;

  std::unique_ptr<BroadcastMatmul> New(DataType data_type, BlasTransposeType transpose_a,
                                       BlasTransposeType transpose_b,
                                       size_t max_num_dims) override {
    auto cuda_data_type = OptCudaDataType(data_type);
    if (max_num_dims <= kMaxNumDims && cuda_data_type.has_value()) {
      return std::make_unique<BroadcastMatmulImpl<kMaxNumDims>>(data_type, transpose_a,
                                                                transpose_b);
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastMatmulFactory, BroadcastMatmulFactoryImpl);

}  // namespace
}  // namespace internal
}  // namespace broadcast_matmul
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow

#endif  // WITH_ROCM
oneflow/core/ep/cuda/primitive/constant_pad.cu
@@ -17,7 +17,11 @@ limitations under the License.
#include "oneflow/core/ep/common/primitive/constant_pad.h"
#include "oneflow/core/ep/cuda/primitive/type_seq.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif
namespace oneflow {
@@ -186,6 +190,7 @@ template<typename T>
void SimplifyThenLaunch(Stream* stream, size_t num_dims, const int64_t* src_dims, const void* src,
                        const int64_t* padding_before, const int64_t* padding_after, T pad_val,
                        void* dst) {
  CHECK_GT(num_dims, 0) << "num_dims must be greater than 0";
  CHECK_LE(num_dims, kMaxNumDims);
  int64_t simplified_dst_dims[kMaxNumDims];
  int64_t simplified_src_dims[kMaxNumDims];
oneflow/core/ep/cuda/primitive/copy_nd.cu
@@ -17,8 +17,11 @@ limitations under the License.
#include "oneflow/core/ep/include/primitive/copy_nd.h"
#include "oneflow/core/ep/common/primitive/copy_nd.h"
#include "oneflow/core/ep/cuda/cuda_stream.h"
#ifdef WITH_ROCM
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif
namespace oneflow {
namespace ep {
@@ -49,7 +52,7 @@ __global__ void CopyNdKernel(CopyNdKernelParams<num_dims, IndexType> params) {
template<size_t num_dims, size_t movement_size, typename IndexType>
void LaunchKernel(Stream* stream, CopyNdKernelParams<num_dims, IndexType> params) {
-  cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
+  GPU(Stream_t) cuda_stream = stream->As<CudaStream>()->cuda_stream();
  CopyNdKernel<num_dims, movement_size, IndexType>
      <<<BlocksNum4ThreadsNum(params.count), kCudaThreadsNumPerBlock, 0, cuda_stream>>>(params);
}
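The `GPU(Stream_t)` spelling is a portability macro whose definition is not shown in these hunks. A plausible token-pasting sketch — an assumption, not the commit's actual definition:

// Plausible sketch of the GPU(...) portability macro (assumption: the real
// definition lives elsewhere in this commit).
#ifdef WITH_ROCM
#define GPU(name) hip##name   // GPU(Stream_t) -> hipStream_t
#else
#define GPU(name) cuda##name  // GPU(Stream_t) -> cudaStream_t
#endif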
oneflow/core/ep/cuda/primitive/elementwise_unary.cu
@@ -86,6 +86,10 @@ class ElementwiseUnaryFactoryImpl : public ElementwiseUnaryFactory {
            UNARY_FLOATING_MATH_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)
        // For Int Type OP
        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_SAME_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                         UNARY_INT_MATH_OP_SEQ, CUDA_PRIMITIVE_INT_TYPE_SEQ)
        // For Utils OP
        OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_DIFFERENT_DTYPE_ELEMENTWISE_UNARY_ENTRY,
                                         UNARY_UTILS_OP_SEQ, UTIL_OPS_DATA_TYPE_SEQ,
oneflow/core/ep/cuda/primitive/fill.cu
@@ -71,20 +71,20 @@ nv_bfloat16 GetValue<nv_bfloat16>(Scalar value) {
#endif // CUDA_VERSION >= 11000
template<typename T, size_t pack>
-typename std::enable_if<(pack != 0), void>::type LaunchPackFill(cudaStream_t stream, T* dst,
+typename std::enable_if<(pack != 0), void>::type LaunchPackFill(GPU(Stream_t) stream, T* dst,
                                                                 T value, size_t count) {
  FillGpu<T, pack>
      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(dst, value, count);
}

template<typename T, size_t pack>
-typename std::enable_if<(pack == 0), void>::type LaunchPackFill(cudaStream_t stream, T* dst,
+typename std::enable_if<(pack == 0), void>::type LaunchPackFill(GPU(Stream_t) stream, T* dst,
                                                                 T value, size_t count) {
  LOG(FATAL) << "wrong alignment";
}

template<typename T>
-void LaunchFill(cudaStream_t stream, T* dst, T value, size_t count) {
+void LaunchFill(GPU(Stream_t) stream, T* dst, T value, size_t count) {
  auto uintptr = reinterpret_cast<std::uintptr_t>(dst);
  if (uintptr % 16 == 0) {
    LaunchPackFill<T, 16 / sizeof(T)>(stream, dst, value, count);
@@ -107,7 +107,7 @@ class FillImpl : public Fill {
  ~FillImpl() override = default;

  void Launch(Stream* stream, void* dst, Scalar value, size_t count) override {
-    cudaStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
+    GPU(Stream_t) cuda_stream = stream->As<CudaStream>()->cuda_stream();
    LaunchFill<T>(cuda_stream, reinterpret_cast<T*>(dst), GetValue<T>(value), count);
  }
};
oneflow/core/ep/cuda/primitive/math_elementwise_unary_math_grad.cu
0 → 100644
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/cuda/primitive/broadcast_elementwise_binary.cuh"
namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {
#define INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary< \
binary_op, OF_PP_PAIR_FIRST(data_type_pair), OF_PP_PAIR_FIRST(data_type_pair)>( \
Scalar attr0, Scalar attr1);
OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(INSTANTIATE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                 BINARY_MATH_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
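This new file contains only explicit instantiations: each macro entry expands to a `template std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary<op, T, T>(Scalar, Scalar);` line, so the math-backward templates are compiled once in this translation unit instead of in every including one, splitting their compile time into a dedicated .cu. A minimal standalone sketch of the pattern, with hypothetical names:

// Minimal sketch (hypothetical AddScaled) of the explicit-instantiation
// pattern: the template definition lives in a header, and each instantiation
// line compiles one concrete version into this object file.
template<typename T>
T AddScaled(T a, T b, T scale) { return a + scale * b; }

// Emitted here; other translation units only need the declaration.
template float AddScaled<float>(float, float, float);
template double AddScaled<double>(double, double, double);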