torch.cpp 1.11 KB
Newer Older
1
2
3
4
5
6
/*!
 *  Copyright (c) 2020 by Contributors
 * \file torch/torch.cpp
 * \brief Implementation of PyTorch adapter library.
 */

7
#include <tensoradapter_exports.h>
8
9
10
11
12
#include <torch/torch.h>
#include <ATen/DLConvertor.h>
#include <vector>
#include <iostream>

13
14
15
16
17
18
#if DLPACK_VERSION > 040
// Compatibility across DLPack versions: newer DLPack releases renamed kDLGPU to
// kDLCUDA and DLContext to DLDevice. This assumes the ABI is otherwise unchanged.
#define kDLGPU kDLCUDA
#define DLContext DLDevice
#endif

19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
namespace tensoradapter {

static at::Device get_device(DLContext ctx) {
  switch (ctx.device_type) {
   case kDLCPU:
    return at::Device(torch::kCPU);
    break;
   case kDLGPU:
    return at::Device(torch::kCUDA, ctx.device_id);
    break;
   default:
    // fallback to CPU
    return at::Device(torch::kCPU);
    break;
  }
}

extern "C" {

38
TA_EXPORTS DLManagedTensor* TAempty(
39
40
41
42
43
44
45
46
47
48
49
50
51
52
    std::vector<int64_t> shape,
    DLDataType dtype,
    DLContext ctx) {
  auto options = torch::TensorOptions()
    .layout(torch::kStrided)
    .device(get_device(ctx))
    .dtype(at::toScalarType(dtype));
  torch::Tensor tensor = torch::empty(shape, options);
  return at::toDLPack(tensor);
}

};

};  // namespace tensoradapter