sglang / Commits

Commit 9592a1f3
Authored Jul 20, 2024 by Lianmin Zheng

Fix random dataset (#671)

parent 35759efa
Showing 1 changed file with 93 additions and 35 deletions.
python/sglang/bench_serving.py  (+93, -35)
@@ -192,6 +192,36 @@ class BenchmarkMetrics:
     p99_itl_ms: float
 
 
+default_sharegpt_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
+
+
+def download_sharegpt_dataset(path):
+    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
+
+    print(f"Downloading dataset from {url}")
+    try:
+        response = requests.get(url, stream=True)
+        response.raise_for_status()
+
+        total_size = int(response.headers.get("content-length", 0))
+        block_size = 8192
+
+        with open(path, "wb") as f, tqdm(
+            desc="Downloading",
+            total=total_size,
+            unit="iB",
+            unit_scale=True,
+            unit_divisor=1024,
+        ) as progress_bar:
+            for data in response.iter_content(block_size):
+                size = f.write(data)
+                progress_bar.update(size)
+
+        print(f"Dataset downloaded and saved to {path}")
+    except requests.RequestException as e:
+        raise Exception(f"Failed to download dataset: {e}")
+
+
 def sample_sharegpt_requests(
     dataset_path: str,
     num_requests: int,
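The hunk above lifts the inline download logic into a reusable module-level helper. A minimal sketch of calling it directly (a hypothetical standalone driver, assuming sglang is installed at this commit; inside bench_serving.py the call happens automatically when no local copy is found):

import os

# Both names are defined in the hunk above at module level, so they are
# importable at this commit (an assumption about packaging, not shown here).
from sglang.bench_serving import default_sharegpt_path, download_sharegpt_dataset

if not os.path.isfile(default_sharegpt_path):
    # Streams the JSON file from Hugging Face with a tqdm progress bar and
    # raises on any HTTP failure, per the helper's try/except.
    download_sharegpt_dataset(default_sharegpt_path)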
@@ -201,36 +231,13 @@ def sample_sharegpt_requests(
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")
 
-    default_dataset_path = "ShareGPT_V3_unfiltered_cleaned_split.json"
-    url = "https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json"
-
-    if not os.path.isfile(dataset_path) and not os.path.isfile(default_dataset_path):
-        print(f"Downloading dataset from {url}")
-        try:
-            response = requests.get(url, stream=True)
-            response.raise_for_status()
-
-            total_size = int(response.headers.get("content-length", 0))
-            block_size = 8192
-
-            with open(default_dataset_path, "wb") as f, tqdm(
-                desc="Downloading",
-                total=total_size,
-                unit="iB",
-                unit_scale=True,
-                unit_divisor=1024,
-            ) as progress_bar:
-                for data in response.iter_content(block_size):
-                    size = f.write(data)
-                    progress_bar.update(size)
-
-            print(f"Dataset downloaded and saved to {default_dataset_path}")
-            dataset_path = default_dataset_path
-        except requests.RequestException as e:
-            raise Exception(f"Failed to download dataset: {e}")
+    # Download sharegpt if necessary
+    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
+        download_sharegpt_dataset(default_sharegpt_path)
+        dataset_path = default_sharegpt_path
     else:
         dataset_path = (
-            dataset_path if os.path.isfile(dataset_path) else default_dataset_path
+            dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
        )
 
     # Load the dataset.
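The net effect of this hunk is a three-way path-resolution rule. A compact restatement as a hypothetical helper (not part of the commit; download_sharegpt_dataset is the function added in the first hunk):

import os

def resolve_dataset_path(dataset_path: str) -> str:
    """Hypothetical restatement of the resolution logic in sample_sharegpt_requests."""
    # 1. Neither the user path nor the cached default exists: download the default.
    if not os.path.isfile(dataset_path) and not os.path.isfile(default_sharegpt_path):
        download_sharegpt_dataset(default_sharegpt_path)
        return default_sharegpt_path
    # 2. The user path exists: use it.  3. Otherwise fall back to the default.
    return dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path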
@@ -279,6 +286,7 @@ def sample_random_requests(
     num_prompts: int,
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
+    dataset_path: str,
 ) -> List[Tuple[str, int, int]]:
     input_lens = np.random.randint(
@@ -291,13 +299,62 @@ def sample_random_requests(
         output_len + 1,
         size=num_prompts,
     )
 
-    offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
-    input_requests = []
-    for i in range(num_prompts):
-        prompt = tokenizer.decode(
-            [(offsets[i] + i + j) % tokenizer.vocab_size for j in range(input_lens[i])]
-        )
-        input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+    if True:
+        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
+
+        # Download sharegpt if necessary
+        if not os.path.isfile(dataset_path) and not os.path.isfile(
+            default_sharegpt_path
+        ):
+            download_sharegpt_dataset(default_sharegpt_path)
+            dataset_path = default_sharegpt_path
+        else:
+            dataset_path = (
+                dataset_path if os.path.isfile(dataset_path) else default_sharegpt_path
+            )
+
+        # Load the dataset.
+        with open(dataset_path) as f:
+            dataset = json.load(f)
+        # Filter out the conversations with less than 2 turns.
+        dataset = [data for data in dataset if len(data["conversations"]) >= 2]
+        # Only keep the first two turns of each conversation.
+        dataset = [
+            (data["conversations"][0]["value"], data["conversations"][1]["value"])
+            for data in dataset
+        ]
+        # Shuffle the dataset.
+        random.shuffle(dataset)
+
+        # Filter out sequences that are too long or too short
+        input_requests: List[Tuple[str, int, int]] = []
+        for i in range(num_prompts):
+            # Tokenize the prompts and completions.
+            prompt = dataset[i][0]
+            prompt_token_ids = tokenizer(prompt).input_ids
+            prompt_len = len(prompt_token_ids)
+
+            if prompt_len <= input_lens[i]:
+                input_ids = prompt_token_ids[: input_lens[i]]
+            else:
+                ratio = (input_lens[i] + prompt_len - 1) // prompt_len
+                input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
+            prompt = tokenizer.decode(input_ids)
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+    else:
+        # Sample token ids from random integers. This can cause some NaN issues.
+        offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
+        input_requests = []
+        for i in range(num_prompts):
+            prompt = tokenizer.decode(
+                [
+                    (offsets[i] + i + j) % tokenizer.vocab_size
+                    for j in range(input_lens[i])
+                ]
+            )
+            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))
+
+    print(f"#Input tokens: {np.sum(input_lens)}")
+    print(f"#Output tokens: {np.sum(output_lens)}")
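The core of the fix: sample_random_requests now builds prompts from real ShareGPT token ids, repeated and truncated to hit each sampled input_lens[i], instead of decoding raw random integers (which, per the retained else-branch comment, can cause NaN issues). A standalone illustration of the ceiling-division repeat/truncate expression, using toy values not taken from the benchmark:

# Toy stand-ins for one ShareGPT prompt and one sampled target length.
prompt_token_ids = [7, 8, 9]   # pretend tokenizer(prompt).input_ids
target_len = 8                 # pretend input_lens[i]

# ceil(target_len / prompt_len) copies of the prompt, truncated to target_len.
prompt_len = len(prompt_token_ids)
ratio = (target_len + prompt_len - 1) // prompt_len   # (8 + 3 - 1) // 3 == 3
input_ids = (prompt_token_ids * ratio)[:target_len]
assert input_ids == [7, 8, 9, 7, 8, 9, 7, 8]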
@@ -575,6 +632,7 @@ def fire(args: argparse.Namespace):
             num_prompts=args.num_prompts,
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
+            dataset_path=args.dataset_path,
         )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
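Finally, fire() threads the new dataset_path argument through to the random sampler. A hypothetical direct call mirroring that wiring (the last four keyword names are visible in the diff; input_len and output_len are assumed from the surrounding code, and the tokenizer choice is illustrative only):

from transformers import AutoTokenizer

from sglang.bench_serving import sample_random_requests

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # stand-in tokenizer

requests_ = sample_random_requests(
    input_len=512,      # assumed parameter name, not shown in this diff
    output_len=128,     # assumed parameter name, not shown in this diff
    num_prompts=10,
    range_ratio=0.5,
    tokenizer=tokenizer,
    dataset_path="",    # missing file -> falls back to default_sharegpt_path
)
# Each element is a (prompt, input_len, output_len) tuple per the return type.
print(len(requests_), requests_[0][1:])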