Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Oneflow
Commits
a715222c
Commit
a715222c
authored
Feb 28, 2023
by
yuguo
Browse files
0.9.1-rocm
parent
f262efc9
Changes
469
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
0 additions
and
1025 deletions
+0
-1025
oneflow/core/ep/rocm/cuda_device_manager.h
oneflow/core/ep/rocm/cuda_device_manager.h
+0
-54
oneflow/core/ep/rocm/cuda_device_manager_factory.cpp
oneflow/core/ep/rocm/cuda_device_manager_factory.cpp
+0
-117
oneflow/core/ep/rocm/cuda_event.cpp
oneflow/core/ep/rocm/cuda_event.cpp
+0
-56
oneflow/core/ep/rocm/cuda_event.h
oneflow/core/ep/rocm/cuda_event.h
+0
-50
oneflow/core/ep/rocm/cuda_stream.cpp
oneflow/core/ep/rocm/cuda_stream.cpp
+0
-180
oneflow/core/ep/rocm/cuda_stream.h
oneflow/core/ep/rocm/cuda_stream.h
+0
-168
oneflow/core/ep/rocm/primitive/add.hip.cpp
oneflow/core/ep/rocm/primitive/add.hip.cpp
+0
-139
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
+0
-151
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
...re/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
+0
-110
No files found.
Too many changes to show.
To preserve performance only
469 of 469+
files are displayed.
Plain diff
Email patch
oneflow/core/ep/rocm/cuda_device_manager.h
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_

#include "oneflow/core/ep/include/device_manager.h"
#include "oneflow/core/ep/rocm/cuda_device.h"

#ifdef WITH_ROCM

namespace oneflow {
namespace ep {

class CudaDevice;

// DeviceManager implementation for ROCm (HIP) devices. The ROCm port keeps
// the "Cuda" naming so it slots into the CUDA-shaped "ep" abstraction.
class CudaDeviceManager : public DeviceManager {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManager);
  CudaDeviceManager(DeviceManagerRegistry* registry);
  ~CudaDeviceManager() override;

  DeviceManagerRegistry* registry() const override;
  std::shared_ptr<Device> GetDevice(size_t device_index) override;
  size_t GetDeviceCount(size_t primary_device_index) override;
  size_t GetDeviceCount() override;
  size_t GetActiveDeviceIndex() override;
  void SetActiveDeviceByIndex(size_t device_index) override;

 private:
  // NOTE(review): presumably guards concurrent access to devices_ — confirm
  // against the .cpp, which is not visible here.
  std::mutex devices_mutex_;
  std::vector<std::shared_ptr<CudaDevice>> devices_;
  DeviceManagerRegistry* registry_;  // non-owning back-pointer to the registry
};

}  // namespace ep
}  // namespace oneflow

#endif  // WITH_ROCM

#endif  // ONEFLOW_CORE_EP_ROCM_CUDA_DEVICE_MANAGER_H_
oneflow/core/ep/rocm/cuda_device_manager_factory.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/device_manager_factory.h"
#include "oneflow/core/ep/include/device_manager_registry.h"
#include "oneflow/core/ep/rocm/cuda_device_manager.h"

#ifdef WITH_ROCM

#include <hip/hip_runtime.h>
#include <miopen/miopen.h>
#include <rccl.h>

namespace oneflow {
namespace ep {

namespace {

// Renders a CUDA-style packed version integer as "major.minor"
// (e.g. 10020 -> "10.2").
std::string GetCudaVersionString(int version) {
  const int major = version / 1000;
  const int minor = (version % 1000) / 10;
  return std::to_string(major) + "." + std::to_string(minor);
}

// Queries the MIOpen library version. Returns false (after logging) on
// failure. The "cuDNN" wording is kept from the CUDA original.
bool GetCudnnVersion(size_t* major, size_t* minor, size_t* patch) {
  const miopenStatus_t status = miopenGetVersion(major, minor, patch);
  if (status != miopenStatusSuccess) {
    LOG(ERROR) << "Failed to get cuDNN version: " << miopenGetErrorString(status);
    return false;
  }
  return true;
}

// Formats the MIOpen version as "major.minor.patch"; false if the query fails.
bool GetCudnnVersionString(std::string* version) {
  size_t version_major = 0;
  size_t version_minor = 0;
  size_t version_patch = 0;
  if (!GetCudnnVersion(&version_major, &version_minor, &version_patch)) { return false; }
  *version = std::to_string(version_major) + "." + std::to_string(version_minor) + "."
             + std::to_string(version_patch);
  return true;
}

// Logs the HIP runtime, MIOpen, and RCCL versions; each sub-query logs its
// own error and the remaining queries still run.
void CudaDumpVersionInfo() {
  {
    int cuda_runtime_version = 0;
    const hipError_t err = hipRuntimeGetVersion(&cuda_runtime_version);
    if (err == hipSuccess) {
      LOG(INFO) << "CUDA runtime version: " << GetCudaVersionString(cuda_runtime_version);
    } else {
      LOG(ERROR) << "Failed to get cuda runtime version: " << hipGetErrorString(err);
    }
  }
  {
    std::string cudnn_version_string;
    if (GetCudnnVersionString(&cudnn_version_string)) {
      LOG(INFO) << "cuDNN version: " << cudnn_version_string;
    }
  }
  {
    int nccl_version = 0;
    const ncclResult_t result = ncclGetVersion(&nccl_version);
    if (result == ncclSuccess) {
      // NCCL changed its version encoding at 2.9: older builds pack it as
      // M*1000 + m*100 + p, newer ones as M*10000 + m*100 + p.
      const bool new_encoding = nccl_version >= 20900;
      const int nccl_version_major =
          new_encoding ? (nccl_version / 10000) : (nccl_version / 1000);
      const int nccl_version_minor =
          new_encoding ? (nccl_version % 10000) / 100 : (nccl_version % 1000) / 100;
      const int nccl_version_patch = nccl_version % 100;
      LOG(INFO) << "NCCL version: " << nccl_version_major << "." << nccl_version_minor << "."
                << nccl_version_patch;
    } else {
      LOG(ERROR) << "Failed to get NCCL version: " << ncclGetErrorString(result);
    }
  }
}

// Factory that creates CudaDeviceManager instances and identifies them as
// DeviceType::kCUDA (the ROCm build reuses the CUDA device type).
class CudaDeviceManagerFactory : public DeviceManagerFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaDeviceManagerFactory);
  CudaDeviceManagerFactory() = default;
  ~CudaDeviceManagerFactory() override = default;

  std::unique_ptr<DeviceManager> NewDeviceManager(DeviceManagerRegistry* registry) override {
    return std::make_unique<CudaDeviceManager>(registry);
  }

  DeviceType device_type() const override { return DeviceType::kCUDA; }

  std::string device_type_name() const override { return "cuda"; }

  void DumpVersionInfo() const override { CudaDumpVersionInfo(); }
};

// Registers the factory at static-initialization time.
COMMAND(DeviceManagerRegistry::RegisterDeviceManagerFactory(
    std::make_unique<CudaDeviceManagerFactory>()))

}  // namespace
}  // namespace ep
}  // namespace oneflow

#endif  // WITH_ROCM
oneflow/core/ep/rocm/cuda_event.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_event.h"

#ifdef WITH_ROCM

namespace oneflow {
namespace ep {

// Creates the underlying HIP event with the given hipEventCreateWithFlags
// flags; any failure aborts via OF_CUDA_CHECK.
CudaEvent::CudaEvent(unsigned int flags) : cuda_event_{} {
  OF_CUDA_CHECK(hipEventCreateWithFlags(&cuda_event_, flags));
}

CudaEvent::~CudaEvent() { OF_CUDA_CHECK(hipEventDestroy(cuda_event_)); }

// Non-blocking completion query: true if the event has fired, false if the
// device work it records is still pending, error otherwise.
Maybe<bool> CudaEvent::QueryDone() {
  const hipError_t err = hipEventQuery(cuda_event_);
  if (err == hipSuccess) {
    return Maybe<bool>(true);
  } else if (err == hipErrorNotReady) {
    return Maybe<bool>(false);
  } else {
    return Error::RuntimeError() << hipGetErrorString(err);
  }
}

// Blocks the calling host thread until the event completes.
Maybe<void> CudaEvent::Sync() {
  const hipError_t err = hipEventSynchronize(cuda_event_);
  if (err == hipSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << hipGetErrorString(err);
  }
}

// Raw handle accessor for code that records/waits on the event directly.
hipEvent_t CudaEvent::cuda_event() { return cuda_event_; }

}  // namespace ep
}  // namespace oneflow

#endif  // WITH_ROCM
oneflow/core/ep/rocm/cuda_event.h
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_

#include "oneflow/core/ep/include/event.h"

#ifdef WITH_ROCM

#include "oneflow/core/device/cuda_util.h"

namespace oneflow {
namespace ep {

// Event implementation backed by a hipEvent_t. Owns the handle: created in
// the constructor, destroyed in the destructor (copy/move disallowed).
class CudaEvent : public Event {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaEvent);
  explicit CudaEvent(unsigned int flags);  // flags forwarded to hipEventCreateWithFlags
  ~CudaEvent() override;

  Maybe<bool> QueryDone() override;  // non-blocking poll
  Maybe<void> Sync() override;       // blocking wait

  // Raw HIP handle for direct record/wait calls.
  hipEvent_t cuda_event();

 private:
  hipEvent_t cuda_event_;
};

}  // namespace ep
}  // namespace oneflow

#endif  // WITH_ROCM

#endif  // ONEFLOW_CORE_EP_ROCM_CUDA_EVENT_H_
oneflow/core/ep/rocm/cuda_stream.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/job/global_for.h"
#include "oneflow/core/job/resource_desc.h"
#include "oneflow/core/hardware/node_device_descriptor_manager.h"
#include "oneflow/core/hardware/cuda_device_descriptor.h"
#include "oneflow/core/ep/rocm/cuda_event.h"
#include "oneflow/core/ep/rocm/cuda_device.h"

#ifdef WITH_ROCM

namespace oneflow {
namespace ep {

namespace {

constexpr size_t kDefaultWorkspaceSize = 4 * 1024 * 1024;  // 4M

// Best-effort: pins CPU and memory affinity to the NUMA node nearest to the
// given device (looked up by PCI bus id). Silently returns if no device
// descriptor manager or descriptor is available.
void SetAffinityByDevice(int dev_id) {
  auto node_device_desc_mgr = Singleton<hardware::NodeDeviceDescriptorManager>::Get();
  if (node_device_desc_mgr == nullptr) { return; }
  auto node_device_desc = node_device_desc_mgr->GetLocalNodeDeviceDescriptor();
  auto cuda_device = std::dynamic_pointer_cast<const hardware::CudaDeviceDescriptor>(
      node_device_desc->GetDevice(hardware::kCudaDeviceDescriptorClassName, dev_id));
  if (!cuda_device) { return; }
  node_device_desc->Topology()->SetCPUAffinityByPCIBusID(cuda_device->PCIBusID());
  node_device_desc->Topology()->SetMemoryAffinityByPCIBusID(cuda_device->PCIBusID());
}

}  // namespace

#ifdef WITH_ROCM_GRAPHS

CudaGraphExecutable::CudaGraphExecutable() : graph_exec_(nullptr), dev_(-1) {}

CudaGraphExecutable::~CudaGraphExecutable() { Reset(); }

// Re-instantiates (or in-place updates) the executable from a freshly
// captured graph. A device change forces a full re-instantiation.
void CudaGraphExecutable::Update(hipGraph_t graph) {
  int dev = -1;
  OF_CUDA_CHECK(hipGetDevice(&dev));
  if (dev != dev_) { Reset(); }
  dev_ = dev;
  if (graph_exec_ != nullptr) {
    hipGraphExecUpdateResult update_result{};
    hipGraphNode_t error_node = nullptr;
    OF_CUDA_CHECK(hipGraphExecUpdate(graph_exec_, graph, &error_node, &update_result));
    if (update_result == hipGraphExecUpdateSuccess) { return; }
  }
  Reset();
  OF_CUDA_CHECK(hipGraphInstantiate(&graph_exec_, graph, NULL, NULL, 0));
}

void CudaGraphExecutable::Launch(hipStream_t stream) const {
  OF_CUDA_CHECK(hipGraphLaunch(graph_exec_, stream));
}

bool CudaGraphExecutable::IsInstantiated() const { return graph_exec_ != nullptr; }

void CudaGraphExecutable::Reset() {
  if (graph_exec_ != nullptr) {
    CudaCurrentDeviceGuard guard(dev_);
    OF_CUDA_CHECK(hipGraphExecDestroy(graph_exec_));
    // BUGFIX: the handle must be cleared after destruction. Previously it was
    // left dangling, so after a device change Update() would see a non-null
    // graph_exec_ and call hipGraphExecUpdate on a destroyed handle, and
    // IsInstantiated() would keep reporting true.
    graph_exec_ = nullptr;
  }
}

#endif  // WITH_ROCM_GRAPHS

// Creates the HIP stream plus the hipBLAS/hipDNN handles bound to it and a
// fixed-size device workspace, all on the owning device.
CudaStream::CudaStream(CudaDevice* device)
    : device_index_(device->device_index()), device_(device) {
  CudaCurrentDeviceGuard guard(device_index_);
  // cuda_stream
  OF_CUDA_CHECK(hipStreamCreate(&cuda_stream_));
  // cublas_handle
  OF_CUBLAS_CHECK(hipblasCreate(&cublas_handle_));
  OF_CUBLAS_CHECK(hipblasSetStream(cublas_handle_, cuda_stream_));
  workspace_size_ = kDefaultWorkspaceSize;
  OF_CUDA_CHECK(hipMalloc(&workspace_, workspace_size_));
  OF_CUDNN_CHECK(hipdnnCreate(&cudnn_handle_));
  OF_CUDNN_CHECK(hipdnnSetStream(cudnn_handle_, cuda_stream_));
}

// Drains outstanding work, then tears down handles, the stream, and the
// workspace in reverse order of creation.
CudaStream::~CudaStream() {
  CudaCurrentDeviceGuard guard(device_index_);
  OF_CUDA_CHECK(hipStreamSynchronize(cuda_stream_));
  OF_CUDNN_CHECK(hipdnnDestroy(cudnn_handle_));
  OF_CUBLAS_CHECK(hipblasDestroy(cublas_handle_));
  OF_CUDA_CHECK(hipStreamDestroy(cuda_stream_));
  OF_CUDA_CHECK(hipFree(workspace_));
}

// Binds the worker thread to this stream's device and its NUMA locality.
Maybe<void> CudaStream::OnExecutionContextSetup() {
  OF_CUDA_CHECK(hipSetDevice(device_index_));
  SetAffinityByDevice(device_index_);
  return Maybe<void>::Ok();
}

Maybe<void> CudaStream::OnExecutionContextTeardown() { return Maybe<void>::Ok(); }

DeviceType CudaStream::device_type() const { return DeviceType::kCUDA; }

CudaDevice* CudaStream::device() const { return device_; }

// Blocks until all work queued on the stream has completed.
Maybe<void> CudaStream::Sync() {
  const hipError_t err = hipStreamSynchronize(cuda_stream_);
  if (err == hipSuccess) {
    return Maybe<void>::Ok();
  } else {
    return Error::RuntimeError() << hipGetErrorString(err) << " (" << err << ") ";
  }
}

void CudaStream::RecordEvent(Event* event) {
  auto* cuda_event = static_cast<CudaEvent*>(event);  // NOLINT
  OF_CUDA_CHECK(hipEventRecord(cuda_event->cuda_event(), cuda_stream_));
}

hipStream_t CudaStream::cuda_stream() const { return cuda_stream_; }

hipblasHandle_t CudaStream::cublas_handle() const { return cublas_handle_; }

void* CudaStream::cublas_workspace() const { return workspace_; }

size_t CudaStream::cublas_workspace_size() const { return workspace_size_; }

hipdnnHandle_t CudaStream::cudnn_handle() const { return cudnn_handle_; }

const hipDeviceProp_t& CudaStream::device_properties() const { return device_->properties(); }

// Packed compute-capability style value: major*100 + minor*10.
int CudaStream::cuda_arch() const {
  return device_->properties().major * 100 + device_->properties().minor * 10;
}

#ifdef WITH_ROCM_GRAPHS

void CudaStream::BeginGraphCapture() {
  CHECK(!is_graph_capturing_);
  is_graph_capturing_ = true;
  OF_CUDA_CHECK(hipStreamBeginCapture(cuda_stream_, hipStreamCaptureModeThreadLocal));
}

// Ends capture, folds the captured graph into `executable`, and frees the
// transient hipGraph_t.
void CudaStream::EndGraphCapture(CudaGraphExecutable* executable) {
  hipGraph_t graph = nullptr;
  OF_CUDA_CHECK(hipStreamEndCapture(cuda_stream_, &graph));
  executable->Update(graph);
  OF_CUDA_CHECK(hipGraphDestroy(graph));
  is_graph_capturing_ = false;
}

bool CudaStream::IsGraphCapturing() const { return is_graph_capturing_; }

void CudaStream::LaunchGraph(const CudaGraphExecutable* executable) {
  executable->Launch(cuda_stream_);
}

#endif  // WITH_ROCM_GRAPHS

}  // namespace ep
}  // namespace oneflow

#endif  // WITH_ROCM
oneflow/core/ep/rocm/cuda_stream.h
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#ifndef ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
#define ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_

#include "oneflow/core/ep/include/stream.h"
#include "oneflow/core/ep/rocm/cuda_device.h"

#ifdef WITH_ROCM

#include <hip/hip_runtime.h>
#include "oneflow/core/hipdnn/hipdnn.h"

// #if CUDA_VERSION >= 11000
// #define WITH_ROCM_GRAPHS
// #endif  // CUDA_VERSION >= 11000

#include "oneflow/core/device/cuda_util.h"

namespace oneflow {
namespace ep {

class CudaDevice;

#ifdef WITH_ROCM_GRAPHS
// Owns an instantiated HIP graph executable; re-instantiated on device
// change, updated in place otherwise. See cuda_stream.cpp for semantics.
class CudaGraphExecutable {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaGraphExecutable);
  CudaGraphExecutable();
  ~CudaGraphExecutable();

  void Update(hipGraph_t graph);
  void Launch(hipStream_t stream) const;
  bool IsInstantiated() const;

 private:
  void Reset();

  hipGraphExec_t graph_exec_;
  int dev_;  // device the executable was instantiated on
};
#endif  // WITH_ROCM_GRAPHS

// Plain launch-geometry triple for kernel dispatch.
struct CudaLaunchConfig {
  dim3 grid_dim;
  dim3 block_dim;
  size_t shared_mem_size;

  CudaLaunchConfig() : grid_dim{}, block_dim{}, shared_mem_size(0) {}

  CudaLaunchConfig(unsigned int grid_size, unsigned int block_size, size_t shared_mem_size)
      : grid_dim(grid_size), block_dim(block_size), shared_mem_size(shared_mem_size) {}
};

// Stream implementation for ROCm: wraps a hipStream_t together with the
// hipBLAS/hipDNN handles bound to it and a device-side scratch workspace.
class CudaStream : public Stream {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CudaStream);
  explicit CudaStream(CudaDevice* device);
  ~CudaStream() override;

  static constexpr uint32_t kDefaultBlockSize = 256;

  DeviceType device_type() const override;
  CudaDevice* device() const override;
  Maybe<void> Sync() override;
  void RecordEvent(Event* event) override;
  Maybe<void> OnExecutionContextSetup() override;
  Maybe<void> OnExecutionContextTeardown() override;

  hipStream_t cuda_stream() const;
  hipblasHandle_t cublas_handle() const;
  // #if CUDA_VERSION >= 10010
  // cublasLtHandle_t cublas_lt_handle() const;
  // #endif
  hipdnnHandle_t cudnn_handle() const;
  void* cublas_workspace() const;
  size_t cublas_workspace_size() const;
  const hipDeviceProp_t& device_properties() const;
  int cuda_arch() const;

  // Chooses a grid size that covers elem_cnt with the given block size but
  // caps it at max_waves full waves of the device's SMs/CUs.
  void InitLaunchConfigWithWaves(CudaLaunchConfig* config, size_t elem_cnt, size_t block_size,
                                 size_t max_waves) const {
    const uint32_t max_grid_size = max_waves * device_properties().multiProcessorCount
                                   * (device_properties().maxThreadsPerMultiProcessor / block_size);
    const uint32_t grid_size =
        std::min<uint32_t>(max_grid_size, (elem_cnt + block_size - 1) / block_size);
    config->grid_dim = dim3(grid_size);
    config->block_dim = dim3(block_size);
    config->shared_mem_size = 0;
  }

#ifdef __HIPCC__
  // Launches `kernel` on this stream with an explicit configuration.
  template<typename... Params, typename... Args>
  void LaunchKernel(void (*kernel)(Params...), const CudaLaunchConfig& launch_config,
                    Args... args) {
    kernel<<<launch_config.grid_dim, launch_config.block_dim, launch_config.shared_mem_size,
             cuda_stream()>>>(args...);
  }

  // Launches `kernel` sized for elem_cnt elements, capped at max_waves waves.
  template<typename... Params, typename... Args>
  void LaunchKernel(void (*kernel)(Params...), size_t elem_cnt, size_t max_waves, Args... args) {
    constexpr uint32_t block_size = kDefaultBlockSize;
    CudaLaunchConfig config{};
    InitLaunchConfigWithWaves(&config, elem_cnt, block_size, max_waves);
    LaunchKernel(kernel, config, args...);
  }

  // Convenience overload using a default wave cap of 32.
  template<typename... Params, typename... Args>
  void LaunchKernelDefaultWaves(void (*kernel)(Params...), size_t elem_cnt, Args... args) {
    const size_t default_waves = 32;
    LaunchKernel(kernel, elem_cnt, default_waves, args...);
  }
#endif  // __HIPCC__

#ifdef WITH_ROCM_GRAPHS
  void BeginGraphCapture();
  void EndGraphCapture(CudaGraphExecutable* executable);
  bool IsGraphCapturing() const;
  void LaunchGraph(const CudaGraphExecutable* executable);
#endif  // WITH_ROCM_GRAPHS

 private:
  hipStream_t cuda_stream_{};
  hipblasHandle_t cublas_handle_{};
  // #if CUDA_VERSION >= 10010
  // cublasLtHandle_t cublas_lt_handle_{};
  // #endif
  hipdnnHandle_t cudnn_handle_{};
  int device_index_;
  void* workspace_{};        // device scratch buffer shared with hipBLAS users
  size_t workspace_size_{};
#ifdef WITH_ROCM_GRAPHS
  bool is_graph_capturing_{};
#endif  // WITH_ROCM_GRAPHS
  CudaDevice* device_;  // non-owning
};

}  // namespace ep
}  // namespace oneflow

#endif  // WITH_ROCM

#endif  // ONEFLOW_CORE_EP_ROCM_CUDA_STREAM_H_
oneflow/core/ep/rocm/primitive/add.hip.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/common/preprocessor.h"
#include "oneflow/core/ep/include/primitive/add.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/device/cuda_pseudo_bfloat16.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace {

// Variadic functor summing its arguments left-to-right via recursion.
template<typename... Args>
struct AddFunctor;

template<typename T>
struct AddFunctor<T> {
  __device__ T operator()(T x) const { return x; }
};

template<typename T, typename U, typename... Args>
struct AddFunctor<T, U, Args...> {
  __device__ T operator()(T x0, U x1, Args... xs) const {
    return x0 + AddFunctor<U, Args...>()(x1, xs...);
  }
};

// Elementwise dst[i] = srcs0[i] + srcs1[i] + ... over `count` elements.
template<typename T, typename... Args>
__global__ void AddGpu(const Args*... srcs, T* dst, size_t count) {
  CUDA_1D_KERNEL_LOOP_T(size_t, i, count) { dst[i] = AddFunctor<Args...>()(srcs[i]...); }
}

template<typename T, typename... Args>
void LaunchAddGpu(hipStream_t stream, const Args*... srcs, T* dst, size_t count) {
  AddGpu<T, Args...>
      <<<BlocksNum4ThreadsNum(count), kCudaThreadsNumPerBlock, 0, stream>>>(srcs..., dst, count);
}

// Dispatches on the number of inputs: specialized paths for arity 0..8,
// recursion for larger arities (the tail is first reduced into dst, then dst
// is fed back in as the 8th operand of the head launch).
template<typename T>
void DispatchLaunch(hipStream_t stream, const T* const* srcs, size_t arity, T* dst, size_t count) {
  if (arity == 0) {
    OF_CUDA_CHECK(hipMemsetAsync(dst, 0, count * sizeof(T), stream));
  } else if (arity == 1) {
    OF_CUDA_CHECK(hipMemcpyAsync(dst, srcs[0], count * sizeof(T), hipMemcpyDefault, stream));
  } else if (arity == 2) {
    OF_CUDA_CHECK((cuda::elementwise::Binary<AddFunctor<T, T>, T, T, T>(
        AddFunctor<T, T>(), count, dst, srcs[0], srcs[1], stream)));
  } else if (arity == 3) {
    OF_CUDA_CHECK((cuda::elementwise::Ternary<AddFunctor<T, T, T>, T, T, T, T>(
        AddFunctor<T, T, T>(), count, dst, srcs[0], srcs[1], srcs[2], stream)));
  } else if (arity == 4) {
    LaunchAddGpu<T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], dst, count);
  } else if (arity == 5) {
    LaunchAddGpu<T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4], dst,
                                   count);
  } else if (arity == 6) {
    LaunchAddGpu<T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
                                      srcs[5], dst, count);
  } else if (arity == 7) {
    LaunchAddGpu<T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
                                         srcs[5], srcs[6], dst, count);
  } else if (arity == 8) {
    LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
                                            srcs[5], srcs[6], srcs[7], dst, count);
  } else {
    // Reduce srcs[7..] into dst first, then add srcs[0..6] plus dst itself.
    DispatchLaunch(stream, srcs + 7, arity - 7, dst, count);
    LaunchAddGpu<T, T, T, T, T, T, T, T, T>(stream, srcs[0], srcs[1], srcs[2], srcs[3], srcs[4],
                                            srcs[5], srcs[6], dst, dst, count);
  }
}

// Add primitive for DeviceType::kCUDA (ROCm build), typed on T.
template<typename T>
class AddImpl : public Add {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AddImpl);
  AddImpl() = default;
  ~AddImpl() override = default;

  using Add::Launch;
  void Launch(Stream* stream, const void* const* srcs, size_t arity, void* dst,
              size_t count) override {
    hipStream_t cuda_stream = stream->As<CudaStream>()->cuda_stream();
    DispatchLaunch(cuda_stream, reinterpret_cast<const T* const*>(srcs), arity,
                   reinterpret_cast<T*>(dst), count);
  }
};

template<typename T>
std::unique_ptr<Add> NewAdd() {
  return std::unique_ptr<Add>(new AddImpl<T>());
}

// Factory mapping a DataType to the matching AddImpl instantiation.
class AddFactoryImpl : public AddFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(AddFactoryImpl);
  AddFactoryImpl() = default;
  ~AddFactoryImpl() override = default;

  std::unique_ptr<Add> New(DataType data_type) override {
#define MAKE_NEW_ADD_ENTRY(type_cpp, type_proto) {type_proto, NewAdd<type_cpp>},
    static const std::map<DataType, std::function<std::unique_ptr<Add>()>> new_add_handle{
        OF_PP_FOR_EACH_TUPLE(MAKE_NEW_ADD_ENTRY, CUDA_PRIMITIVE_ALL_TYPE_SEQ)};
#undef MAKE_NEW_ADD_ENTRY
    const auto it = new_add_handle.find(data_type);
    if (it != new_add_handle.end()) {
      return it->second();
    } else {
      return nullptr;
    }
  }
};

REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, AddFactory, AddFactoryImpl);

}  // namespace
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow
oneflow/core/ep/rocm/primitive/binary_functor.hip.h
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
// BUGFIX: this header had no include guard; including it from two places in
// one translation unit would redefine every specialization below. Guard added.
#ifndef ONEFLOW_CORE_EP_ROCM_PRIMITIVE_BINARY_FUNCTOR_HIP_H_
#define ONEFLOW_CORE_EP_ROCM_PRIMITIVE_BINARY_FUNCTOR_HIP_H_

#include "oneflow/core/ep/common/primitive/binary_functor.h"

namespace oneflow {
namespace ep {
namespace primitive {
namespace broadcast_elementwise_binary {

// Generic pow: defers to the device/host `pow` overload for Src.
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src src0, Src src1) const { return pow(src0, src1); }
};

// bool pow: computed in double then narrowed back to bool.
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, bool, bool> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC bool operator()(bool src0, bool src1) const {
    return static_cast<bool>(pow(static_cast<double>(src0), static_cast<double>(src1)));
  }
};

// half pow: computed in float then narrowed back to half.
template<>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, half, half> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC half operator()(half src0, half src1) const {
    return static_cast<half>(pow(static_cast<float>(src0), static_cast<float>(src1)));
  }
};

// GELU backward: dy * d/dx[ 0.5 * x * (1 + erf(x / sqrt(2))) ].
// coef = sqrt(2/pi); selected per compile pass (CUDA device, HIP device, host).
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kGeluBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {
#if defined(__CUDA_ARCH__)
    coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#elif defined(__HIP_DEVICE_COMPILE__)
    coef = sqrt(static_cast<Src>(2.0) / acos(static_cast<Src>(-1.0)));
#else
    coef = std::sqrt(static_cast<Src>(2.0) / std::acos(static_cast<Src>(-1.0)));
#endif
  }

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    return static_cast<Src>(0.5)
           * (static_cast<Src>(1.0) + erf(static_cast<Src>(M_SQRT1_2) * x)
              + x * coef * exp(static_cast<Src>(-0.5) * x * x))
           * dy;
  }

  Src coef;
};

// tanh backward: dy * (1 - tanh(x)^2).
template<typename Src, typename Dst>
struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kTanhBackwardWithDyX, Src, Dst> {
  OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}

  OF_DEVICE_FUNC Dst operator()(Src dy, Src x) const {
    Src tanh_val = tanh(x);
    return static_cast<Dst>(dy * (static_cast<Src>(1.0) - tanh_val * tanh_val));
  }
};

// /*********nv_bfloat16_kernel*******/
// #if CUDA_VERSION >= 11000
// template<>
// struct BinaryFunctor<DeviceType::kCUDA, BinaryOp::kPow, nv_bfloat16, nv_bfloat16> {
//   OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) {}
//   OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {
//     return static_cast<nv_bfloat16>(pow(static_cast<float>(src0), static_cast<float>(src1)));
//   }
// };
// #define SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(op)                                   \
//   template<>                                                                                \
//   struct BinaryFunctor<DeviceType::kCUDA, op, nv_bfloat16, nv_bfloat16> {                   \
//     OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {} \
//                                                                                             \
//     BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;                       \
//     OF_DEVICE_FUNC nv_bfloat16 operator()(nv_bfloat16 src0, nv_bfloat16 src1) const {       \
//       return __float2bfloat16(float_functor(__bfloat162float(src0), __bfloat162float(src1))); \
//     }                                                                                       \
//   };
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardsigmoidBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kHardtanhBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kLeakyReluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);
// SPECIALIZATION_PSEUDO_BFLOAT16_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
// #endif  // CUDA_VERSION >= 11000

// half specializations that delegate to the float functor and convert at the
// boundaries (half lacks native overloads for most of these ops).
#define SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(op)                                           \
  template<>                                                                                    \
  struct BinaryFunctor<DeviceType::kCUDA, op, half, half> {                                     \
    OF_DEVICE_FUNC BinaryFunctor(Scalar attr0, Scalar attr1) : float_functor(attr0, attr1) {}   \
                                                                                                \
    BinaryFunctor<DeviceType::kCUDA, op, float, float> float_functor;                           \
    OF_DEVICE_FUNC half operator()(half src0, half src1) const {                                \
      return __float2half(float_functor(__half2float(src0), __half2float(src1)));               \
    }                                                                                           \
  };

SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kEluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kCeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kGeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardswishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kHardshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kMishBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSiluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSeluBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftplusBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftsignBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kSoftshrinkBackwardWithDyY);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kThresholdBackwardWithDyX);
SPECIALIZATION_PSEUDO_HALF_BINARY_FUNCTOR(BinaryOp::kTanhBackwardWithDyX);

}  // namespace broadcast_elementwise_binary
}  // namespace primitive
}  // namespace ep
}  // namespace oneflow

#endif  // ONEFLOW_CORE_EP_ROCM_PRIMITIVE_BINARY_FUNCTOR_HIP_H_
\ No newline at end of file
oneflow/core/ep/rocm/primitive/broadcast_elementwise_binary.hip.cpp
deleted
100644 → 0
View file @
f262efc9
/*
Copyright 2020 The OneFlow Authors. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "oneflow/core/ep/include/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/common/primitive/broadcast_elementwise_binary.h"
#include "oneflow/core/ep/rocm/primitive/type_seq.h"
#include "oneflow/core/ep/rocm/cuda_stream.h"
#include "oneflow/core/hip/elementwise.hip.h"
#include "oneflow/core/ep/rocm/primitive/binary_functor.hip.h"
namespace
oneflow
{
namespace
ep
{
namespace
primitive
{
namespace
broadcast_elementwise_binary
{
// Forward declaration. The per-(op, Src, Dst) factory functions are defined in the
// explicit-instantiation translation units; this file only references them when
// populating the dispatch table below.
template<BinaryOp binary_op, typename Src, typename Dst>
std::unique_ptr<BroadcastElementwiseBinary> NewBroadcastElementwiseBinary(Scalar attr0,
                                                                          Scalar attr1);
namespace
{
// Factory that creates BroadcastElementwiseBinary primitives for the CUDA/ROCm
// device. Dispatch is table-driven: a static map keyed by
// (BinaryOp, src DataType, dst DataType) holds one creation function per
// supported combination, generated at compile time from the op/type sequences.
class BroadcastElementwiseBinaryFactoryImpl : public BroadcastElementwiseBinaryFactory {
 public:
  OF_DISALLOW_COPY_AND_MOVE(BroadcastElementwiseBinaryFactoryImpl);
  BroadcastElementwiseBinaryFactoryImpl() = default;
  ~BroadcastElementwiseBinaryFactoryImpl() override = default;

  // Convenience overload: no scalar attributes.
  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type,
                                                  DataType dst_type,
                                                  size_t max_num_dims) override {
    return New(op, src_type, dst_type, max_num_dims, Scalar(), Scalar());
  }

  // Convenience overload: one scalar attribute.
  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp op, DataType src_type,
                                                  DataType dst_type, size_t max_num_dims,
                                                  Scalar attr0) override {
    return New(op, src_type, dst_type, max_num_dims, attr0, Scalar());
  }

  // Main entry point.
  // Returns nullptr when max_num_dims exceeds the supported limit or when the
  // (op, src_type, dst_type) combination has no registered implementation;
  // callers are expected to check for nullptr.
  std::unique_ptr<BroadcastElementwiseBinary> New(BinaryOp binary_op, DataType src_type,
                                                  DataType dst_type, size_t max_num_dims,
                                                  Scalar attr0, Scalar attr1) override {
    if (max_num_dims > kMaxNumDims) { return nullptr; }

// Math ops: src and dst share one data type.
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY(binary_op, data_type_pair) \
  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair),                    \
                   OF_PP_PAIR_SECOND(data_type_pair)),                              \
   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),       \
                                 OF_PP_PAIR_FIRST(data_type_pair)>},

// Comparison/logical ops: arbitrary src type, bool dst type.
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY(     \
    binary_op, src_data_type_pair, dst_data_type_pair)                           \
  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(src_data_type_pair),             \
                   OF_PP_PAIR_SECOND(dst_data_type_pair)),                       \
   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(src_data_type_pair),\
                                 OF_PP_PAIR_FIRST(dst_data_type_pair)>},

// Activation backward ops: floating types only, src and dst share one type.
#define MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY(binary_op, data_type_pair) \
  {std::make_tuple(binary_op, OF_PP_PAIR_SECOND(data_type_pair),                               \
                   OF_PP_PAIR_SECOND(data_type_pair)),                                         \
   NewBroadcastElementwiseBinary<binary_op, OF_PP_PAIR_FIRST(data_type_pair),                  \
                                 OF_PP_PAIR_FIRST(data_type_pair)>},

    // Built once on first call; lookups afterwards are read-only.
    static const std::map<
        std::tuple<BinaryOp, DataType, DataType>,
        std::function<std::unique_ptr<BroadcastElementwiseBinary>(Scalar, Scalar)>>
        new_broadcast_elementwise_binary_handle{
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY,
                                             BINARY_MATH_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ)
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
                MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY,
                BINARY_COMPARISION_OP_SEQ BINARY_LOGICAL_OP_SEQ, CUDA_PRIMITIVE_ALL_TYPE_SEQ,
                CUDA_PRIMITIVE_BOOL_TYPE_SEQ)
            OF_PP_SEQ_PRODUCT_FOR_EACH_TUPLE(
                MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY,
                BINARY_ACTIVATION_BACKWARD_OP_SEQ, CUDA_PRIMITIVE_FLOATING_TYPE_SEQ)};

// Fix: the ACTIVATION_GRAD entry macro was previously left defined, leaking a
// function-local helper macro into the rest of the translation unit. Undef all
// three entry macros now that the table has been built.
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_ACTIVATION_GRAD_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_COMPARASION_AND_LOGICAL_ENTRY
#undef MAKE_NEW_BROADCAST_ELEMENTWISE_BINARY_MATH_ENTRY

    const auto it = new_broadcast_elementwise_binary_handle.find(
        std::make_tuple(binary_op, src_type, dst_type));
    if (it != new_broadcast_elementwise_binary_handle.end()) {
      return it->second(attr0, attr1);
    } else {
      return nullptr;
    }
  }
};
// Register the factory so Primitive lookups for DeviceType::kCUDA resolve to it.
REGISTER_PRIMITIVE_FACTORY(DeviceType::kCUDA, BroadcastElementwiseBinaryFactory,
                           BroadcastElementwiseBinaryFactoryImpl);
}
// namespace
}
// namespace broadcast_elementwise_binary
}
// namespace primitive
}
// namespace ep
}
// namespace oneflow
\ No newline at end of file
Prev
1
…
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment