Commit 9592a1f3 authored by Lianmin Zheng

Fix random dataset (#671)

parent 35759efa
@@ -192,19 +192,12 @@ class BenchmarkMetrics:
     p99_itl_ms: float
 
 
-def sample_sharegpt_requests(
-    dataset_path: str,
-    num_requests: int,
-    tokenizer: PreTrainedTokenizerBase,
-    fixed_output_len: Optional[int] = None,
-) -> List[Tuple[str, int, int]]:
-    if fixed_output_len is not None and fixed_output_len < 4:
-        raise ValueError("output_len too small")
-
-    default_dataset_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+
+
+def download_sharegpt_dataset(path):
     url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_dataset_path):
-        print(f"Downloading dataset from {url}")
-        try:
-            response = requests.get(url, stream=True)
+
+    print(f"Downloading dataset from {url}")
+    try:
+        response = requests.get(url, stream=True)
@@ -213,7 +206,7 @@ def sample_sharegpt_requests(
         total_size = int(response.headers.get("content-length", 0))
         block_size = 8192
 
-        with open(default_dataset_path, "wb") as f, tqdm(
+        with open(path, "wb") as f, tqdm(
             desc="Downloading",
             total=total_size,
             unit="iB",
@@ -224,13 +217,27 @@ def sample_sharegpt_requests(
                 size = f.write(data)
                 progress_bar.update(size)
 
-        print(f"Dataset downloaded and saved to {default_dataset_path}")
-        dataset_path = default_dataset_path
+        print(f"Dataset downloaded and saved to {path}")
     except requests.RequestException as e:
         raise Exception(f"Failed to download dataset: {e}")
+
+
+def sample_sharegpt_requests(
+    dataset_path: str,
+    num_requests: int,
+    tokenizer: PreTrainedTokenizerBase,
+    fixed_output_len: Optional[int] = None,
+) -> List[Tuple[str, int, int]]:
+    if fixed_output_len is not None and fixed_output_len < 4:
+        raise ValueError("output_len too small")
+
+    # Download sharegpt if necessary
+    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
+        download_sharegpt_dataset(default_sharegpt_path)
+        dataset_path = default_sharegpt_path
     else:
         dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_dataset_path
+            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
         )
 
     # Load the dataset.
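
Not part of the commit: the hunks above extract the download logic into download_sharegpt_dataset and introduce a shared default_sharegpt_path, so every sampler can resolve its dataset file the same way. A minimal sketch of that resolve-or-download pattern, assuming the two names above are in scope; resolve_sharegpt_path is a hypothetical helper, not something this commit adds:

import os

def resolve_sharegpt_path(dataset_path: str) -> str:
    # Prefer the caller-supplied file; otherwise fall back to the shared default,
    # downloading it once if it is not already cached locally.
    if os.path.isfile(dataset_path):
        return dataset_path
    if not os.path.isfile(default_sharegpt_path):
        download_sharegpt_dataset(default_sharegpt_path)
    return default_sharegpt_path
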
@@ -279,6 +286,7 @@ def sample_random_requests(
     num_prompts: int,
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
+    dataset_path: str,
 ) -> List[Tuple[str, int, int]]:
 
     input_lens = np.random.randint(
@@ -291,11 +299,60 @@ def sample_random_requests(
         output_len + 1,
         size=num_prompts,
     )
 
+    if True:
+        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
+
+        # Download sharegpt if necessary
+        if not os.path.isfile(dataset_path) and not os.path.isfile(
+            default_sharegpt_path
+        ):
+            download_sharegpt_dataset(default_sharegpt_path)
+            dataset_path = default_sharegpt_path
+        else:
+            dataset_path = (
+                dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+            )
+
+        # Load the dataset.
+        with open(dataset_path) as f:
+            dataset = json.load(f)
+        # Filter out the conversations with less than 2 turns.
+        dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+        # Only keep the first two turns of each conversation.
+        dataset = [
+            (data["conversations"][0]["value"], data["conversations"][1]["value"])
+            for data in dataset
+        ]
+        # Shuffle the dataset.
+        random.shuffle(dataset)
+
+        # Filter out sequences that are too long or too short
+        input_requests: List[Tuple[str, int, int]] = []
+        for i in range(num_prompts):
+            # Tokenize the prompts and completions.
+            prompt = dataset[i][0]
+            prompt_token_ids = tokenizer(prompt).input_ids
+            prompt_len = len(prompt_token_ids)
+
+            if prompt_len <= input_lens[i]:
+                input_ids = prompt_token_ids[: input_lens[i]]
+            else:
+                ratio = (input_lens[i] + prompt_len - 1) // prompt_len
+                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
+            prompt = tokenizer.decode(input_ids)
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+    else:
+        # Sample token ids from random integers. This can cause some NaN issues.
-    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
-    input_requests = []
-    for i in range(num_prompts):
-        prompt = tokenizer.decode(
-            [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
-        )
-        input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+        input_requests = []
+        for i in range(num_prompts):
+            prompt = tokenizer.decode(
+                [
+                    (offsets[i] + i + j) % tokenizer.vocab_size
+                    for j in range(input_lens[i])
+                ]
+            )
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
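
Not part of the commit: the key step in the hunk above is the "repeat/truncate" idea from the comment, where a ShareGPT prompt's token ids are stretched or cut to the sampled input length using a ceiling division. A standalone sketch of that idea with a toy worked example (all names here are illustrative only):

from typing import List


def fit_to_length(prompt_token_ids: List[int], target_len: int) -> List[int]:
    # Repeat the prompt's token ids until they cover target_len, then truncate.
    # ratio is ceil(target_len / prompt_len), written in the (a + b - 1) // b form used above.
    prompt_len = len(prompt_token_ids)
    if prompt_len >= target_len:
        return prompt_token_ids[:target_len]
    ratio = (target_len + prompt_len - 1) // prompt_len
    return (prompt_token_ids * ratio)[:target_len]


# A 3-token prompt stretched to 7 tokens: ratio = ceil(7 / 3) = 3, so the ids are
# tiled to 9 tokens and truncated to [11, 22, 33, 11, 22, 33, 11].
assert fit_to_length([11, 22, 33], 7) == [11, 22, 33, 11, 22, 33, 11]
# A 5-token prompt asked for 2 tokens is simply truncated.
assert fit_to_length([1, 2, 3, 4, 5], 2) == [1, 2]
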
@@ -575,6 +632,7 @@ def fire(args: argparse.Namespace):
             num_prompts=args.num_prompts,
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
+            dataset_path=args.dataset_path,
         )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
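
Not part of the commit: with dataset_path threaded through in the last hunk, the random sampler can be exercised directly. A hedged usage sketch, assuming the benchmark script's functions are importable and that the leading parameters are named input_len and output_len (they are not visible in the hunks); the model name and numbers are examples only:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")  # example model only
input_requests = sample_random_requests(
    input_len=1024,        # assumed parameter name, not shown in the hunks
    output_len=128,        # assumed parameter name, not shown in the hunks
    num_prompts=8,
    range_ratio=0.5,
    tokenizer=tokenizer,
    dataset_path="",       # empty path -> fall back to default_sharegpt_path, downloading if needed
)
for prompt, in_len, out_len in input_requests[:2]:
    print(in_len, out_len, prompt[:80])
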