// torch.cpp — conversion helpers between torch (ATen) tensors and the
// project's internal Tensor type, plus the torch stream context guard.
#include "torch.h"

#include <ATen/hip/HIPContext.h>

#include <limits>
#include <stdexcept>
#include <type_traits>

using spdlog::fmt_lib::format;

// Checked integral conversion: returns `x` as `To`, throwing std::runtime_error
// if the value is not representable. The naive `x < To::min() || x > To::max()`
// form is wrong for mixed signedness: with signed Ti and unsigned To, the usual
// arithmetic conversions turn a negative x into a huge unsigned value, so e.g.
// int_cast<unsigned>(-1) silently succeeded. Handle each signedness case
// explicitly instead.
template<typename To, typename Ti>
static To int_cast(Ti x) {
    static_assert(std::is_integral_v<To> && std::is_integral_v<Ti>,
                  "int_cast only supports integral types");
    bool fits;
    if constexpr (std::is_signed_v<Ti> == std::is_signed_v<To>) {
        // Same signedness: comparisons promote to the wider type, no surprises.
        fits = x >= std::numeric_limits<To>::min() && x <= std::numeric_limits<To>::max();
    } else if constexpr (std::is_signed_v<Ti>) {
        // Signed -> unsigned: reject negatives first, then compare magnitudes
        // in the unsigned domain.
        fits = x >= 0 && static_cast<std::make_unsigned_t<Ti>>(x) <= std::numeric_limits<To>::max();
    } else {
        // Unsigned -> signed: compare against To's max as an unsigned value.
        fits = x <= static_cast<std::make_unsigned_t<To>>(std::numeric_limits<To>::max());
    }
    if (!fits) {
        throw std::runtime_error("integer overflow");
    }
    return static_cast<To>(x);
}

// Wraps a torch tensor as an internal Tensor without copying: shape, strides
// and dtype are mirrored, and the storage is shared via BufferTorchTensor.
// Throws std::runtime_error on unsupported dtypes or shape-field overflow.
Tensor from_torch(at::Tensor input) {
    Tensor result;

    // Mirror extents and strides exactly as torch reports them (element strides).
    const int ndims = int_cast<int>(input.ndimension());
    for (int i = 0; i < ndims; i++) {
        result.shape.dataExtent.push_back(int_cast<decltype(result.shape.dataExtent)::value_type>(input.size(i)));
        result.shape.dataStride.push_back(int_cast<decltype(result.shape.dataStride)::value_type>(input.stride(i)));
    }

    // Torch dtype -> internal scalar type. Char and Byte both collapse to INT8;
    // 8-bit signedness is not preserved by this mapping (see to_torch).
    static const std::map<at::ScalarType, Tensor::ScalarType> mapType = {
        {at::ScalarType::Char, Tensor::INT8},
        {at::ScalarType::Byte, Tensor::INT8},
        {at::ScalarType::Int, Tensor::INT32},
        {at::ScalarType::Long, Tensor::INT64},
        {at::ScalarType::Float, Tensor::FP32},
        {at::ScalarType::Half, Tensor::FP16},
        {at::ScalarType::BFloat16, Tensor::BF16},
        {at::ScalarType::Short, Tensor::INT16},
        {at::ScalarType::Float8_e4m3fn, Tensor::FP8_E4M3},
        {at::ScalarType::Float8_e5m2, Tensor::FP8_E5M2},
    };

    // Fail with a clear message instead of map::at's opaque std::out_of_range.
    const auto it = mapType.find(input.scalar_type());
    if (it == mapType.end()) {
        throw std::runtime_error("from_torch: unsupported torch scalar type");
    }
    result.scalarType = it->second;
    result.buffer     = std::make_shared<BufferTorchTensor>(std::move(input));

    // Pin the shared buffer to the current torch stream so the storage is not
    // recycled while kernels already enqueued on that stream may still use it.
    // Qualified as at::hip:: for consistency with TorchOpContext below.
    Tensor::lockBuffer(result.buffer, at::hip::getCurrentHIPStreamMasqueradingAsCUDA());

    return result;
}

// Copies an internal Tensor into a freshly allocated torch tensor with the
// same shape, dtype and device. Requires a contiguous input, since strides
// are not forwarded to torch.
at::Tensor to_torch(Tensor input) {
    // assert() compiles out under NDEBUG, which would let a strided tensor
    // through and silently produce garbage — fail loudly instead.
    if (!input.is_contiguous()) {
        throw std::runtime_error("to_torch: input tensor must be contiguous");
    }

    std::vector<int64_t> shape;
    shape.reserve(input.ndims());
    for (size_t i = 0; i < input.ndims(); i++) {
        shape.push_back(input.size(i));
    }

    // Internal scalar type -> torch dtype. NOTE: INT8 maps back to Byte
    // (unsigned) even though from_torch also accepts Char; 8-bit signedness
    // does not round-trip.
    static const std::map<Tensor::ScalarType, at::ScalarType> mapType = {
        {Tensor::INT8, at::ScalarType::Byte},
        {Tensor::INT32, at::ScalarType::Int},
        {Tensor::INT64, at::ScalarType::Long},
        {Tensor::FP32, at::ScalarType::Float},
        {Tensor::FP16, at::ScalarType::Half},
        {Tensor::BF16, at::ScalarType::BFloat16},
        {Tensor::INT16, at::ScalarType::Short},
        {Tensor::FP8_E4M3, at::ScalarType::Float8_e4m3fn},
        {Tensor::FP8_E5M2, at::ScalarType::Float8_e5m2},
    };

    c10::TensorOptions opts(mapType.at(input.scalar_type()));
    if (input.device().type == Device::CPU) {
        opts = opts.device("cpu");
    } else {
        // On this HIP build torch exposes HIP devices under the "cuda" alias.
        opts = opts.device(format("cuda:{}", input.device().idx));
    }

    // Allocate the destination on the torch side, then copy through the
    // internal Tensor API by wrapping it with from_torch.
    at::Tensor result = torch::empty(at::IntArrayRef(shape), opts);
    from_torch(result).copy_(input);

    return result;
}

// Captures torch's current stream for the lifetime of this context, so that
// library kernels are launched on the same stream as surrounding torch ops.
TorchOpContext::TorchOpContext() {
    const auto torchStream = at::hip::getCurrentHIPStreamMasqueradingAsCUDA();
    stackCUDAStreams.push(torchStream.stream());
}

// Pops the stream captured at construction. Contexts must be destroyed in
// LIFO order while the same torch stream is still current.
TorchOpContext::~TorchOpContext() {
    const auto current = at::hip::getCurrentHIPStreamMasqueradingAsCUDA().stream();
    assert(stackCUDAStreams.top() == current);
    (void)current; // keep 'current' used when NDEBUG strips the assert
    stackCUDAStreams.pop();
}