Unverified Commit 3ec1eacf authored by Michał Marcinkiewicz's avatar Michał Marcinkiewicz Committed by GitHub
Browse files

Fix issue 193 (#203)

* Fixes for build with the new pytorch

* Fixes for build with the new pytorch

* Fixes for build with the new pytorch
parent d987d295
......@@ -153,7 +153,7 @@ hetero_sample(const vector<node_t> &node_types,
// Add the input nodes to the output nodes:
for (const auto &kv : input_node_dict) {
const auto &node_type = kv.key();
const auto &input_node = kv.value();
const torch::Tensor &input_node = kv.value();
const auto *input_node_data = input_node.data_ptr<int64_t>();
auto &samples = samples_dict.at(node_type);
......@@ -180,8 +180,8 @@ hetero_sample(const vector<node_t> &node_types,
auto &src_samples = samples_dict.at(src_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
const auto *colptr_data = colptr_dict.at(rel_type).data_ptr<int64_t>();
const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
const auto *colptr_data = ((torch::Tensor)colptr_dict.at(rel_type)).data_ptr<int64_t>();
const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
auto &rows = rows_dict.at(rel_type);
auto &cols = cols_dict.at(rel_type);
......@@ -261,8 +261,8 @@ hetero_sample(const vector<node_t> &node_types,
const auto &dst_samples = samples_dict.at(dst_node_type);
auto &to_local_src_node = to_local_node_dict.at(src_node_type);
const auto *colptr_data = kv.value().data_ptr<int64_t>();
const auto *row_data = row_dict.at(rel_type).data_ptr<int64_t>();
const auto *colptr_data = ((torch::Tensor)kv.value()).data_ptr<int64_t>();
const auto *row_data = ((torch::Tensor)row_dict.at(rel_type)).data_ptr<int64_t>();
auto &rows = rows_dict.at(rel_type);
auto &cols = cols_dict.at(rel_type);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment