Unverified Commit bd80a6c0 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

Ifeval: Dowload `punkt_tab` on rank 0 (#2267)

* download nltk `punkt_tab` on LOCAL_RANK=0

* remove print

* remove `time`

* nit
parent 060e8761
...@@ -15,13 +15,14 @@ ...@@ -15,13 +15,14 @@
"""Utility library of instructions.""" """Utility library of instructions."""
import functools import functools
import os
import random import random
import re import re
from importlib.metadata import version
import immutabledict import immutabledict
import nltk import nltk
import pkg_resources from packaging.version import parse as parse_version
from packaging import version
# Downloading 'punkt' with nltk<3.9 has a remote code vuln. # Downloading 'punkt' with nltk<3.9 has a remote code vuln.
...@@ -29,19 +30,22 @@ from packaging import version ...@@ -29,19 +30,22 @@ from packaging import version
# and https://github.com/nltk/nltk/issues/3266 # and https://github.com/nltk/nltk/issues/3266
# for more information. # for more information.
NLTK_MIN_VERSION = "3.9.1" NLTK_MIN_VERSION = "3.9.1"
RANK = os.environ.get("LOCAL_RANK", "0")
def download_nltk_resources(): def download_nltk_resources():
"""Download 'punkt' if not already installed""" """Download 'punkt' if not already installed"""
nltk_version = pkg_resources.get_distribution("nltk").version
assert ( assert (
version.parse(nltk_version) >= version.parse(NLTK_MIN_VERSION) (nltk_version := parse_version(version("nltk")))
>= parse_version(NLTK_MIN_VERSION)
), f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability." ), f"`nltk` version {nltk_version} is not >= {NLTK_MIN_VERSION}. Please update `nltk` before proceeding--older versions are vulnerable to a remote code execution vulnerability."
try: try:
nltk.data.find("tokenizers/punkt_tab") nltk.data.find("tokenizers/punkt_tab")
except LookupError: except LookupError:
nltk.download("punkt_tab") if RANK == "0":
nltk.download("punkt_tab")
print("Downloaded punkt_tab on rank 0")
download_nltk_resources() download_nltk_resources()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment