Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
978aed53
Unverified
Commit
978aed53
authored
Jul 16, 2024
by
Michael Goin
Committed by
GitHub
Jul 16, 2024
Browse files
[Kernel][Attention] Separate `Attention.kv_scale` into `k_scale` and `v_scale` (#6081)
parent
160e1d8c
Changes
33
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
160 additions
and
86 deletions
+160
-86
vllm/attention/backends/pallas.py
vllm/attention/backends/pallas.py
+3
-2
vllm/attention/backends/rocm_flash_attn.py
vllm/attention/backends/rocm_flash_attn.py
+6
-3
vllm/attention/backends/torch_sdpa.py
vllm/attention/backends/torch_sdpa.py
+7
-4
vllm/attention/backends/xformers.py
vllm/attention/backends/xformers.py
+5
-3
vllm/attention/layer.py
vllm/attention/layer.py
+8
-6
vllm/attention/ops/ipex_attn.py
vllm/attention/ops/ipex_attn.py
+4
-2
vllm/attention/ops/paged_attn.py
vllm/attention/ops/paged_attn.py
+10
-5
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+9
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+38
-13
vllm/model_executor/model_loader/weight_utils.py
vllm/model_executor/model_loader/weight_utils.py
+53
-5
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+5
-14
vllm/model_executor/models/mixtral.py
vllm/model_executor/models/mixtral.py
+6
-15
vllm/model_executor/models/qwen2.py
vllm/model_executor/models/qwen2.py
+6
-14
No files found.
vllm/attention/backends/pallas.py
View file @
978aed53
...
@@ -131,7 +131,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -131,7 +131,8 @@ class PallasAttentionBackendImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]],
kv_cache
:
Tuple
[
Optional
[
torch
.
Tensor
],
Optional
[
torch
.
Tensor
]],
attn_metadata
:
PallasMetadata
,
attn_metadata
:
PallasMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with Pallas attention.
"""Forward pass with Pallas attention.
...
@@ -146,7 +147,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
...
@@ -146,7 +147,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns:
Returns:
shape = [batch_size, seq_len, num_heads * head_size]
shape = [batch_size, seq_len, num_heads * head_size]
"""
"""
assert
kv_scale
==
1.0
assert
k
_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
...
...
vllm/attention/backends/rocm_flash_attn.py
View file @
978aed53
...
@@ -296,7 +296,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -296,7 +296,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
kv_cache
:
torch
.
Tensor
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
attn_metadata
:
ROCmFlashAttentionMetadata
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with FlashAttention and PagedAttention.
"""Forward pass with FlashAttention and PagedAttention.
...
@@ -336,7 +337,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -336,7 +337,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
num_prefill_tokens
=
attn_metadata
.
num_prefill_tokens
...
@@ -456,7 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
...
@@ -456,7 +458,8 @@ class ROCmFlashAttentionImpl(AttentionImpl):
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
kv_scale
,
k_scale
,
v_scale
,
)
)
# Reshape the output tensor.
# Reshape the output tensor.
...
...
vllm/attention/backends/torch_sdpa.py
View file @
978aed53
...
@@ -144,7 +144,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -144,7 +144,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
value
:
torch
.
Tensor
,
value
:
torch
.
Tensor
,
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
attn_metadata
:
TorchSDPAMetadata
,
# type: ignore
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with torch SDPA and PagedAttention.
"""Forward pass with torch SDPA and PagedAttention.
...
@@ -158,7 +159,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -158,7 +159,7 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns:
Returns:
shape = [num_tokens, num_heads * head_size]
shape = [num_tokens, num_heads * head_size]
"""
"""
assert
kv_scale
==
1.0
assert
k
_scale
==
1.0
and
v_scale
==
1.0
if
attn_type
!=
AttentionType
.
DECODER
:
if
attn_type
!=
AttentionType
.
DECODER
:
raise
NotImplementedError
(
"Encoder self-attention and "
raise
NotImplementedError
(
"Encoder self-attention and "
"encoder/decoder cross-attention "
"encoder/decoder cross-attention "
...
@@ -176,7 +177,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -176,7 +177,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
PagedAttention
.
write_to_paged_cache
(
key
,
value
,
key_cache
,
value_cache
,
value_cache
,
attn_metadata
.
slot_mapping
,
attn_metadata
.
slot_mapping
,
self
.
kv_cache_dtype
,
kv_scale
)
self
.
kv_cache_dtype
,
k_scale
,
v_scale
)
if
attn_metadata
.
is_prompt
:
if
attn_metadata
.
is_prompt
:
assert
attn_metadata
.
seq_lens
is
not
None
assert
attn_metadata
.
seq_lens
is
not
None
...
@@ -239,7 +241,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
...
@@ -239,7 +241,8 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
kv_scale
,
k_scale
,
v_scale
,
)
)
# Reshape the output tensor.
# Reshape the output tensor.
...
...
vllm/attention/backends/xformers.py
View file @
978aed53
...
@@ -427,7 +427,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -427,7 +427,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
value
:
Optional
[
torch
.
Tensor
],
value
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
kv_cache
:
Optional
[
torch
.
Tensor
],
attn_metadata
:
"XFormersMetadata"
,
attn_metadata
:
"XFormersMetadata"
,
kv_scale
:
float
=
1.0
,
k_scale
:
float
=
1.0
,
v_scale
:
float
=
1.0
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
attn_type
:
AttentionType
=
AttentionType
.
DECODER
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""Forward pass with xFormers and PagedAttention.
"""Forward pass with xFormers and PagedAttention.
...
@@ -531,7 +532,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -531,7 +532,7 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
value_cache
,
value_cache
,
updated_slot_mapping
,
updated_slot_mapping
,
self
.
kv_cache_dtype
,
self
.
kv_cache_dtype
,
kv_scale
)
k
_scale
,
v_scale
)
if
attn_type
!=
AttentionType
.
ENCODER
:
if
attn_type
!=
AttentionType
.
ENCODER
:
# Decoder self-attention supports chunked prefill.
# Decoder self-attention supports chunked prefill.
...
@@ -620,7 +621,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
...
@@ -620,7 +621,8 @@ class XFormersImpl(AttentionImpl[XFormersMetadata]):
self
.
num_kv_heads
,
self
.
num_kv_heads
,
self
.
scale
,
self
.
scale
,
self
.
alibi_slopes
,
self
.
alibi_slopes
,
kv_scale
,
k_scale
,
v_scale
,
)
)
# Reshape the output tensor.
# Reshape the output tensor.
...
...
vllm/attention/layer.py
View file @
978aed53
...
@@ -47,13 +47,14 @@ class Attention(nn.Module):
...
@@ -47,13 +47,14 @@ class Attention(nn.Module):
if
num_kv_heads
is
None
:
if
num_kv_heads
is
None
:
num_kv_heads
=
num_heads
num_kv_heads
=
num_heads
# The default kv_scale is set to 1.0. This is ignored
# The default k
/
v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
# when kv-cache is not fp8, and should be used with
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# kv-cache in fp8_e5m2. For kv-cache in fp8_e4m3, we
# expect the pre-quantized kv_scale to be loaded along
# expect the pre-quantized k
/
v_scale to be loaded along
# with the model weights.
# with the model weights.
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
_kv_scale
=
1.0
self
.
_k_scale
=
1.0
self
.
_v_scale
=
1.0
quant_method
=
quant_config
.
get_quant_method
(
quant_method
=
quant_config
.
get_quant_method
(
self
)
if
quant_config
else
None
self
)
if
quant_config
else
None
if
quant_method
is
not
None
:
if
quant_method
is
not
None
:
...
@@ -66,8 +67,8 @@ class Attention(nn.Module):
...
@@ -66,8 +67,8 @@ class Attention(nn.Module):
"fp8 checkpoints."
)
"fp8 checkpoints."
)
# When FP8 quantization is enabled, we make a parameter
# When FP8 quantization is enabled, we make a parameter
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# "kv_scale" so that it can be loaded from FP8 checkpoint.
# The kv_scale will then be converted back to
self._kv_scale
# The k
/
v_scale will then be converted back to
# in a native float32 value after weight loading
.
#
self._kv_scale
in a native float32 value after weight loading
self
.
quant_method
=
quant_method
self
.
quant_method
=
quant_method
self
.
quant_method
.
create_weights
(
self
)
self
.
quant_method
.
create_weights
(
self
)
...
@@ -98,7 +99,8 @@ class Attention(nn.Module):
...
@@ -98,7 +99,8 @@ class Attention(nn.Module):
value
,
value
,
kv_cache
,
kv_cache
,
attn_metadata
,
attn_metadata
,
self
.
_kv_scale
,
self
.
_k_scale
,
self
.
_v_scale
,
attn_type
=
attn_type
)
attn_type
=
attn_type
)
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
vllm/attention/ops/ipex_attn.py
View file @
978aed53
...
@@ -45,7 +45,8 @@ class PagedAttention:
...
@@ -45,7 +45,8 @@ class PagedAttention:
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
*
args
,
*
args
,
)
->
None
:
)
->
None
:
ipex_modules
.
PagedAttention
.
reshape_and_cache
(
ipex_modules
.
PagedAttention
.
reshape_and_cache
(
...
@@ -64,7 +65,8 @@ class PagedAttention:
...
@@ -64,7 +65,8 @@ class PagedAttention:
num_kv_heads
:
int
,
num_kv_heads
:
int
,
scale
:
float
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
*
args
,
*
args
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
output
=
torch
.
empty_like
(
query
)
output
=
torch
.
empty_like
(
query
)
...
...
vllm/attention/ops/paged_attn.py
View file @
978aed53
...
@@ -66,7 +66,8 @@ class PagedAttention:
...
@@ -66,7 +66,8 @@ class PagedAttention:
value_cache
:
torch
.
Tensor
,
value_cache
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
slot_mapping
:
torch
.
Tensor
,
kv_cache_dtype
:
str
,
kv_cache_dtype
:
str
,
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
)
->
None
:
)
->
None
:
ops
.
reshape_and_cache
(
ops
.
reshape_and_cache
(
key
,
key
,
...
@@ -75,7 +76,8 @@ class PagedAttention:
...
@@ -75,7 +76,8 @@ class PagedAttention:
value_cache
,
value_cache
,
slot_mapping
.
flatten
(),
slot_mapping
.
flatten
(),
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
)
)
@
staticmethod
@
staticmethod
...
@@ -90,7 +92,8 @@ class PagedAttention:
...
@@ -90,7 +92,8 @@ class PagedAttention:
num_kv_heads
:
int
,
num_kv_heads
:
int
,
scale
:
float
,
scale
:
float
,
alibi_slopes
:
Optional
[
torch
.
Tensor
],
alibi_slopes
:
Optional
[
torch
.
Tensor
],
kv_scale
:
float
,
k_scale
:
float
,
v_scale
:
float
,
tp_rank
:
int
=
0
,
tp_rank
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_local_blocks
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
blocksparse_vert_stride
:
int
=
0
,
...
@@ -135,7 +138,8 @@ class PagedAttention:
...
@@ -135,7 +138,8 @@ class PagedAttention:
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
...
@@ -172,7 +176,8 @@ class PagedAttention:
...
@@ -172,7 +176,8 @@ class PagedAttention:
max_seq_len
,
max_seq_len
,
alibi_slopes
,
alibi_slopes
,
kv_cache_dtype
,
kv_cache_dtype
,
kv_scale
,
k_scale
,
v_scale
,
tp_rank
,
tp_rank
,
blocksparse_local_blocks
,
blocksparse_local_blocks
,
blocksparse_vert_stride
,
blocksparse_vert_stride
,
...
...
vllm/model_executor/layers/linear.py
View file @
978aed53
...
@@ -196,6 +196,15 @@ class ReplicatedLinear(LinearBase):
...
@@ -196,6 +196,15 @@ class ReplicatedLinear(LinearBase):
else
:
else
:
self
.
register_parameter
(
"bias"
,
None
)
self
.
register_parameter
(
"bias"
,
None
)
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param
.
size
()
==
loaded_weight
.
size
()
param
.
data
.
copy_
(
loaded_weight
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
978aed53
...
@@ -407,31 +407,56 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
...
@@ -407,31 +407,56 @@ class Fp8KVCacheMethod(QuantizeMethodBase):
self
.
quant_config
=
quant_config
self
.
quant_config
=
quant_config
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
):
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
):
"""Create "weight" (aka kv_scale) for an attention layer.
"""Create "weight" (aka k
_scale and
v_scale) for an attention layer.
Args:
Args:
layer: The layer that is using the QuantizeMethodBase factory.
layer: The layer that is using the QuantizeMethodBase factory.
"""
"""
# Initialize the KV cache scale to 1.0
as the default
value.
# Initialize the KV cache scale
s
to
-
1.0
, which is an invalid
value.
# If the kv_scale appears in the checkpoint, it will be
# If the k
/
v_scale appears in the checkpoint, it will be
# overwritten when loading weights.
# overwritten when loading weights.
layer
.
kv_scale
=
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
layer
.
k_scale
=
Parameter
(
torch
.
tensor
(
-
1.0
),
requires_grad
=
False
)
layer
.
v_scale
=
Parameter
(
torch
.
tensor
(
-
1.0
),
requires_grad
=
False
)
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
)
->
torch
.
Tensor
:
raise
RuntimeError
(
"Fp8KVCacheMethod.apply should not be called."
)
raise
RuntimeError
(
"Fp8KVCacheMethod.apply should not be called."
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# If the kv-cache dtype is auto, we enforce the k
v-
scale to be 1.0
# If the kv-cache dtype is auto, we enforce the k
/v_
scale to be 1.0
# regardless whether the kv-scale is available in the checkpoint.
# regardless whether the kv-scale is available in the checkpoint.
if
layer
.
kv_cache_dtype
!=
"auto"
:
if
layer
.
kv_cache_dtype
!=
"auto"
:
kv_scale
=
layer
.
kv_scale
.
to
(
"cpu"
).
tolist
()
if
layer
.
k_scale
>
0.0
and
layer
.
v_scale
>
0.0
:
if
not
isinstance
(
kv_scale
,
float
):
# We prefer to use separate k_scale and v_scale if present
k_scale
=
layer
.
k_scale
.
to
(
"cpu"
).
tolist
()
v_scale
=
layer
.
v_scale
.
to
(
"cpu"
).
tolist
()
elif
layer
.
k_scale
<
0.0
and
layer
.
v_scale
<
0.0
:
# If no scales were loaded (both scales are invalid negative
# values), use the default value of 1.0
k_scale
=
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
v_scale
=
Parameter
(
torch
.
tensor
(
1.0
),
requires_grad
=
False
)
else
:
# If we find a single kv_scale in the checkpoint, we remap
# kv_scale to k_scale during weight loading, and duplicate
# k_scale to v_scale here
assert
layer
.
k_scale
>
0.0
scale_to_duplicate
=
max
(
layer
.
k_scale
,
layer
.
v_scale
)
k_scale
=
scale_to_duplicate
.
to
(
"cpu"
).
tolist
()
v_scale
=
scale_to_duplicate
.
to
(
"cpu"
).
tolist
()
if
not
isinstance
(
k_scale
,
float
)
or
not
isinstance
(
v_scale
,
float
):
raise
ValueError
(
"Only support per-tensor scaling factor "
raise
ValueError
(
"Only support per-tensor scaling factor "
"for fp8 KV cache"
)
"for fp8 KV cache"
)
layer
.
_kv_scale
=
kv_scale
if
layer
.
_kv_scale
==
1.0
and
"e5m2"
not
in
layer
.
kv_cache_dtype
:
# These are used in the final Attention.forward()
layer
.
_k_scale
=
k_scale
layer
.
_v_scale
=
v_scale
if
(
layer
.
_k_scale
==
1.0
and
layer
.
_v_scale
==
1.0
and
"e5m2"
not
in
layer
.
kv_cache_dtype
):
print_warning_once
(
print_warning_once
(
"Using KV cache scaling factor 1.0 for fp8_e4m3. This may "
"Using KV cache scaling factor 1.0 for fp8_e4m3. This "
"cause accuracy issues. Please make sure kv-cache scaling "
"may cause accuracy issues. Please make sure k/v_scale "
"factor is available in the fp8 checkpoint."
)
"scaling factors are available in the fp8 checkpoint."
)
del
layer
.
kv_scale
del
layer
.
k_scale
del
layer
.
v_scale
vllm/model_executor/model_loader/weight_utils.py
View file @
978aed53
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
...
@@ -22,6 +22,7 @@ from vllm.logger import init_logger
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
from
vllm.model_executor.layers.quantization
import
(
QuantizationConfig
,
get_quantization_config
)
get_quantization_config
)
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.model_executor.layers.quantization.schema
import
QuantParamSchema
from
vllm.utils
import
print_warning_once
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -431,11 +432,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
...
@@ -431,11 +432,6 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor:
def
default_weight_loader
(
param
:
torch
.
Tensor
,
def
default_weight_loader
(
param
:
torch
.
Tensor
,
loaded_weight
:
torch
.
Tensor
)
->
None
:
loaded_weight
:
torch
.
Tensor
)
->
None
:
"""Default weight loader."""
"""Default weight loader."""
# If the weight on disk does not have a shape, give it one
# (such scales for AutoFp8).
if
len
(
loaded_weight
.
shape
)
==
0
:
loaded_weight
=
loaded_weight
.
reshape
(
1
)
assert
param
.
size
()
==
loaded_weight
.
size
()
assert
param
.
size
()
==
loaded_weight
.
size
()
param
.
data
.
copy_
(
loaded_weight
)
param
.
data
.
copy_
(
loaded_weight
)
...
@@ -462,3 +458,55 @@ def initialize_dummy_weights(
...
@@ -462,3 +458,55 @@ def initialize_dummy_weights(
param
.
data
.
copy_
(
tmp_param
)
param
.
data
.
copy_
(
tmp_param
)
else
:
else
:
param
.
uniform_
(
low
,
high
)
param
.
uniform_
(
low
,
high
)
def
maybe_remap_kv_scale_name
(
name
:
str
,
params_dict
:
dict
)
->
Optional
[
str
]:
"""Remap the name of FP8 k/v_scale parameters.
This function handles the remapping of FP8 k/v_scale parameter names.
It detects if the given name ends with a suffix and attempts to remap
it to the expected name format in the model. If the remapped name is not
found in the params_dict, a warning is printed and None is returned.
Args:
name (str): The original loaded checkpoint parameter name.
params_dict (dict): Dictionary containing the model's named parameters.
Returns:
str: The remapped parameter name if successful, or the original name
if no remapping is needed.
None: If the remapped name is not found in params_dict.
"""
if
name
.
endswith
(
".kv_scale"
):
print_warning_once
(
"DEPRECATED. Found kv_scale in the checkpoint. "
"This format is deprecated in favor of separate k_scale and "
"v_scale tensors and will be removed in a future release. "
"Functionally, we will remap kv_scale to k_scale and duplicate "
"k_scale to v_scale"
)
# NOTE: we remap the deprecated kv_scale to k_scale
remapped_name
=
name
.
replace
(
".kv_scale"
,
".attn.k_scale"
)
if
remapped_name
not
in
params_dict
:
print_warning_once
(
f
"Found kv_scale in the checkpoint (e.g.
{
name
}
), "
"but not found the expected name in the model "
f
"(e.g.
{
remapped_name
}
). kv_scale is "
"not loaded."
)
return
None
return
remapped_name
possible_scale_names
=
[
".k_scale"
,
".v_scale"
]
for
scale_name
in
possible_scale_names
:
if
name
.
endswith
(
scale_name
):
remapped_name
=
name
.
replace
(
scale_name
,
f
".attn
{
scale_name
}
"
)
if
remapped_name
not
in
params_dict
:
print_warning_once
(
f
"Found
{
scale_name
}
in the checkpoint (e.g.
{
name
}
), "
"but not found the expected name in the model "
f
"(e.g.
{
remapped_name
}
).
{
scale_name
}
is "
"not loaded."
)
return
None
return
remapped_name
# If there were no matches, return the untouched param name
return
name
vllm/model_executor/models/llama.py
View file @
978aed53
...
@@ -44,10 +44,10 @@ from vllm.model_executor.layers.sampler import Sampler
...
@@ -44,10 +44,10 @@ from vllm.model_executor.layers.sampler import Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
(
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
kv_cache_scales_loader
)
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
is_hip
,
print_warning_once
from
vllm.utils
import
is_hip
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
from
.utils
import
is_pp_missing_parameter
,
make_layers
from
.utils
import
is_pp_missing_parameter
,
make_layers
...
@@ -460,18 +460,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
...
@@ -460,18 +460,9 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Remapping the name of FP8 kv-scale.
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
remapped_kv_scale_name
=
name
.
replace
(
if
name
is
None
:
".kv_scale"
,
".attn.kv_scale"
)
continue
if
remapped_kv_scale_name
not
in
params_dict
:
print_warning_once
(
f
"Found kv scale in the checkpoint (e.g.
{
name
}
), "
"but not found the expected name in the model "
f
"(e.g.
{
remapped_kv_scale_name
}
). kv-scale is "
"not loaded."
)
continue
else
:
name
=
remapped_kv_scale_name
if
is_pp_missing_parameter
(
name
,
self
):
if
is_pp_missing_parameter
(
name
,
self
):
continue
continue
...
...
vllm/model_executor/models/mixtral.py
View file @
978aed53
...
@@ -42,10 +42,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -42,10 +42,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
DEFAULT_VOCAB_PADDING_SIZE
,
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
...
@@ -415,19 +415,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
...
@@ -415,19 +415,10 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA):
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Remapping the name of FP8 kv-scale.
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
remapped_kv_scale_name
=
name
.
replace
(
if
name
is
None
:
".kv_scale"
,
".attn.kv_scale"
)
continue
if
remapped_kv_scale_name
not
in
params_dict
:
print_warning_once
(
"Found kv scale in the checkpoint "
f
"(e.g.
{
name
}
), but not found the expected "
f
"name in the model "
f
"(e.g.
{
remapped_kv_scale_name
}
). "
"kv-scale is not loaded."
)
continue
else
:
name
=
remapped_kv_scale_name
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
vllm/model_executor/models/qwen2.py
View file @
978aed53
...
@@ -43,10 +43,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
...
@@ -43,10 +43,10 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.sampler
import
Sampler
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
ParallelLMHead
,
VocabParallelEmbedding
)
from
vllm.model_executor.model_loader.weight_utils
import
default_weight_loader
from
vllm.model_executor.model_loader.weight_utils
import
(
default_weight_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.utils
import
print_warning_once
from
.interfaces
import
SupportsLoRA
from
.interfaces
import
SupportsLoRA
...
@@ -382,18 +382,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
...
@@ -382,18 +382,10 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
if
name
.
endswith
(
".bias"
)
and
name
not
in
params_dict
:
continue
continue
# Remapping the name of FP8 kv-scale.
# Remapping the name of FP8 kv-scale.
if
name
.
endswith
(
"kv_scale"
):
name
=
maybe_remap_kv_scale_name
(
name
,
params_dict
)
remapped_kv_scale_name
=
name
.
replace
(
if
name
is
None
:
".kv_scale"
,
".attn.kv_scale"
)
continue
if
remapped_kv_scale_name
not
in
params_dict
:
print_warning_once
(
f
"Found kv scale in the checkpoint (e.g.
{
name
}
), "
"but not found the expected name in the model "
f
"(e.g.
{
remapped_kv_scale_name
}
). kv-scale is "
"not loaded."
)
continue
else
:
name
=
remapped_kv_scale_name
param
=
params_dict
[
name
]
param
=
params_dict
[
name
]
weight_loader
=
getattr
(
param
,
"weight_loader"
,
weight_loader
=
getattr
(
param
,
"weight_loader"
,
default_weight_loader
)
default_weight_loader
)
...
...
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment