Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fcfc474d
Commit
fcfc474d
authored
Apr 09, 2025
by
zhuwenwen
Browse files
Merge tag 'v0.8.3' into v0.8.3-dev
parents
bb94d2e5
296c6572
Changes
503
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1120 additions
and
268 deletions
+1120
-268
tests/v1/tpu/test_sampler.py
tests/v1/tpu/test_sampler.py
+5
-64
tests/v1/tpu/test_topk_topp_sampler.py
tests/v1/tpu/test_topk_topp_sampler.py
+132
-0
tests/v1/tpu/worker/__init__.py
tests/v1/tpu/worker/__init__.py
+0
-0
tests/v1/tpu/worker/test_tpu_model_runner.py
tests/v1/tpu/worker/test_tpu_model_runner.py
+311
-0
tools/shellcheck.sh
tools/shellcheck.sh
+2
-2
vllm/__init__.py
vllm/__init__.py
+4
-16
vllm/_custom_ops.py
vllm/_custom_ops.py
+84
-13
vllm/_ipex_ops.py
vllm/_ipex_ops.py
+22
-9
vllm/assets/video.py
vllm/assets/video.py
+11
-8
vllm/attention/backends/cpu_mla.py
vllm/attention/backends/cpu_mla.py
+303
-0
vllm/attention/backends/mla/common.py
vllm/attention/backends/mla/common.py
+13
-5
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+8
-20
vllm/attention/ops/chunked_prefill_paged_decode.py
vllm/attention/ops/chunked_prefill_paged_decode.py
+103
-54
vllm/attention/ops/nki_flash_attn.py
vllm/attention/ops/nki_flash_attn.py
+38
-47
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+2
-0
vllm/attention/ops/prefix_prefill.py
vllm/attention/ops/prefix_prefill.py
+2
-1
vllm/attention/ops/triton_merge_attn_states.py
vllm/attention/ops/triton_merge_attn_states.py
+9
-0
vllm/benchmarks/backend_request_func.py
vllm/benchmarks/backend_request_func.py
+9
-1
vllm/benchmarks/benchmark_serving.py
vllm/benchmarks/benchmark_serving.py
+31
-12
vllm/benchmarks/benchmark_throughput.py
vllm/benchmarks/benchmark_throughput.py
+31
-16
No files found.
tests/v1/tpu/test_sampler.py
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
tempfile
from
time
import
time
import
pytest
from
vllm
import
LLM
,
envs
...
...
@@ -15,60 +12,6 @@ if not envs.VLLM_USE_V1:
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"D4nt3/Qwen2.5-two-layers"
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
def
test_sampler_compilation
(
model_name
:
str
,
monkeypatch
):
"""
Check that no recompilation happens despite changing sampling parameters.
We can't read XLA metrics from the engine process, hence we measure time.
"""
with
tempfile
.
TemporaryDirectory
()
as
temp_dir
:
monkeypatch
.
setenv
(
"VLLM_XLA_CACHE_PATH"
,
temp_dir
)
# Compiling model init may still take some time, enforce_eager to skip.
llm
=
LLM
(
model_name
,
enforce_eager
=
True
,
max_num_seqs
=
16
,
max_model_len
=
1024
,
gpu_memory_utilization
=
0.5
)
prompts
=
[
"A robot may not injure a human being"
,
"It is only with the heart that one can see rightly;"
,
]
# First inference should be slow
sampling_params
=
SamplingParams
(
temperature
=
0.7
,
# top_p=0.6, # TODO too slow!
top_k
=
10
,
min_p
=
0.2
,
max_tokens
=
16
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run1
=
time
()
-
s
# Second request with different params, but for which we
# compiled for in previous eager iteration.
sampling_params
=
SamplingParams
(
temperature
=
0.1
,
top_k
=
12
,
min_p
=
0.8
,
max_tokens
=
24
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run2
=
time
()
-
s
# Much faster after compiling
assert
run1
*
0.1
>
run2
print
(
"TIMES"
,
run1
,
run2
)
# Third request with min_p set to "None". It will not trigger
# recompilation as a default 0 value will be used.
sampling_params
=
SamplingParams
(
max_tokens
=
24
,
temperature
=
0.0
)
s
=
time
()
_
=
llm
.
generate
(
prompts
,
sampling_params
)
run3
=
time
()
-
s
assert
run1
*
0.1
>
run3
print
(
"TIMES"
,
run1
,
run3
)
@
pytest
.
mark
.
parametrize
(
"model_name"
,
[
"Qwen/Qwen2.5-1.5B-Instruct"
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
is_tpu
(),
reason
=
"This test needs a TPU"
)
...
...
@@ -77,13 +20,11 @@ def test_sampler_different(model_name: str):
Test significantly different sampling params to assert the model produces
different results.
"""
llm
=
LLM
(
model_name
,
enforce_eager
=
True
,
max_num_seqs
=
1
,
max_model_len
=
64
,
# TODO: setting to 0.5 or it will go OOM
gpu_memory_utilization
=
0.5
)
llm
=
LLM
(
model_name
,
enforce_eager
=
False
,
max_num_seqs
=
1
,
max_model_len
=
512
,
max_num_batched_tokens
=
512
)
prompts
=
[
"Write a short story about a robot that dreams for the first time."
]
...
...
tests/v1/tpu/test_topk_topp_sampler.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
math
import
pytest
import
torch
from
vllm.platforms
import
current_platform
from
vllm.v1.sample.ops.topk_topp_sampler
import
apply_top_k_top_p_tpu
if
not
current_platform
.
is_tpu
():
pytest
.
skip
(
"This test needs a TPU."
,
allow_module_level
=
True
)
import
torch_xla.core.xla_model
as
xm
BATCH_SIZE
=
1024
VOCAB_SIZE
=
128
*
1024
TOLERANCE
=
1e-6
def
test_topp_result_sums_past_p
():
with
torch
.
device
(
xm
.
xla_device
()):
xm
.
set_rng_state
(
seed
=
33
)
logits
=
torch
.
rand
((
BATCH_SIZE
,
VOCAB_SIZE
))
probs
=
logits
.
softmax
(
dim
=-
1
)
# Random top-p values between 0 and 1.
p
=
torch
.
rand
((
BATCH_SIZE
,
))
# Set p=1 for ~50% of requests in the batch (top-p disabled).
p
.
masked_fill_
(
torch
.
randint
(
0
,
2
,
(
BATCH_SIZE
,
),
dtype
=
bool
),
1
)
no_op_k
=
torch
.
tensor
([
VOCAB_SIZE
])
logits_masked
=
apply_top_k_top_p_tpu
(
logits
=
logits
.
clone
(),
k
=
no_op_k
,
p
=
p
)
# Verify that the masked logit's probability sums to at least p.
probs
.
masked_fill_
(
logits_masked
.
isinf
(),
0
)
masked_prob_sum
=
probs
.
sum
(
dim
=-
1
)
xm
.
mark_step
()
# Perform assertion on CPU.
assert
torch
.
all
(
torch
.
ge
(
masked_prob_sum
.
cpu
()
+
TOLERANCE
,
p
.
cpu
()))
def
test_topp_basic
():
with
torch
.
device
(
xm
.
xla_device
()):
logits
=
torch
.
tensor
([[
math
.
log
(
0.2
),
math
.
log
(
0.3
),
math
.
log
(
0.5
)],
[
math
.
log
(
0.5
),
math
.
log
(
0.1
),
math
.
log
(
0.4
)]])
result
=
apply_top_k_top_p_tpu
(
logits
=
logits
.
clone
(),
k
=
torch
.
tensor
([
3
,
3
]),
p
=
torch
.
tensor
([
0.79
,
0.79
]))
xm
.
mark_step
()
# Expect the smallest elements to be dropped.
expected_result
=
logits
.
clone
().
cpu
()
expected_result
[
0
,
0
]
=
float
(
"-inf"
)
expected_result
[
1
,
1
]
=
float
(
"-inf"
)
assert
torch
.
allclose
(
expected_result
,
result
.
cpu
())
def
test_topp_select_all
():
with
torch
.
device
(
xm
.
xla_device
()):
logits
=
torch
.
tensor
([[
math
.
log
(
0.2
),
math
.
log
(
0.3
),
math
.
log
(
0.5
)],
[
math
.
log
(
0.5
),
math
.
log
(
0.1
),
math
.
log
(
0.4
)]])
result
=
apply_top_k_top_p_tpu
(
logits
=
logits
.
clone
(),
k
=
torch
.
tensor
([
3
,
3
]),
p
=
torch
.
tensor
([
1.0
,
1.0
]))
xm
.
mark_step
()
assert
torch
.
allclose
(
logits
.
cpu
(),
result
.
cpu
())
def
test_topp_with_ties
():
with
torch
.
device
(
xm
.
xla_device
()):
# Input has multiple math.log(0.3).
logits
=
torch
.
tensor
(
[[
math
.
log
(
0.3
),
math
.
log
(
0.3
),
math
.
log
(
0.3
),
math
.
log
(
0.1
)]])
result
=
apply_top_k_top_p_tpu
(
logits
=
logits
.
clone
(),
k
=
torch
.
tensor
([
4
]),
p
=
torch
.
tensor
([
0.2
]))
xm
.
mark_step
()
# All tie values are included in the top-p set. Tie breaking is left
# to be done during final sampling (all tie tokens have equal
# probability of being chosen).
expected_result
=
logits
.
clone
().
cpu
()
expected_result
[
0
,
3
]
=
float
(
"-inf"
)
assert
torch
.
allclose
(
expected_result
,
result
.
cpu
())
def
test_both_topk_topp
():
with
torch
.
device
(
xm
.
xla_device
()):
logits
=
torch
.
tensor
([[
math
.
log
(
0.2
),
math
.
log
(
0.3
),
math
.
log
(
0.5
)],
[
math
.
log
(
0.5
),
math
.
log
(
0.1
),
math
.
log
(
0.4
)]])
# Set k=1 for the first batch.
result
=
apply_top_k_top_p_tpu
(
logits
=
logits
.
clone
(),
k
=
torch
.
tensor
([
1
,
3
]),
p
=
torch
.
tensor
([
0.79
,
0.79
]))
xm
.
mark_step
()
# Since for the first batch k=1, expect only the largest element gets
# selected.
expected_result
=
logits
.
clone
().
cpu
()
expected_result
[
0
,
0
]
=
float
(
"-inf"
)
expected_result
[
0
,
1
]
=
float
(
"-inf"
)
expected_result
[
1
,
1
]
=
float
(
"-inf"
)
assert
torch
.
allclose
(
expected_result
,
result
.
cpu
())
tests/
lora/data
/__init__.py
→
tests/
v1/tpu/worker
/__init__.py
View file @
fcfc474d
File moved
tests/v1/tpu/worker/test_tpu_model_runner.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
import
unittest.mock
as
mock
import
pytest
from
vllm.config
import
CacheConfig
,
ModelConfig
,
SchedulerConfig
,
VllmConfig
from
vllm.sampling_params
import
SamplingParams
from
vllm.v1.core.sched.output
import
(
CachedRequestData
,
NewRequestData
,
SchedulerOutput
)
from
vllm.v1.worker.tpu_model_runner
import
(
TPUModelRunner
,
_get_padded_token_len
,
_get_paddings
)
# Mock torch_xla module since it may not be available in the test environments
torch_xla_patcher
=
mock
.
patch
.
dict
(
"sys.modules"
,
{
"torch_xla"
:
mock
.
MagicMock
(),
"torch_xla.core.xla_model"
:
mock
.
MagicMock
(),
"torch_xla.runtime"
:
mock
.
MagicMock
(),
})
torch_xla_patcher
.
start
()
# Mock the PallasAttentionBackend
pallas_attention_backend_patcher
=
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.PallasAttentionBackend"
,
)
pallas_attention_backend_patcher
.
start
()
@
pytest
.
fixture
def
model_runner
():
# Patchers have already been started at module level.
scheduler_config
=
SchedulerConfig
(
max_num_seqs
=
10
,
max_num_batched_tokens
=
512
,
max_model_len
=
512
,
)
model_config
=
ModelConfig
(
model
=
"facebook/opt-125m"
,
task
=
"generate"
,
tokenizer
=
"facebook/opt-125m"
,
tokenizer_mode
=
"auto"
,
trust_remote_code
=
True
,
dtype
=
"bfloat16"
,
# TPUs typically use bfloat16
seed
=
42
,
)
cache_config
=
CacheConfig
(
block_size
=
16
,
gpu_memory_utilization
=
0.9
,
swap_space
=
0
,
cache_dtype
=
"auto"
,
)
vllm_config
=
VllmConfig
(
model_config
=
model_config
,
cache_config
=
cache_config
,
scheduler_config
=
scheduler_config
,
)
device
=
"xla:0"
# Mocking TPU device
with
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.torch"
),
\
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.xm"
),
\
mock
.
patch
(
"vllm.v1.worker.tpu_model_runner.xr"
):
return
TPUModelRunner
(
vllm_config
,
device
)
@
pytest
.
fixture
(
autouse
=
True
,
scope
=
"session"
)
def
cleanup_patches
():
yield
torch_xla_patcher
.
stop
()
pallas_attention_backend_patcher
.
stop
()
def
_schedule_new_request
(
*
req_ids
:
str
)
->
SchedulerOutput
:
new_reqs
=
[]
num_scheduled_tokens
=
{}
total_num_scheduled_tokens
=
0
for
req_id
in
req_ids
:
new_reqs
.
append
(
NewRequestData
(
req_id
=
req_id
,
prompt_token_ids
=
[
1
,
2
,
3
],
prompt
=
"test"
,
mm_inputs
=
[],
mm_hashes
=
[],
mm_positions
=
[],
sampling_params
=
SamplingParams
(),
block_ids
=
[
0
],
num_computed_tokens
=
0
,
lora_request
=
None
,
))
num_scheduled_tokens
[
req_id
]
=
3
total_num_scheduled_tokens
+=
num_scheduled_tokens
[
req_id
]
return
SchedulerOutput
(
scheduled_new_reqs
=
new_reqs
,
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
num_scheduled_tokens
,
total_num_scheduled_tokens
=
total_num_scheduled_tokens
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
def
_is_req_scheduled
(
model_runner
,
req_id
:
str
)
->
bool
:
return
req_id
in
model_runner
.
input_batch
.
req_id_to_index
def
_is_req_added
(
model_runner
,
req_id
:
str
)
->
bool
:
return
req_id
in
model_runner
.
requests
def
_is_req_state_block_table_match
(
model_runner
,
req_id
:
str
)
->
bool
:
req_index
=
model_runner
.
input_batch
.
req_id_to_index
[
req_id
]
block_table
=
model_runner
.
input_batch
.
block_table
req_state
=
model_runner
.
requests
[
req_id
]
if
block_table
.
num_blocks_per_row
[
req_index
]
!=
len
(
req_state
.
block_ids
):
return
False
num_blocks
=
block_table
.
num_blocks_per_row
[
req_index
]
return
(
block_table
.
block_table_np
[
req_index
,
:
num_blocks
]
==
req_state
.
block_ids
).
all
()
def
test_update_states_new_request
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_request_finished
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# finish req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
{
req_id
},
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
assert
not
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
def
test_update_states_request_resumed
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# unschedule req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{},
total_num_scheduled_tokens
=
0
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
not
_is_req_scheduled
(
model_runner
,
req_id
)
# resume req
cached_req_data
=
CachedRequestData
(
req_id
=
req_id
,
resumed_from_preemption
=
False
,
new_token_ids
=
[],
new_block_ids
=
[],
num_computed_tokens
=
0
,
)
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[
cached_req_data
],
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_no_changes
(
model_runner
):
req_id
=
"req_0"
# new req
scheduler_output
=
_schedule_new_request
(
req_id
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
# schedule req
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_id
:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_id
)
assert
_is_req_scheduled
(
model_runner
,
req_id
)
assert
_is_req_state_block_table_match
(
model_runner
,
req_id
)
def
test_update_states_request_unscheduled
(
model_runner
):
req_ids
=
(
"req_0"
,
"req_1"
)
# new reqs
scheduler_output
=
_schedule_new_request
(
*
req_ids
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
assert
_is_req_added
(
model_runner
,
req_ids
[
1
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
1
])
# unschedule req_1
scheduler_output
=
SchedulerOutput
(
scheduled_new_reqs
=
[],
scheduled_cached_reqs
=
[],
num_scheduled_tokens
=
{
req_ids
[
0
]:
1
},
total_num_scheduled_tokens
=
1
,
scheduled_spec_decode_tokens
=
{},
scheduled_encoder_inputs
=
{},
num_common_prefix_blocks
=
0
,
finished_req_ids
=
set
(),
free_encoder_input_ids
=
[],
structured_output_request_ids
=
{},
grammar_bitmask
=
None
,
)
model_runner
.
_update_states
(
scheduler_output
)
assert
_is_req_added
(
model_runner
,
req_ids
[
0
])
assert
_is_req_scheduled
(
model_runner
,
req_ids
[
0
])
assert
_is_req_added
(
model_runner
,
req_ids
[
1
])
assert
not
_is_req_scheduled
(
model_runner
,
req_ids
[
1
])
def
test_get_paddings
():
min_token_size
,
max_token_size
,
padding_gap
=
16
,
512
,
64
expected_paddings
=
[
16
,
32
,
64
,
128
,
192
,
256
,
320
,
384
,
448
,
512
]
actual_paddings
=
_get_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
actual_paddings
==
expected_paddings
def
test_get_padded_token_len
():
min_token_size
,
max_token_size
,
padding_gap
=
16
,
512
,
64
paddings
=
_get_paddings
(
min_token_size
,
max_token_size
,
padding_gap
)
assert
_get_padded_token_len
(
paddings
,
1
)
==
16
assert
_get_padded_token_len
(
paddings
,
16
)
==
16
assert
_get_padded_token_len
(
paddings
,
20
)
==
32
assert
_get_padded_token_len
(
paddings
,
300
)
==
320
assert
_get_padded_token_len
(
paddings
,
512
)
==
512
tools/shellcheck.sh
View file @
fcfc474d
...
...
@@ -18,5 +18,5 @@ if ! [ -x "$(command -v shellcheck)" ]; then
export
PATH
=
"
$PATH
:
$(
pwd
)
/shellcheck-
${
scversion
}
"
fi
# TODO - fix warnings in .buildkite/run-amd-test.sh
find
.
-name
"*.sh"
".git"
-prune
-not
-path
"./.buildkite/run-amd-test.sh"
-print0
| xargs
-0
-I
{}
sh
-c
'git check-ignore -q "{}" || shellcheck -s bash "{}"'
# TODO - fix warnings in .buildkite/
scripts/hardware_ci/
run-amd-test.sh
find
.
-name
"*.sh"
".git"
-prune
-not
-path
"./.buildkite/
scripts/hardware_ci/
run-amd-test.sh"
-print0
| xargs
-0
-I
{}
sh
-c
'git check-ignore -q "{}" || shellcheck -s bash "{}"'
vllm/__init__.py
View file @
fcfc474d
...
...
@@ -4,9 +4,10 @@
# version library first. Such assumption is critical for some customization.
from
.version
import
__version__
,
__version_tuple__
# isort:skip
import
os
import
torch
# The environment variables override should be imported before any other
# modules to ensure that the environment variables are set before any
# other modules are imported.
import
vllm.env_override
# isort:skip # noqa: F401
from
vllm.engine.arg_utils
import
AsyncEngineArgs
,
EngineArgs
from
vllm.engine.async_llm_engine
import
AsyncLLMEngine
...
...
@@ -25,19 +26,6 @@ from vllm.sampling_params import SamplingParams
from
vllm.version
import
__version__
,
__version_tuple__
,
__hcu_version__
# set some common config/environment variables that should be set
# for all processes created by vllm and all processes
# that interact with vllm workers.
# they are executed whenever `import vllm` is called.
# see https://github.com/NVIDIA/nccl/issues/1234
os
.
environ
[
'NCCL_CUMEM_ENABLE'
]
=
'0'
# see https://github.com/vllm-project/vllm/issues/10480
os
.
environ
[
'TORCHINDUCTOR_COMPILE_THREADS'
]
=
'1'
# see https://github.com/vllm-project/vllm/issues/10619
torch
.
_inductor
.
config
.
compile_threads
=
1
__all__
=
[
"__version__"
,
"__version_tuple__"
,
...
...
vllm/_custom_ops.py
View file @
fcfc474d
...
...
@@ -148,6 +148,7 @@ def paged_attention_v2_with_mask(
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
query_start_loc
:
Optional
[
torch
.
Tensor
],
block_size
:
int
,
max_seq_len
:
int
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
...
...
@@ -440,6 +441,7 @@ def paged_attention_v2_opt_tc_with_mask(
# scale: float,
# block_tables: torch.Tensor,
# seq_lens: torch.Tensor,
# query_start_loc: Optional[torch.Tensor],
# block_size: int,
# max_seq_len: int,
# alibi_slopes: Optional[torch.Tensor],
...
...
@@ -450,8 +452,21 @@ def paged_attention_v2_opt_tc_with_mask(
# torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
# key_cache, value_cache, num_kv_heads,
# scale, block_tables, seq_lens,
# block_size, max_seq_len, alibi_slopes,
# kv_cache_dtype, k_scale, v_scale)
# query_start_loc, block_size, max_seq_len,
# alibi_slopes, kv_cache_dtype, k_scale,
# v_scale)
def
mla_decode_kvcache_cpu
(
out
:
torch
.
Tensor
,
query
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
scale
:
float
,
block_tables
:
torch
.
Tensor
,
seq_lens
:
torch
.
Tensor
,
)
->
None
:
torch
.
ops
.
_C_cpu
.
mla_decode_kvcache
(
out
,
query
,
kv_cache
,
scale
,
block_tables
,
seq_lens
)
# pos encoding ops
...
...
@@ -792,7 +807,6 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# memory_format=torch.contiguous_format)
# if hasattr(torch.ops._C, "allspark_w8a16_gemm"):
# @register_fake("_C::allspark_w8a16_gemm")
...
...
@@ -810,13 +824,16 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# if hasattr(torch.ops._C, "ggml_dequantize"):
# @register_fake("_C::ggml_dequantize")
# def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int,
# m: torch.SymInt,
# n: torch.SymInt) -> torch.Tensor:
# def _ggml_dequantize_fake(
# W: torch.Tensor,
# quant_type: int,
# m: torch.SymInt,
# n: torch.SymInt,
# dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# return torch.empty((m, n), dtype=torch.float16, device=W.device)
# @register_fake("_C::ggml_mul_mat_vec_a8")
# def _ggml_mul_mat_vec_a8_fake(
# W: torch.Tensor,
...
...
@@ -995,6 +1012,9 @@ def cutlass_sparse_scaled_mm_supported(cuda_device_capability: int) -> bool:
cuda_device_capability
)
def
cutlass_group_gemm_supported
(
cuda_device_capability
:
int
)
->
bool
:
return
torch
.
ops
.
_C
.
cutlass_group_gemm_supported
(
cuda_device_capability
)
def
cutlass_sparse_compress
(
a
:
torch
.
Tensor
)
\
->
tuple
[
torch
.
Tensor
,
torch
.
Tensor
]:
"""
...
...
@@ -1085,6 +1105,56 @@ def cutlass_scaled_sparse_mm(
return
out
def
get_cutlass_moe_mm_data
(
topk_ids
:
torch
.
Tensor
,
expert_offsets
:
torch
.
Tensor
,
problem_sizes1
:
torch
.
Tensor
,
problem_sizes2
:
torch
.
Tensor
,
input_permutation
:
torch
.
Tensor
,
output_permutation
:
torch
.
Tensor
,
num_experts
:
int
,
n
:
int
,
k
:
int
):
"""
Prepare data necessary to perform CUTLASS grouped matrix multiplications
used in CUTLASS-based fused MoE.
The function takes in topk_ids (token-expert mapping) and uses it to
compute:
- expert_offsets: Indices that mark at which token index each expert begins
its computation after the input is sorted with
input_permutation. The number of tokens computed with
expert E is expert_offsets[E + 1] - expert_offsets[E]
- problem_sizes1, problem_sizes2: MxNxK sizes of each expert's
multiplication in two grouped MMs used in
the fused MoE operation.
- input_permutation: Permutation that must be used to shuffle the input
before executing the MMs.
- output_permutation: Permutation that must be used to shuffle the output
after executing the MMs.
"""
torch
.
ops
.
_C
.
get_cutlass_moe_mm_data
(
topk_ids
,
expert_offsets
,
problem_sizes1
,
problem_sizes2
,
input_permutation
,
output_permutation
,
num_experts
,
n
,
k
)
def
cutlass_moe_mm
(
out_tensors
:
torch
.
Tensor
,
a_tensors
:
torch
.
Tensor
,
b_tensors
:
torch
.
Tensor
,
a_scales
:
torch
.
Tensor
,
b_scales
:
torch
.
Tensor
,
expert_offsets
:
torch
.
Tensor
,
problem_sizes
:
torch
.
Tensor
,
a_strides
:
torch
.
Tensor
,
b_strides
:
torch
.
Tensor
,
c_strides
:
torch
.
Tensor
):
"""
A single grouped matrix multiplication used in CUTLASS-based fused MoE.
The function executes fp8-quantized OUT = AB matrix multiplication.
- expert_offsets: Indices that mark at which token index each expert begins
its computation. The number of tokens computed with
expert E is expert_offsets[E + 1] - expert_offsets[E]
- problem_sizes: MxNxK sizes of each expert's multiplication in two grouped
MMs used in the fused MoE operation.
- a/b/c_strides: The data strides passed to grouped matrix multiplication.
"""
torch
.
ops
.
_C
.
cutlass_moe_mm
(
out_tensors
,
a_tensors
,
b_tensors
,
a_scales
,
b_scales
,
expert_offsets
,
problem_sizes
,
a_strides
,
b_strides
,
c_strides
)
# aqlm
def
aqlm_gemm
(
input
:
torch
.
Tensor
,
codes
:
torch
.
Tensor
,
codebooks
:
torch
.
Tensor
,
scales
:
torch
.
Tensor
,
...
...
@@ -1452,9 +1522,9 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# gguf
# def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
#
n: int
) -> torch.Tensor:
# return torch.ops._C.ggml_dequantize(W, quant_type, m, n)
# def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int,
n: int,
#
dtype: Optional[torch.dtype]
) -> torch.Tensor:
# return torch.ops._C.ggml_dequantize(W, quant_type, m, n
, dtype
)
def
ggml_mul_mat_vec_a8
(
...
...
@@ -1579,7 +1649,7 @@ def moe_wna16_gemm(input: torch.Tensor, output: torch.Tensor,
def
topk_softmax
(
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
token_expert_indicies
:
torch
.
Tensor
,
gating_output
:
float
)
->
None
:
gating_output
:
torch
.
Tensor
)
->
None
:
torch
.
ops
.
_moe_C
.
topk_softmax
(
topk_weights
,
topk_ids
,
token_expert_indicies
,
gating_output
)
...
...
@@ -1692,9 +1762,9 @@ def get_max_shared_memory_per_block_device_attribute(device: int) -> int:
# custom ar
def
init_custom_ar
(
ipc_tensors
:
list
[
torch
.
Tensor
],
rank_data
:
torch
.
Tensor
,
rank
:
int
,
full
_nvlink
:
bool
)
->
int
:
rank
:
int
,
full
y_connected
:
bool
)
->
int
:
return
torch
.
ops
.
_C_custom_ar
.
init_custom_ar
(
ipc_tensors
,
rank_data
,
rank
,
full
_nvlink
)
full
y_connected
)
def
all_reduce
(
fa
:
int
,
inp
:
torch
.
Tensor
,
out
:
torch
.
Tensor
,
reg_buffer
:
int
,
...
...
@@ -1760,6 +1830,7 @@ def write_cache_multi_layers(
value_caches
,
slot_mapping
,
kv_cache_dtype
)
def
get_flash_mla_metadata
(
cache_seqlens
:
torch
.
Tensor
,
num_heads_per_head_k
:
int
,
...
...
vllm/_ipex_ops.py
View file @
fcfc474d
...
...
@@ -187,15 +187,28 @@ class ipex_ops:
gen_
:
torch
.
Generator
,
logits_soft_cap
:
float
,
)
->
None
:
ipex
.
llm
.
functional
.
varlen_attention
(
query
.
contiguous
(),
key
.
contiguous
(),
value
.
contiguous
(),
out
,
seqlen_q
.
int
(),
seqlen_k
.
int
(),
max_seqlen_q
,
max_seqlen_k
,
pdropout
,
softmax_scale
,
zero_tensors
,
is_causal
,
return_softmax
,
gen_
,
logits_soft_cap
)
if
ipex
.
__version__
.
endswith
(
"cpu"
):
if
logits_soft_cap
!=
0.0
:
raise
ValueError
(
"IPEX CPU does not support logits_soft_cap"
)
ipex
.
llm
.
functional
.
varlen_attention
(
query
.
contiguous
(),
key
.
contiguous
(),
value
.
contiguous
(),
out
,
seqlen_q
.
int
(),
seqlen_k
.
int
(),
max_seqlen_q
,
max_seqlen_k
,
pdropout
,
softmax_scale
,
zero_tensors
,
is_causal
,
return_softmax
,
gen_
)
else
:
# XPU build
ipex
.
llm
.
functional
.
varlen_attention
(
query
.
contiguous
(),
key
.
contiguous
(),
value
.
contiguous
(),
out
,
seqlen_q
.
int
(),
seqlen_k
.
int
(),
max_seqlen_q
,
max_seqlen_k
,
pdropout
,
softmax_scale
,
zero_tensors
,
is_causal
,
return_softmax
,
gen_
,
logits_soft_cap
)
@
staticmethod
def
reshape_and_cache
(
...
...
vllm/assets/video.py
View file @
fcfc474d
...
...
@@ -10,8 +10,6 @@ import numpy.typing as npt
from
huggingface_hub
import
hf_hub_download
from
PIL
import
Image
from
vllm.multimodal.video
import
sample_frames_from_video
from
.base
import
get_cache_dir
...
...
@@ -43,14 +41,19 @@ def video_to_ndarrays(path: str, num_frames: int = -1) -> npt.NDArray:
total_frames
=
int
(
cap
.
get
(
cv2
.
CAP_PROP_FRAME_COUNT
))
frames
=
[]
for
i
in
range
(
total_frames
):
ret
,
frame
=
cap
.
read
()
if
ret
:
frames
.
append
(
frame
)
cap
.
release
()
num_frames
=
num_frames
if
num_frames
>
0
else
total_frames
frame_indices
=
np
.
linspace
(
0
,
total_frames
-
1
,
num_frames
,
dtype
=
int
)
for
idx
in
range
(
total_frames
):
ok
=
cap
.
grab
()
# next img
if
not
ok
:
break
if
idx
in
frame_indices
:
# only decompress needed
ret
,
frame
=
cap
.
retrieve
()
if
ret
:
frames
.
append
(
frame
)
frames
=
np
.
stack
(
frames
)
frames
=
sample_frames_from_video
(
frames
,
num_frames
)
if
len
(
frames
)
<
num_frames
:
raise
ValueError
(
f
"Could not read enough frames from video file
{
path
}
"
f
" (expected
{
num_frames
}
frames, got
{
len
(
frames
)
}
)"
)
...
...
vllm/attention/backends/cpu_mla.py
0 → 100644
View file @
fcfc474d
# SPDX-License-Identifier: Apache-2.0
from
dataclasses
import
dataclass
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
torch
import
vllm._custom_ops
as
ops
from
vllm._ipex_ops
import
ipex_ops
from
vllm.attention.backends.abstract
import
(
AttentionBackend
,
AttentionMetadataBuilder
,
AttentionType
,
is_quantized_kv_cache
)
from
vllm.attention.backends.mla.common
import
MLACommonImpl
,
MLACommonState
from
vllm.attention.backends.torch_sdpa
import
TorchSDPAMetadata
from
vllm.utils
import
make_tensor_with_pad
from
vllm.worker.cpu_model_runner
import
ModelInputForCPUBuilder
class
CPUMLABackend
(
AttentionBackend
):
@
staticmethod
def
get_name
()
->
str
:
return
"CPU_MLA"
@
staticmethod
def
get_metadata_cls
()
->
Type
[
"CPUMLAMetadata"
]:
return
CPUMLAMetadata
@
staticmethod
def
get_builder_cls
()
->
Type
[
"CPUMLAMetadataBuilder"
]:
return
CPUMLAMetadataBuilder
@
staticmethod
def
get_state_cls
()
->
Type
[
"MLACommonState"
]:
return
MLACommonState
@
staticmethod
def
get_impl_cls
()
->
Type
[
"CPUMLAImpl"
]:
return
CPUMLAImpl
@
staticmethod
def
get_kv_cache_shape
(
num_blocks
:
int
,
block_size
:
int
,
num_kv_heads
:
int
,
# assumed to be 1 for MLA
head_size
:
int
,
)
->
Tuple
[
int
,
...]:
return
(
num_blocks
,
block_size
,
head_size
)
@
staticmethod
def
swap_blocks
(
src_kv_cache
:
torch
.
Tensor
,
dst_kv_cache
:
torch
.
Tensor
,
src_to_dst
:
torch
.
Tensor
,
)
->
None
:
ops
.
swap_blocks
(
src_kv_cache
,
dst_kv_cache
,
src_to_dst
)
@
staticmethod
def
copy_blocks
(
kv_caches
:
List
[
torch
.
Tensor
],
src_to_dists
:
torch
.
Tensor
,
)
->
None
:
ops
.
copy_blocks_mla
(
kv_caches
,
src_to_dists
)
@
staticmethod
def
get_supported_head_sizes
()
->
List
[
int
]:
return
[
576
]
@
dataclass
class
CPUMLAMetadata
(
TorchSDPAMetadata
):
# New for MLA
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions
:
torch
.
Tensor
=
None
# required by MLACommonImpl
is_profile_run
:
bool
=
False
class
CPUMLAMetadataBuilder
(
AttentionMetadataBuilder
[
CPUMLAMetadata
]):
def
__init__
(
self
,
input_builder
:
ModelInputForCPUBuilder
)
->
None
:
self
.
chunked_prefill
=
input_builder
.
chunked_prefill
self
.
input_builder
=
input_builder
assert
not
self
.
chunked_prefill
,
\
"chunked prefill is currently not supported"
def
prepare
(
self
):
self
.
input_data
=
self
.
input_builder
.
input_data
def
build
(
self
,
seq_lens
,
query_lens
,
cuda_graph_pad_size
,
batch_size
):
input_data
=
self
.
input_data
prefill_seq_lens
=
seq_lens
[
0
:
input_data
.
num_prefills
]
prefill_query_lens
=
query_lens
[
0
:
input_data
.
num_prefills
]
slot_mapping
=
torch
.
tensor
(
input_data
.
slot_mapping
,
dtype
=
torch
.
long
,
device
=
"cpu"
)
# metadata for prefill
if
input_data
.
num_prefills
>
0
:
query_lens_tensor
=
torch
.
tensor
(
prefill_query_lens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
kv_lens_tensor
=
torch
.
tensor
(
prefill_seq_lens
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
query_start_loc
=
torch
.
zeros
(
input_data
.
num_prefills
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
kv_start_loc
=
torch
.
zeros
(
input_data
.
num_prefills
+
1
,
dtype
=
torch
.
int32
,
device
=
"cpu"
)
torch
.
cumsum
(
query_lens_tensor
,
dim
=
0
,
dtype
=
torch
.
int32
,
out
=
query_start_loc
[
1
:])
torch
.
cumsum
(
kv_lens_tensor
,
dim
=
0
,
dtype
=
torch
.
int32
,
out
=
kv_start_loc
[
1
:])
max_query_len
=
max
(
prefill_query_lens
)
max_kv_len
=
max
(
prefill_seq_lens
)
# for chunked-prefill
if
self
.
chunked_prefill
:
prefill_block_tables
=
make_tensor_with_pad
(
self
.
input_data
.
prefill_block_tables
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
)
else
:
prefill_block_tables
=
None
else
:
query_start_loc
=
None
kv_start_loc
=
None
max_query_len
=
None
max_kv_len
=
None
prefill_block_tables
=
None
# metadata for decode
if
input_data
.
num_decode_tokens
!=
0
:
seq_lens_tensor
=
torch
.
tensor
(
input_data
.
seq_lens
[
input_data
.
num_prefills
:],
dtype
=
torch
.
int32
,
device
=
"cpu"
,
)
block_tables
=
make_tensor_with_pad
(
self
.
input_data
.
decode_block_tables
,
pad
=
0
,
dtype
=
torch
.
int32
,
device
=
"cpu"
,
)
else
:
block_tables
=
torch
.
tensor
([])
seq_lens_tensor
=
torch
.
tensor
(
input_data
.
seq_lens
[:
input_data
.
num_prefills
],
dtype
=
torch
.
int32
,
device
=
"cpu"
,
)
# For multi-modal models
placeholder_index_maps
=
None
if
len
(
input_data
.
multi_modal_inputs_list
)
!=
0
:
placeholder_index_maps
=
{
modality
:
placeholder_map
.
index_map
()
for
modality
,
placeholder_map
in
input_data
.
multi_modal_placeholder_maps
.
items
()
}
return
CPUMLAMetadata
(
chunked_prefill
=
self
.
chunked_prefill
,
seq_lens
=
prefill_seq_lens
,
seq_lens_tensor
=
seq_lens_tensor
,
max_query_len
=
max_query_len
,
max_kv_len
=
max_kv_len
,
query_start_loc
=
query_start_loc
,
kv_start_loc
=
kv_start_loc
,
max_decode_seq_len
=
input_data
.
max_decode_seq_len
,
num_prefills
=
input_data
.
num_prefills
,
num_prefill_tokens
=
input_data
.
num_prefill_tokens
,
num_decode_tokens
=
input_data
.
num_decode_tokens
,
block_tables
=
block_tables
,
prefill_block_tables
=
prefill_block_tables
,
slot_mapping
=
slot_mapping
,
multi_modal_placeholder_index_maps
=
placeholder_index_maps
,
enable_kv_scales_calculation
=
False
,
input_positions
=
torch
.
tensor
([
self
.
input_data
.
input_positions
]))
class
CPUMLAImpl
(
MLACommonImpl
[
CPUMLAMetadata
]):
def
__init__
(
self
,
num_heads
:
int
,
head_size
:
int
,
scale
:
float
,
num_kv_heads
:
int
,
alibi_slopes
:
Optional
[
List
[
float
]],
sliding_window
:
Optional
[
int
],
kv_cache_dtype
:
str
,
blocksparse_params
:
Optional
[
Dict
[
str
,
Any
]],
logits_soft_cap
:
Optional
[
float
],
attn_type
:
str
,
# MLA Specific Arguments
**
mla_args
)
->
None
:
super
().
__init__
(
num_heads
,
head_size
,
scale
,
num_kv_heads
,
alibi_slopes
,
sliding_window
,
kv_cache_dtype
,
blocksparse_params
,
logits_soft_cap
,
attn_type
,
**
mla_args
)
unsupported_features
=
[
alibi_slopes
,
sliding_window
,
blocksparse_params
,
logits_soft_cap
]
if
any
(
unsupported_features
):
raise
NotImplementedError
(
"CPUMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap"
)
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CPUMLAImpl"
)
# states is implemented.
if
is_quantized_kv_cache
(
self
.
kv_cache_dtype
):
raise
NotImplementedError
(
"CPUMLAImpl with FP8 KV cache not yet supported"
)
def
_forward_prefill
(
self
,
q
:
torch
.
Tensor
,
kv_c_normed
:
torch
.
Tensor
,
k_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
CPUMLAMetadata
,
# type: ignore[override]
)
->
torch
.
Tensor
:
prefill_metadata
=
attn_metadata
.
prefill_metadata
assert
prefill_metadata
is
not
None
kv_nope
=
self
.
kv_b_proj
(
kv_c_normed
)[
0
].
view
(
\
-
1
,
self
.
num_heads
,
self
.
qk_nope_head_dim
+
self
.
v_head_dim
)
k_nope
,
v
=
kv_nope
\
.
split
([
self
.
qk_nope_head_dim
,
self
.
v_head_dim
],
dim
=-
1
)
k
=
torch
.
cat
((
k_nope
,
k_pe
.
expand
((
*
k_nope
.
shape
[:
-
1
],
-
1
))),
dim
=-
1
)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded
=
torch
.
nn
.
functional
.
pad
(
v
,
[
0
,
q
.
shape
[
-
1
]
-
v
.
shape
[
-
1
]],
value
=
0
)
output
=
torch
.
empty_like
(
q
)
ipex_ops
.
varlen_attention
(
query
=
q
,
key
=
k
,
value
=
v_padded
,
out
=
output
,
seqlen_q
=
prefill_metadata
.
query_start_loc
,
seqlen_k
=
prefill_metadata
.
query_start_loc
,
max_seqlen_q
=
prefill_metadata
.
max_query_len
,
max_seqlen_k
=
prefill_metadata
.
max_query_len
,
pdropout
=
0.0
,
softmax_scale
=
self
.
scale
,
zero_tensors
=
False
,
is_causal
=
True
,
return_softmax
=
False
,
gen_
=
None
,
logits_soft_cap
=
0.0
,
)
# remove padding
output
=
output
.
view
(
-
1
,
self
.
num_heads
,
q
.
shape
[
-
1
])[...,
:
v
.
shape
[
-
1
]]
output
=
output
.
reshape
(
-
1
,
self
.
num_heads
*
v
.
shape
[
-
1
])
return
self
.
o_proj
(
output
)[
0
]
def
_forward_decode
(
self
,
q_nope
:
torch
.
Tensor
,
q_pe
:
torch
.
Tensor
,
kv_c_and_k_pe_cache
:
torch
.
Tensor
,
attn_metadata
:
CPUMLAMetadata
,
# type: ignore[override]
)
->
torch
.
Tensor
:
assert
kv_c_and_k_pe_cache
.
numel
()
>
0
decode_meta
=
attn_metadata
.
decode_metadata
assert
decode_meta
is
not
None
q
=
torch
.
cat
([
q_nope
,
q_pe
],
dim
=-
1
)
o
=
q
.
new_empty
(
q
.
shape
[
0
],
self
.
num_heads
,
self
.
kv_lora_rank
)
# Run MQA
ops
.
mla_decode_kvcache_cpu
(
o
,
q
,
kv_c_and_k_pe_cache
,
self
.
scale
,
decode_meta
.
block_tables
,
decode_meta
.
seq_lens_tensor
)
return
self
.
_v_up_proj_and_o_proj
(
o
)
vllm/attention/backends/mla/common.py
View file @
fcfc474d
...
...
@@ -204,7 +204,6 @@ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
from
vllm.attention.backends.utils
import
(
PAD_SLOT_ID
,
compute_slot_mapping
,
compute_slot_mapping_start_idx
,
is_block_tables_empty
)
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
LinearBase
,
RowParallelLinear
,
UnquantizedLinearMethod
)
...
...
@@ -212,18 +211,27 @@ from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding
,
RotaryEmbedding
)
from
vllm.multimodal
import
MultiModalPlaceholderMap
from
vllm.platforms
import
current_platform
from
vllm.triton_utils
import
HAS_TRITON
from
vllm.utils
import
async_tensor_h2d
,
cdiv
,
make_tensor_with_pad
,
round_down
from
vllm.vllm_flash_attn.fa_utils
import
get_flash_attn_version
if
HAS_TRITON
:
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
from
vllm.attention.ops.triton_merge_attn_states
import
merge_attn_states
else
:
merge_attn_states
=
None
triton_attention
=
None
try
:
from
vllm.vllm_flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
True
except
ImportError
:
# For rocm use upstream flash attention
from
flash_attn
import
flash_attn_varlen_func
is_vllm_fa
=
False
from
vllm.attention.ops.triton_flash_attention
import
triton_attention
try
:
# For rocm use upstream flash attention
from
flash_attn
import
flash_attn_varlen_func
except
ImportError
:
flash_attn_varlen_func
=
None
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
(
ModelInputForGPUBuilder
,
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
fcfc474d
...
...
@@ -18,16 +18,13 @@ from vllm.attention.ops.paged_attn import (PagedAttention,
from
vllm.logger
import
init_logger
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.platforms.rocm
import
use_rocm_custom_paged_attention
if
TYPE_CHECKING
:
from
vllm.worker.model_runner
import
ModelInputForGPUWithSamplingMetadata
logger
=
init_logger
(
__name__
)
_PARTITION_SIZE_ROCM
=
256
_GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
_ON_NAVI
=
"gfx1"
in
_GPU_ARCH
_ON_MI250_MI300
=
any
(
arch
in
_GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
class
ROCmFlashAttentionBackend
(
AttentionBackend
):
...
...
@@ -804,9 +801,9 @@ class ROCmFlashAttentionImpl(AttentionImpl):
num_seqs
,
num_heads
,
head_size
=
decode_query
.
shape
block_size
=
value_cache
.
shape
[
3
]
gqa_ratio
=
num_heads
//
self
.
num_kv_heads
# use_custom =
_
use_rocm_custom_paged_attention(
# use_custom = use_rocm_custom_paged_attention(
# decode_query.dtype, head_size, block_size, gqa_ratio,
# decode_meta.max_decode_seq_len)
# decode_meta.max_decode_seq_len
, self.sliding_window
)
use_custom
=
False
if
use_custom
:
max_seq_len
=
(
decode_meta
.
max_decode_seq_len
if
self
.
attn_type
...
...
@@ -832,6 +829,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
out
=
output
[
num_prefill_tokens
:]
else
:
out
=
output
query_start_loc
=
None
ops
.
paged_attention_rocm
(
out
,
exp_sums
,
...
...
@@ -848,6 +847,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
decode_meta
.
seq_lens_tensor
if
self
.
attn_type
!=
AttentionType
.
ENCODER_DECODER
else
decode_meta
.
encoder_seq_lens_tensor
,
query_start_loc
,
block_size
,
max_seq_len
,
self
.
alibi_slopes
,
...
...
@@ -902,9 +902,8 @@ def _sdpa_attention(
for
i
,
seq_len
in
enumerate
(
seq_lens
):
end
=
start
+
seq_len
with
torch
.
backends
.
cuda
.
sdp_kernel
(
enable_math
=
True
,
enable_flash
=
False
,
enable_mem_efficient
=
False
):
with
torch
.
nn
.
attention
.
sdpa_kernel
(
torch
.
nn
.
attention
.
SDPBackend
.
MATH
):
sub_out
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
query
[:,
start
:
end
,
:],
key
[:,
start
:
end
,
:],
...
...
@@ -917,14 +916,3 @@ def _sdpa_attention(
start
=
end
return
output
def
_use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
block_size
:
int
,
gqa_ratio
:
int
,
max_seq_len
:
int
)
->
bool
:
# rocm custom page attention not support on navi (gfx1*)
return
(
_ON_MI250_MI300
and
not
_ON_NAVI
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
block_size
==
16
or
block_size
==
32
)
and
(
gqa_ratio
>=
1
and
gqa_ratio
<=
16
)
and
max_seq_len
<=
32768
)
vllm/attention/ops/chunked_prefill_paged_decode.py
View file @
fcfc474d
...
...
@@ -10,6 +10,9 @@ import torch
import
triton
import
triton.language
as
tl
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms.rocm
import
use_rocm_custom_paged_attention
from
.prefix_prefill
import
context_attention_fwd
...
...
@@ -33,26 +36,26 @@ def kernel_paged_attention_2d(
num_query_heads
:
tl
.
constexpr
,
# int
num_queries_per_kv
:
tl
.
constexpr
,
# int
num_queries_per_kv_padded
:
tl
.
constexpr
,
# int
block_table_stride
:
tl
.
constexpr
,
# int
query_stride_0
:
tl
.
constexpr
,
# int
query_stride_1
:
tl
.
constexpr
,
# int, should be equal to head_size
output_stride_0
:
tl
.
constexpr
,
# int
output_stride_1
:
tl
.
constexpr
,
# int, should be equal to head_size
block_table_stride
:
tl
.
int64
,
# int
query_stride_0
:
tl
.
int64
,
# int
query_stride_1
:
tl
.
int64
,
# int, should be equal to head_size
output_stride_0
:
tl
.
int64
,
# int
output_stride_1
:
tl
.
int64
,
# int, should be equal to head_size
BLOCK_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE
:
tl
.
constexpr
,
# int
HEAD_SIZE_PADDED
:
tl
.
constexpr
,
# int, must be power of 2
USE_ALIBI_SLOPES
:
tl
.
constexpr
,
# bool
SLIDING_WINDOW
:
tl
.
constexpr
,
# int
x
:
tl
.
constexpr
,
# int
stride_k_cache_0
:
tl
.
constexpr
,
# int
stride_k_cache_1
:
tl
.
constexpr
,
# int
stride_k_cache_2
:
tl
.
constexpr
,
# int
stride_k_cache_3
:
tl
.
constexpr
,
# int
stride_k_cache_4
:
tl
.
constexpr
,
# int
stride_v_cache_0
:
tl
.
constexpr
,
# int
stride_v_cache_1
:
tl
.
constexpr
,
# int
stride_v_cache_2
:
tl
.
constexpr
,
# int
stride_v_cache_3
:
tl
.
constexpr
,
# int
stride_k_cache_0
:
tl
.
int64
,
# int
stride_k_cache_1
:
tl
.
int64
,
# int
stride_k_cache_2
:
tl
.
int64
,
# int
stride_k_cache_3
:
tl
.
int64
,
# int
stride_k_cache_4
:
tl
.
int64
,
# int
stride_v_cache_0
:
tl
.
int64
,
# int
stride_v_cache_1
:
tl
.
int64
,
# int
stride_v_cache_2
:
tl
.
int64
,
# int
stride_v_cache_3
:
tl
.
int64
,
# int
filter_by_query_len
:
tl
.
constexpr
,
# bool
query_start_len_ptr
,
# [num_seqs+1]
):
...
...
@@ -212,6 +215,7 @@ def chunked_prefill_paged_decode(
block_table
,
query_start_loc
,
seq_lens
,
max_seq_len
,
max_query_len
,
k_scale
,
v_scale
,
...
...
@@ -240,6 +244,7 @@ def chunked_prefill_paged_decode(
b_loc
=
block_table
,
b_start_loc
=
query_start_loc
,
b_seq_len
=
seq_lens
,
max_seq_len
=
max_seq_len
,
max_input_len
=
max_query_len
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
...
...
@@ -275,43 +280,87 @@ def chunked_prefill_paged_decode(
num_queries_per_kv_padded
=
max
(
triton
.
next_power_of_2
(
num_queries_per_kv
),
16
)
kernel_paged_attention_2d
[(
num_seqs
,
num_kv_heads
,
)](
output_ptr
=
output
,
query_ptr
=
query
,
key_cache_ptr
=
key_cache
,
value_cache_ptr
=
value_cache
,
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seq_lens
,
alibi_slopes_ptr
=
alibi_slopes
,
scale
=
sm_scale
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
num_query_heads
=
num_query_heads
,
num_queries_per_kv
=
num_queries_per_kv
,
num_queries_per_kv_padded
=
num_queries_per_kv_padded
,
block_table_stride
=
block_table
.
stride
(
0
),
query_stride_0
=
query
.
stride
(
0
),
query_stride_1
=
query
.
stride
(
1
),
output_stride_0
=
output
.
stride
(
0
),
output_stride_1
=
output
.
stride
(
1
),
BLOCK_SIZE
=
block_size
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
SLIDING_WINDOW
=
sliding_window
,
x
=
key_cache
.
shape
[
4
],
stride_k_cache_0
=
key_cache
.
stride
(
0
),
stride_k_cache_1
=
key_cache
.
stride
(
1
),
stride_k_cache_2
=
key_cache
.
stride
(
2
),
stride_k_cache_3
=
key_cache
.
stride
(
3
),
stride_k_cache_4
=
key_cache
.
stride
(
4
),
stride_v_cache_0
=
value_cache
.
stride
(
0
),
stride_v_cache_1
=
value_cache
.
stride
(
1
),
stride_v_cache_2
=
value_cache
.
stride
(
2
),
stride_v_cache_3
=
value_cache
.
stride
(
3
),
filter_by_query_len
=
True
,
query_start_len_ptr
=
query_start_loc
,
)
use_custom
=
use_rocm_custom_paged_attention
(
query
.
dtype
,
head_size
,
block_size
,
num_queries_per_kv
,
max_seq_len
,
sliding_window
)
if
use_custom
:
_PARTITION_SIZE_ROCM
=
256
max_num_partitions
=
((
max_seq_len
+
_PARTITION_SIZE_ROCM
-
1
)
//
_PARTITION_SIZE_ROCM
)
assert
_PARTITION_SIZE_ROCM
%
block_size
==
0
total_num_seq
=
query
.
shape
[
0
]
tmp_output
=
torch
.
empty
(
size
=
(
total_num_seq
,
num_query_heads
,
max_num_partitions
,
head_size
),
dtype
=
output
.
dtype
,
device
=
output
.
device
,
)
exp_sums
=
torch
.
empty
(
size
=
(
total_num_seq
,
num_query_heads
,
max_num_partitions
),
dtype
=
torch
.
float32
,
device
=
output
.
device
,
)
max_logits
=
torch
.
empty_like
(
exp_sums
)
ops
.
paged_attention_rocm
(
output
,
exp_sums
,
max_logits
,
tmp_output
,
query
,
key_cache
,
value_cache
,
num_kv_heads
,
scale
=
sm_scale
,
block_tables
=
block_table
,
seq_lens
=
seq_lens
,
query_start_loc
=
query_start_loc
,
block_size
=
block_size
,
max_seq_len
=
max_seq_len
,
alibi_slopes
=
alibi_slopes
,
kv_cache_dtype
=
kv_cache_dtype
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
)
else
:
kernel_paged_attention_2d
[(
num_seqs
,
num_kv_heads
,
)](
output_ptr
=
output
,
query_ptr
=
query
,
key_cache_ptr
=
key_cache
,
value_cache_ptr
=
value_cache
,
block_tables_ptr
=
block_table
,
seq_lens_ptr
=
seq_lens
,
alibi_slopes_ptr
=
alibi_slopes
,
scale
=
sm_scale
,
k_scale
=
k_scale
,
v_scale
=
v_scale
,
num_query_heads
=
num_query_heads
,
num_queries_per_kv
=
num_queries_per_kv
,
num_queries_per_kv_padded
=
num_queries_per_kv_padded
,
block_table_stride
=
block_table
.
stride
(
0
),
query_stride_0
=
query
.
stride
(
0
),
query_stride_1
=
query
.
stride
(
1
),
output_stride_0
=
output
.
stride
(
0
),
output_stride_1
=
output
.
stride
(
1
),
BLOCK_SIZE
=
block_size
,
HEAD_SIZE
=
head_size
,
HEAD_SIZE_PADDED
=
triton
.
next_power_of_2
(
head_size
),
USE_ALIBI_SLOPES
=
use_alibi_slopes
,
SLIDING_WINDOW
=
sliding_window
,
x
=
key_cache
.
shape
[
4
],
stride_k_cache_0
=
key_cache
.
stride
(
0
),
stride_k_cache_1
=
key_cache
.
stride
(
1
),
stride_k_cache_2
=
key_cache
.
stride
(
2
),
stride_k_cache_3
=
key_cache
.
stride
(
3
),
stride_k_cache_4
=
key_cache
.
stride
(
4
),
stride_v_cache_0
=
value_cache
.
stride
(
0
),
stride_v_cache_1
=
value_cache
.
stride
(
1
),
stride_v_cache_2
=
value_cache
.
stride
(
2
),
stride_v_cache_3
=
value_cache
.
stride
(
3
),
filter_by_query_len
=
True
,
query_start_len_ptr
=
query_start_loc
,
)
vllm/attention/ops/nki_flash_attn.py
View file @
fcfc474d
...
...
@@ -144,8 +144,7 @@ def transform_block_tables_for_indirect_load(
def
load_kv_tile_from_cache
(
cur_k_tile
,
cur_v_tile
,
key_cache
,
value_cache
,
kv_cache
,
block_tables
,
large_k_tile_idx
,
num_blocks_per_large_tile
,
...
...
@@ -169,8 +168,8 @@ def load_kv_tile_from_cache(
for
load_idx
in
nl
.
affine_range
(
num_loads
):
i_p
=
nl
.
arange
(
B_P_SIZE
)[:,
None
]
i_f
=
nl
.
arange
(
tiled_block_size
*
B_D_SIZE
)[
None
,
:]
loaded
=
nl
.
load
(
k
ey
_cache
[
block_tables
[
load_idx
,
i_p
,
large_k_tile_idx
],
i_f
])
loaded
=
nl
.
load
(
k
v
_cache
[
0
,
block_tables
[
load_idx
,
i_p
,
large_k_tile_idx
],
i_f
])
if
cur_k_tile
.
dtype
!=
loaded
.
dtype
:
loaded
=
nl
.
copy
(
loaded
,
dtype
=
cur_k_tile
.
dtype
)
# Transpose SBUF tensor using PE
...
...
@@ -185,7 +184,7 @@ def load_kv_tile_from_cache(
# load value cache
for
load_idx
in
nl
.
affine_range
(
num_loads
):
loaded
=
nl
.
load
(
v
alue
_cache
[
block_tables
[
load_idx
,
i_p
,
loaded
=
nl
.
load
(
k
v_cache
[
1
,
block_tables
[
load_idx
,
i_p
,
large_k_tile_idx
],
i_f
])
if
cur_v_tile
.
dtype
!=
loaded
.
dtype
:
loaded
=
nl
.
copy
(
loaded
,
dtype
=
cur_v_tile
.
dtype
)
...
...
@@ -418,8 +417,7 @@ def flash_paged_attention(
query
,
key
,
value
,
key_cache
,
value_cache
,
kv_cache
,
block_tables
,
mask
,
softmax_scale
=
None
,
...
...
@@ -434,8 +432,7 @@ def flash_paged_attention(
- query: shape (1, n_heads, d, seq_q)
- key: shape (1, n_kv_heads, d, seq_k)
- value: shape (1, n_kv_heads, seq_v, d)
- key_cache: (num_blocks, n_kv_heads, block_size, d)
- value_cache: (num_blocks, n_kv_heads, block_size, d)
- kv_cache: (2, num_blocks, n_kv_heads, block_size, d)
- block_tables: (num_active_blocks, )
- mask: (seq_q, num_active_blocks * block_size + seq_q)
- o: shape (1, n_heads, seq_q, d)
...
...
@@ -444,7 +441,7 @@ def flash_paged_attention(
- We use continuous batching by default, so the batch dimension is
always 1, and different requests are concatenated along sequence
dimension.
- We use paged cache blocks (k
ey_cache, value
_cache) to store KV cache.
- We use paged cache blocks (k
v
_cache) to store KV cache.
IO tensor dtypes:
- This kernel assumes all IO tensors have the same dtype except for
...
...
@@ -475,15 +472,13 @@ def flash_paged_attention(
b
,
h
,
d
,
seqlen_q
=
query
.
shape
B_D_SIZE
=
d
n_tile_q
=
seqlen_q
//
B_P_SIZE
# since q will be loaded on tensor engine
num_blocks
,
k_h
,
block_size
,
_
=
k
ey
_cache
.
shape
_
,
num_blocks
,
k_h
,
block_size
,
_
=
k
v
_cache
.
shape
q_h_per_k_h
=
h
//
k_h
assert
b
==
1
,
f
"invalid batch size
{
b
=
}
"
assert
d
<=
128
,
f
" we do not support head_dim > 128, got head dim
{
d
=
}
"
cache_shape
=
(
num_blocks
,
k_h
,
block_size
,
d
)
assert
(
tuple
(
key_cache
.
shape
)
==
cache_shape
),
f
"
{
key_cache
.
shape
=
}
mismatch, expect
{
cache_shape
}
"
assert
(
tuple
(
value_cache
.
shape
)
==
cache_shape
),
f
"
{
value_cache
.
shape
=
}
mismatch, expect
{
cache_shape
}
"
cache_shape
=
(
2
,
num_blocks
,
k_h
,
block_size
,
d
)
assert
(
tuple
(
kv_cache
.
shape
)
==
cache_shape
),
f
"
{
kv_cache
.
shape
=
}
mismatch, expect
{
cache_shape
}
"
assert
key
is
None
or
tuple
(
key
.
shape
)
==
(
1
,
k_h
,
...
...
@@ -580,13 +575,13 @@ def flash_paged_attention(
head_id
=
head_id
,
)
# Flatten KV cache to be
2
D for loading into SBUF
# Flatten KV cache to be
3
D for loading into SBUF
new_cache_shape
=
(
2
,
num_blocks
*
k_h
*
block_size_tiling_factor
,
tiled_block_size
*
d
,
)
key_cache
=
key_cache
.
reshape
(
new_cache_shape
)
value_cache
=
value_cache
.
reshape
(
new_cache_shape
)
kv_cache
=
kv_cache
.
reshape
(
new_cache_shape
)
# Global Flash Attention accumulators
o_buffer
=
nl
.
zeros
(
...
...
@@ -621,8 +616,7 @@ def flash_paged_attention(
load_kv_tile_from_cache
(
cur_k_tile
=
cur_k_tile
,
cur_v_tile
=
cur_v_tile
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
kv_cache
=
kv_cache
,
block_tables
=
block_tables_sbuf
,
large_k_tile_idx
=
large_k_tile_idx
,
num_blocks_per_large_tile
=
num_blocks_per_large_tile
,
...
...
@@ -821,8 +815,7 @@ def flash_attn_varlen_nkifunc(
query
,
key
,
value
,
key_cache
,
value_cache
,
kv_cache
,
block_table
,
attn_mask
,
n_kv_head
=
None
,
...
...
@@ -838,8 +831,7 @@ def flash_attn_varlen_nkifunc(
- query: (1, n_heads, d, seq_q)
- key: (1, n_kv_heads, d, seq_k)
- value: (1, n_kv_heads, seq_v, d)
- key_cache: (n_blocks, n_kv_heads, block_size, d)
- value_cache: (n_blocks, n_kv_heads, block_size, d)
- kv_cache: (2, n_blocks, n_kv_heads, block_size, d)
- block_tables: (n_active_blocks, )
- attn_mask: (seq_q, n_active_blocks * block_size + seq_q)
...
...
@@ -849,17 +841,17 @@ def flash_attn_varlen_nkifunc(
for better DMA throughput
"""
if
n_kv_head
is
None
:
n_kv_head
=
key_cache
.
shape
[
1
]
assert
key_cache
.
shape
[
1
]
==
n_kv_head
n_kv_head
=
kv_cache
.
shape
[
2
]
assert
kv_cache
.
shape
[
0
]
==
2
assert
kv_cache
.
shape
[
2
]
==
n_kv_head
if
head_size
is
None
:
head_size
=
k
ey
_cache
.
shape
[
-
1
]
head_size
=
k
v
_cache
.
shape
[
-
1
]
kwargs
=
dict
(
query
=
query
,
key
=
key
,
value
=
value
,
key_cache
=
key_cache
,
value_cache
=
value_cache
,
kv_cache
=
kv_cache
,
block_tables
=
block_table
,
mask
=
attn_mask
,
softmax_scale
=
1.0
/
(
head_size
**
0.5
),
...
...
@@ -874,8 +866,7 @@ def flash_attn_varlen_nkifunc(
def
reshape_and_cache
(
key
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
key_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
)
->
None
:
"""
...
...
@@ -886,29 +877,29 @@ def reshape_and_cache(
(num_tokens, n_kv_head, d_head)
value (torch.Tensor): Value tensor with shape
(num_tokens, n_kv_head, d_head)
key_cache (torch.Tensor): Key cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
value_cache (torch.Tensor): Value cache tensor with shape
(num_blocks, n_kv_head, block_size, d_head)
kv_cache (torch.Tensor): Key/value cache tensor with shape
(2, num_blocks, n_kv_head, block_size, d_head)
slot_mapping (torch.Tensor): Mapping tensor indicating cache positions
with shape (num_tokens)
Returns:
None: Updates the k
ey_cache and value
_cache tensor
s
in-place
None: Updates the k
v
_cache tensor in-place
"""
block_size
=
key_cache
.
size
(
2
)
block_size
=
kv_cache
.
size
(
3
)
n_kv_head
=
key
.
size
(
1
)
# Calculate indices with explicit floor division
block_indices
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_offsets
=
slot_mapping
%
block_size
# Create the head indices tensor
head_indices
=
torch
.
arange
(
n_kv_head
,
device
=
key
.
device
)
# Update caches using index_put_
key_cache
.
index_put_
(
(
block_indices
.
unsqueeze
(
1
),
torch
.
arange
(
key_cache
.
size
(
1
),
device
=
key
.
device
),
block_offsets
.
unsqueeze
(
1
)),
key
)
value_cache
.
index_put_
(
(
block_indices
.
unsqueeze
(
1
),
torch
.
arange
(
value_cache
.
size
(
1
),
device
=
value
.
device
),
block_offsets
.
unsqueeze
(
1
)),
value
)
kv_cache
.
index_put_
(
(
torch
.
tensor
([
0
],
device
=
key
.
device
),
block_indices
[:,
None
],
head_indices
[
None
,
:],
block_offsets
[:,
None
]),
key
)
kv_cache
.
index_put_
(
(
torch
.
tensor
([
1
],
device
=
key
.
device
),
block_indices
[:,
None
],
head_indices
[
None
,
:],
block_offsets
[:,
None
]),
value
)
vllm/attention/ops/paged_attn.py
View file @
fcfc474d
...
...
@@ -435,6 +435,7 @@ class PagedAttention:
v_scale
:
torch
.
Tensor
,
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
max_seq_len
=
None
context_attention_fwd
(
query
,
key
,
...
...
@@ -447,6 +448,7 @@ class PagedAttention:
# query_start_loc is (batch_size + 1,)
query_start_loc
,
seq_lens_tensor
,
max_seq_len
,
max_query_len
,
k_scale
,
v_scale
,
...
...
vllm/attention/ops/prefix_prefill.py
View file @
fcfc474d
...
...
@@ -729,6 +729,7 @@ if triton.__version__ >= "2.1.0":
b_loc
,
b_start_loc
,
b_seq_len
,
max_seq_len
,
max_input_len
,
k_scale
:
torch
.
Tensor
,
v_scale
:
torch
.
Tensor
,
...
...
@@ -756,7 +757,7 @@ if triton.__version__ >= "2.1.0":
assert
(
v_cache
.
dtype
==
torch
.
uint8
)
if
kv_cache_dtype
in
(
"fp8"
,
"fp8_e4m3"
):
target_dtype
=
torch
.
float8_e4m3fn
target_dtype
=
current_platform
.
fp8_dtype
()
elif
kv_cache_dtype
==
"fp8_e5m2"
:
target_dtype
=
torch
.
float8_e5m2
else
:
...
...
vllm/attention/ops/triton_merge_attn_states.py
View file @
fcfc474d
...
...
@@ -54,6 +54,15 @@ def merge_attn_states_kernel(
p_lse
=
tl
.
load
(
prefix_lse
+
head_idx
*
num_tokens
+
token_idx
)
s_lse
=
tl
.
load
(
suffix_lse
+
head_idx
*
num_tokens
+
token_idx
)
# FA2 and FA3 have different behavior for when the sum-exp is 0, this namely
# arises with 0 len seqlens. FA3 returns -inf here while FA2 returns inf.
# If we see an inf assume FA2 and convert inf to -inf for consistency
# and correctness. Inf generally doesn't make sense in this context outside
# of undefined-behavior/FA2-case, so I think this a safe assumption.
p_lse
=
float
(
'-inf'
)
if
p_lse
==
float
(
'inf'
)
else
p_lse
s_lse
=
float
(
'-inf'
)
if
s_lse
==
float
(
'inf'
)
else
s_lse
max_lse
=
tl
.
maximum
(
p_lse
,
s_lse
)
p_lse
=
p_lse
-
max_lse
s_lse
=
s_lse
-
max_lse
...
...
vllm/benchmarks/backend_request_func.py
View file @
fcfc474d
...
...
@@ -219,7 +219,15 @@ async def async_request_deepspeed_mii(
if
response
.
status
==
200
:
parsed_resp
=
await
response
.
json
()
output
.
latency
=
time
.
perf_counter
()
-
st
output
.
generated_text
=
parsed_resp
[
"text"
][
0
]
if
"choices"
in
parsed_resp
:
output
.
generated_text
=
parsed_resp
[
"choices"
][
0
][
"text"
]
elif
"text"
in
parsed_resp
:
output
.
generated_text
=
parsed_resp
[
"text"
][
0
]
else
:
output
.
error
=
(
"Unexpected response format: "
"neither 'choices' nor 'text' found"
)
output
.
success
=
False
output
.
success
=
True
else
:
output
.
error
=
response
.
reason
or
""
...
...
vllm/benchmarks/benchmark_serving.py
View file @
fcfc474d
...
...
@@ -7,9 +7,6 @@ On the server side, run one of the following commands:
--swap-space 16 \
--disable-log-requests
(TGI backend)
./launch_tgi_server.sh <your_model> <max_batch_total_tokens>
On the client side, run:
python benchmarks/benchmark_serving.py \
--backend <backend> \
...
...
@@ -52,9 +49,11 @@ try:
except
ImportError
:
from
argparse
import
ArgumentParser
as
FlexibleArgumentParser
from
benchmark_dataset
import
(
BurstGPTDataset
,
HuggingFaceDataset
,
RandomDataset
,
SampleRequest
,
ShareGPTDataset
,
SonnetDataset
,
VisionArenaDataset
)
from
benchmark_dataset
import
(
AIMODataset
,
BurstGPTDataset
,
ConversationDataset
,
HuggingFaceDataset
,
InstructCoderDataset
,
RandomDataset
,
SampleRequest
,
ShareGPTDataset
,
SonnetDataset
,
VisionArenaDataset
)
from
benchmark_utils
import
convert_to_pytorch_benchmark_format
,
write_to_json
MILLISECONDS_TO_SECONDS_CONVERSION
=
1000
...
...
@@ -586,19 +585,39 @@ def main(args: argparse.Namespace):
return_prompt_formatted
=
True
)
elif
args
.
dataset_name
==
"hf"
:
# Choose between VisionArenaDataset
# and HuggingFaceDataset based on provided parameters.
dataset_class
=
(
VisionArenaDataset
if
args
.
dataset_path
==
VisionArenaDataset
.
VISION_ARENA_DATASET_PATH
and
args
.
hf_subset
is
None
else
HuggingFaceDataset
)
# all following datasets are implemented from the
# HuggingFaceDataset base class
if
args
.
dataset_path
in
VisionArenaDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
VisionArenaDataset
args
.
hf_split
=
"train"
args
.
hf_subset
=
None
elif
args
.
dataset_path
in
InstructCoderDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
InstructCoderDataset
args
.
hf_split
=
"train"
elif
args
.
dataset_path
in
ConversationDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
ConversationDataset
elif
args
.
dataset_path
in
AIMODataset
.
SUPPORTED_DATASET_PATHS
:
dataset_class
=
AIMODataset
args
.
hf_split
=
"train"
else
:
supported_datasets
=
set
([
dataset_name
for
cls
in
HuggingFaceDataset
.
__subclasses__
()
for
dataset_name
in
cls
.
SUPPORTED_DATASET_PATHS
])
raise
ValueError
(
f
"Unsupported dataset path:
{
args
.
dataset_path
}
. "
"Huggingface dataset only supports dataset_path"
f
" from one of following:
{
supported_datasets
}
. "
"Please consider contributing if you would "
"like to add support for additional dataset formats."
)
input_requests
=
dataset_class
(
dataset_path
=
args
.
dataset_path
,
dataset_subset
=
args
.
hf_subset
,
dataset_split
=
args
.
hf_split
,
random_seed
=
args
.
seed
,
).
sample
(
num_requests
=
args
.
num_prompts
,
tokenizer
=
tokenizer
,
random_seed
=
args
.
seed
,
output_len
=
args
.
hf_output_len
,
)
...
...
vllm/benchmarks/benchmark_throughput.py
View file @
fcfc474d
...
...
@@ -14,7 +14,8 @@ from typing import Any, Optional, Union
import
numpy
as
np
import
torch
import
uvloop
from
benchmark_dataset
import
(
BurstGPTDataset
,
HuggingFaceDataset
,
from
benchmark_dataset
import
(
AIMODataset
,
BurstGPTDataset
,
ConversationDataset
,
InstructCoderDataset
,
RandomDataset
,
SampleRequest
,
ShareGPTDataset
,
SonnetDataset
,
VisionArenaDataset
)
from
benchmark_utils
import
convert_to_pytorch_benchmark_format
,
write_to_json
...
...
@@ -347,6 +348,7 @@ def get_requests(args, tokenizer):
"input_len"
:
args
.
input_len
,
"output_len"
:
args
.
output_len
,
}
if
args
.
dataset_path
is
None
or
args
.
dataset_name
==
"random"
:
sample_kwargs
[
"range_ratio"
]
=
args
.
random_range_ratio
sample_kwargs
[
"prefix_len"
]
=
args
.
prefix_len
...
...
@@ -364,18 +366,23 @@ def get_requests(args, tokenizer):
elif
args
.
dataset_name
==
"burstgpt"
:
dataset_cls
=
BurstGPTDataset
elif
args
.
dataset_name
==
"hf"
:
if
args
.
backend
!=
"vllm-chat"
:
raise
ValueError
(
"hf datasets only are supported by vllm-chat backend"
)
# Choose between VisionArenaDataset and HuggingFaceDataset based on
# provided parameters.
dataset_cls
=
(
VisionArenaDataset
if
args
.
dataset_path
==
VisionArenaDataset
.
VISION_ARENA_DATASET_PATH
and
args
.
hf_subset
is
None
else
HuggingFaceDataset
)
common_kwargs
[
'dataset_subset'
]
=
args
.
hf_subset
common_kwargs
[
'dataset_split'
]
=
args
.
hf_split
sample_kwargs
[
"enable_multimodal_chat"
]
=
True
if
args
.
dataset_path
in
VisionArenaDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_cls
=
VisionArenaDataset
common_kwargs
[
'dataset_subset'
]
=
None
common_kwargs
[
'dataset_split'
]
=
"train"
sample_kwargs
[
"enable_multimodal_chat"
]
=
True
elif
args
.
dataset_path
in
InstructCoderDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_cls
=
InstructCoderDataset
common_kwargs
[
'dataset_split'
]
=
"train"
elif
args
.
dataset_path
in
ConversationDataset
.
SUPPORTED_DATASET_PATHS
:
dataset_cls
=
ConversationDataset
common_kwargs
[
'dataset_subset'
]
=
args
.
hf_subset
common_kwargs
[
'dataset_split'
]
=
args
.
hf_split
sample_kwargs
[
"enable_multimodal_chat"
]
=
True
elif
args
.
dataset_path
in
AIMODataset
.
SUPPORTED_DATASET_PATHS
:
dataset_cls
=
AIMODataset
common_kwargs
[
'dataset_subset'
]
=
None
common_kwargs
[
'dataset_split'
]
=
"train"
else
:
raise
ValueError
(
f
"Unknown dataset name:
{
args
.
dataset_name
}
"
)
# Remove None values
...
...
@@ -509,9 +516,17 @@ def validate_args(args):
warnings
.
warn
(
"--hf-subset and --hf-split will be ignored
\
since --dataset-name is not 'hf'."
,
stacklevel
=
2
)
elif
args
.
dataset_name
==
"hf"
and
args
.
backend
!=
"vllm-chat"
:
raise
ValueError
(
"When --dataset-name is 'hf', backend must be 'vllm-chat'"
)
elif
args
.
dataset_name
==
"hf"
:
if
args
.
dataset_path
in
(
VisionArenaDataset
.
SUPPORTED_DATASET_PATHS
.
keys
()
|
ConversationDataset
.
SUPPORTED_DATASET_PATHS
):
assert
args
.
backend
==
"vllm-chat"
,
f
"
{
args
.
dataset_path
}
needs to use vllm-chat as the backend."
#noqa: E501
elif
args
.
dataset_path
in
(
InstructCoderDataset
.
SUPPORTED_DATASET_PATHS
|
AIMODataset
.
SUPPORTED_DATASET_PATHS
):
assert
args
.
backend
==
"vllm"
,
f
"
{
args
.
dataset_path
}
needs to use vllm as the backend."
#noqa: E501
else
:
raise
ValueError
(
f
"
{
args
.
dataset_path
}
is not supported by hf dataset."
)
# --random-range-ratio: only used when dataset_name is 'random'
if
args
.
dataset_name
!=
'random'
and
args
.
random_range_ratio
is
not
None
:
...
...
Prev
1
…
10
11
12
13
14
15
16
17
18
…
26
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment