// pytorch_compat.h — minimal shims mapping PyTorch/c10 API names onto
// this project's Tensor and HIP runtime helpers.
#include "hip/hip_runtime.h"
#pragma once

#include <cassert>
#include <memory>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

#include "common.h"
#include "Tensor.h"

namespace pytorch_compat {
// Minimal stand-in for PyTorch's TORCH_CHECK.
//
// Unlike the previous assert()-based version, the check stays active in
// NDEBUG (release) builds and the caller-supplied message is no longer
// silently discarded: failure raises std::runtime_error carrying `msg`,
// mirroring how the real TORCH_CHECK throws c10::Error.
inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
    if (!cond) {
        throw std::runtime_error(msg);
    }
}
// Maps c10's C10_HIP_CHECK onto the project's checkCUDA() status
// validator; accepts any status type checkCUDA() handles.
template<typename T>
inline void C10_HIP_CHECK(T ret) {
    checkCUDA(ret);
}
namespace at {
using ::Tensor;
// Map the torch scalar-type tags onto this project's Tensor dtype enum.
// kFloat and kFloat32 intentionally share Tensor::FP32 (in PyTorch,
// torch.float is float32).
constexpr auto kFloat32  = Tensor::FP32;
constexpr auto kFloat    = Tensor::FP32;
constexpr auto kFloat16  = Tensor::FP16;
constexpr auto kBFloat16 = Tensor::BF16;
constexpr auto kInt32    = Tensor::INT32;
constexpr auto kInt64    = Tensor::INT64;

struct Generator {
    // RNG state is not supported in this compat layer: constructing a
    // Generator fails immediately and loudly.
    Generator() {
        throw std::runtime_error("Not implemented");
    }

    // Mirrors the mutex_ member of PyTorch's at::Generator so code that
    // names it still compiles.
    std::mutex mutex_;
};

namespace cuda {
using ::getCurrentDeviceProperties;
struct StreamWrapper {
    // Wraps a raw hipStream_t behind the CUDAStream-like `.stream()`
    // accessor that ported PyTorch code expects.
    hipStream_t st;

    hipStream_t stream() const {
        return st;
    }
};
// torch-style accessor for the current HIP stream: wraps the project's
// global ::getCurrentHIPStreamMasqueradingAsCUDA() in a StreamWrapper.
// Uses brace initialization — StreamWrapper is an aggregate, and
// parenthesized aggregate initialization only exists since C++20, so
// the brace form is portable to C++17 builds.
inline StreamWrapper getCurrentHIPStreamMasqueradingAsCUDA() {
    return StreamWrapper{::getCurrentHIPStreamMasqueradingAsCUDA()};
}

struct HIPGuardMasqueradingAsCUDA {
    // No-op stand-in for c10's device guard: records the requested
    // device index but performs no device switch here.
    int dev;
};
namespace detail {
// Compat stand-in for at::cuda::detail::getDefaultCUDAGenerator.
// The Generator constructor throws, so any caller reaching this path
// fails loudly rather than silently producing bad RNG state.
inline Generator getDefaultCUDAGenerator() {
    return Generator();
}
} // namespace detail
} // namespace cuda

// PyTorch's at::CUDAGeneratorImpl is represented by the stub Generator.
using CUDAGeneratorImpl = Generator;

// Stub for at::get_generator_or_default — RNG plumbing is unsupported,
// so any instantiation that is actually called fails at runtime.
// Signature (optional<Generator> + fallback) kept for source compat.
template<typename T>
std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
    throw std::runtime_error("Not implemented");
}
} // namespace at

namespace torch {
// Re-export the at:: dtype tags under torch::, as PyTorch itself does.
using at::kFloat32;
using at::kFloat;
using at::kFloat16;
using at::kBFloat16;
using at::kInt32;
using at::kInt64;
// The only device this compat layer models is the project's CUDA/HIP device.
constexpr Device kCUDA = Device::cuda();

// NOTE(review): real torch::IntArrayRef views int64_t; this stand-in
// narrows to int — confirm no caller needs 64-bit extents.
using IntArrayRef   = std::vector<int>;
using TensorOptions = Tensor::TensorOptions;

// torch::empty_like shim: forwards to the project's Tensor factory.
inline Tensor empty_like(const Tensor &tensor) {
    return Tensor::empty_like(tensor);
}
// torch::empty shim: allocate an uninitialized Tensor with the dtype
// and device unpacked from `options`.
inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
    return Tensor::empty(shape, options.dtype(), options.device());
}
// torch::zeros shim: allocate via Tensor::empty, then zero-fill in
// place with zero_().
inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
    return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}

namespace nn {
namespace functional {
// NOTE(review): the real PadFuncOptions also carries mode/value beside
// the pad widths; this alias keeps only the width list.
using PadFuncOptions = std::vector<int>;
// Unimplemented stub — any call raises std::runtime_error.
inline Tensor pad(Tensor x, PadFuncOptions options) {
    throw std::runtime_error("Not implemented");
}
} // namespace functional
} // namespace nn

namespace indexing {
// NOTE(review): torch::indexing::None is a distinct sentinel in real
// PyTorch; here it collapses to 0, so callers cannot distinguish
// "open-ended" from index 0 — confirm no caller relies on that.
constexpr int None = 0;
// Minimal start/end pair standing in for torch::indexing::Slice; exact
// bounds semantics depend on the Tensor index operator — confirm there.
struct Slice {
    int a;
    int b;
};
} // namespace indexing
} // namespace torch

namespace c10 {
// Modern PyTorch aliases c10::optional to std::optional; do the same so
// c10::optional<T> spellings in ported code keep compiling.
using std::optional;
} // namespace c10

} // namespace pytorch_compat