// torch.cpp — conversions between project Tensors and torch (ATen) tensors.
#include "torch.h"

#include <ATen/cuda/CUDAContext.h>

#include <cstdint>
#include <limits>
#include <stdexcept>
#include <type_traits>

using spdlog::fmt_lib::format;

// Checked narrowing conversion between integer types.
// Throws std::runtime_error if `x` is not representable in `To`.
//
// The naive check `x < numeric_limits<To>::min()` is wrong for mixed
// signedness: with an unsigned Ti and signed To, the negative lower bound is
// promoted to a huge unsigned value, so perfectly valid inputs would throw.
// Dispatch on signedness instead and compare in std::uintmax_t where needed.
template<typename To, typename Ti>
static To int_cast(Ti x) {
    static_assert(std::is_integral_v<To> && std::is_integral_v<Ti>,
                  "int_cast requires integer types");
    if constexpr (std::is_signed_v<Ti> && !std::is_signed_v<To>) {
        // signed -> unsigned: reject negatives, then compare magnitudes.
        if (x < 0 || static_cast<std::uintmax_t>(x) > static_cast<std::uintmax_t>(std::numeric_limits<To>::max())) {
            throw std::runtime_error("integer overflow");
        }
    } else if constexpr (!std::is_signed_v<Ti> && std::is_signed_v<To>) {
        // unsigned -> signed: only the upper bound can be violated.
        if (static_cast<std::uintmax_t>(x) > static_cast<std::uintmax_t>(std::numeric_limits<To>::max())) {
            throw std::runtime_error("integer overflow");
        }
    } else {
        // Same signedness: the usual arithmetic conversions are value-preserving.
        if (x < std::numeric_limits<To>::min() || x > std::numeric_limits<To>::max()) {
            throw std::runtime_error("integer overflow");
        }
    }
    return static_cast<To>(x);
}

// Wraps an at::Tensor as a project Tensor WITHOUT copying the data: the torch
// tensor is moved into a BufferTorchTensor, whose refcount keeps the storage
// alive, and the buffer is locked against the current CUDA stream.
// Throws std::out_of_range (via map::at) for unsupported dtypes and
// std::runtime_error (via int_cast) on extent/stride overflow.
Tensor from_torch(at::Tensor input) {
    Tensor result;

    // Mirror shape and stride metadata, with checked narrowing to the
    // project's extent/stride value types.
    const int ndims = int_cast<int>(input.ndimension());
    for (int i = 0; i < ndims; i++) {
        result.shape.dataExtent.push_back(int_cast<decltype(result.shape.dataExtent)::value_type>(input.size(i)));
        result.shape.dataStride.push_back(int_cast<decltype(result.shape.dataStride)::value_type>(input.stride(i)));
    }

    // Torch dtype -> project scalar type. NOTE: both Char (signed) and Byte
    // (unsigned) map to INT8 — signedness is not distinguished here.
    static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
        { at::ScalarType::Char,          Tensor::INT8 },
        { at::ScalarType::Byte,          Tensor::INT8 },
        { at::ScalarType::Short,         Tensor::INT16 },
        { at::ScalarType::Int,           Tensor::INT32 },
        { at::ScalarType::Long,          Tensor::INT64 },
        { at::ScalarType::Float,         Tensor::FP32 },
        { at::ScalarType::Half,          Tensor::FP16 },
        { at::ScalarType::BFloat16,      Tensor::BF16 },
        { at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3 },
        { at::ScalarType::Float8_e5m2,   Tensor::FP8_E5M2 },
    };

    // Read the dtype BEFORE moving `input` into the buffer.
    result.scalarType = mapType.at(input.scalar_type());
    result.buffer = std::make_shared<BufferTorchTensor>(std::move(input));

    // Tie the buffer's lifetime to work queued on the current CUDA stream.
    Tensor::lockBuffer(result.buffer, getCurrentCUDAStream());

    return result;
}

// Copies a (contiguous) project Tensor into a freshly allocated at::Tensor on
// the same device. Throws std::out_of_range (via map::at) for unsupported
// scalar types.
at::Tensor to_torch(Tensor input) {
    assert(input.is_contiguous());

    std::vector<int64_t> shape;
    shape.reserve(input.ndims());  // one allocation instead of growth reallocs
    for (size_t i = 0; i < input.ndims(); i++) {
        shape.push_back(input.size(i));
    }

    // Project scalar type -> torch dtype. NOTE: INT8 maps to Byte (unsigned),
    // matching from_torch, which accepts both Char and Byte as INT8.
    static const std::map<Tensor::ScalarType, at::ScalarType> mapType = {
        { Tensor::INT8,     at::ScalarType::Byte },
        { Tensor::INT16,    at::ScalarType::Short },
        { Tensor::INT32,    at::ScalarType::Int },
        { Tensor::INT64,    at::ScalarType::Long },
        { Tensor::FP32,     at::ScalarType::Float },
        { Tensor::FP16,     at::ScalarType::Half },
        { Tensor::BF16,     at::ScalarType::BFloat16 },
        { Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn },
        { Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2 },
    };

    c10::TensorOptions opts(mapType.at(input.scalar_type()));
    if (input.device().type == Device::CPU) {
        opts = opts.device("cpu");
    } else {
        opts = opts.device(format("cuda:{}", input.device().idx));
    }

    // Allocate the destination, then route the copy through from_torch so it
    // uses the project's own (stream-locked) copy path.
    at::Tensor result = torch::empty(at::IntArrayRef(shape), opts);
    from_torch(result).copy_(input);

    return result;
}

// On entry, record the CUDA stream torch currently considers active by
// pushing it onto stackCUDAStreams; the destructor pops it again.
TorchOpContext::TorchOpContext() {
    const auto torchStream = at::cuda::getCurrentCUDAStream().stream();
    stackCUDAStreams.push(torchStream);
}

// Pop the stream recorded by the constructor. The assert (debug builds only)
// checks that torch's current stream did not change while this context lived.
TorchOpContext::~TorchOpContext() {
    assert(at::cuda::getCurrentCUDAStream().stream() == stackCUDAStreams.top());
    stackCUDAStreams.pop();
}