Unverified Commit 40e53d65 authored by Liangsheng Yin and committed by GitHub

Add disk cache for loading ShareGPT dataset. (#542)

parent fb9296f0
@@ -19,6 +19,7 @@ On the client side, run:
 import argparse
 import asyncio
 import json
+import os
 import random
 import time
 from typing import AsyncGenerator, List, Tuple
@@ -37,43 +38,62 @@ def sample_requests(
     num_requests: int,
     tokenizer: AutoTokenizer,
 ) -> List[Tuple[str, int, int]]:
-    # Load the dataset.
-    with open(dataset_path) as f:
-        dataset = json.load(f)
-    # Filter out the conversations with less than 2 turns.
-    dataset = [data for data in dataset if len(data["conversations"]) >= 2]
-    # Only keep the first two turns of each conversation.
-    dataset = [
-        (data["conversations"][0]["value"], data["conversations"][1]["value"])
-        for data in dataset
-    ]
-
-    # Tokenize the prompts and completions.
-    prompts = [prompt for prompt, _ in dataset]
-    prompt_token_ids = tokenizer(prompts).input_ids
-    completions = [completion for _, completion in dataset]
-    completion_token_ids = tokenizer(completions).input_ids
-    tokenized_dataset = []
-    for i in range(len(dataset)):
-        output_len = len(completion_token_ids[i])
-        tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
-
-    # Filter out too long sequences.
-    filtered_dataset: List[Tuple[str, int, int]] = []
-    for prompt, prompt_token_ids, output_len in tokenized_dataset:
-        prompt_len = len(prompt_token_ids)
-        if prompt_len < 4 or output_len < 4:
-            # Prune too short sequences.
-            # This is because TGI causes errors when the input or output length
-            # is too short.
-            continue
-        if prompt_len > 1024 or prompt_len + output_len > 2048:
-            # Prune too long sequences.
-            continue
-        filtered_dataset.append((prompt, prompt_len, output_len))
+    def load_dataset():
+        with open(dataset_path, encoding="utf-8") as f:
+            dataset = json.load(f)
+        # Filter out the conversations with less than 2 turns.
+        dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+        # Only keep the first two turns of each conversation.
+        dataset = [
+            (data["conversations"][0]["value"], data["conversations"][1]["value"])
+            for data in dataset
+        ]
+
+        # Tokenize the prompts and completions.
+        prompts = [prompt for prompt, _ in dataset]
+        prompt_token_ids = tokenizer(prompts).input_ids
+        completions = [completion for _, completion in dataset]
+        completion_token_ids = tokenizer(completions).input_ids
+        tokenized_dataset = []
+        for i in range(len(dataset)):
+            output_len = len(completion_token_ids[i])
+            tokenized_dataset.append((prompts[i], prompt_token_ids[i], output_len))
+
+        # Filter out too long sequences.
+        filtered_dataset: List[Tuple[str, int, int]] = []
+        for prompt, prompt_token_ids, output_len in tokenized_dataset:
+            prompt_len = len(prompt_token_ids)
+            if prompt_len < 4 or output_len < 4:
+                # Prune too short sequences.
+                # This is because TGI causes errors when the input or output length
+                # is too short.
+                continue
+            if prompt_len > 1024 or prompt_len + output_len > 2048:
+                # Prune too long sequences.
+                continue
+            filtered_dataset.append((prompt, prompt_len, output_len))
+
+        return filtered_dataset
+
+    try:
+        from diskcache import Cache
+
+        home_dir = os.path.expanduser("~")
+        cache = Cache(f"{home_dir}/.cache/sglang")
+        with Cache(cache.directory) as reference:
+            reference_key = f"{dataset_path}_{tokenizer.name_or_path}"
+            if reference_key in reference:
+                print("Reading dataset from cache...")
+                dataset = reference[reference_key]
+            else:
+                dataset = load_dataset()
+                reference[reference_key] = dataset
+    except ImportError:
+        dataset = load_dataset()
 
     # Sample the requests.
-    sampled_requests = random.sample(filtered_dataset, num_requests)
+    sampled_requests = random.sample(dataset, num_requests)
     return sampled_requests
...
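The change wraps the existing load-and-tokenize logic in a local `load_dataset()` helper and memoizes its result on disk, keyed by both the dataset path and the tokenizer name so that switching tokenizers cannot serve stale token counts. Because the `diskcache` import is wrapped in a try/except, the dependency stays optional and the benchmark falls back to re-tokenizing on every run. Below is a minimal sketch of the same memoization pattern in isolation, assuming `diskcache` is installed; `expensive_load`, `cached_load`, and the cache directory are illustrative placeholders, not part of this commit.

```python
import os

from diskcache import Cache


def expensive_load(path: str) -> list:
    # Placeholder for the costly step (here, reading and tokenizing ShareGPT).
    with open(path, encoding="utf-8") as f:
        return f.read().splitlines()


def cached_load(path: str, key_suffix: str) -> list:
    # Key on every input that determines the result (e.g. tokenizer name),
    # mirroring the dataset_path + tokenizer.name_or_path key in the diff.
    cache_dir = os.path.expanduser("~/.cache/sglang")
    with Cache(cache_dir) as cache:
        key = f"{path}_{key_suffix}"
        if key in cache:  # cache hit: skip the expensive work entirely
            return cache[key]
        data = expensive_load(path)
        cache[key] = data  # persisted on disk for subsequent runs
        return data
```

On the first run the helper pays the full load cost and writes the result to `~/.cache/sglang`; later runs with the same path and key suffix read it straight from disk.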