Commit 2830c329 authored by zhangshao's avatar zhangshao
Browse files

Merge branch 'fp8e5m2_bw' into 'v0.11.0-dev'

解决fp8 kv cache scale错误问题

See merge request dcutoolkit/deeplearing/vllm!455
parents f03e2ab3 bf93e83b
......@@ -358,10 +358,11 @@ class Attention(nn.Module, AttentionLayerBase):
query, key, value, self.layer_name)
def calc_kv_scales(self, query, key, value):
self.check_fp8_overflow = False
if self.calculate_kv_scales == False :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()<=200 : #check fp8 overflow
return
if torch.abs(query).max().item()>=0.01 : #check fp8 too small
if self.kv_cache_dtype in {"fp8_e5m2"} and torch.abs(query).max().item()>=0.01 : #check fp8 too small
return
bias=0.0 # add bias to avoid q values are too small(or zeros) and scales are not correct
if torch.abs(query).max().item() < 0.01:
......@@ -378,7 +379,7 @@ class Attention(nn.Module, AttentionLayerBase):
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
self.check_fp8_overflow = False
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment