"...text-generation-inference.git" did not exist on "c8bbbd812900ea52bab9b950f846ed2d975d9e78"
Unverified Commit 905c0aa5 authored by David Min, committed by GitHub

[Feature][Performance][GPU] Introducing UnifiedTensor for efficient zero-copy host memory access from GPU (#3086)

* Add pytorch-direct version

* Initial commit of unified tensor

* Merge branch 'master' of https://github.com/davidmin7/dgl



* Remove unnecessary things

* Fix error message

* Fix/Add descriptions

* whitespace fix

* add unpin

* disable IndexSelectCPUFromGPU with no CUDA

* add a newline for unified_tensor.py

* Apply changes based on feedback

* add 'os' module

* skip unified tensor unit test for cpu only

* Update tests/pytorch/test_unified_tensor.py
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>

* reflect feedback
Co-authored-by: shhssdm <shhssdm@gmail.com>
Co-authored-by: Jinjing Zhou <VoVAllen@users.noreply.github.com>
Co-authored-by: xiang song (charlie.song) <classicxsong@gmail.com>
parent 7e923180
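For context, here is a minimal sketch of how the new UnifiedTensor API is intended to be used from PyTorch. The tensor names and shapes are illustrative only, and a CUDA-enabled build of DGL with this change is assumed:

import torch as th
import dgl

# Hypothetical feature matrix that stays in host memory.
feats = th.rand(1000000, 128)

# Wrapping the tensor pins it and maps it into the GPU address space.
feats_unified = dgl.contrib.UnifiedTensor(feats, device=th.device('cuda'))

# CPU indices fall back to regular PyTorch indexing ...
cpu_rows = feats_unified[th.tensor([0, 1, 2])]

# ... while GPU indices trigger the zero-copy gather kernel and return a
# CUDA tensor without staging the whole feature matrix on the GPU.
gpu_rows = feats_unified[th.tensor([0, 1, 2], device='cuda')]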
...@@ -238,6 +238,8 @@ macro(dgl_config_cuda out_variable)
file(GLOB_RECURSE DGL_CUDA_SRC
src/array/cuda/*.cc
src/array/cuda/*.cu
src/array/cuda/uvm/*.cc
src/array/cuda/uvm/*.cu
src/kernel/cuda/*.cc
src/kernel/cuda/*.cu
src/partition/cuda/*.cu
...
import dgl
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import dgl.nn.pytorch as dglnn
import time
import argparse
import tqdm
from model import SAGE
from load_graph import load_reddit, inductive_split, load_ogb
def compute_acc(pred, labels):
"""
Compute the accuracy of prediction given the labels.
"""
labels = labels.long()
return (th.argmax(pred, dim=1) == labels).float().sum() / len(pred)
def evaluate(model, g, nfeat, labels, val_nid, device):
"""
Evaluate the model on the validation set specified by ``val_nid``.
g : The entire graph.
inputs : The features of all the nodes.
labels : The labels of all the nodes.
val_nid : the node Ids for validation.
device : The GPU device to evaluate on.
"""
model.eval()
with th.no_grad():
pred = model.inference(g, nfeat, device, args.batch_size, args.num_workers)
model.train()
return compute_acc(pred[val_nid], labels[val_nid].to(pred.device))
def load_subtensor(nfeat, labels, seeds, input_nodes, device):
"""
Extracts features and labels for a subset of nodes
"""
batch_inputs = nfeat[input_nodes.to(device)]
batch_labels = labels[seeds].to(device)
return batch_inputs, batch_labels
#### Entry point
def run(args, device, data):
# Unpack data
n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
val_nfeat, val_labels, test_nfeat, test_labels = data
in_feats = train_nfeat.shape[1]
train_nid = th.nonzero(train_g.ndata['train_mask'], as_tuple=True)[0]
val_nid = th.nonzero(val_g.ndata['val_mask'], as_tuple=True)[0]
test_nid = th.nonzero(~(test_g.ndata['train_mask'] | test_g.ndata['val_mask']), as_tuple=True)[0]
dataloader_device = th.device('cpu')
if args.sample_gpu:
train_nid = train_nid.to(device)
# copy only the csc to the GPU
train_g = train_g.formats(['csc'])
train_g = train_g.to(device)
dataloader_device = device
# Create PyTorch DataLoader for constructing blocks
sampler = dgl.dataloading.MultiLayerNeighborSampler(
[int(fanout) for fanout in args.fan_out.split(',')])
dataloader = dgl.dataloading.NodeDataLoader(
train_g,
train_nid,
sampler,
device=dataloader_device,
batch_size=args.batch_size,
shuffle=True,
drop_last=False,
num_workers=args.num_workers)
if args.data_cpu:
# Convert input feature tensor to unified tensor
train_nfeat = dgl.contrib.UnifiedTensor(train_nfeat, device=device)
# Define model and optimizer
model = SAGE(in_feats, args.num_hidden, n_classes, args.num_layers, F.relu, args.dropout)
model = model.to(device)
loss_fcn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Training loop
avg = 0
iter_tput = []
for epoch in range(args.num_epochs):
tic = time.time()
# Loop over the dataloader to sample the computation dependency graph as a list of
# blocks.
tic_step = time.time()
for step, (input_nodes, seeds, blocks) in enumerate(dataloader):
# Load the input features as well as output labels
batch_inputs, batch_labels = load_subtensor(train_nfeat, train_labels,
seeds, input_nodes, device)
blocks = [block.int().to(device) for block in blocks]
# Compute loss and prediction
batch_pred = model(blocks, batch_inputs)
loss = loss_fcn(batch_pred, batch_labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
iter_tput.append(len(seeds) / (time.time() - tic_step))
if step % args.log_every == 0:
acc = compute_acc(batch_pred, batch_labels)
gpu_mem_alloc = th.cuda.max_memory_allocated() / 1000000 if th.cuda.is_available() else 0
print('Epoch {:05d} | Step {:05d} | Loss {:.4f} | Train Acc {:.4f} | Speed (samples/sec) {:.4f} | GPU {:.1f} MB'.format(
epoch, step, loss.item(), acc.item(), np.mean(iter_tput[3:]), gpu_mem_alloc))
tic_step = time.time()
toc = time.time()
print('Epoch Time(s): {:.4f}'.format(toc - tic))
if epoch >= 5:
avg += toc - tic
if epoch % args.eval_every == 0 and epoch != 0:
eval_acc = evaluate(model, val_g, val_nfeat, val_labels, val_nid, device)
print('Eval Acc {:.4f}'.format(eval_acc))
test_acc = evaluate(model, test_g, test_nfeat, test_labels, test_nid, device)
print('Test Acc: {:.4f}'.format(test_acc))
print('Avg epoch time: {}'.format(avg / (epoch - 4)))
if __name__ == '__main__':
argparser = argparse.ArgumentParser()
argparser.add_argument('--gpu', type=int, default=0,
help="GPU device ID. Use -1 for CPU training")
argparser.add_argument('--dataset', type=str, default='reddit')
argparser.add_argument('--num-epochs', type=int, default=20)
argparser.add_argument('--num-hidden', type=int, default=16)
argparser.add_argument('--num-layers', type=int, default=2)
argparser.add_argument('--fan-out', type=str, default='10,25')
argparser.add_argument('--batch-size', type=int, default=1000)
argparser.add_argument('--log-every', type=int, default=20)
argparser.add_argument('--eval-every', type=int, default=5)
argparser.add_argument('--lr', type=float, default=0.003)
argparser.add_argument('--dropout', type=float, default=0.5)
argparser.add_argument('--num-workers', type=int, default=4,
help="Number of sampling processes. Use 0 for no extra process.")
argparser.add_argument('--sample-gpu', action='store_true',
help="Perform the sampling process on the GPU. Must have 0 workers.")
argparser.add_argument('--inductive', action='store_true',
help="Inductive learning setting")
argparser.add_argument('--data-cpu', action='store_true',
help="By default the script puts all node features and labels "
"on the GPU to save time on data copies. This may be "
"undesirable if they cannot fit in GPU memory at once. "
"Setting this flag keeps the node features in host memory "
"and accesses them through a unified tensor instead.")
args = argparser.parse_args()
if args.gpu >= 0:
device = th.device('cuda:%d' % args.gpu)
else:
device = th.device('cpu')
if args.dataset == 'reddit':
g, n_classes = load_reddit()
elif args.dataset == 'ogbn-products':
g, n_classes = load_ogb('ogbn-products')
else:
raise Exception('unknown dataset')
if args.inductive:
train_g, val_g, test_g = inductive_split(g)
train_nfeat = train_g.ndata.pop('features')
val_nfeat = val_g.ndata.pop('features')
test_nfeat = test_g.ndata.pop('features')
train_labels = train_g.ndata.pop('labels')
val_labels = val_g.ndata.pop('labels')
test_labels = test_g.ndata.pop('labels')
else:
train_g = val_g = test_g = g
train_nfeat = val_nfeat = test_nfeat = g.ndata.pop('features')
train_labels = val_labels = test_labels = g.ndata.pop('labels')
if not args.data_cpu:
train_nfeat = train_nfeat.to(device)
train_labels = train_labels.to(device)
# Pack data
data = n_classes, train_g, val_g, test_g, train_nfeat, train_labels, \
val_nfeat, val_labels, test_nfeat, test_labels
run(args, device, data)
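As a usage note, and assuming the script above is saved under a hypothetical name such as train_sampling_unified_tensor.py, it could be launched along the lines of: python train_sampling_unified_tensor.py --dataset reddit --data-cpu --sample-gpu --num-workers 0. Here --data-cpu keeps the node features in host memory behind a UnifiedTensor, while --sample-gpu (which requires zero sampling workers) performs neighbor sampling on the GPU.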
...@@ -168,6 +168,35 @@
} \
} while (0)
/*
* Dispatch data type only based on bit-width (8-bit, 16-bit, 32-bit, 64-bit):
*
* ATEN_DTYPE_BITS_ONLY_SWITCH(array->dtype, DType, "array", {
* // Now DType is the type which has the same bit-width with the
* // data type in array.
* // Do not use for computation, but only for read and write.
* // For instance, one can do this for a CPU array:
* DType *data = static_cast<DType *>(array->data);
* });
*/
#define ATEN_DTYPE_BITS_ONLY_SWITCH(val, DType, val_name, ...) do { \
if ((val).bits == 8) { \
typedef int8_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 16) { \
typedef int16_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 32) { \
typedef int32_t DType; \
{__VA_ARGS__} \
} else if ((val).bits == 64) { \
typedef int64_t DType; \
{__VA_ARGS__} \
} else { \
LOG(FATAL) << (val_name) << " can only be 8-bit, 16-bit, 32-bit, or 64-bit"; \
} \
} while (0)
/*
* Dispatch according to integral type of CSR graphs.
* Identical to ATEN_ID_TYPE_SWITCH except for a different error message.
...
...@@ -546,6 +546,16 @@ DGL_DLL int DGLStreamStreamSynchronize(int device_type,
*/
DGL_DLL int DGLLoadTensorAdapter(const char *path);
/*!
* \brief Pin host memory.
*/
int DGLArrayPinData(DGLArrayHandle handle, DLContext ctx);
/*!
* \brief Unpin host memory.
*/
int DGLArrayUnpinData(DGLArrayHandle handle, DLContext ctx);
/*!
* \brief Bug report macro.
*
...
...@@ -140,6 +140,24 @@ class DeviceAPI {
DGL_DLL virtual void SyncStreamFromTo(DGLContext ctx,
DGLStreamHandle event_src,
DGLStreamHandle event_dst);
/*!
* \brief Pin host memory using cudaHostRegister().
*
* \param ctx The context of pinning and mapping.
* \param ptr The host memory pointer to be pinned.
* \param nbytes The size to be pinned.
*/
DGL_DLL virtual void PinData(DGLContext ctx, void* ptr, size_t nbytes);
/*!
* \brief Unpin host memory using cudaHostUnregister().
*
* \param ctx The context to unmap and unpin.
* \param ptr The host memory pointer to be unpinned.
*/
DGL_DLL virtual void UnpinData(DGLContext ctx, void* ptr);
/*!
* \brief Allocate temporal workspace for backend execution.
*
...
...@@ -316,6 +316,26 @@ class NDArrayBase(_NDArrayBase):
raise ValueError("Unsupported target type %s" % str(type(target)))
return target
def pin_memory_(self, ctx):
"""Pin host memory and map into GPU address space (in-place)
Parameters
----------
ctx : DGLContext
The target GPU to map the host memory space
"""
check_call(_LIB.DGLArrayPinData(self.handle, ctx))
def unpin_memory_(self, ctx):
"""Unpin host memory pinned by pin_memory_()
Parameters
----------
ctx : DGLContext
The target GPU to map the host memory space
"""
check_call(_LIB.DGLArrayUnpinData(self.handle, ctx))
def free_extension_handle(handle, type_code):
"""Free c++ extension type handle
...
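For illustration, a minimal sketch of the low-level pin/unpin path these methods expose, mirroring what UnifiedTensor does internally; the variable names are illustrative and a CUDA build of DGL is assumed:

import torch as th
import dgl.backend as F
from dgl import utils

feats = th.rand(4096, 128)                      # host tensor to be pinned
arr = F.zerocopy_to_dgl_ndarray(feats)          # view it as a DGL NDArray
cuda_ctx = utils.to_dgl_context(th.device('cuda'))

arr.pin_memory_(cuda_ctx)    # cudaHostRegister() on the underlying buffer
# ... GPU kernels may now read the buffer via zero-copy access ...
arr.unpin_memory_(cuda_ctx)  # cudaHostUnregister() when done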
...@@ -2,3 +2,4 @@ from . import sampling
from . import graph_store
from .dis_kvstore import KVClient, KVServer
from .dis_kvstore import read_ip_config
from .unified_tensor import UnifiedTensor
"""Unified Tensor."""
from .. import backend as F
from .._ffi.function import _init_api
from .. import utils
class UnifiedTensor:
'''Class for storing a unified tensor. Creating a
UnifiedTensor automatically pins the input tensor.
Parameters
----------
input : Tensor
Tensor which we want to convert into the
unified tensor.
device : device
Device to create the mapping of the unified tensor.
'''
def __init__(self, input, device):
if F.device_type(device) != 'cuda':
raise ValueError("Target device must be a cuda device")
if F.device_type(F.context(input)) != 'cpu':
raise ValueError("Input tensor must be a cpu tensor")
self._input = input
self._array = F.zerocopy_to_dgl_ndarray(self._input)
self._device = device
self._array.pin_memory_(utils.to_dgl_context(self._device))
def __len__(self):
return len(self._array)
def __repr__(self):
return self._input.__repr__()
def __getitem__(self, key):
'''Perform zero-copy access from the GPU if the key tensor
is on a cuda device. Otherwise, safely fall back to the
backend-specific indexing scheme.
Parameters
----------
key : Tensor
Tensor which contains the index ids
'''
if F.device_type(F.context(key)) != 'cuda':
return self._input[key]
else:
return F.zerocopy_from_dgl_ndarray(
_CAPI_DGLIndexSelectCPUFromGPU(self._array,
F.zerocopy_to_dgl_ndarray(key)))
def __setitem__(self, key, val):
self._input[key] = val
def __del__(self):
if hasattr(self, '_array') and self._array is not None:
self._array.unpin_memory_(utils.to_dgl_context(self._device))
self._array = None
if hasattr(self, '_input'):
self._input = None
@property
def shape(self):
"""Shape of this tensor"""
return self._array.shape
@property
def dtype(self):
"""Type of this tensor"""
return self._array.dtype
@property
def device(self):
"""Device of this tensor"""
return self._device
_init_api("dgl.ndarray.uvm", __name__)
/*!
* Copyright (c) 2019 by Contributors
* \file array/cuda/uvm/array_index_select_uvm.cu
* \brief Array index select GPU implementation
*/
#include <dgl/array.h>
#include "../../../runtime/cuda/cuda_common.h"
#include "./array_index_select_uvm.cuh"
#include "../utils.h"
namespace dgl {
using runtime::NDArray;
namespace aten {
namespace impl {
// Gather rows of a (pinned) CPU array with a GPU index array. The result is
// allocated on the GPU; the kernels read the host buffer directly through
// zero-copy (UVM) access.
template<typename DType, typename IdType>
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
auto* thr_entry = runtime::CUDAThreadEntry::ThreadLocal();
const DType* array_data = static_cast<DType*>(array->data);
const IdType* idx_data = static_cast<IdType*>(index->data);
const int64_t arr_len = array->shape[0];
const int64_t len = index->shape[0];
int64_t num_feat = 1;
std::vector<int64_t> shape{len};
CHECK_EQ(array->ctx.device_type, kDLCPU);
CHECK_EQ(index->ctx.device_type, kDLGPU);
for (int d = 1; d < array->ndim; ++d) {
num_feat *= array->shape[d];
shape.emplace_back(array->shape[d]);
}
NDArray ret = NDArray::Empty(shape, array->dtype, index->ctx);
if (len == 0)
return ret;
DType* ret_data = static_cast<DType*>(ret->data);
if (num_feat == 1) {
const int nt = cuda::FindNumThreads(len);
const int nb = (len + nt - 1) / nt;
CUDA_KERNEL_CALL(IndexSelectSingleKernel, nb, nt, 0, thr_entry->stream,
array_data, idx_data, len, ret_data);
} else {
// Use a 2D thread block: x spans the feature dimension, y spans output
// rows. Halve x (and double y) until x is below twice the feature count,
// keeping 256 threads per block.
dim3 block(256, 1);
while (static_cast<int64_t>(block.x) >= 2*num_feat) {
block.x /= 2;
block.y *= 2;
}
const dim3 grid((len+block.y-1)/block.y);
CUDA_KERNEL_CALL(IndexSelectMultiKernel, grid, block, 0, thr_entry->stream,
array_data, num_feat, idx_data, len, ret_data);
}
return ret;
}
template NDArray IndexSelectCPUFromGPU<int8_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int8_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int16_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int16_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int32_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int32_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<int64_t, int64_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<float, int32_t>(NDArray, IdArray);
template NDArray IndexSelectCPUFromGPU<float, int64_t>(NDArray, IdArray);
} // namespace impl
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2021 by Contributors
* \file array/cuda/uvm/array_index_select_uvm.cuh
* \brief Array index select GPU kernel implementation
*/
#ifndef DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_
#define DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_
namespace dgl {
namespace aten {
namespace impl {
// One output element per thread (grid-stride loop over the index array);
// used when each row has a single feature.
template <typename DType, typename IdType>
__global__ void IndexSelectSingleKernel(const DType* array, const IdType* index,
int64_t length, DType* out) {
int tx = blockIdx.x * blockDim.x + threadIdx.x;
int stride_x = gridDim.x * blockDim.x;
while (tx < length) {
out[tx] = array[index[tx]];
tx += stride_x;
}
}
// Threads along y each handle one output row; threads along x stride over
// that row's features.
template <typename DType, typename IdType>
__global__ void IndexSelectMultiKernel(
const DType* const array,
const int64_t num_feat,
const IdType* const index,
const int64_t length,
DType* const out) {
int64_t out_row = blockIdx.x*blockDim.y+threadIdx.y;
const int64_t stride = blockDim.y*gridDim.x;
while (out_row < length) {
int64_t col = threadIdx.x;
const int64_t in_row = index[out_row];
while (col < num_feat) {
out[out_row*num_feat+col] = array[in_row*num_feat+col];
col += blockDim.x;
}
out_row += stride;
}
}
} // namespace impl
} // namespace aten
} // namespace dgl
#endif  // DGL_ARRAY_CUDA_ARRAY_INDEX_SELECT_UVM_CUH_
/*!
* Copyright (c) 2019 by Contributors
* \file array/uvm_array.cc
* \brief DGL array utilities implementation
*/
#include <dgl/array.h>
#include <sstream>
#include "../c_api_common.h"
#include "./uvm_array_op.h"
using namespace dgl::runtime;
namespace dgl {
namespace aten {
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
#ifdef DGL_USE_CUDA
CHECK_EQ(array->ctx.device_type, kDLCPU) << "Only a CPU device type input "
<< "array is supported";
CHECK_EQ(index->ctx.device_type, kDLGPU) << "Only a GPU device type input "
<< "index is supported";
CHECK_GE(array->ndim, 1) << "Only arrays with at least 1 dimension are supported";
CHECK_EQ(index->ndim, 1) << "Index must be a 1D array.";
ATEN_DTYPE_BITS_ONLY_SWITCH(array->dtype, DType, "values", {
ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
return impl::IndexSelectCPUFromGPU<DType, IdType>(array, index);
});
});
#endif
LOG(FATAL) << "IndexSelectCPUFromGPU requires CUDA";
// Should be unreachable
return NDArray{};
}
DGL_REGISTER_GLOBAL("ndarray.uvm._CAPI_DGLIndexSelectCPUFromGPU")
.set_body([] (DGLArgs args, DGLRetValue* rv) {
NDArray array = args[0];
IdArray index = args[1];
*rv = IndexSelectCPUFromGPU(array, index);
});
} // namespace aten
} // namespace dgl
/*!
* Copyright (c) 2019 by Contributors
* \file array/uvm_array_op.h
* \brief Array operator templates
*/
#ifndef DGL_ARRAY_UVM_ARRAY_OP_H_
#define DGL_ARRAY_UVM_ARRAY_OP_H_
#include <dgl/array.h>
#include <utility>
namespace dgl {
namespace aten {
namespace impl {
// Take CPU array and GPU index, and then index with GPU.
template <typename DType, typename IdType>
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index);
} // namespace impl
} // namespace aten
} // namespace dgl
#endif // DGL_ARRAY_UVM_ARRAY_OP_H_
...@@ -123,6 +123,14 @@ void DeviceAPI::SyncStreamFromTo(DGLContext ctx,
DGLStreamHandle event_dst) {
LOG(FATAL) << "Device does not support stream api.";
}
void DeviceAPI::PinData(DGLContext ctx, void* ptr, size_t nbytes) {
LOG(FATAL) << "Device does not support cudaHostRegister api.";
}
void DeviceAPI::UnpinData(DGLContext ctx, void* ptr) {
LOG(FATAL) << "Device does not support cudaHostUnregister api.";
}
} // namespace runtime
} // namespace dgl
...
...@@ -170,6 +170,16 @@ class CUDADeviceAPI final : public DeviceAPI {
->stream = static_cast<cudaStream_t>(stream);
}
void PinData(DGLContext ctx, void* ptr, size_t nbytes) {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CUDA_CALL(cudaHostRegister(ptr, nbytes, cudaHostRegisterDefault));
}
void UnpinData(DGLContext ctx, void* ptr) {
CUDA_CALL(cudaSetDevice(ctx.device_id));
CUDA_CALL(cudaHostUnregister(ptr));
}
void* AllocWorkspace(DGLContext ctx, size_t size, DGLType type_hint) final {
return CUDAThreadEntry::ThreadLocal()->pool.AllocWorkspace(ctx, size);
}
...
...@@ -510,3 +510,20 @@ int DGLArrayCopyToBytes(DGLArrayHandle handle,
nbytes, handle->ctx, cpu_ctx, handle->dtype, nullptr);
API_END();
}
int DGLArrayPinData(DGLArrayHandle handle,
DLContext ctx) {
API_BEGIN();
CHECK_EQ(ctx.device_type, kDLGPU);
DeviceAPI::Get(ctx)->PinData(ctx, handle->data,
GetDataSize(*handle));
API_END();
}
int DGLArrayUnpinData(DGLArrayHandle handle,
DLContext ctx) {
API_BEGIN();
CHECK_EQ(ctx.device_type, kDLGPU);
DeviceAPI::Get(ctx)->UnpinData(ctx, handle->data);
API_END();
}
import unittest, os
import torch as th
import dgl
import backend as F
@unittest.skipIf(os.name == 'nt', reason='Windows is not supported yet')
@unittest.skipIf(F.ctx().type == 'cpu', reason='gpu only test')
def test_unified_tensor():
test_row_size = 65536
test_col_size = 128
rand_test_size = 8192
input = th.rand((test_row_size, test_col_size))
input_unified = dgl.contrib.UnifiedTensor(input, device=th.device('cuda'))
seq_idx = th.arange(0, test_row_size)
assert th.all(th.eq(input[seq_idx], input_unified[seq_idx]))
seq_idx = seq_idx.to(th.device('cuda'))
assert th.all(th.eq(input[seq_idx].to(th.device('cuda')), input_unified[seq_idx]))
rand_idx = th.randint(0, test_row_size, (rand_test_size,))
assert th.all(th.eq(input[rand_idx], input_unified[rand_idx]))
rand_idx = rand_idx.to(th.device('cuda'))
assert th.all(th.eq(input[rand_idx].to(th.device('cuda')), input_unified[rand_idx]))
if __name__ == '__main__':
test_unified_tensor()