Unverified Commit b3170335 authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

issue/411: adjust mudnn handle use to prevent mudnn handle mismatching

parents 1d064392 f34d4e3b
......@@ -21,11 +21,13 @@ class Handle::Internal {
_block_size[3],
_grid_size[3];
int _device_id;
template <typename T>
using Fn = std::function<infiniStatus_t(T)>;
public:
Internal(int);
Internal(int device_id);
infiniStatus_t useMublas(musaStream_t stream, const Fn<mublasHandle_t> &f) const;
infiniStatus_t useMudnn(musaStream_t stream, const Fn<::musa::dnn::Handle &> &f) const;
......
......@@ -11,7 +11,8 @@ auto Handle::internal() const -> const std::shared_ptr<Internal> & {
return _internal;
}
Handle::Internal::Internal(int device_id) {
Handle::Internal::Internal(int device_id)
: _device_id(device_id) {
musaDeviceProp prop;
musaGetDeviceProperties(&prop, device_id);
_warp_size = prop.warpSize;
......@@ -45,7 +46,7 @@ infiniStatus_t Handle::Internal::useMudnn(musaStream_t stream, const Fn<::musa::
if (opt_handle.has_value()) {
handle = std::move(*opt_handle);
} else {
handle = std::make_unique<::musa::dnn::Handle>();
handle = std::make_unique<::musa::dnn::Handle>(_device_id);
}
CHECK_MUDNN(handle->SetStream(stream));
CHECK_STATUS(f(*handle));
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment