pytorch_compat.h 2.33 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
#pragma once

#include <memory>
#include <mutex>
#include <optional>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include "common.h"
#include "Tensor.h"

namespace pytorch_compat {
Muyang Li's avatar
Muyang Li committed
7
8
9
/// Drop-in replacement for PyTorch's TORCH_CHECK.
///
/// Throws std::runtime_error when `cond` is false. Any trailing arguments are
/// streamed with operator<< and concatenated into the exception message, so
/// both `TORCH_CHECK(ok, "msg")` and PyTorch-style
/// `TORCH_CHECK(ok, "expected ", a, " but got ", b)` work. The message is only
/// built on the failure path.
///
/// This is a real runtime check: unlike assert(), it is not compiled out when
/// NDEBUG is defined.
///
/// @param cond  condition that must hold
/// @param msg   zero or more streamable values forming the error message
/// @throws std::runtime_error if `cond` is false; the message is
///         "TORCH_CHECK failed" when no (or an empty) message was supplied
template<typename... Args>
inline void TORCH_CHECK(bool cond, const Args &...msg) {
    if (cond) {
        return;
    }
    std::ostringstream oss;
    // Comma fold: streams every argument in order; an empty pack is a no-op.
    ((oss << msg), ...);
    std::string text = oss.str();
    if (text.empty()) {
        text = "TORCH_CHECK failed";
    }
    throw std::runtime_error(text);
}
Zhekai Zhang's avatar
Zhekai Zhang committed
10

Muyang Li's avatar
Muyang Li committed
11
12
13
14
/// Drop-in replacement for PyTorch's C10_CUDA_CHECK: forwards a CUDA return
/// code to the project's checkCUDA() handler.
///
/// The result of checkCUDA is explicitly discarded rather than returned, so
/// this wrapper compiles whether checkCUDA returns void or returns the status
/// code (a `return <non-void expr>;` inside a void function is ill-formed).
/// NOTE(review): what happens on error is defined by checkCUDA, which is
/// declared outside this header (presumably common.h) — confirm there.
template<typename T>
inline void C10_CUDA_CHECK(T ret) {
    static_cast<void>(checkCUDA(ret));
}
Zhekai Zhang's avatar
Zhekai Zhang committed
15

Muyang Li's avatar
Muyang Li committed
16
17
namespace at {
using ::Tensor;
Zhekai Zhang's avatar
Zhekai Zhang committed
18

Muyang Li's avatar
Muyang Li committed
19
20
21
22
23
24
25
26
27
28
constexpr auto kFloat32  = Tensor::FP32;
constexpr auto kFloat    = Tensor::FP32;
constexpr auto kFloat16  = Tensor::FP16;
constexpr auto kBFloat16 = Tensor::BF16;
constexpr auto kInt32    = Tensor::INT32;
constexpr auto kInt64    = Tensor::INT64;

struct Generator {
    Generator() {
        throw std::runtime_error("Not implemented");
Zhekai Zhang's avatar
Zhekai Zhang committed
29
    }
Muyang Li's avatar
Muyang Li committed
30
31
32
33
34
    std::mutex mutex_;
};

namespace cuda {
using ::getCurrentDeviceProperties;
Zhekai Zhang's avatar
Zhekai Zhang committed
35

Muyang Li's avatar
Muyang Li committed
36
37
38
39
struct StreamWrapper {
    cudaStream_t st;
    cudaStream_t stream() const {
        return st;
Zhekai Zhang's avatar
Zhekai Zhang committed
40
    }
Muyang Li's avatar
Muyang Li committed
41
42
43
44
45
46
47
48
};
inline StreamWrapper getCurrentCUDAStream() {
    return StreamWrapper(::getCurrentCUDAStream());
}

struct CUDAGuard {
    int dev;
};
Zhekai Zhang's avatar
Zhekai Zhang committed
49

Muyang Li's avatar
Muyang Li committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
namespace detail {
inline Generator getDefaultCUDAGenerator() {
    return Generator();
}
} // namespace detail
} // namespace cuda

using CUDAGeneratorImpl = Generator;

template<typename T>
std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
    throw std::runtime_error("Not implemented");
}
} // namespace at

namespace torch {
// Re-export the scalar-type tags so that e.g. `torch::kFloat16` resolves as in PyTorch.
using at::kFloat32;
using at::kFloat;
using at::kFloat16;
using at::kBFloat16;
using at::kInt32;
using at::kInt64;

// CUDA device tag. `inline constexpr` gives a single definition shared by
// every translation unit that includes this header.
inline constexpr Device kCUDA = Device::cuda();

// NOTE(review): PyTorch's IntArrayRef is a non-owning view over int64_t; this
// alias is an owning std::vector<int>.
using IntArrayRef   = std::vector<int>;
using TensorOptions = Tensor::TensorOptions;

/// Allocates a tensor like `tensor` by delegating to Tensor::empty_like.
inline Tensor empty_like(const Tensor &tensor) {
    return Tensor::empty_like(tensor);
}

/// Allocates a tensor of `shape` using the dtype and device held by `options`.
inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
    return Tensor::empty(shape, options.dtype(), options.device());
}

/// Same allocation as empty(), then filled in place with zero_().
inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
    return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}

namespace nn {
namespace functional {
using PadFuncOptions = std::vector<int>;

/// Stub for torch::nn::functional::pad; always throws
/// std::runtime_error("Not implemented").
inline Tensor pad(Tensor x, PadFuncOptions options) {
    throw std::runtime_error("Not implemented");
}
} // namespace functional
} // namespace nn

namespace indexing {
// NOTE(review): in PyTorch `None` is a sentinel distinct from every index;
// here it is simply the integer 0.
inline constexpr int None = 0;

/// Half-open index range [a, b) placeholder for torch::indexing::Slice
/// (presumably interpreted by Tensor's indexing code — TODO confirm).
struct Slice {
    int a;
    int b;
};
} // namespace indexing
} // namespace torch

namespace c10 {
// PyTorch's c10::optional / c10::nullopt are exposed as aliases of the
// standard-library versions.
using std::nullopt;
using std::optional;
} // namespace c10

Muyang Li's avatar
Muyang Li committed
109
} // namespace pytorch_compat