Commit 1547bd93 authored by Rhett Ying's avatar Rhett Ying Committed by RhettYing
Browse files

[doc] use tqdm from tqdm.auto (#7191)

parent 69247f5b
...@@ -249,11 +249,11 @@ ...@@ -249,11 +249,11 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import tqdm\n", "from tqdm.auto import tqdm\n",
"for epoch in range(3):\n", "for epoch in range(3):\n",
" model.train()\n", " model.train()\n",
" total_loss = 0\n", " total_loss = 0\n",
" for step, data in tqdm.tqdm(enumerate(create_train_dataloader())):\n", " for step, data in tqdm(enumerate(create_train_dataloader())):\n",
" # Get node pairs with labels for loss calculation.\n", " # Get node pairs with labels for loss calculation.\n",
" compacted_pairs, labels = data.node_pairs_with_labels\n", " compacted_pairs, labels = data.node_pairs_with_labels\n",
" node_feature = data.node_features[\"feat\"]\n", " node_feature = data.node_features[\"feat\"]\n",
...@@ -306,7 +306,7 @@ ...@@ -306,7 +306,7 @@
"\n", "\n",
"logits = []\n", "logits = []\n",
"labels = []\n", "labels = []\n",
"for step, data in tqdm.tqdm(enumerate(eval_dataloader)):\n", "for step, data in tqdm(enumerate(eval_dataloader)):\n",
" # Get node pairs with labels for loss calculation.\n", " # Get node pairs with labels for loss calculation.\n",
" compacted_pairs, label = data.node_pairs_with_labels\n", " compacted_pairs, label = data.node_pairs_with_labels\n",
"\n", "\n",
...@@ -370,4 +370,4 @@ ...@@ -370,4 +370,4 @@
}, },
"nbformat": 4, "nbformat": 4,
"nbformat_minor": 0 "nbformat_minor": 0
} }
\ No newline at end of file
...@@ -297,12 +297,12 @@ ...@@ -297,12 +297,12 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"import tqdm\n", "from tqdm.auto import tqdm\n",
"\n", "\n",
"for epoch in range(10):\n", "for epoch in range(10):\n",
" model.train()\n", " model.train()\n",
"\n", "\n",
" with tqdm.tqdm(train_dataloader) as tq:\n", " with tqdm(train_dataloader) as tq:\n",
" for step, data in enumerate(tq):\n", " for step, data in enumerate(tq):\n",
" x = data.node_features[\"feat\"]\n", " x = data.node_features[\"feat\"]\n",
" labels = data.labels\n", " labels = data.labels\n",
...@@ -328,7 +328,7 @@ ...@@ -328,7 +328,7 @@
"\n", "\n",
" predictions = []\n", " predictions = []\n",
" labels = []\n", " labels = []\n",
" with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():\n", " with tqdm(valid_dataloader) as tq, torch.no_grad():\n",
" for data in tq:\n", " for data in tq:\n",
" x = data.node_features[\"feat\"]\n", " x = data.node_features[\"feat\"]\n",
" labels.append(data.labels.cpu().numpy())\n", " labels.append(data.labels.cpu().numpy())\n",
......
...@@ -4,7 +4,7 @@ import pickle ...@@ -4,7 +4,7 @@ import pickle
import pandas as pd import pandas as pd
from ogb.utils import smiles2graph as smiles2graph_OGB from ogb.utils import smiles2graph as smiles2graph_OGB
from tqdm import tqdm from tqdm.auto import tqdm
from .. import backend as F from .. import backend as F
......
...@@ -3,7 +3,7 @@ import pickle ...@@ -3,7 +3,7 @@ import pickle
import numpy as np import numpy as np
from scipy.spatial.distance import cdist from scipy.spatial.distance import cdist
from tqdm import tqdm from tqdm.auto import tqdm
from .. import backend as F from .. import backend as F
from ..convert import graph as dgl_graph from ..convert import graph as dgl_graph
......
...@@ -12,7 +12,7 @@ import networkx.algorithms as A ...@@ -12,7 +12,7 @@ import networkx.algorithms as A
import numpy as np import numpy as np
import requests import requests
from tqdm import tqdm from tqdm.auto import tqdm
from .. import backend as F from .. import backend as F
from .graph_serialize import load_graphs, load_labels, save_graphs from .graph_serialize import load_graphs, load_labels, save_graphs
......
...@@ -5,7 +5,7 @@ from math import sqrt ...@@ -5,7 +5,7 @@ from math import sqrt
import torch import torch
from torch import nn from torch import nn
from tqdm import tqdm from tqdm.auto import tqdm
from ....base import EID, NID from ....base import EID, NID
from ....subgraph import khop_in_subgraph from ....subgraph import khop_in_subgraph
......
"""Network Embedding NN Modules""" """Network Embedding NN Modules"""
# pylint: disable= invalid-name # pylint: disable= invalid-name
import random import random
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
import tqdm
from torch import nn from torch import nn
from torch.nn import init from torch.nn import init
from tqdm.auto import trange
from ...base import NID from ...base import NID
from ...convert import to_heterogeneous, to_homogeneous from ...convert import to_heterogeneous, to_homogeneous
...@@ -340,7 +341,7 @@ class MetaPath2Vec(nn.Module): ...@@ -340,7 +341,7 @@ class MetaPath2Vec(nn.Module):
num_nodes_total = hg.num_nodes() num_nodes_total = hg.num_nodes()
node_frequency = torch.zeros(num_nodes_total) node_frequency = torch.zeros(num_nodes_total)
# random walk # random walk
for idx in tqdm.trange(hg.num_nodes(node_metapath[0])): for idx in trange(hg.num_nodes(node_metapath[0])):
traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath) traces, _ = random_walk(g=hg, nodes=[idx], metapath=metapath)
for tr in traces.cpu().numpy(): for tr in traces.cpu().numpy():
tr_nids = [ tr_nids = [
......
...@@ -47,5 +47,7 @@ dependencies: ...@@ -47,5 +47,7 @@ dependencies:
- clang-format - clang-format
- pylint - pylint
- lintrunner - lintrunner
- jupyterlab
- ipywidgets
variables: variables:
DGL_HOME: __DGL_HOME__ DGL_HOME: __DGL_HOME__
...@@ -589,7 +589,7 @@ Transformer as a Graph Neural Network ...@@ -589,7 +589,7 @@ Transformer as a Graph Neural Network
# #
# .. code:: python # .. code:: python
# #
# from tqdm import tqdm # from tqdm.auto import tqdm
# import torch as th # import torch as th
# import numpy as np # import numpy as np
# #
......
...@@ -20,7 +20,6 @@ models with multi-GPU with ``DistributedDataParallel``. ...@@ -20,7 +20,6 @@ models with multi-GPU with ``DistributedDataParallel``.
""" """
###################################################################### ######################################################################
# Importing Packages # Importing Packages
# --------------- # ---------------
...@@ -42,9 +41,9 @@ import torch.multiprocessing as mp ...@@ -42,9 +41,9 @@ import torch.multiprocessing as mp
import torch.nn as nn import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import torchmetrics.functional as MF import torchmetrics.functional as MF
import tqdm
from torch.distributed.algorithms.join import Join from torch.distributed.algorithms.join import Join
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from tqdm.auto import tqdm
###################################################################### ######################################################################
...@@ -155,7 +154,7 @@ def evaluate(rank, model, graph, features, itemset, num_classes, device): ...@@ -155,7 +154,7 @@ def evaluate(rank, model, graph, features, itemset, num_classes, device):
is_train=False, is_train=False,
) )
for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader: for data in tqdm(dataloader) if rank == 0 else dataloader:
blocks = data.blocks blocks = data.blocks
x = data.node_features["feat"] x = data.node_features["feat"]
y.append(data.labels) y.append(data.labels)
...@@ -212,7 +211,7 @@ def train( ...@@ -212,7 +211,7 @@ def train(
total_loss = torch.tensor(0, dtype=torch.float, device=device) total_loss = torch.tensor(0, dtype=torch.float, device=device)
num_train_items = 0 num_train_items = 0
with Join([model]): with Join([model]):
for data in tqdm.tqdm(dataloader) if rank == 0 else dataloader: for data in tqdm(dataloader) if rank == 0 else dataloader:
# The input features are from the source nodes in the first # The input features are from the source nodes in the first
# layer's computation graph. # layer's computation graph.
x = data.node_features["feat"] x = data.node_features["feat"]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment