Commit 0b2c14e3 authored by zhuwenwen's avatar zhuwenwen
Browse files

fix glm fp8-e4m3 acc error

parent 770d33f9
...@@ -137,6 +137,10 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -137,6 +137,10 @@ class Attention(nn.Module, AttentionLayerBase):
# with the model weights. # with the model weights.
self.kv_cache_dtype = kv_cache_dtype self.kv_cache_dtype = kv_cache_dtype
self.calculate_kv_scales = calculate_kv_scales self.calculate_kv_scales = calculate_kv_scales
if self.kv_cache_dtype in {"fp8", "fp8_e4m3"} :
self.check_fp8_overflow = True
else:
self.check_fp8_overflow = False
self._k_scale = torch.tensor(1.0, dtype=torch.float32) self._k_scale = torch.tensor(1.0, dtype=torch.float32)
self._v_scale = torch.tensor(1.0, dtype=torch.float32) self._v_scale = torch.tensor(1.0, dtype=torch.float32)
# FlashAttn doesn't support quantizing the kv-cache only # FlashAttn doesn't support quantizing the kv-cache only
...@@ -281,12 +285,13 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -281,12 +285,13 @@ class Attention(nn.Module, AttentionLayerBase):
context using context using
`vllm.forward_context.get_forward_context().attn_metadata`. `vllm.forward_context.get_forward_context().attn_metadata`.
""" """
if self.calculate_kv_scales: if self.calculate_kv_scales or self.check_fp8_overflow:
# attn_metadata = get_forward_context().attn_metadata # attn_metadata = get_forward_context().attn_metadata
# if attn_metadata.enable_kv_scales_calculation: # if attn_metadata.enable_kv_scales_calculation:
# self.calc_kv_scales(query, key, value) # self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name) self.layer_name)
self.check_fp8_overflow=False
output_dtype = query.dtype output_dtype = query.dtype
if self.query_quant is not None: if self.query_quant is not None:
...@@ -583,6 +588,8 @@ def maybe_calc_kv_scales( ...@@ -583,6 +588,8 @@ def maybe_calc_kv_scales(
# Only calculate if the layer's calculate_kv_scales flag is True # Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass # This flag gets set to False after the first forward pass
if self.check_fp8_overflow and torch.abs(query).max().item()>200:
self.calculate_kv_scales=True
if not self.calculate_kv_scales: if not self.calculate_kv_scales:
return return
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment