Unverified Commit 0075a46a authored by Kirthi Shankar Sivamani's avatar Kirthi Shankar Sivamani Committed by GitHub
Browse files

Upgrade NLTK version to circumvent unsafe pickling in v3.8.1 (#1102)



* Switch to nltk>3.8.1 and new data
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix nltk install
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: default avatarKirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4b2b39b4
datasets datasets
flax>=0.7.1 flax>=0.7.1
nltk nltk>=3.8.2
optax optax
...@@ -168,7 +168,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): ...@@ -168,7 +168,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len): def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers.""" """Convert tokens to numbers."""
nltk.download("punkt") nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"]) dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......
...@@ -147,7 +147,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn): ...@@ -147,7 +147,7 @@ def eval_model(state, test_ds, batch_size, var_collect, eval_fn):
def data_preprocess(dataset, vocab, word_id, max_seq_len): def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers.""" """Convert tokens to numbers."""
nltk.download("punkt") nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"]) dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......
...@@ -250,7 +250,7 @@ def eval_model( ...@@ -250,7 +250,7 @@ def eval_model(
def data_preprocess(dataset, vocab, word_id, max_seq_len): def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers.""" """Convert tokens to numbers."""
nltk.download("punkt") nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"]) dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......
...@@ -144,7 +144,7 @@ def eval_model(state, test_ds, batch_size, var_collect): ...@@ -144,7 +144,7 @@ def eval_model(state, test_ds, batch_size, var_collect):
def data_preprocess(dataset, vocab, word_id, max_seq_len): def data_preprocess(dataset, vocab, word_id, max_seq_len):
"""Convert tokens to numbers.""" """Convert tokens to numbers."""
nltk.download("punkt") nltk.download("punkt_tab")
dataset_size = len(dataset["sentence"]) dataset_size = len(dataset["sentence"])
output = np.zeros((dataset_size, max_seq_len), dtype=np.int32) output = np.zeros((dataset_size, max_seq_len), dtype=np.int32)
mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8) mask_3d = np.ones((dataset_size, max_seq_len, max_seq_len), dtype=np.uint8)
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
set -xe set -xe
pip install nltk==3.8.1 pip install "nltk>=3.8.2"
pip install pytest==8.2.1 pip install pytest==8.2.1
: ${TE_PATH:=/opt/transformerengine} : ${TE_PATH:=/opt/transformerengine}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment