Unverified Commit 969660c7 authored by fzyzcjy's avatar fzyzcjy Committed by GitHub
Browse files

Recover from corrupted cache file in bench serving (#6510)

parent 16d4f680
......@@ -24,6 +24,7 @@ import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, field
from datetime import datetime
from json import JSONDecodeError
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
......@@ -588,7 +589,7 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
filename = os.path.join("/tmp", url.split("/")[-1])
# Check if the cache file already exists
if os.path.exists(filename):
if is_file_valid_json(filename):
return filename
print(f"Downloading from {url} to {filename}")
......@@ -616,6 +617,22 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
return filename
def is_file_valid_json(path):
    """Return True if ``path`` is an existing regular file that parses as JSON.

    Used to decide whether a cached download can be reused. A file that is
    missing, is not a regular file (e.g. a directory or the empty string),
    or cannot be parsed as JSON is reported as invalid so the caller can
    fetch a fresh copy.

    Args:
        path: Filesystem path of the candidate JSON file.

    Returns:
        True when the whole file loads with ``json.load``; False otherwise.
    """
    if not os.path.isfile(path):
        return False

    # TODO can fuse into the real file open later
    try:
        with open(path) as f:
            json.load(f)
    except (JSONDecodeError, UnicodeDecodeError) as e:
        # A corrupted cache file is not necessarily valid text: truncated or
        # garbled bytes make the text decoder raise UnicodeDecodeError before
        # the JSON parser ever sees the data. Both cases mean "unusable file".
        print(
            f"{path} exists but json loading fails ({e=}), thus treat as invalid file"
        )
        return False
    return True
@dataclass
class DatasetRow:
prompt: str
......@@ -755,7 +772,7 @@ def sample_sharegpt_requests(
raise ValueError("output_len too small")
# Download sharegpt if necessary
if not os.path.isfile(dataset_path) and dataset_path == "":
if not is_file_valid_json(dataset_path) and dataset_path == "":
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
......@@ -853,7 +870,7 @@ def sample_random_requests(
# Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
# Download sharegpt if necessary
if not os.path.isfile(dataset_path):
if not is_file_valid_json(dataset_path):
dataset_path = download_and_cache_file(SHAREGPT_URL)
# Load the dataset.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment