Unverified Commit 2104e4c1 authored by jberchtold-nvidia, committed by GitHub
Browse files

[JAX] Use "nyu-mll/glue" instead of "glue" for encoder datasets to fix 404 error (#2625)



* Use "nyu-mll/glue" instead of "glue" for encoder datasets to fix 404 error
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* rename mnist dataset path
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

* add dataset manifest
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>

---------
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
parent 2dbfbc74
# Datasets used by TE encoder tests. Pull these to pre-emptively cache datasets
ylecun/mnist
nyu-mll/glue
\ No newline at end of file
......@@ -219,11 +219,11 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......
......@@ -197,11 +197,11 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......
......@@ -307,11 +307,11 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......
......@@ -195,11 +195,11 @@ def get_datasets(max_seq_len):
vocab = {}
word_id = 0
train_ds = load_dataset("glue", "cola", split="train")
train_ds = load_dataset("nyu-mll/glue", "cola", split="train")
train_ds.set_format(type="np")
train_ds, vocab, word_id = data_preprocess(train_ds, vocab, word_id, max_seq_len)
test_ds = load_dataset("glue", "cola", split="validation")
test_ds = load_dataset("nyu-mll/glue", "cola", split="validation")
test_ds.set_format(type="np")
test_ds, vocab, word_id = data_preprocess(test_ds, vocab, word_id, max_seq_len)
return train_ds, test_ds, word_id
......
......@@ -146,7 +146,7 @@ def eval_model(state, test_ds, batch_size, var_collect):
def get_datasets():
"""Load MNIST train and test datasets into memory."""
train_ds = load_dataset("mnist", split="train", trust_remote_code=True)
train_ds = load_dataset("ylecun/mnist", split="train", trust_remote_code=True)
train_ds.set_format(type="np")
batch_size = train_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
......@@ -154,7 +154,7 @@ def get_datasets():
"image": train_ds["image"].astype(np.float32).reshape(shape) / 255.0,
"label": train_ds["label"],
}
test_ds = load_dataset("mnist", split="test", trust_remote_code=True)
test_ds = load_dataset("ylecun/mnist", split="test", trust_remote_code=True)
test_ds.set_format(type="np")
batch_size = test_ds["image"].shape[0]
shape = (batch_size, IMAGE_H, IMAGE_W, IMAGE_C)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment