"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "1ca608bcd155c771d0fed683a75d8367fe9c7144"
uvm_array.cc 2.41 KB
Newer Older
sangwzh's avatar
sangwzh committed
1
// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2019-2022 by Contributors
 * @file array/uvm_array.cc
 * @brief DGL array utilities implementation
 */
#include <dgl/array.h>
8

9
#include <sstream>
10

11
#include "../c_api_common.h"
sangwzh's avatar
sangwzh committed
12
#include "uvm_array_op.h"
13
14
15
16
17
18
19
20

using namespace dgl::runtime;

namespace dgl {
namespace aten {

/**
 * @brief Index-select from a pinned array using an index that lives on the
 *        GPU.
 *
 * When DGL_USE_CUDA is defined, the arguments are validated and the call is
 * dispatched on the value dtype of `array` and the id dtype of `index` to
 * impl::IndexSelectCPUFromGPU<DType, IdType>. When it is not defined, the
 * function terminates through LOG(FATAL).
 *
 * @param array Input array. Checked to be pinned and to have ndim >= 1.
 * @param index Index array. Checked to be 1D and on a kDGLCUDA device.
 * @return The result produced by impl::IndexSelectCPUFromGPU.
 */
NDArray IndexSelectCPUFromGPU(NDArray array, IdArray index) {
#ifdef DGL_USE_CUDA
  // Argument validation. The order of the checks decides which message is
  // reported first when several preconditions are violated.
  CHECK(array.IsPinned())
      << "Input array must be in pinned memory.";
  CHECK_EQ(index->ctx.device_type, kDGLCUDA)
      << "Index must be on the GPU.";
  CHECK_GE(array->ndim, 1)
      << "Input array must have at least 1 dimension.";
  CHECK_EQ(index->ndim, 1)
      << "Index must be a 1D array.";

  // Outer switch picks the value type, inner switch picks the index type;
  // the innermost body returns directly from this function.
  ATEN_DTYPE_BITS_ONLY_SWITCH(array->dtype, DType, "values", {
    ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
      return impl::IndexSelectCPUFromGPU<DType, IdType>(array, index);
    });
  });
#endif
  LOG(FATAL) << "IndexSelectCPUFromGPU requires CUDA.";
  // Not expected to be reached; present so every path returns a value.
  return NDArray{};
}

37
38
/**
 * @brief Scatter a GPU-resident source array into a pinned destination array
 *        at the positions given by a GPU-resident index.
 *
 * When DGL_USE_CUDA is defined, the arguments are validated and the call is
 * dispatched on the value dtype of `source` and the id dtype of `index` to
 * impl::IndexScatterGPUToCPU<DType, IdType>. Otherwise the function
 * terminates through LOG(FATAL).
 *
 * @param dest   Destination array. Checked to be pinned, to have ndim >= 1
 *               and to share its dtype with `source`.
 * @param index  Index array. Checked to be 1D and on a kDGLCUDA device.
 * @param source Source array. Checked to be on a kDGLCUDA device.
 */
void IndexScatterGPUToCPU(NDArray dest, IdArray index, NDArray source) {
#ifdef DGL_USE_CUDA
  // Placement checks: destination pinned, index and source on the GPU.
  CHECK(dest.IsPinned())
      << "Destination array must be in pinned memory.";
  CHECK_EQ(index->ctx.device_type, kDGLCUDA)
      << "Index must be on the GPU.";
  CHECK_EQ(source->ctx.device_type, kDGLCUDA)
      << "Source array must be on the GPU.";

  // Type and rank checks.
  CHECK_EQ(dest->dtype, source->dtype)
      << "Destination array and source array "
         "must have the same dtype.";
  CHECK_GE(dest->ndim, 1)
      << "Destination array must have at least 1 dimension.";
  CHECK_EQ(index->ndim, 1)
      << "Index must be a 1D array.";

  // Outer switch picks the value type, inner switch picks the index type.
  ATEN_DTYPE_BITS_ONLY_SWITCH(source->dtype, DType, "values", {
    ATEN_ID_TYPE_SWITCH(index->dtype, IdType, {
      impl::IndexScatterGPUToCPU<DType, IdType>(dest, index, source);
    });
  });
#else
  LOG(FATAL) << "IndexScatterGPUToCPU requires CUDA.";
#endif
}

59
// Registers the packed function "ndarray.uvm._CAPI_DGLIndexSelectCPUFromGPU".
// Arguments: [0] the input array, [1] the index array.
// The return slot receives the result of IndexSelectCPUFromGPU.
DGL_REGISTER_GLOBAL("ndarray.uvm._CAPI_DGLIndexSelectCPUFromGPU")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray input = args[0];
      IdArray positions = args[1];
      *rv = IndexSelectCPUFromGPU(input, positions);
    });
65

66
// Registers the packed function "ndarray.uvm._CAPI_DGLIndexScatterGPUToCPU".
// Arguments: [0] the destination array, [1] the index array, [2] the source
// array. Nothing is written to the return slot.
DGL_REGISTER_GLOBAL("ndarray.uvm._CAPI_DGLIndexScatterGPUToCPU")
    .set_body([](DGLArgs args, DGLRetValue* rv) {
      NDArray target = args[0];
      IdArray positions = args[1];
      NDArray values = args[2];
      IndexScatterGPUToCPU(target, positions, values);
    });
73

74
75
}  // namespace aten
}  // namespace dgl