Unverified Commit 0075a46a authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Upgrade NLTK version to circumvent unsafe pickling in v3.8.1 (#1102)



* Switch to nltk>3.8.1 and new data
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix nltk install
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4b2b39b4
datasets
flax>=0.7.1
nltk
nltk>=3.8.2
optax
......@@ -168,7 +168,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt")
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......
......@@ -147,7 +147,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt")
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......
......@@ -250,7 +250,7 @@ def eval_model(
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt")
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......
......@@ -144,7 +144,7 @@ def eval_model(state, test_ds, batch_size, var_collect):
def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers."""
nltk.download("punkt")
nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......
......@@ -4,7 +4,7 @@
set -xe
pip install nltk==3.8.1
pip install "nltk>=3.8.2"
pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment