"examples/git@developer.sourcefind.cn:OpenDAS/dgl.git" did not exist on "0a42d863b740e3e13d79ee081d3792a4a04aed87"
Unverified Commit 6451807b authored by czkkkkkk's avatar czkkkkkk Committed by GitHub
Browse files

[Graphbolt] Add an EdgeAttribute function in FusedCSCSamplingGraph to access...

[Graphbolt] Add an EdgeAttribute function in FusedCSCSamplingGraph to access an edge attribute by name. (#6756)
parent 9bb36f1a
...@@ -154,6 +154,44 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder { ...@@ -154,6 +154,44 @@ class FusedCSCSamplingGraph : public torch::CustomClassHolder {
return edge_attributes_; return edge_attributes_;
} }
/**
* @brief Get the node attribute tensor by name.
*
* If the input name is empty, return nullopt. Otherwise, return the node
* attribute tensor by name.
*/
inline torch::optional<torch::Tensor> NodeAttribute(
torch::optional<std::string> name) const {
if (!name.has_value()) {
return torch::nullopt;
}
TORCH_CHECK(
node_attributes_.has_value() &&
node_attributes_.value().contains(name.value()),
"Node attribute ", name.value(), " does not exist.");
return torch::optional<torch::Tensor>(
node_attributes_.value().at(name.value()));
}
/**
* @brief Get the edge attribute tensor by name.
*
* If the input name is empty, return nullopt. Otherwise, return the edge
* attribute tensor by name.
*/
inline torch::optional<torch::Tensor> EdgeAttribute(
torch::optional<std::string> name) const {
if (!name.has_value()) {
return torch::nullopt;
}
TORCH_CHECK(
edge_attributes_.has_value() &&
edge_attributes_.value().contains(name.value()),
"Edge attribute ", name.value(), " does not exist.");
return torch::optional<torch::Tensor>(
edge_attributes_.value().at(name.value()));
}
/** @brief Set the csc index pointer tensor. */ /** @brief Set the csc index pointer tensor. */
inline void SetCSCIndptr(const torch::Tensor& indptr) { indptr_ = indptr; } inline void SetCSCIndptr(const torch::Tensor& indptr) { indptr_ = indptr; }
......
...@@ -537,9 +537,8 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors( ...@@ -537,9 +537,8 @@ c10::intrusive_ptr<FusedSampledSubgraph> FusedCSCSamplingGraph::SampleNeighbors(
const torch::Tensor& nodes, const std::vector<int64_t>& fanouts, const torch::Tensor& nodes, const std::vector<int64_t>& fanouts,
bool replace, bool layer, bool return_eids, bool replace, bool layer, bool return_eids,
torch::optional<std::string> probs_name) const { torch::optional<std::string> probs_name) const {
torch::optional<torch::Tensor> probs_or_mask = torch::nullopt; auto probs_or_mask = this->EdgeAttribute(probs_name);
if (probs_name.has_value() && !probs_name.value().empty()) { if (probs_name.has_value()) {
probs_or_mask = edge_attributes_.value().at(probs_name.value());
// Note probs will be passed as input for 'torch.multinomial' in deeper // Note probs will be passed as input for 'torch.multinomial' in deeper
// stack, which doesn't support 'torch.half' and 'torch.bool' data types. To // stack, which doesn't support 'torch.half' and 'torch.bool' data types. To
// avoid crashes, convert 'probs_or_mask' to 'float32' data type. // avoid crashes, convert 'probs_or_mask' to 'float32' data type.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment