Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
93872128
Commit
93872128
authored
Sep 26, 2024
by
zhuwenwen
Browse files
support head_dim 160 and update benchmark_throughput.py
parent
a087cda8
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
14 additions
and
13 deletions
+14
-13
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+4
-5
csrc/attention/attention_kernels.cu
csrc/attention/attention_kernels.cu
+6
-0
csrc/attention/static_switch.h
csrc/attention/static_switch.h
+0
-3
vllm/benchmarks/benchmark_throughput.py
vllm/benchmarks/benchmark_throughput.py
+4
-5
No files found.
benchmarks/benchmark_throughput.py
View file @
93872128
...
...
@@ -348,12 +348,11 @@ def main(args: argparse.Namespace):
# Sample the requests.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
tokenizer
,
trust_remote_code
=
args
.
trust_remote_code
)
if
args
.
dataset
is
None
:
# Synthesize a prompt with the given input length.
warmup_prompt
=
"hi"
*
10
warmup_requests
=
[(
warmup_prompt
,
10
,
10
)
for
_
in
range
(
1
)]
if
args
.
dataset
is
None
:
# Synthesize a prompt with the given input length.
prompt
=
"hi"
*
(
args
.
input_len
-
1
)
requests
=
[(
prompt
,
args
.
input_len
,
args
.
output_len
)
for
_
in
range
(
args
.
num_prompts
)]
...
...
@@ -363,7 +362,7 @@ def main(args: argparse.Namespace):
if
args
.
backend
==
"vllm"
:
run_args
=
[
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
warmup_requests
,
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
...
...
csrc/attention/attention_kernels.cu
View file @
93872128
...
...
@@ -757,6 +757,9 @@ void paged_attention_v1_launcher(
case
128
:
LAUNCH_PAGED_ATTENTION_V1
(
128
);
break
;
case
160
:
LAUNCH_PAGED_ATTENTION_V1
(
160
);
break
;
case
192
:
LAUNCH_PAGED_ATTENTION_V1
(
192
);
break
;
...
...
@@ -921,6 +924,9 @@ void paged_attention_v2_launcher(
case
128
:
LAUNCH_PAGED_ATTENTION_V2
(
128
);
break
;
case
160
:
LAUNCH_PAGED_ATTENTION_V2
(
160
);
break
;
case
192
:
LAUNCH_PAGED_ATTENTION_V2
(
192
);
break
;
...
...
csrc/attention/static_switch.h
View file @
93872128
...
...
@@ -37,9 +37,6 @@
} else if (HEADDIM == 112) { \
constexpr static int HEAD_SIZE = 112; \
return __VA_ARGS__(); \
} else if (HEADDIM == 120) { \
constexpr static int HEAD_SIZE = 120; \
return __VA_ARGS__(); \
} else if (HEADDIM == 128) { \
constexpr static int HEAD_SIZE = 128; \
return __VA_ARGS__(); \
...
...
vllm/benchmarks/benchmark_throughput.py
View file @
93872128
...
...
@@ -348,12 +348,11 @@ def main(args: argparse.Namespace):
# Sample the requests.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
tokenizer
,
trust_remote_code
=
args
.
trust_remote_code
)
if
args
.
dataset
is
None
:
# Synthesize a prompt with the given input length.
warmup_prompt
=
"hi"
*
10
warmup_requests
=
[(
warmup_prompt
,
10
,
10
)
for
_
in
range
(
1
)]
if
args
.
dataset
is
None
:
# Synthesize a prompt with the given input length.
prompt
=
"hi"
*
(
args
.
input_len
-
1
)
requests
=
[(
prompt
,
args
.
input_len
,
args
.
output_len
)
for
_
in
range
(
args
.
num_prompts
)]
...
...
@@ -363,7 +362,7 @@ def main(args: argparse.Namespace):
if
args
.
backend
==
"vllm"
:
run_args
=
[
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
warmup_requests
,
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment