#ifndef PYTORCH_DEVICE_REGISTRY_H #define PYTORCH_DEVICE_REGISTRY_H #include #include #include #include #include inline std::string GetDeviceStr(const at::Device& device) { std::string str = DeviceTypeName(device.type(), true); if (device.has_index()) { str.push_back(':'); str.append(std::to_string(device.index())); } return str; } // Registry template class DeviceRegistry; template class DeviceRegistry { public: using FunctionType = Ret (*)(Args...); static const int MAX_DEVICE_TYPES = int8_t(at::DeviceType::COMPILE_TIME_MAX_DEVICE_TYPES); void Register(at::DeviceType device, FunctionType function) { funcs_[int8_t(device)] = function; } FunctionType Find(at::DeviceType device) const { return funcs_[int8_t(device)]; } static DeviceRegistry& instance() { static DeviceRegistry inst; return inst; } private: DeviceRegistry() { for (size_t i = 0; i < MAX_DEVICE_TYPES; ++i) { funcs_[i] = nullptr; } }; FunctionType funcs_[MAX_DEVICE_TYPES]; }; // get device of first tensor param template , at::Tensor>::value, bool> = true> at::Device GetFirstTensorDevice(T&& t, Args&&... args) { return std::forward(t).device(); } template , at::Tensor>::value, bool> = true> at::Device GetFirstTensorDevice(T&& t, Args&&... args) { return GetFirstTensorDevice(std::forward(args)...); } // check device consistency inline std::pair CheckDeviceConsistency( const at::Device& device, int index) { return {index, device}; } template , at::Tensor>::value, bool> = true> std::pair CheckDeviceConsistency(const at::Device& device, int index, T&& t, Args&&... args); template , at::Tensor>::value, bool> = true> std::pair CheckDeviceConsistency(const at::Device& device, int index, T&& t, Args&&... args) { auto new_device = std::forward(t).device(); if (new_device.type() != device.type() || new_device.index() != device.index()) { return {index, new_device}; } return CheckDeviceConsistency(device, index + 1, std::forward(args)...); } template < typename T, typename... Args, std::enable_if_t, at::Tensor>::value, bool>> std::pair CheckDeviceConsistency(const at::Device& device, int index, T&& t, Args&&... args) { return CheckDeviceConsistency(device, index + 1, std::forward(args)...); } // dispatch template auto Dispatch(const R& registry, const char* name, Args&&... args) { auto device = GetFirstTensorDevice(std::forward(args)...); auto inconsist = CheckDeviceConsistency(device, 0, std::forward(args)...); TORCH_CHECK(inconsist.first >= int(sizeof...(Args)), name, ": at param ", inconsist.first, ", inconsistent device: ", GetDeviceStr(inconsist.second).c_str(), " vs ", GetDeviceStr(device).c_str(), "\n") auto f_ptr = registry.Find(device.type()); TORCH_CHECK(f_ptr != nullptr, name, ": implementation for device ", GetDeviceStr(device).c_str(), " not found.\n") return f_ptr(std::forward(args)...); } // helper macro #define DEVICE_REGISTRY(key) DeviceRegistry::instance() #define REGISTER_DEVICE_IMPL(key, device, value) \ struct key##_##device##_registerer { \ key##_##device##_registerer() { \ DEVICE_REGISTRY(key).Register(at::k##device, value); \ } \ }; \ static key##_##device##_registerer _##key##_##device##_registerer; #define DISPATCH_DEVICE_IMPL(key, ...) \ Dispatch(DEVICE_REGISTRY(key), #key, __VA_ARGS__) #endif // PYTORCH_DEVICE_REGISTRY