// pytorch_compat.h
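//
// Shims mapping the small subset of the PyTorch / ATen C++ API that this
// project uses onto the local ::Tensor type, so code written against the
// torch:: / at:: names can build without PyTorch itself.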
#pragma once

#include "common.h"
#include "Tensor.h"

// Standard headers used below; harmless if common.h already pulls them in.
#include <memory>
#include <mutex>
#include <optional>
#include <stdexcept>
#include <string>
#include <vector>

namespace pytorch_compat {
    // Simplified TORCH_CHECK: PyTorch's version throws c10::Error with the
    // message. A plain assert would both drop `msg` and compile away under
    // NDEBUG, so we throw instead.
    inline void TORCH_CHECK(bool cond, const std::string &msg = "") {
        if (!cond) {
            throw std::runtime_error(msg.empty() ? "TORCH_CHECK failed" : msg);
        }
    }

    // Forwards CUDA return codes to the project's checkCUDA() helper.
    template<typename T>
    inline void C10_CUDA_CHECK(T ret) {
        checkCUDA(ret);
    }
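    // Usage example (cudaMemsetAsync is the standard CUDA runtime call):
    //   C10_CUDA_CHECK(cudaMemsetAsync(ptr, 0, nbytes, stream));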

    namespace at {
        using ::Tensor;

        // ATen scalar-type tags mapped onto the local Tensor dtype enum.
        constexpr auto kFloat32 = Tensor::FP32;
        constexpr auto kFloat = Tensor::FP32;
        constexpr auto kFloat16 = Tensor::FP16;
        constexpr auto kBFloat16 = Tensor::BF16;
        constexpr auto kInt32 = Tensor::INT32;
        constexpr auto kInt64 = Tensor::INT64;

        // Stub for at::Generator: RNG state is not supported by this shim,
        // so constructing one fails loudly at runtime.
        struct Generator {
            Generator() { throw std::runtime_error("Not implemented"); }
            std::mutex mutex_;
        };

        namespace cuda {
            using ::getCurrentDeviceProperties;

            // Minimal stand-in for c10::cuda::CUDAStream: wraps the raw
            // cudaStream_t and mirrors the .stream() accessor.
            struct StreamWrapper {
                cudaStream_t st;
                cudaStream_t stream() const { return st; }
            };
            inline StreamWrapper getCurrentCUDAStream() {
                // Brace-init: parenthesized aggregate initialization requires C++20.
                return StreamWrapper{::getCurrentCUDAStream()};
            }

            // No-op stand-in for c10::cuda::CUDAGuard: records the requested
            // device index but performs no actual device switch.
            struct CUDAGuard {
                int dev;
            };

            namespace detail {
                // Throws via Generator's constructor; see the stub above.
                inline Generator getDefaultCUDAGenerator() {
                    return Generator();
                }
            }
        }

        using CUDAGeneratorImpl = Generator;

        // Stub: PyTorch returns the provided generator or the default one;
        // RNG is unsupported here, so this always throws.
        template<typename T>
        std::unique_ptr<Generator> get_generator_or_default(std::optional<Generator> gen, T gen2) {
            throw std::runtime_error("Not implemented");
        }
    }

    namespace torch {
        using at::kFloat32;
        using at::kFloat;
        using at::kFloat16;
        using at::kBFloat16;
        using at::kInt32;
        using at::kInt64;
        constexpr Device kCUDA = Device::cuda();

        // Note: ATen's IntArrayRef is a non-owning view over int64_t; an
        // owning std::vector<int> stands in for it here.
        using IntArrayRef = std::vector<int>;
        using TensorOptions = Tensor::TensorOptions;

        // Factory shims over the local Tensor type; zeros() is composed from
        // empty() followed by an in-place zero_().
        inline Tensor empty_like(const Tensor &tensor) {
            return Tensor::empty_like(tensor);
        }
        inline Tensor empty(TensorShape shape, Tensor::TensorOptions options) {
            return Tensor::empty(shape, options.dtype(), options.device());
        }
        inline Tensor zeros(TensorShape shape, Tensor::TensorOptions options) {
            return Tensor::empty(shape, options.dtype(), options.device()).zero_();
        }

        namespace nn {
            namespace functional {
                // Stub for torch::nn::functional::pad; the options struct is
                // reduced to the raw pad-width list.
                using PadFuncOptions = std::vector<int>;
                inline Tensor pad(Tensor x, PadFuncOptions options) {
                    throw std::runtime_error("Not implemented");
                }
            }
        }

        namespace indexing {
            // Name-only placeholders for torch::indexing::None and Slice;
            // NumPy-style slicing semantics are not provided.
            constexpr int None = 0;
            struct Slice {
                int a;
                int b;
            };
        }
    }

    namespace c10 {
        // c10::optional maps directly onto std::optional.
        using std::optional;
    }

}
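
// Hypothetical usage sketch (not part of this header): code written against
// the PyTorch C++ API compiles against these shims by pulling the namespace
// into scope. `launch_my_kernel` is an assumed example function, not one the
// project necessarily provides.
//
//   #include "pytorch_compat.h"
//   using namespace pytorch_compat;
//
//   Tensor run(const Tensor &x) {
//       at::cuda::CUDAGuard guard{0};  // no-op stub here
//       cudaStream_t stream = at::cuda::getCurrentCUDAStream().stream();
//       Tensor out = torch::empty_like(x);
//       launch_my_kernel(x, out, stream);  // assumed kernel launcher
//       return out;
//   }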