"...git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "071d863e208b40fa1bb986ad230e322b2bbbbcf5"
Commit 4dcfd0ae authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.11.0-dev-kvscale' into 'v0.11.0-dev'

V0.11.0 dev kvscale

See merge request dcutoolkit/deeplearing/vllm!378
parents ebf3d1d8 c77bc77c
...@@ -282,9 +282,11 @@ class Attention(nn.Module, AttentionLayerBase): ...@@ -282,9 +282,11 @@ class Attention(nn.Module, AttentionLayerBase):
`vllm.forward_context.get_forward_context().attn_metadata`. `vllm.forward_context.get_forward_context().attn_metadata`.
""" """
if self.calculate_kv_scales: if self.calculate_kv_scales:
attn_metadata = get_forward_context().attn_metadata # attn_metadata = get_forward_context().attn_metadata
if attn_metadata.enable_kv_scales_calculation: # if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(query, key, value) # self.calc_kv_scales(query, key, value)
torch.ops.vllm.maybe_calc_kv_scales(query, key, value,
self.layer_name)
output_dtype = query.dtype output_dtype = query.dtype
if self.query_quant is not None: if self.query_quant is not None:
...@@ -570,6 +572,38 @@ def maybe_save_kv_layer_to_connector( ...@@ -570,6 +572,38 @@ def maybe_save_kv_layer_to_connector(
connector.save_kv_layer(layer_name, kv_cache_layer, connector.save_kv_layer(layer_name, kv_cache_layer,
attn_metadata[layer_name]) attn_metadata[layer_name])
def maybe_calc_kv_scales(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
forward_context: ForwardContext = get_forward_context()
self = forward_context.no_compile_layers[layer_name]
# Only calculate if the layer's calculate_kv_scales flag is True
# This flag gets set to False after the first forward pass
if not self.calculate_kv_scales:
return
self.calc_kv_scales(query, key, value)
def maybe_calc_kv_scales_fake(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
layer_name: str,
) -> None:
return
direct_register_custom_op(
op_name="maybe_calc_kv_scales",
op_func=maybe_calc_kv_scales,
mutates_args=["query", "key", "value"],
fake_impl=maybe_calc_kv_scales_fake,
)
def unified_attention( def unified_attention(
query: torch.Tensor, query: torch.Tensor,
......
...@@ -593,9 +593,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -593,9 +593,9 @@ class FlashAttentionImpl(AttentionImpl):
block_table = attn_metadata.block_table block_table = attn_metadata.block_table
scheduler_metadata = attn_metadata.scheduler_metadata scheduler_metadata = attn_metadata.scheduler_metadata
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads) # descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
if not current_platform.is_rocm(): if not current_platform.is_rocm():
descale_shape = (cu_seqlens_q.shape[0] - 1, self.num_kv_heads)
flash_attn_varlen_func( flash_attn_varlen_func(
q=query[:num_actual_tokens], q=query[:num_actual_tokens],
k=key_cache, k=key_cache,
...@@ -643,8 +643,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -643,8 +643,9 @@ class FlashAttentionImpl(AttentionImpl):
scheduler_metadata=scheduler_metadata, scheduler_metadata=scheduler_metadata,
# fa_version=self.vllm_flash_attn_version, # fa_version=self.vllm_flash_attn_version,
# q_descale=layer._q_scale.expand(descale_shape), # q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape), q_descale=None,
# v_descale=layer._v_scale.expand(descale_shape), k_descale=layer._k_scale,
v_descale=layer._v_scale,
# num_splits=attn_metadata.max_num_splits, # num_splits=attn_metadata.max_num_splits,
s_aux=self.sinks, s_aux=self.sinks,
is_prefix_cache=True, is_prefix_cache=True,
...@@ -699,8 +700,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -699,8 +700,9 @@ class FlashAttentionImpl(AttentionImpl):
prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata, prefix_scheduler_metadata=attn_metadata.prefix_scheduler_metadata,
suffix_scheduler_metadata=attn_metadata.scheduler_metadata, suffix_scheduler_metadata=attn_metadata.scheduler_metadata,
# q_descale=layer._q_scale, # q_descale=layer._q_scale,
# k_descale=layer._k_scale, q_descale=None,
# v_descale=layer._v_scale, k_descale=layer._k_scale,
v_descale=layer._v_scale,
) )
return output return output
...@@ -778,6 +780,9 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -778,6 +780,9 @@ class FlashAttentionImpl(AttentionImpl):
# q_descale=layer._q_scale.expand(descale_shape), # q_descale=layer._q_scale.expand(descale_shape),
# k_descale=layer._k_scale.expand(descale_shape), # k_descale=layer._k_scale.expand(descale_shape),
# v_descale=layer._v_scale.expand(descale_shape), # v_descale=layer._v_scale.expand(descale_shape),
q_descale=None,
k_descale=layer._k_scale,
v_descale=layer._v_scale,
is_prefix_cache=False, is_prefix_cache=False,
) )
...@@ -931,12 +936,12 @@ def cascade_attention( ...@@ -931,12 +936,12 @@ def cascade_attention(
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=prefix_scheduler_metadata, scheduler_metadata=prefix_scheduler_metadata,
# fa_version=fa_version, # fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None, if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None, if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None, if v_descale is not None else None,
is_prefix_cache=True, is_prefix_cache=True,
) )
...@@ -984,12 +989,12 @@ def cascade_attention( ...@@ -984,12 +989,12 @@ def cascade_attention(
return_softmax_lse=True, return_softmax_lse=True,
scheduler_metadata=suffix_scheduler_metadata, scheduler_metadata=suffix_scheduler_metadata,
# fa_version=fa_version, # fa_version=fa_version,
# q_descale=q_descale.expand(descale_shape) q_descale=q_descale.expand(descale_shape)
# if q_descale is not None else None, if q_descale is not None else None,
# k_descale=k_descale.expand(descale_shape) k_descale=k_descale.expand(descale_shape)
# if k_descale is not None else None, if k_descale is not None else None,
# v_descale=v_descale.expand(descale_shape) v_descale=v_descale.expand(descale_shape)
# if v_descale is not None else None, if v_descale is not None else None,
is_prefix_cache=True, is_prefix_cache=True,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment