Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
598b6d7b
Unverified
Commit
598b6d7b
authored
Nov 01, 2024
by
Pavani Majety
Committed by
GitHub
Nov 01, 2024
Browse files
[Bugfix/Core] Flashinfer k_scale and v_scale (#9861)
parent
aff1fd81
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
25 additions
and
12 deletions
+25
-12
tests/kernels/test_cache.py
tests/kernels/test_cache.py
+14
-7
vllm/attention/backends/flashinfer.py
vllm/attention/backends/flashinfer.py
+6
-3
vllm/model_executor/layers/quantization/modelopt.py
vllm/model_executor/layers/quantization/modelopt.py
+5
-2
No files found.
tests/kernels/test_cache.py
View file @
598b6d7b
...
@@ -258,19 +258,20 @@ def test_reshape_and_cache_flash(
...
@@ -258,19 +258,20 @@ def test_reshape_and_cache_flash(
del
key_caches
del
key_caches
del
value_caches
del
value_caches
k_scale
=
key
.
amax
().
item
()
/
256
v_scale
=
value
.
amax
().
item
()
/
256
# Clone the KV caches.
# Clone the KV caches.
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
cloned_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
)
ops
.
convert_fp8
(
cloned_key_cache
,
key_cache
,
k_scale
,
kv_cache_dtype
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
cloned_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
)
ops
.
convert_fp8
(
cloned_value_cache
,
value_cache
,
v_scale
,
kv_cache_dtype
)
else
:
else
:
cloned_key_cache
=
key_cache
.
clone
()
cloned_key_cache
=
key_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
cloned_value_cache
=
value_cache
.
clone
()
# Using default kv_scale
k_scale
=
v_scale
=
1.0
# Call the reshape_and_cache kernel.
# Call the reshape_and_cache kernel.
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
opcheck
(
torch
.
ops
.
_C_cache_ops
.
reshape_and_cache_flash
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
(
key
,
value
,
key_cache
,
value_cache
,
slot_mapping
,
kv_cache_dtype
,
...
@@ -281,9 +282,15 @@ def test_reshape_and_cache_flash(
...
@@ -281,9 +282,15 @@ def test_reshape_and_cache_flash(
if
kv_cache_dtype
==
"fp8"
:
if
kv_cache_dtype
==
"fp8"
:
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
result_key_cache
=
torch
.
empty_like
(
key_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
)
ops
.
convert_fp8
(
result_key_cache
,
key_cache
,
k_scale
,
kv_dtype
=
kv_cache_dtype
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
result_value_cache
=
torch
.
empty_like
(
value_cache
,
dtype
=
torch
.
float16
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
)
ops
.
convert_fp8
(
result_value_cache
,
value_cache
,
v_scale
,
kv_dtype
=
kv_cache_dtype
)
# Run the reference implementation.
# Run the reference implementation.
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
block_indicies
=
torch
.
div
(
slot_mapping
,
block_size
,
rounding_mode
=
"floor"
)
...
...
vllm/attention/backends/flashinfer.py
View file @
598b6d7b
...
@@ -759,8 +759,6 @@ class FlashInferImpl(AttentionImpl):
...
@@ -759,8 +759,6 @@ class FlashInferImpl(AttentionImpl):
v_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
assert
k_scale
==
1.0
and
v_scale
==
1.0
,
(
"key/v_scale is not supported in FlashInfer."
)
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
...
@@ -874,7 +872,12 @@ def unified_flash_infer(
...
@@ -874,7 +872,12 @@ def unified_flash_infer(
assert
prefill_meta
is
not
None
assert
prefill_meta
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
assert
prefill_meta
.
prefill_wrapper
is
not
None
prefill_output
=
prefill_meta
.
prefill_wrapper
.
forward
(
prefill_output
=
prefill_meta
.
prefill_wrapper
.
forward
(
query
,
kv_cache
,
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
)
query
,
kv_cache
,
logits_soft_cap
=
logits_soft_cap
,
causal
=
True
,
k_scale
=
k_scale
,
v_scale
=
v_scale
)
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
if
decode_meta
:
=
attn_metadata
.
decode_metadata
:
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
assert
attn_metadata
.
decode_metadata
.
decode_wrapper
is
not
None
...
...
vllm/model_executor/layers/quantization/modelopt.py
View file @
598b6d7b
...
@@ -141,6 +141,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
...
@@ -141,6 +141,9 @@ class ModelOptFp8LinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"input_scale"
,
scale
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
weight
=
layer
.
weight
max_w_scale
=
layer
.
weight_scale
.
max
()
if
not
(
layer
.
weight_scale
==
layer
.
weight_scale
[
0
]).
all
():
max_w_scale
,
weight
=
requantize_with_max_scale
(
max_w_scale
,
weight
=
requantize_with_max_scale
(
layer
.
weight
,
layer
.
weight_scale
,
layer
.
logical_widths
)
layer
.
weight
,
layer
.
weight_scale
,
layer
.
logical_widths
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
layer
.
weight
=
Parameter
(
weight
.
t
(),
requires_grad
=
False
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment