Commit 3d01cce7 authored by xiabo's avatar xiabo
Browse files

1、kvcache支持fp8的scale

parent 6dcb89d2
......@@ -282,9 +282,11 @@ class Attention(nn.Module, AttentionLayerBase):
`vllm.forward_context.get_forward_context().attn_metadata`.
"""
if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value)
# attn_metadata = get_forward_context().attn_metadata
# if attn_metadata.enable_kv_scales_calculation:
# self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
output_dtype = query.dtype
if self.query_quant is not None:
......@@ -570,6 +572,38 @@ def maybe_save_kv_layer_to_connector(
connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name])
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if not self.calculate_kv_scales:
return
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=["query", "key", "value"],
fake_impl=maybe_calc_kv_scales_fake,
)
def unified_attention(
query: torch.Tensor,
......
......@@ -593,9 +593,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
# descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
if not current_platform.is_rocm():
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
flash_attn_varlen_func(
q=query[:num_actual_tokens],
k=key_cache,
......@@ -643,8 +643,9 @@ class FlashAttentionImpl(AttentionImpl):
scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks,
is_prefix_cache=True,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment