Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
0b2c14e3
Commit
0b2c14e3
authored
Feb 24, 2026
by
zhuwenwen
Browse files
fix glm fp8-e4m3 acc error
parent
770d33f9
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
8 additions
and
1 deletion
+8
-1
vllm/attention/layer.py
vllm/attention/layer.py
+8
-1
No files found.
vllm/attention/layer.py
View file @
0b2c14e3
...
@@ -137,6 +137,10 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -137,6 +137,10 @@ class Attention(nn.Module, AttentionLayerBase):
# with the model weights.
# with the model weights.
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
kv_cache_dtype
=
kv_cache_dtype
self
.
calculate_kv_scales
=
calculate_kv_scales
self
.
calculate_kv_scales
=
calculate_kv_scales
if
self
.
kv_cache_dtype
in
{
"fp8"
,
"fp8_e4m3"
}
:
self
.
check_fp8_overflow
=
True
else
:
self
.
check_fp8_overflow
=
False
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_k_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
self
.
_v_scale
=
torch
.
tensor
(
1.0
,
dtype
=
torch
.
float32
)
# FlashAttn doesn't support quantizing the kv-cache only
# FlashAttn doesn't support quantizing the kv-cache only
...
@@ -281,12 +285,13 @@ class Attention(nn.Module, AttentionLayerBase):
...
@@ -281,12 +285,13 @@ class Attention(nn.Module, AttentionLayerBase):
context using
context using
`vllm.forward_context.get_forward_context().attn_metadata`.
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
"""
if
self
.
calculate_kv_scales
:
if
self
.
calculate_kv_scales
or
self
.
check_fp8_overflow
:
# attn_metadata = get_forward_context().attn_metadata
# attn_metadata = get_forward_context().attn_metadata
# if attn_metadata.enable_kv_scales_calculation:
# if attn_metadata.enable_kv_scales_calculation:
# self.calc_kv_scales(query, key, value)
# self.calc_kv_scales(query, key, value)
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
query
,
key
,
value
,
torch
.
ops
.
vllm
.
maybe_calc_kv_scales
(
query
,
key
,
value
,
self
.
layer_name
)
self
.
layer_name
)
self
.
check_fp8_overflow
=
False
output_dtype
=
query
.
dtype
output_dtype
=
query
.
dtype
if
self
.
query_quant
is
not
None
:
if
self
.
query_quant
is
not
None
:
...
@@ -583,6 +588,8 @@ def maybe_calc_kv_scales(
...
@@ -583,6 +588,8 @@ def maybe_calc_kv_scales(
# Only calculate if the layer's calculate_kv_scales flag is True
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
# This flag gets set to False after the first forward pass
if
self
.
check_fp8_overflow
and
torch
.
abs
(
query
).
max
().
item
()
>
200
:
self
.
calculate_kv_scales
=
True
if
not
self
.
calculate_kv_scales
:
if
not
self
.
calculate_kv_scales
:
return
return
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment