Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
efb2f75f
Commit
efb2f75f
authored
Aug 10, 2024
by
zhuwenwen
Browse files
add benchmarks to vllm and add env of VLLM_USE_FLASH_ATTN_AUTO
parent
749242a0
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
524 additions
and
13 deletions
+524
-13
benchmarks/benchmark_throughput.py
benchmarks/benchmark_throughput.py
+7
-5
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+48
-8
vllm/benchmarks/benchmark_throughput.py
vllm/benchmarks/benchmark_throughput.py
+464
-0
vllm/envs.py
vllm/envs.py
+5
-0
No files found.
benchmarks/benchmark_throughput.py
View file @
efb2f75f
...
...
@@ -136,7 +136,9 @@ def run_vllm(
ignore_eos
=
True
,
max_tokens
=
output_len
,
))
print
(
"Warming up..."
)
for
_
in
tqdm
(
range
(
args
.
num_iters_warmup
),
desc
=
"Warmup iterations"
):
llm
.
generate
(
warmup_prompts
,
warmup_sampling_params
,
use_tqdm
=
True
)
# dummy_prompt_token_ids = np.random.randint(10000,
...
...
@@ -335,10 +337,10 @@ if __name__ == "__main__":
default
=
1
,
help
=
"Number of generated sequences per prompt."
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
#
parser.add_argument('--num-iters-warmup',
#
type=int,
#
default=1,
#
help='Number of iterations to run for warmup.')
parser
.
add_argument
(
'--num-iters-warmup'
,
type
=
int
,
default
=
1
,
help
=
'Number of iterations to run for warmup.'
)
parser
.
add_argument
(
"--num-prompts"
,
type
=
int
,
default
=
1000
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
efb2f75f
...
...
@@ -228,13 +228,25 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
use_naive_attn
=
False
# NOTE: Allow for switching between Triton and CK. Defaulting to triton.
self
.
use_triton_flash_attn
=
envs
.
VLLM_USE_TRITON_FLASH_ATTN
# NOTE: Allow automatic switching between Triton and CK. Defaulting to triton when seqlen > 8192
self
.
use_flash_attn_auto
=
envs
.
VLLM_USE_FLASH_ATTN_AUTO
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func_triton
=
flash_attn_varlen_func
from
flash_attn
import
flash_attn_varlen_func
# noqa: F401
self
.
attn_func_ck
=
flash_attn_varlen_func
logger
.
debug
(
"When SEQ_LEN > 8192, Use Triton FA in ROCmBackend, otherwise Use CK FA"
)
else
:
# from vllm.attention.ops.triton_flash_attention import ( # noqa: F401
# triton_attention)
from
vllm.attention.ops.flash_attn_triton_mqa_gqa
import
(
flash_attn_varlen_func
)
self
.
attn_func
=
flash_attn_varlen_func
# triton_attention
logger
.
debug
(
"Using Triton FA in ROCmBackend"
)
else
:
# if not using triton, navi3x/navi21/navi10 do not use flash-attn
# either
...
...
@@ -327,6 +339,32 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# When block_tables are not filled, it means q and k are the
# prompt, and they have the same length.
if
self
.
use_triton_flash_attn
:
if
self
.
use_flash_attn_auto
:
if
prefill_meta
.
max_prefill_seq_len
>
8192
:
out
=
self
.
attn_func_triton
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlens_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlens_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
out
=
self
.
attn_func_ck
(
q
=
query
,
k
=
key
,
v
=
value
,
cu_seqlens_q
=
prefill_meta
.
seq_start_loc
,
cu_seqlens_k
=
prefill_meta
.
seq_start_loc
,
max_seqlen_q
=
prefill_meta
.
max_prefill_seq_len
,
max_seqlen_k
=
prefill_meta
.
max_prefill_seq_len
,
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
else
:
# out, _ = self.attn_func(
# query,
# key,
...
...
@@ -338,6 +376,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
# prefill_meta.max_prefill_seq_len,
# True,
# self.scale,
# )
out
=
self
.
attn_func
(
q
=
query
,
k
=
key
,
...
...
@@ -349,6 +388,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
softmax_scale
=
self
.
scale
,
causal
=
True
,
)
elif
self
.
use_naive_attn
:
if
self
.
num_kv_heads
!=
self
.
num_heads
:
# Interleave for MQA workaround.
...
...
vllm/benchmarks/benchmark_throughput.py
0 → 100644
View file @
efb2f75f
"""Benchmark offline inference throughput."""
import
argparse
import
json
import
random
import
time
from
typing
import
List
,
Optional
,
Tuple
import
numpy
as
np
import
torch
from
tqdm
import
tqdm
from
transformers
import
(
AutoModelForCausalLM
,
AutoTokenizer
,
PreTrainedTokenizerBase
)
from
vllm.inputs
import
PromptStrictInputs
from
vllm.model_executor.layers.quantization
import
QUANTIZATION_METHODS
def
sample_requests
(
dataset_path
:
str
,
num_requests
:
int
,
tokenizer
:
PreTrainedTokenizerBase
,
fixed_output_len
:
Optional
[
int
],
)
->
List
[
Tuple
[
str
,
int
,
int
]]:
if
fixed_output_len
is
not
None
and
fixed_output_len
<
4
:
raise
ValueError
(
"output_len too small"
)
# Load the dataset.
with
open
(
dataset_path
)
as
f
:
dataset
=
json
.
load
(
f
)
# Filter out the conversations with less than 2 turns.
dataset
=
[
data
for
data
in
dataset
if
len
(
data
[
"conversations"
])
>=
2
]
# Only keep the first two turns of each conversation.
dataset
=
[(
data
[
"conversations"
][
0
][
"value"
],
data
[
"conversations"
][
1
][
"value"
])
for
data
in
dataset
]
# Shuffle the dataset.
random
.
shuffle
(
dataset
)
# Filter out sequences that are too long or too short
filtered_dataset
:
List
[
Tuple
[
str
,
int
,
int
]]
=
[]
for
i
in
range
(
len
(
dataset
)):
if
len
(
filtered_dataset
)
==
num_requests
:
break
# Tokenize the prompts and completions.
prompt
=
dataset
[
i
][
0
]
prompt_token_ids
=
tokenizer
(
prompt
).
input_ids
completion
=
dataset
[
i
][
1
]
completion_token_ids
=
tokenizer
(
completion
).
input_ids
prompt_len
=
len
(
prompt_token_ids
)
output_len
=
len
(
completion_token_ids
)
if
fixed_output_len
is
None
else
fixed_output_len
if
prompt_len
<
4
or
output_len
<
4
:
# Prune too short sequences.
continue
if
prompt_len
>
1024
or
prompt_len
+
output_len
>
2048
:
# Prune too long sequences.
continue
filtered_dataset
.
append
((
prompt
,
prompt_len
,
output_len
))
return
filtered_dataset
def
run_vllm
(
warmup_requests
:
List
[
Tuple
[
str
,
int
,
int
]],
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tokenizer
:
str
,
quantization
:
Optional
[
str
],
tensor_parallel_size
:
int
,
seed
:
int
,
n
:
int
,
use_beam_search
:
bool
,
trust_remote_code
:
bool
,
dtype
:
str
,
max_model_len
:
Optional
[
int
],
enforce_eager
:
bool
,
kv_cache_dtype
:
str
,
quantization_param_path
:
Optional
[
str
],
device
:
str
,
enable_prefix_caching
:
bool
,
enable_chunked_prefill
:
bool
,
max_num_batched_tokens
:
int
,
distributed_executor_backend
:
Optional
[
str
],
gpu_memory_utilization
:
float
=
0.9
,
download_dir
:
Optional
[
str
]
=
None
,
)
->
float
:
from
vllm
import
LLM
,
SamplingParams
llm
=
LLM
(
model
=
model
,
tokenizer
=
tokenizer
,
quantization
=
quantization
,
tensor_parallel_size
=
tensor_parallel_size
,
seed
=
seed
,
trust_remote_code
=
trust_remote_code
,
dtype
=
dtype
,
max_model_len
=
max_model_len
,
gpu_memory_utilization
=
gpu_memory_utilization
,
enforce_eager
=
enforce_eager
,
kv_cache_dtype
=
kv_cache_dtype
,
quantization_param_path
=
quantization_param_path
,
device
=
device
,
enable_prefix_caching
=
enable_prefix_caching
,
download_dir
=
download_dir
,
enable_chunked_prefill
=
enable_chunked_prefill
,
max_num_batched_tokens
=
max_num_batched_tokens
,
distributed_executor_backend
=
distributed_executor_backend
,
)
# Add the requests to the engine.
prompts
=
[]
sampling_params
=
[]
for
prompt
,
_
,
output_len
in
requests
:
prompts
.
append
(
prompt
)
sampling_params
.
append
(
SamplingParams
(
n
=
n
,
temperature
=
0.0
if
use_beam_search
else
1.0
,
top_p
=
1.0
,
use_beam_search
=
use_beam_search
,
ignore_eos
=
True
,
max_tokens
=
output_len
,
))
# warmup
warmup_prompts
=
[]
warmup_sampling_params
=
[]
for
prompt
,
_
,
output_len
in
warmup_requests
:
warmup_prompts
.
append
(
prompt
)
warmup_sampling_params
.
append
(
SamplingParams
(
n
=
n
,
temperature
=
0.0
if
use_beam_search
else
1.0
,
top_p
=
1.0
,
use_beam_search
=
use_beam_search
,
ignore_eos
=
True
,
max_tokens
=
output_len
,
))
print
(
"Warming up..."
)
for
_
in
tqdm
(
range
(
args
.
num_iters_warmup
),
desc
=
"Warmup iterations"
):
llm
.
generate
(
warmup_prompts
,
warmup_sampling_params
,
use_tqdm
=
True
)
# dummy_prompt_token_ids = np.random.randint(10000,
# size=(args.num_prompts,
# args.input_len))
# dummy_inputs: List[PromptStrictInputs] = [{
# "prompt_token_ids": batch
# } for batch in dummy_prompt_token_ids.tolist()]
# def run_to_completion():
# llm.generate(dummy_inputs,
# sampling_params=sampling_params,
# use_tqdm=False)
# print("Warming up...")
# for _ in tqdm(range(args.num_iters_warmup), desc="Warmup iterations"):
# run_to_completion()
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
sampling_params
,
use_tqdm
=
True
)
end
=
time
.
perf_counter
()
return
end
-
start
def
run_hf
(
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tokenizer
:
PreTrainedTokenizerBase
,
n
:
int
,
use_beam_search
:
bool
,
max_batch_size
:
int
,
trust_remote_code
:
bool
,
)
->
float
:
assert
not
use_beam_search
llm
=
AutoModelForCausalLM
.
from_pretrained
(
model
,
torch_dtype
=
torch
.
float16
,
trust_remote_code
=
trust_remote_code
)
if
llm
.
config
.
model_type
==
"llama"
:
# To enable padding in the HF backend.
tokenizer
.
pad_token
=
tokenizer
.
eos_token
llm
=
llm
.
cuda
()
pbar
=
tqdm
(
total
=
len
(
requests
))
start
=
time
.
perf_counter
()
batch
:
List
[
str
]
=
[]
max_prompt_len
=
0
max_output_len
=
0
for
i
in
range
(
len
(
requests
)):
prompt
,
prompt_len
,
output_len
=
requests
[
i
]
# Add the prompt to the batch.
batch
.
append
(
prompt
)
max_prompt_len
=
max
(
max_prompt_len
,
prompt_len
)
max_output_len
=
max
(
max_output_len
,
output_len
)
if
len
(
batch
)
<
max_batch_size
and
i
!=
len
(
requests
)
-
1
:
# Check if we can add more requests to the batch.
_
,
next_prompt_len
,
next_output_len
=
requests
[
i
+
1
]
if
(
max
(
max_prompt_len
,
next_prompt_len
)
+
max
(
max_output_len
,
next_output_len
))
<=
2048
:
# We can add more requests to the batch.
continue
# Generate the sequences.
input_ids
=
tokenizer
(
batch
,
return_tensors
=
"pt"
,
padding
=
True
).
input_ids
llm_outputs
=
llm
.
generate
(
input_ids
=
input_ids
.
cuda
(),
do_sample
=
not
use_beam_search
,
num_return_sequences
=
n
,
temperature
=
1.0
,
top_p
=
1.0
,
use_cache
=
True
,
max_new_tokens
=
max_output_len
,
)
# Include the decoding time.
tokenizer
.
batch_decode
(
llm_outputs
,
skip_special_tokens
=
True
)
pbar
.
update
(
len
(
batch
))
# Clear the batch.
batch
=
[]
max_prompt_len
=
0
max_output_len
=
0
end
=
time
.
perf_counter
()
return
end
-
start
def
run_mii
(
requests
:
List
[
Tuple
[
str
,
int
,
int
]],
model
:
str
,
tensor_parallel_size
:
int
,
output_len
:
int
,
)
->
float
:
from
mii
import
client
,
serve
llm
=
serve
(
model
,
tensor_parallel
=
tensor_parallel_size
)
prompts
=
[
prompt
for
prompt
,
_
,
_
in
requests
]
start
=
time
.
perf_counter
()
llm
.
generate
(
prompts
,
max_new_tokens
=
output_len
)
end
=
time
.
perf_counter
()
client
=
client
(
model
)
client
.
terminate_server
()
return
end
-
start
def
main
(
args
:
argparse
.
Namespace
):
print
(
args
)
random
.
seed
(
args
.
seed
)
# Sample the requests.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
args
.
tokenizer
,
trust_remote_code
=
args
.
trust_remote_code
)
if
args
.
dataset
is
None
:
# Synthesize a prompt with the given input length.
warmup_prompt
=
"hi"
*
10
warmup_requests
=
[(
warmup_prompt
,
10
,
10
)
for
_
in
range
(
1
)]
prompt
=
"hi"
*
(
args
.
input_len
-
1
)
requests
=
[(
prompt
,
args
.
input_len
,
args
.
output_len
)
for
_
in
range
(
args
.
num_prompts
)]
else
:
requests
=
sample_requests
(
args
.
dataset
,
args
.
num_prompts
,
tokenizer
,
args
.
output_len
)
if
args
.
backend
==
"vllm"
:
elapsed_time
=
run_vllm
(
warmup_requests
,
requests
,
args
.
model
,
args
.
tokenizer
,
args
.
quantization
,
args
.
tensor_parallel_size
,
args
.
seed
,
args
.
n
,
args
.
use_beam_search
,
args
.
trust_remote_code
,
args
.
dtype
,
args
.
max_model_len
,
args
.
enforce_eager
,
args
.
kv_cache_dtype
,
args
.
quantization_param_path
,
args
.
device
,
args
.
enable_prefix_caching
,
args
.
enable_chunked_prefill
,
args
.
max_num_batched_tokens
,
args
.
distributed_executor_backend
,
args
.
gpu_memory_utilization
,
args
.
download_dir
)
elif
args
.
backend
==
"hf"
:
assert
args
.
tensor_parallel_size
==
1
elapsed_time
=
run_hf
(
requests
,
args
.
model
,
tokenizer
,
args
.
n
,
args
.
use_beam_search
,
args
.
hf_max_batch_size
,
args
.
trust_remote_code
)
elif
args
.
backend
==
"mii"
:
elapsed_time
=
run_mii
(
requests
,
args
.
model
,
args
.
tensor_parallel_size
,
args
.
output_len
)
else
:
raise
ValueError
(
f
"Unknown backend:
{
args
.
backend
}
"
)
total_num_tokens
=
sum
(
prompt_len
+
output_len
for
_
,
prompt_len
,
output_len
in
requests
)
if
args
.
dataset
is
None
:
total_out_tokens
=
args
.
output_len
*
args
.
num_prompts
else
:
total_out_tokens
=
sum
(
output_len
for
_
,
_
,
output_len
in
requests
)
print
(
f
"Latency:
{
elapsed_time
:.
2
f
}
s"
)
print
(
f
"All Throughput:
{
len
(
requests
)
/
elapsed_time
:.
2
f
}
requests/s, "
f
"
{
total_num_tokens
/
elapsed_time
:.
2
f
}
tokens/s"
)
print
(
f
"Generate Throughput:
{
total_out_tokens
/
elapsed_time
:.
2
f
}
tokens/s"
)
# Output JSON results if specified
if
args
.
output_json
:
results
=
{
"elapsed_time"
:
elapsed_time
,
"num_requests"
:
len
(
requests
),
"total_num_tokens"
:
total_num_tokens
,
"requests_per_second"
:
len
(
requests
)
/
elapsed_time
,
"tokens_per_second"
:
total_num_tokens
/
elapsed_time
,
}
with
open
(
args
.
output_json
,
"w"
)
as
f
:
json
.
dump
(
results
,
f
,
indent
=
4
)
if
__name__
==
"__main__"
:
parser
=
argparse
.
ArgumentParser
(
description
=
"Benchmark the throughput."
)
parser
.
add_argument
(
"--backend"
,
type
=
str
,
choices
=
[
"vllm"
,
"hf"
,
"mii"
],
default
=
"vllm"
)
parser
.
add_argument
(
"--dataset"
,
type
=
str
,
default
=
None
,
help
=
"Path to the dataset."
)
parser
.
add_argument
(
"--input-len"
,
type
=
int
,
default
=
None
,
help
=
"Input prompt length for each request"
)
parser
.
add_argument
(
"--output-len"
,
type
=
int
,
default
=
None
,
help
=
"Output length for each request. Overrides the "
"output length from the dataset."
)
parser
.
add_argument
(
"--model"
,
type
=
str
,
default
=
"facebook/opt-125m"
)
parser
.
add_argument
(
"--tokenizer"
,
type
=
str
,
default
=
None
)
parser
.
add_argument
(
'--quantization'
,
'-q'
,
choices
=
[
*
QUANTIZATION_METHODS
,
None
],
default
=
None
)
parser
.
add_argument
(
"--tensor-parallel-size"
,
"-tp"
,
type
=
int
,
default
=
1
)
parser
.
add_argument
(
"--n"
,
type
=
int
,
default
=
1
,
help
=
"Number of generated sequences per prompt."
)
parser
.
add_argument
(
"--use-beam-search"
,
action
=
"store_true"
)
parser
.
add_argument
(
'--num-iters-warmup'
,
type
=
int
,
default
=
1
,
help
=
'Number of iterations to run for warmup.'
)
parser
.
add_argument
(
"--num-prompts"
,
type
=
int
,
default
=
1000
,
help
=
"Number of prompts to process."
)
parser
.
add_argument
(
"--seed"
,
type
=
int
,
default
=
0
)
parser
.
add_argument
(
"--hf-max-batch-size"
,
type
=
int
,
default
=
None
,
help
=
"Maximum batch size for HF backend."
)
parser
.
add_argument
(
'--trust-remote-code'
,
action
=
'store_true'
,
help
=
'trust remote code from huggingface'
)
parser
.
add_argument
(
'--max-model-len'
,
type
=
int
,
default
=
None
,
help
=
'Maximum length of a sequence (including prompt and output). '
'If None, will be derived from the model.'
)
parser
.
add_argument
(
'--dtype'
,
type
=
str
,
default
=
'auto'
,
choices
=
[
'auto'
,
'half'
,
'float16'
,
'bfloat16'
,
'float'
,
'float32'
],
help
=
'data type for model weights and activations. '
'The "auto" option will use FP16 precision '
'for FP32 and FP16 models, and BF16 precision '
'for BF16 models.'
)
parser
.
add_argument
(
'--gpu-memory-utilization'
,
type
=
float
,
default
=
0.9
,
help
=
'the fraction of GPU memory to be used for '
'the model executor, which can range from 0 to 1.'
'If unspecified, will use the default value of 0.9.'
)
parser
.
add_argument
(
"--enforce-eager"
,
action
=
"store_true"
,
help
=
"enforce eager execution"
)
parser
.
add_argument
(
'--kv-cache-dtype'
,
type
=
str
,
choices
=
[
'auto'
,
'fp8'
,
'fp8_e5m2'
,
'fp8_e4m3'
],
default
=
"auto"
,
help
=
'Data type for kv cache storage. If "auto", will use model '
'data type. CUDA 11.8+ supports fp8 (=fp8_e4m3) and fp8_e5m2. '
'ROCm (AMD GPU) supports fp8 (=fp8_e4m3)'
)
parser
.
add_argument
(
'--quantization-param-path'
,
type
=
str
,
default
=
None
,
help
=
'Path to the JSON file containing the KV cache scaling factors. '
'This should generally be supplied, when KV cache dtype is FP8. '
'Otherwise, KV cache scaling factors default to 1.0, which may cause '
'accuracy issues. FP8_E5M2 (without scaling) is only supported on '
'cuda version greater than 11.8. On ROCm (AMD GPU), FP8_E4M3 is '
'instead supported for common inference criteria.'
)
parser
.
add_argument
(
"--device"
,
type
=
str
,
default
=
"cuda"
,
choices
=
[
"cuda"
,
"cpu"
],
help
=
'device type for vLLM execution, supporting CUDA and CPU.'
)
parser
.
add_argument
(
"--enable-prefix-caching"
,
action
=
'store_true'
,
help
=
"enable automatic prefix caching for vLLM backend."
)
parser
.
add_argument
(
"--enable-chunked-prefill"
,
action
=
'store_true'
,
help
=
"enable chunked prefill for vLLM backend."
)
parser
.
add_argument
(
'--max-num-batched-tokens'
,
type
=
int
,
default
=
None
,
help
=
'maximum number of batched tokens per '
'iteration'
)
parser
.
add_argument
(
'--download-dir'
,
type
=
str
,
default
=
None
,
help
=
'directory to download and load the weights, '
'default to the default cache dir of huggingface'
)
parser
.
add_argument
(
'--output-json'
,
type
=
str
,
default
=
None
,
help
=
'Path to save the throughput results in JSON format.'
)
parser
.
add_argument
(
'--distributed-executor-backend'
,
choices
=
[
'ray'
,
'mp'
],
default
=
None
,
help
=
'Backend to use for distributed serving. When more than 1 GPU '
'is used, will be automatically set to "ray" if installed '
'or "mp" (multiprocessing) otherwise.'
)
args
=
parser
.
parse_args
()
if
args
.
tokenizer
is
None
:
args
.
tokenizer
=
args
.
model
if
args
.
dataset
is
None
:
assert
args
.
input_len
is
not
None
assert
args
.
output_len
is
not
None
else
:
assert
args
.
input_len
is
None
if
args
.
backend
==
"vllm"
:
if
args
.
hf_max_batch_size
is
not
None
:
raise
ValueError
(
"HF max batch size is only for HF backend."
)
elif
args
.
backend
==
"hf"
:
if
args
.
hf_max_batch_size
is
None
:
raise
ValueError
(
"HF max batch size is required for HF backend."
)
if
args
.
quantization
is
not
None
:
raise
ValueError
(
"Quantization is only for vLLM backend."
)
elif
args
.
backend
==
"mii"
:
if
args
.
dtype
!=
"auto"
:
raise
ValueError
(
"dtype must be auto for MII backend."
)
if
args
.
n
!=
1
:
raise
ValueError
(
"n must be 1 for MII backend."
)
if
args
.
use_beam_search
:
raise
ValueError
(
"Beam search is not supported for MII backend."
)
if
args
.
quantization
is
not
None
:
raise
ValueError
(
"Quantization is only for vLLM backend."
)
if
args
.
hf_max_batch_size
is
not
None
:
raise
ValueError
(
"HF max batch size is only for HF backend."
)
if
args
.
tokenizer
!=
args
.
model
:
raise
ValueError
(
"Tokenizer must be the same as the model for MII "
"backend."
)
main
(
args
)
\ No newline at end of file
vllm/envs.py
View file @
efb2f75f
...
...
@@ -133,6 +133,11 @@ environment_variables: Dict[str, Callable[[], Any]] = {
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_TRITON_FLASH_ATTN"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# flag to control vllm to automatically switch between Triton FA and CK FA
"VLLM_USE_FLASH_ATTN_AUTO"
:
lambda
:
(
os
.
environ
.
get
(
"VLLM_USE_FLASH_ATTN_AUTO"
,
"False"
).
lower
()
in
(
"true"
,
"1"
)),
# local rank of the process in the distributed setting, used to determine
# the GPU device id
"LOCAL_RANK"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment