Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
dd66fd2b
Unverified
Commit
dd66fd2b
authored
Jan 28, 2025
by
Mengqing Cao
Committed by
GitHub
Jan 28, 2025
Browse files
[CI] fix pre-commit error (#12494)
Signed-off-by:
Mengqing Cao
<
cmq0113@163.com
>
parent
0f465ab5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
29 additions
and
16 deletions
+29
-16
vllm/attention/ops/nki_flash_attn.py
vllm/attention/ops/nki_flash_attn.py
+25
-12
vllm/spec_decode/spec_decode_worker.py
vllm/spec_decode/spec_decode_worker.py
+4
-4
No files found.
vllm/attention/ops/nki_flash_attn.py
View file @
dd66fd2b
...
@@ -106,11 +106,12 @@ def _flash_attention_core(
...
@@ -106,11 +106,12 @@ def _flash_attention_core(
assert
(
continuous_batching_mask
assert
(
continuous_batching_mask
is
not
None
),
"continuous_batching_mask input is required."
is
not
None
),
"continuous_batching_mask input is required."
if
continuous_batching_mask
is
not
None
:
if
continuous_batching_mask
is
not
None
:
assert
(
logit_bias_tile
is
assert
(
None
),
"continuous_batching_mask does not support logit_bias!"
logit_bias_tile
is
None
),
"continuous_batching_mask does not support logit_bias!"
# mask are used to only apply computation to the lower half of the matrix,
# mask are used to only apply computation to the lower half of the matrix,
# which reduce the arth
i
metic intensity by half
# which reduce the ar
i
thmetic intensity by half
forward_mask
=
(
q_tile_idx
*
B_P_SIZE
>=
local_k_large_tile_idx
*
forward_mask
=
(
q_tile_idx
*
B_P_SIZE
>=
local_k_large_tile_idx
*
LARGE_TILE_SZ
if
use_causal_mask
else
None
)
LARGE_TILE_SZ
if
use_causal_mask
else
None
)
...
@@ -468,9 +469,11 @@ def flash_paged_attention(
...
@@ -468,9 +469,11 @@ def flash_paged_attention(
block_in_partition
)
block_in_partition
)
loaded_v
=
nl
.
load
(
value_cache
[
block_tables_sbuf
[
v_i
,
j
],
:,
loaded_v
=
nl
.
load
(
value_cache
[
block_tables_sbuf
[
v_i
,
j
],
:,
head_id
,
:])
head_id
,
:])
cur_v_tile
[
partition_idx
,
cur_v_tile
[
nl
.
ds
(
block_in_partition
*
partition_idx
,
block_size
,
block_size
),
:,
]
=
loaded_v
nl
.
ds
(
block_in_partition
*
block_size
,
block_size
),
:,
]
=
loaded_v
cur_mask
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
LARGE_TILE_SZ
),
cur_mask
=
nl
.
ndarray
((
par_dim
(
B_P_SIZE
),
LARGE_TILE_SZ
),
dtype
=
mask
.
dtype
)
dtype
=
mask
.
dtype
)
...
@@ -601,20 +604,30 @@ def flash_paged_attention(
...
@@ -601,20 +604,30 @@ def flash_paged_attention(
)
)
nl
.
store
(
nl
.
store
(
o
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
o
[
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
:,
],
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
:,
],
out
,
out
,
)
)
# maximum and summation statistics
# maximum and summation statistics
if
return_debug_tensors
:
if
return_debug_tensors
:
nl
.
store
(
nl
.
store
(
hbm_m_buffer
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
hbm_m_buffer
[
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
],
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
],
m_buffer
[
i
,
i_q_h
,
:,
:],
m_buffer
[
i
,
i_q_h
,
:,
:],
)
)
nl
.
store
(
nl
.
store
(
hbm_l_buffer
[
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
hbm_l_buffer
[
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
],
batch_id
,
head_id
*
q_h_per_k_h
+
i_q_h
,
nl
.
ds
(
i
*
B_P_SIZE
,
B_P_SIZE
),
],
l_buffer
[:,
i
,
i_q_h
],
l_buffer
[:,
i
,
i_q_h
],
)
)
nl
.
store
(
nl
.
store
(
...
...
vllm/spec_decode/spec_decode_worker.py
View file @
dd66fd2b
...
@@ -870,10 +870,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
...
@@ -870,10 +870,10 @@ class SpecDecodeWorker(LoraNotSupportedWorkerBase):
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
accepted_index
=
accepted_token_ids
+
1
# Convert -1 to 0
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
# b
accepted_index
=
accepted_index
.
count_nonzero
(
dim
=
1
).
add_
(
-
1
)
# b
# Drop non-terminal prefill chunks hidden states.
# Drop non-terminal prefill chunks hidden states.
hidden_states
=
hidden_states
[
hidden_states
=
hidden_states
[
accepted_index
!=
accepted_index
!=
VLLM_INVALID_TOKEN_ID
]
VLLM_INVALID_TOKEN_ID
]
accepted_index
=
accepted_index
[
accepted_index
=
accepted_index
[
accepted_index
!=
accepted_index
!=
VLLM_INVALID_TOKEN_ID
]
VLLM_INVALID_TOKEN_ID
]
assert
len
(
accepted_index
)
==
hidden_states
.
shape
[
0
]
==
len
(
assert
len
(
accepted_index
)
==
hidden_states
.
shape
[
0
]
==
len
(
terminal_metadata
)
terminal_metadata
)
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
index
=
accepted_index
[:,
None
,
None
].
expand
(
-
1
,
1
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment