Unverified Commit 7b1639f1 authored by Andrzej Kotłowski's avatar Andrzej Kotłowski Committed by GitHub
Browse files

[TESTS] Make pytorch test_nn.py fully reproducible (#5887)

parent 3e2b5a04
......@@ -126,6 +126,11 @@ def abs(a):
pass
def seed(a):
"""Set seed to for random generator"""
pass
###############################################################################
# Tensor functions used *only* on index tensor
# ----------------
......
......@@ -100,3 +100,7 @@ def dot(a, b):
def abs(a):
return nd.abs(a)
def seed(a):
return mx.random.seed(a)
......@@ -93,3 +93,7 @@ def dot(a, b):
def abs(a):
return a.abs()
def seed(a):
return th.manual_seed(a)
......@@ -104,3 +104,7 @@ def dot(a, b):
def abs(a):
return tf.abs(a)
def seed(a):
return tf.random.set_seed(a)
import io
import pickle
import random
from copy import deepcopy
import backend as F
......@@ -8,6 +9,7 @@ import dgl
import dgl.function as fn
import dgl.nn.pytorch as nn
import networkx as nx
import numpy as np # For setting seed for scipy
import pytest
import scipy as sp
import torch
......@@ -24,6 +26,13 @@ from utils.graph_cases import (
random_graph,
)
# Set seeds to make tests fully reproducible.
SEED = 12345 # random.randint(1, 99999)
random.seed(SEED) # For networkx
np.random.seed(SEED) # For scipy
dgl.seed(SEED)
F.seed(SEED)
tmp_buffer = io.BytesIO()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment