Commit 3fe3d07c authored by zhangshao
Browse files

修复qwen vl系列kv cache e5m2计算scale bug

parent ef16700d
...@@ -137,7 +137,7 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -137,7 +137,7 @@ class Attention(nn.Module, AttentionLayerBase):
# with the model weights. # with the model weights.
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales self.calculate_kv_scales = calculate_kv_scales
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} : if self.kv_cache_dtype in {"fp8", "fp8_e4m3","fp8_e5m2"} :
self.check_fp8_overflow = True self.check_fp8_overflow = True
else: else:
self.check_fp8_overflow = False self.check_fp8_overflow = False
...@@ -291,7 +291,6 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -291,7 +291,6 @@ class Attention(nn.Module, AttentionLayerBase):
# self.calc_kv_scales(query, key, value) # self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name) self.layer_name)
self.check_fp8_overflow=False
output_dtype = query.dtype output_dtype = query.dtype
if self.query_quant is not None: if self.query_quant is not None:
...@@ -359,14 +358,27 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -359,14 +358,27 @@ class Attention(nn.Module, AttentionLayerBase):
query, key, value, self.layer_name) query, key, value, self.layer_name)
def calc_kv_scales(self, query, key, value): def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range) if self.calculate_kv_scales == False :
self._k_scale.copy_(torch.abs(key).max() / self.k_range) if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()<=200 : #check fp8 overflow
self._v_scale.copy_(torch.abs(value).max() / self.v_range) return
if torch.abs(query).max().item()>=0.01 : #check fp8 too small
return
bias=0.0 # add bias to avoid q values are too small(or zeros) and scales are not correct
if torch.abs(query).max().item() < 0.01:
if self.kv_cache_dtype in {"fp8_e5m2"}:
bias = 0.1
else :
bias = 1.0
self._q_scale.copy_(torch.abs(query).max() / self.q_range+bias)
self._k_scale.copy_(torch.abs(key).max() / self.k_range+bias)
self._v_scale.copy_(torch.abs(value).max() / self.v_range+bias)
self._q_scale_float = self._q_scale.item() self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item() self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item() self._v_scale_float = self._v_scale.item()
# We only calculate the scales once # We only calculate the scales once
self.calculate_kv_scales = False self.calculate_kv_scales = False
self.check_fp8_overflow = False
def extra_repr(self) -> str: def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore s = f"head_size={self.impl.head_size}" # type: ignore
...@@ -588,10 +600,6 @@ def maybe_calc_kv_scales( ...@@ -588,10 +600,6 @@ def maybe_calc_kv_scales(
# Only calculate if the layer's calculate_kv_scales flag is True # Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass # This flag gets set to False after the first forward pass
if self.check_fp8_overflow and torch.abs(query).max().item()>200:
self.calculate_kv_scales=True
if not self.calculate_kv_scales:
return
self.calc_kv_scales(query, key, value) self.calc_kv_scales(query, key, value)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment