"vscode:/vscode.git/clone" did not exist on "78d13ea9de4b1ce5e4d8a5af9738fea71fb024e5"
Commit f03e2ab3 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'fp8e5m2_bw' into 'v0.11.0-dev'

修复qwen vl系列kv cache e5m2计算scale bug

See merge request dcutoolkit/deeplearing/vllm!452
parents ef16700d 3fe3d07c
......@@ -137,7 +137,7 @@ class Attention(nn.Module, AttentionLayerBase):
# with the model weights.
self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3","fp8_e5m2"} :
self.check_fp8_overflow = True
else:
self.check_fp8_overflow = False
......@@ -291,7 +291,6 @@ class Attention(nn.Module, AttentionLayerBase):
# self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
self.check_fp8_overflow=False
output_dtype = query.dtype
if self.query_quant is not None:
......@@ -359,14 +358,27 @@ class Attention(nn.Module, AttentionLayerBase):
query, key, value, self.layer_name)
def calc_kv_scales(self, query, key, value):
self._q_scale.copy_(torch.abs(query).max() / self.q_range)
self._k_scale.copy_(torch.abs(key).max() / self.k_range)
self._v_scale.copy_(torch.abs(value).max() / self.v_range)
if self.calculate_kv_scales == False :
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} and torch.abs(query).max().item()<=200 : #check fp8 overflow
return
if torch.abs(query).max().item()>=0.01 : #check fp8 too small
return
bias=0.0 # add bias to avoid q values are too small(or zeros) and scales are not correct
if torch.abs(query).max().item() < 0.01:
if self.kv_cache_dtype in {"fp8_e5m2"}:
bias = 0.1
else :
bias = 1.0
self._q_scale.copy_(torch.abs(query).max() / self.q_range+bias)
self._k_scale.copy_(torch.abs(key).max() / self.k_range+bias)
self._v_scale.copy_(torch.abs(value).max() / self.v_range+bias)
self._q_scale_float = self._q_scale.item()
self._k_scale_float = self._k_scale.item()
self._v_scale_float = self._v_scale.item()
# We only calculate the scales once
self.calculate_kv_scales = False
self.check_fp8_overflow = False
def extra_repr(self) -> str:
s = f"head_size={self.impl.head_size}" # type: ignore
......@@ -588,10 +600,6 @@ def maybe_calc_kv_scales(
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if self.check_fp8_overflow and torch.abs(query).max().item()>200:
self.calculate_kv_scales=True
if not self.calculate_kv_scales:
return
self.calc_kv_scales(query, key, value)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment