"tests/python/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "219c9f1a2910a1f5d22f56fede28252d55eee277"
Unverified Commit 520cef88 authored by Quan (Andy) Gan's avatar Quan (Andy) Gan Committed by GitHub
Browse files

[Feature] Compatibility to DLPack 0.6 in tensoradapter (#3803)



* compatibility to DLPack 0.6 in tensoradapter

* fix

* oops
Co-authored-by: default avatarJinjing Zhou <VoVAllen@users.noreply.github.com>
parent 27d3af01
...@@ -13,12 +13,6 @@ ...@@ -13,12 +13,6 @@
#include <dlpack/dlpack.h> #include <dlpack/dlpack.h>
#include <vector> #include <vector>
#if defined(WIN32) || defined(_WIN32)
#define TA_EXPORTS __declspec(dllexport)
#else
#define TA_EXPORTS
#endif
namespace tensoradapter { namespace tensoradapter {
extern "C" { extern "C" {
...@@ -31,7 +25,7 @@ extern "C" { ...@@ -31,7 +25,7 @@ extern "C" {
* \param ctx The device * \param ctx The device
* \return The allocated tensor * \return The allocated tensor
*/ */
TA_EXPORTS DLManagedTensor* TAempty( DLManagedTensor* TAempty(
std::vector<int64_t> shape, DLDataType dtype, DLContext ctx); std::vector<int64_t> shape, DLDataType dtype, DLContext ctx);
} }
......
/*!
* Copyright (c) 2020 by Contributors
* \file tensoradapter_exports.h
* \brief Header file for functions exposed by the adapter library.
*/
#ifndef TENSORADAPTER_EXPORTS_H_
#define TENSORADAPTER_EXPORTS_H_
#if defined(WIN32) || defined(_WIN32)
#define TA_EXPORTS __declspec(dllexport)
#else
#define TA_EXPORTS
#endif
#endif // TENSORADAPTER_EXPORTS_H_
...@@ -4,12 +4,18 @@ ...@@ -4,12 +4,18 @@
* \brief Implementation of PyTorch adapter library. * \brief Implementation of PyTorch adapter library.
*/ */
#include <tensoradapter.h> #include <tensoradapter_exports.h>
#include <torch/torch.h> #include <torch/torch.h>
#include <ATen/DLConvertor.h> #include <ATen/DLConvertor.h>
#include <vector> #include <vector>
#include <iostream> #include <iostream>
#if DLPACK_VERSION > 040
// Compatibility across DLPack - note that this assumes that the ABI stays the same.
#define kDLGPU kDLCUDA
#define DLContext DLDevice
#endif
namespace tensoradapter { namespace tensoradapter {
static at::Device get_device(DLContext ctx) { static at::Device get_device(DLContext ctx) {
...@@ -29,7 +35,7 @@ static at::Device get_device(DLContext ctx) { ...@@ -29,7 +35,7 @@ static at::Device get_device(DLContext ctx) {
extern "C" { extern "C" {
DLManagedTensor* TAempty( TA_EXPORTS DLManagedTensor* TAempty(
std::vector<int64_t> shape, std::vector<int64_t> shape,
DLDataType dtype, DLDataType dtype,
DLContext ctx) { DLContext ctx) {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment