Unverified Commit 969660c7 authored by fzyzcjy, committed by GitHub

Recover from corrupted cache file in bench serving (#6510)

parent 16d4f680
@@ -24,6 +24,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
+from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
@@ -588,7 +589,7 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
         filename = os.path.join("/tmp", url.split("/")[-1])
 
     # Check if the cache file already exists
-    if os.path.exists(filename):
+    if is_file_valid_json(filename):
         return filename
 
     print(f"Downloading from {url} to {filename}")
@@ -616,6 +617,22 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
     return filename
 
 
+def is_file_valid_json(path):
+    if not os.path.isfile(path):
+        return False
+
+    # TODO can fuse into the real file open later
+    try:
+        with open(path) as f:
+            json.load(f)
+        return True
+    except JSONDecodeError as e:
+        print(
+            f"{path} exists but json loading fails ({e=}), thus treat as invalid file"
+        )
+        return False
+
+
 @dataclass
 class DatasetRow:
     prompt: str
@@ -755,7 +772,7 @@ def sample_sharegpt_requests(
         raise ValueError("output_len too small")
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path) and dataset_path == "":
+    if not is_file_valid_json(dataset_path) and dataset_path == "":
         dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
@@ -853,7 +870,7 @@ def sample_random_requests(
     # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
     # Download sharegpt if necessary
-    if not os.path.isfile(dataset_path):
+    if not is_file_valid_json(dataset_path):
         dataset_path = download_and_cache_file(SHAREGPT_URL)
 
     # Load the dataset.
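For readers skimming the diff, here is a minimal standalone sketch of the recovery behavior this commit adds: a cached file only counts as a cache hit if it parses as JSON, otherwise it is re-downloaded. Only is_file_valid_json mirrors the helper added above; fetch_json_cached, its download logic, and the timeout are hypothetical stand-ins for download_and_cache_file, not the actual implementation.

# Sketch only, assuming a plain requests-based download; names other than
# is_file_valid_json (e.g. fetch_json_cached) are hypothetical.
import json
import os
from json import JSONDecodeError
from typing import Optional

import requests


def is_file_valid_json(path):
    # Mirrors the helper in this commit: a cache hit requires that the file
    # exists AND parses as JSON.
    if not os.path.isfile(path):
        return False
    try:
        with open(path) as f:
            json.load(f)
        return True
    except JSONDecodeError as e:
        print(f"{path} exists but json loading fails ({e=}), thus treat as invalid file")
        return False


def fetch_json_cached(url: str, filename: Optional[str] = None) -> str:
    # Hypothetical stand-in for download_and_cache_file: re-download whenever
    # the cached copy is missing or fails JSON validation.
    if filename is None:
        filename = os.path.join("/tmp", url.split("/")[-1])
    if is_file_valid_json(filename):
        return filename
    print(f"Downloading from {url} to {filename}")
    resp = requests.get(url, timeout=60)
    resp.raise_for_status()
    with open(filename, "wb") as f:
        f.write(resp.content)
    return filename

The practical effect is that a benchmark run interrupted mid-download no longer leaves a half-written ShareGPT cache in /tmp that every later run trips over; the next run detects the invalid JSON and re-fetches the file.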