Commit d9de5133 authored by YdrMaster's avatar YdrMaster
Browse files

issue/158/refactor: 支持天数的 pytorch 测试


Signed-off-by: default avatarYdrMaster <ydrml@hotmail.com>
parent e09a0b7c
#include "nvidia_handle.cuh"
namespace device::nvidia {
namespace device {
Handle::Handle(int device_id)
: InfiniopHandle{INFINI_DEVICE_NVIDIA, device_id},
namespace nvidia {
Handle::Handle(infiniDevice_t device, int device_id)
: InfiniopHandle{device, device_id},
_internal(std::make_shared<Handle::Internal>(device_id)) {}
auto Handle::internal() const -> const std::shared_ptr<Internal> & {
......@@ -83,9 +85,23 @@ cudnnDataType_t getCudnnDtype(infiniDtype_t dt) {
}
#endif
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(INFINI_DEVICE_NVIDIA, device_id);
return INFINI_STATUS_SUCCESS;
}
} // namespace nvidia
namespace iluvatar {
Handle::Handle(int device_id)
: nvidia::Handle(INFINI_DEVICE_ILUVATAR, device_id) {}
infiniStatus_t Handle::create(InfiniopHandle **handle_ptr, int device_id) {
*handle_ptr = new Handle(device_id);
return INFINI_STATUS_SUCCESS;
}
} // namespace device::nvidia
} // namespace iluvatar
} // namespace device
......@@ -4,13 +4,17 @@
#include "../../handle.h"
#include <memory>
namespace device::nvidia {
namespace device {
namespace nvidia {
struct Handle : public InfiniopHandle {
Handle(int device_id);
class Internal;
auto internal() const -> const std::shared_ptr<Internal> &;
protected:
Handle(infiniDevice_t device, int device_id);
public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
......@@ -18,6 +22,19 @@ private:
std::shared_ptr<Internal> _internal;
};
} // namespace device::nvidia
} // namespace nvidia
namespace iluvatar {
struct Handle : public nvidia::Handle {
Handle(int device_id);
public:
static infiniStatus_t create(InfiniopHandle **handle_ptr, int device_id);
};
} // namespace iluvatar
} // namespace device
#endif // __INFINIOP_CUDA_HANDLE_H__
......@@ -262,6 +262,11 @@ def get_args():
action="store_true",
help="Run NVIDIA GPU test",
)
parser.add_argument(
"--iluvatar",
action="store_true",
help="Run Iluvatar GPU test",
)
parser.add_argument(
"--cambricon",
action="store_true",
......@@ -566,6 +571,8 @@ def get_test_devices(args):
devices_to_test.append(InfiniDeviceEnum.CPU)
if args.nvidia:
devices_to_test.append(InfiniDeviceEnum.NVIDIA)
if args.iluvatar:
devices_to_test.append(InfiniDeviceEnum.ILUVATAR)
if args.cambricon:
import torch_mlu
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment