pytorch_compat.h 2.33 KB
Newer Older
Zhekai Zhang's avatar
Zhekai Zhang committed
1
2
3
4
5
6
#pragma once

#include "common.h"
#include "Tensor.h"

namespace pytorch_compat {
Muyang Li's avatar
Muyang Li committed
7
8
9
/// Runtime precondition check mirroring PyTorch's `TORCH_CHECK`.
///
/// Unlike `assert`, this stays active in release (NDEBUG) builds and reports
/// the caller's message instead of discarding it.
///
/// @param cond  condition that must hold.
/// @param msg   text for the exception when `cond` is false; when empty, a
///              generic "TORCH_CHECK failed" message is used.
/// @throws std::runtime_error if `cond` is false.
inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
    if (!cond) {
        throw std::runtime_error(msg.empty() ? std::string("TORCH_CHECK failed") : msg);
    }
}
Zhekai Zhang's avatar
Zhekai Zhang committed
10

Muyang Li's avatar
Muyang Li committed
11
/// Counterpart of PyTorch's `C10_CUDA_CHECK`: hands the status code `ret` to the
/// project's `checkCUDA` helper (declared outside this header).
///
/// `checkCUDA` is invoked as a discarded expression rather than with
/// `return checkCUDA(ret);`. Returning a value from a `void` function is
/// ill-formed, so the old form only compiled while `checkCUDA` itself returned
/// `void`. The `static_cast<void>` form is valid for either return type and
/// also silences a possible `[[nodiscard]]` on `checkCUDA`.
template<typename T>
inline void C10_CUDA_CHECK(T ret) {
    static_cast<void>(checkCUDA(ret));
}
Zhekai Zhang's avatar
Zhekai Zhang committed
15

Muyang Li's avatar
Muyang Li committed
16
17
namespace at {
using ::Tensor;
Zhekai Zhang's avatar
Zhekai Zhang committed
18

Muyang Li's avatar
Muyang Li committed
19
20
21
22
23
24
25
26
27
28
// Dtype constants spelled like PyTorch's `at::k*` names, aliased to this
// project's `Tensor` dtype values.
// `inline constexpr` (C++17) gives each constant a single definition shared by
// every translation unit. A plain namespace-scope `constexpr` in a header has
// internal linkage, so each TU had its own copy, and ODR-using one (e.g. binding
// it to a `const&`) inside an inline function risked an ODR violation.
inline constexpr auto kFloat32  = Tensor::FP32;
inline constexpr auto kFloat    = Tensor::FP32;
inline constexpr auto kFloat16  = Tensor::FP16;
inline constexpr auto kBFloat16 = Tensor::BF16;
inline constexpr auto kInt32    = Tensor::INT32;
inline constexpr auto kInt64    = Tensor::INT64;

/// Placeholder for `at::Generator`. Random-number generation is not supported
/// by this shim: the constructor unconditionally throws
/// `std::runtime_error("Not implemented")`, so no instance can ever exist.
/// NOTE(review): `mutex_` is presumably kept so ported code that locks the
/// generator still compiles — confirm against callers.
struct Generator {
    std::mutex mutex_;

    Generator() {
        throw std::runtime_error("Not implemented");
    }
};

namespace cuda {
using ::getCurrentDeviceProperties;
Zhekai Zhang's avatar
Zhekai Zhang committed
35

Muyang Li's avatar
Muyang Li committed
36
struct StreamWrapper {
fengzch-das's avatar
fengzch-das committed
37
38
    cudaStream_t st;
    cudaStream_t stream() const {
Muyang Li's avatar
Muyang Li committed
39
        return st;
Zhekai Zhang's avatar
Zhekai Zhang committed
40
    }
Muyang Li's avatar
Muyang Li committed
41
};
fengzch-das's avatar
fengzch-das committed
42
43
/// Counterpart of `at::cuda::getCurrentCUDAStream()`: wraps the stream returned
/// by the global `::getCurrentCUDAStream()` in a `StreamWrapper`.
inline StreamWrapper getCurrentCUDAStream() {
    // Brace-initialise the aggregate. `StreamWrapper(x)` needs C++20
    // parenthesised aggregate initialisation and is ill-formed in C++17.
    return StreamWrapper{::getCurrentCUDAStream()};
}

fengzch-das's avatar
fengzch-das committed
46
/// Stand-in for `at::cuda::CUDAGuard`: a plain aggregate that only records a
/// device index in `dev`. It declares no constructor or destructor and makes
/// no CUDA calls of its own.
/// NOTE(review): unlike PyTorch's guard, nothing here switches or restores the
/// active device — confirm callers do not rely on that.
struct CUDAGuard { int dev; };
Zhekai Zhang's avatar
Zhekai Zhang committed
49

Muyang Li's avatar
Muyang Li committed
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
namespace detail {
/// Counterpart of `at::cuda::detail::getDefaultCUDAGenerator()`. It builds a
/// `Generator` in place (no copy or move is involved). Because `Generator`'s
/// constructor throws "Not implemented", every call throws as well.
inline Generator getDefaultCUDAGenerator() {
    return {};
}
} // namespace detail
} // namespace cuda

using CUDAGeneratorImpl = Generator;

/// Counterpart of `at::get_generator_or_default`. Not implemented in this shim:
/// every call throws `std::runtime_error("Not implemented")`.
///
/// Both arguments are taken by const reference. `Generator` holds a
/// `std::mutex`, so neither it nor `std::optional<Generator>` can be copied or
/// moved. With by-value parameters, any lvalue argument (such as a named
/// `std::optional<Generator>`) failed to compile. Temporaries still bind as
/// before.
///
/// @tparam T  type of the fallback generator argument (unused).
/// @throws std::runtime_error always.
template<typename T>
std::unique_ptr<Generator> get_generator_or_default([[maybe_unused]] const std::optional<Generator> &gen,
                                                    [[maybe_unused]] const T &gen2) {
    throw std::runtime_error("Not implemented");
}
} // namespace at

namespace torch {
// Re-export the dtype constants under `torch::`, as PyTorch does.
using at::kFloat32;
using at::kFloat;
using at::kFloat16;
using at::kBFloat16;
using at::kInt32;
using at::kInt64;

// `inline constexpr` (C++17): one shared definition across translation units
// instead of an internal-linkage copy per TU, which avoids ODR hazards when the
// constant is ODR-used from inline functions in this header.
inline constexpr Device kCUDA = Device::cuda();

// Simplified stand-ins for PyTorch's `IntArrayRef` and `TensorOptions`.
using IntArrayRef   = std::vector<int>;
using TensorOptions = Tensor::TensorOptions;

/// `torch::empty_like` counterpart; simply forwards to `Tensor::empty_like`.
inline auto empty_like(const Tensor &tensor) -> Tensor {
    return Tensor::empty_like(tensor);
}
inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
    return Tensor::empty(shape, options.dtype(), options.device());
}
inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
    return Tensor::empty(shape, options.dtype(), options.device()).zero_();
}

namespace nn::functional {
/// Simplified stand-in for `torch::nn::functional::PadFuncOptions`: a plain
/// list of ints.
using PadFuncOptions = std::vector<int>;

/// Stub for `torch::nn::functional::pad`. Padding is not supported by this
/// shim; every call throws `std::runtime_error("Not implemented")`.
inline Tensor pad([[maybe_unused]] Tensor x, [[maybe_unused]] PadFuncOptions options) {
    throw std::runtime_error("Not implemented");
}
} // namespace nn::functional

namespace indexing {
/// Stand-in for `torch::indexing::None` in slice bounds. `inline constexpr`
/// gives it a single definition shared by all translation units.
/// NOTE(review): it is the integer 0, so `Slice{None, n}` is indistinguishable
/// from `Slice{0, n}`, and an open upper bound is also encoded as 0 — confirm
/// that consumers of `Slice` account for this.
inline constexpr int None = 0;

/// Stand-in for `torch::indexing::Slice`: a start bound `a` and a stop bound `b`.
/// Both default to `None`, so a default-constructed Slice is fully initialised
/// instead of holding indeterminate values. Brace initialisation such as
/// `Slice{1, 5}` behaves exactly as before.
struct Slice {
    int a = None;
    int b = None;
};
} // namespace indexing
} // namespace torch

namespace c10 {
using std::optional;
Zhekai Zhang's avatar
Zhekai Zhang committed
107
108
}

Muyang Li's avatar
Muyang Li committed
109
} // namespace pytorch_compat