Unverified Commit bd80a6c0 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

Ifeval: Dowload `punkt_tab` on rank 0 (#2267)

* download nltk `punkt_tab` on LOCAL_RANK=0

* remove print

* remove `time`

* nit
parent 060e8761
......@@ -15,13 +15,14 @@
"""Utility library of instructions."""
import functools
import os
import random
import re
from importlib.metadata import version
import immutabledict
import nltk
import pkg_resources
from packaging import version
from packaging.version import parse as parse_version
# Downloading 'punkt' with nltk<3.9 has a remote code vuln.
......@@ -29,19 +30,22 @@ from packaging import version
# and https://github.com/nltk/nltk/issues/3266
# for more information.
NLTK_MIN_VERSION = "3.9.1"
RANK = os.environ.get("LOCAL_RANK", "0")
def download_nltk_resources():
"""Download 'punkt' if not already installed"""
nltk_version = pkg_resources.get_distribution("nltk").version
assert (
version.parse(nltk_version) >= version.parse(NLTK_MIN_VERSION)
(nltk_version := parse_version(version("nltk")))
>= parse_version(NLTK_MIN_VERSION)
), f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability."
try:
nltk.data.find("tokenizers/punkt_tab")
except LookupError:
nltk.download("punkt_tab")
if RANK == "0":
nltk.download("punkt_tab")
print("Downloaded punkt_tab on rank 0")
download_nltk_resources()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment