Commit 9592a1f3 authored by Lianmin Zheng

Fix random dataset (#671)

parent 35759efa
@@ -192,6 +192,36 @@ class BenchmarkMetrics:
     p99_itl_ms: float
+default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+def download_sharegpt_dataset(path):
+    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+    print(f"Downloading dataset from {url}")
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+        total_size = int(response.headers.get("content-length", 0))
+        block_size = 8192
+        with open(path, "wb") as f, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar:
+            for data in response.iter_content(block_size):
+                size = f.write(data)
+                progress_bar.update(size)
+        print(f"Dataset downloaded and saved to {path}")
+    except requests.RequestException as e:
+        raise Exception(f"Failed to download dataset: {e}")
 def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
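The download_sharegpt_dataset helper added above can also be exercised on its own before a benchmark run; a minimal sketch, reusing the module-level default_sharegpt_path and assuming requests and tqdm are already imported by the script:

import os

# Fetch the default ShareGPT file once; skip the download if it is already on disk.
if not os.path.isfile(default_sharegpt_path):
    download_sharegpt_dataset(default_sharegpt_path)
print(f"{os.path.getsize(default_sharegpt_path) / 1e6:.1f} MB on disk")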
@@ -201,36 +231,13 @@ def sample_sharegpt_requests(
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
-    default_dataset_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_dataset_path):
-        print(f"Downloading dataset from {url}")
-        try:
-            response = requests.get(url, stream=True)
-            response.raise_for_status()
-            total_size = int(response.headers.get("content-length", 0))
-            block_size = 8192
-            with open(default_dataset_path, "wb") as f, tqdm(
-                desc="Downloading",
-                total=total_size,
-                unit="iB",
-                unit_scale=True,
-                unit_divisor=1024,
-            ) as progress_bar:
-                for data in response.iter_content(block_size):
-                    size = f.write(data)
-                    progress_bar.update(size)
-            print(f"Dataset downloaded and saved to {default_dataset_path}")
-            dataset_path = default_dataset_path
-        except requests.RequestException as e:
-            raise Exception(f"Failed to download dataset: {e}")
+    # Download sharegpt if necessary
+    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
+        download_sharegpt_dataset(default_sharegpt_path)
+        dataset_path = default_sharegpt_path
     else:
         dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_dataset_path
+            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
         )
     # Load the dataset.
@@ -279,6 +286,7 @@ def sample_random_requests(
     num_prompts: int,
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
+    dataset_path: str,
 ) -> List[Tuple[str, int, int]]:
     input_lens = np.random.randint(
@@ -291,13 +299,62 @@
         output_len + 1,
         size=num_prompts,
     )
-    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
-    input_requests = []
-    for i in range(num_prompts):
-        prompt = tokenizer.decode(
-            [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
-        )
-        input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+    if True:
+        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
+        # Download sharegpt if necessary
+        if not os.path.isfile(dataset_path) and not os.path.isfile(
+            default_sharegpt_path
+        ):
+            download_sharegpt_dataset(default_sharegpt_path)
+            dataset_path = default_sharegpt_path
+        else:
+            dataset_path = (
+                dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+            )
+        # Load the dataset.
+        with open(dataset_path) as f:
+            dataset = json.load(f)
+        # Filter out the conversations with less than 2 turns.
+        dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+        # Only keep the first two turns of each conversation.
+        dataset = [
+            (data["conversations"][0]["value"], data["conversations"][1]["value"])
+            for data in dataset
+        ]
+        # Shuffle the dataset.
+        random.shuffle(dataset)
+        # Filter out sequences that are too long or too short
+        input_requests: List[Tuple[str, int, int]] = []
+        for i in range(num_prompts):
+            # Tokenize the prompts and completions.
+            prompt = dataset[i][0]
+            prompt_token_ids = tokenizer(prompt).input_ids
+            prompt_len = len(prompt_token_ids)
+            if prompt_len > input_lens[i]:
+                input_ids = prompt_token_ids[: input_lens[i]]
+            else:
+                ratio = (input_lens[i] + prompt_len - 1) // prompt_len
+                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
+            prompt = tokenizer.decode(input_ids)
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+    else:
+        # Sample token ids from random integers. This can cause some NaN issues.
+        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+        input_requests = []
+        for i in range(num_prompts):
+            prompt = tokenizer.decode(
+                [
+                    (offsets[i] + i + j) % tokenizer.vocab_size
+                    for j in range(input_lens[i])
+                ]
+            )
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
     print(f"#Input tokens: {np.sum(input_lens)}")
     print(f"#Output tokens: {np.sum(output_lens)}")
@@ -575,6 +632,7 @@ def fire(args: argparse.Namespace):
             num_prompts=args.num_prompts,
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
+            dataset_path=args.dataset_path,
         )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")