// !!! This is a file automatically generated by hipify!!!
/*
 * Copyright (c) 2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#pragma once

19
#include <hip/hip_runtime_api.h>

#include <cstdio>
#include <stdexcept>
#include <string>

// Evaluates `status_expr` once and hands the result to nv::cuda_check_ together
// with the file name and line number of the call site.
//
// NOTE: the expansion is a bare brace block, not a single statement. Writing
// `if (c) CUDA_CHECK(x); else ...` without braces around the call therefore
// does not parse; brace such call sites.
#define CUDA_CHECK(status_expr)                         \
  {                                                     \
    nv::cuda_check_((status_expr), __FILE__, __LINE__); \
  }

namespace nv {

// Exception type used to report a failed runtime call. It adds nothing to
// std::runtime_error beyond a distinct type that callers can catch.
class CudaException : public std::runtime_error {
 public:
  // message: text later returned by what().
  CudaException(const std::string& message) : std::runtime_error(message) {}
};

34
35
// Converts a failed runtime status into a CudaException.
//
// val:  status code returned by a runtime API call.
// file: source file of the call site (CUDA_CHECK passes __FILE__).
// line: source line of the call site (CUDA_CHECK passes __LINE__).
//
// Returns normally when val == hipSuccess. Otherwise throws CudaException with
// the message "<file>:<line>: CUDA error <code>: <hipGetErrorString(val)>".
inline void cuda_check_(hipError_t val, const char* file, int line) {
  if (val == hipSuccess) {
    return;
  }

  std::string message(file);
  message += ":";
  message += std::to_string(line);
  message += ": CUDA error ";
  message += std::to_string(val);
  message += ": ";
  message += hipGetErrorString(val);
  throw CudaException(message);
}

class CudaDeviceRestorer {
 public:
43
44
  CudaDeviceRestorer() { CUDA_CHECK(hipGetDevice(&dev_)); }
  ~CudaDeviceRestorer() { CUDA_CHECK(hipSetDevice(dev_)); }
45
46
47
48
49
50
51
52
53
54
55
56
57
  void check_device(int device) const {
    if (device != dev_) {
      throw std::runtime_error(
          std::string(__FILE__) + ":" + std::to_string(__LINE__) +
          ": Runtime Error: The device id in the context is not consistent with configuration");
    }
  }

 private:
  int dev_;
};

inline int get_dev(const void* ptr) {
58
59
  hipPointerAttribute_t attr;
  CUDA_CHECK(hipPointerGetAttributes(&attr, ptr));
60
61
  int dev = -1;

62
63
#if DTKRT_VERSION >= 10000
  if (attr.type == hipMemoryTypeDevice)
64
#else
65
  if (attr.memoryType == hipMemoryTypeDevice)
66
67
68
69
70
71
72
73
74
75
#endif
  {
    dev = attr.device;
  }
  return dev;
}

inline void switch_to_dev(const void* ptr) {
  int dev = get_dev(ptr);
  if (dev >= 0) {
76
    CUDA_CHECK(hipSetDevice(dev));
77
78
79
80
  }
}

}  // namespace nv