Commit 0351bb6b authored by Baber's avatar Baber
Browse files

handle nltk punkt_tab

parent 07429c86
......@@ -4,7 +4,11 @@ import uuid
import numpy as np
import wonderwords
import nltk
from nltk import sent_tokenize
from packaging.version import parse as parse_version
from importlib.metadata import version
from tqdm import tqdm
from transformers import AutoTokenizer
......@@ -32,6 +36,27 @@ WORDS = sorted(list(set(words)))
# Positions
DEPTHS = list(np.round(np.linspace(0, 100, num=40, endpoint=True)).astype(int))
NLTK_MIN_VERSION = "3.9.1"
RANK = os.environ.get("LOCAL_RANK", "0")
def download_nltk_resources():
"""Download 'punkt' if not already installed"""
assert (
(nltk_version := parse_version(version("nltk")))
>= parse_version(NLTK_MIN_VERSION)
), f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability."
try:
nltk.data.find("tokenizers/punkt_tab")
except LookupError:
if RANK == "0":
nltk.download("punkt_tab")
print("Downloaded punkt_tab on rank 0")
download_nltk_resources()
def generate_random_number(num_digits=7):
lower_bound = 10 ** (num_digits - 1)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment