Commit bf93e83b authored by zhangshao's avatar zhangshao
Browse files

解决fp8 kv cache scale错误问题

parent 3fe3d07c
...@@ -358,10 +358,11 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -358,10 +358,11 @@ class Attention(nn.Module, AttentionLayerBase):
query, key, value, self.layer_name) query, key, value, self.layer_name)
def calc_kv_scales(self, query, key, value): def calc_kv_scales(self, query, key, value):
self.check_fp8_overflow = False
if self.calculate_kv_scales == False : if self.calculate_kv_scales == False :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()<=200 : #check fp8 overflow if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()<=200 : #check fp8 overflow
return return
if torch.abs(query).max().item()>=0.01 : #check fp8 too small if self.kv_cache_dtype in {"fp8_e5m2"} and torch.abs(query).max().item()>=0.01 : #check fp8 too small
return return
bias=0.0 # add bias to avoid q values are too small(or zeros) and scales are not correct bias=0.0 # add bias to avoid q values are too small(or zeros) and scales are not correct
if torch.abs(query).max().item() < 0.01: if torch.abs(query).max().item() < 0.01:
...@@ -378,7 +379,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -378,7 +379,7 @@ class Attention(nn.Module, AttentionLayerBase):
self._v_scale_float = self._v_scale.item() self._v_scale_float = self._v_scale.item()
# We only calculate the scales once # We only calculate the scales once
self.calculate_kv_scales = False self.calculate_kv_scales = False
self.check_fp8_overflow = False
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment