Unverified commit 6e98f6d8, authored by Taeksang Kim, committed by GitHub
Browse files

Implement zero-copy GQA for multimodal and CPU (#33732)


Signed-off-by: Taeksang Kim <ts.kim@hyperaccel.ai>
parent 2f6d17cb
...@@ -80,7 +80,7 @@ class MMEncoderAttention(CustomOp): ...@@ -80,7 +80,7 @@ class MMEncoderAttention(CustomOp):
def enabled(cls) -> bool: def enabled(cls) -> bool:
return True return True
def maybe_reshape_qkv_to_4d( def view_qkv_to_4d(
self, self,
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
...@@ -97,11 +97,6 @@ class MMEncoderAttention(CustomOp): ...@@ -97,11 +97,6 @@ class MMEncoderAttention(CustomOp):
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size) key = key.view(bsz, kv_len, self.num_kv_heads, self.head_size)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size) value = value.view(bsz, kv_len, self.num_kv_heads, self.head_size)
if (num_repeat := self.num_queries_per_kv) > 1:
# Handle MQA and GQA
key = torch.repeat_interleave(key, num_repeat, dim=2)
value = torch.repeat_interleave(value, num_repeat, dim=2)
return query, key, value return query, key, value
def _forward_sdpa( def _forward_sdpa(
...@@ -119,9 +114,7 @@ class MMEncoderAttention(CustomOp): ...@@ -119,9 +114,7 @@ class MMEncoderAttention(CustomOp):
kv_len = key.size(1) kv_len = key.size(1)
is_reshaped = query.dim() != 4 is_reshaped = query.dim() != 4
query, key, value = self.maybe_reshape_qkv_to_4d( query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
query, key, value, bsz, q_len, kv_len
)
output = vit_torch_sdpa_wrapper( output = vit_torch_sdpa_wrapper(
q=query, q=query,
...@@ -129,6 +122,7 @@ class MMEncoderAttention(CustomOp): ...@@ -129,6 +122,7 @@ class MMEncoderAttention(CustomOp):
v=value, v=value,
scale=self.scale, scale=self.scale,
cu_seqlens=cu_seqlens, cu_seqlens=cu_seqlens,
enable_gqa=self.num_heads > self.num_kv_heads,
) )
if is_reshaped: if is_reshaped:
output = output.reshape(bsz, q_len, -1) output = output.reshape(bsz, q_len, -1)
...@@ -154,9 +148,7 @@ class MMEncoderAttention(CustomOp): ...@@ -154,9 +148,7 @@ class MMEncoderAttention(CustomOp):
kv_len = key.size(1) kv_len = key.size(1)
is_reshaped = query.dim() != 4 is_reshaped = query.dim() != 4
query, key, value = self.maybe_reshape_qkv_to_4d( query, key, value = self.view_qkv_to_4d(query, key, value, bsz, q_len, kv_len)
query, key, value, bsz, q_len, kv_len
)
output = vit_flash_attn_wrapper( output = vit_flash_attn_wrapper(
q=query, q=query,
......
...@@ -628,18 +628,6 @@ class ImagePoolingAttention(nn.Module): ...@@ -628,18 +628,6 @@ class ImagePoolingAttention(nn.Module):
key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim) key = key.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim) value = value.view(bsz, kv_len, self.num_kv_heads, self.head_dim)
if self.num_heads != self.num_kv_heads:
key = torch.repeat_interleave(
key,
self.num_heads // self.num_kv_heads,
dim=2,
)
value = torch.repeat_interleave(
value,
self.num_heads // self.num_kv_heads,
dim=2,
)
query, key, value = (x.transpose(1, 2) for x in (query, key, value)) query, key, value = (x.transpose(1, 2) for x in (query, key, value))
out = F.scaled_dot_product_attention( out = F.scaled_dot_product_attention(
...@@ -648,6 +636,7 @@ class ImagePoolingAttention(nn.Module): ...@@ -648,6 +636,7 @@ class ImagePoolingAttention(nn.Module):
value, value,
attn_mask=attn_mask, attn_mask=attn_mask,
is_causal=False, is_causal=False,
enable_gqa=self.num_heads > self.num_kv_heads,
).transpose(1, 2) ).transpose(1, 2)
return out.reshape(bsz, q_len, -1) return out.reshape(bsz, q_len, -1)
......
...@@ -398,10 +398,6 @@ class CPUAttentionBackendImpl(AttentionImpl): ...@@ -398,10 +398,6 @@ class CPUAttentionBackendImpl(AttentionImpl):
key = key.movedim(0, key.dim() - 2) key = key.movedim(0, key.dim() - 2)
value = value.movedim(0, value.dim() - 2) value = value.movedim(0, value.dim() - 2)
if self.num_kv_heads != self.num_heads:
key = key.repeat_interleave(self.num_queries_per_kv, dim=-3)
value = value.repeat_interleave(self.num_queries_per_kv, dim=-3)
causal_attn = attn_type == AttentionType.DECODER causal_attn = attn_type == AttentionType.DECODER
sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore sdpa_start_loc = attn_metadata.sdpa_start_loc.numpy() # type: ignore
...@@ -418,6 +414,7 @@ class CPUAttentionBackendImpl(AttentionImpl): ...@@ -418,6 +414,7 @@ class CPUAttentionBackendImpl(AttentionImpl):
dropout_p=0.0, dropout_p=0.0,
is_causal=causal_attn and mask is None, is_causal=causal_attn and mask is None,
scale=self.scale, scale=self.scale,
enable_gqa=self.num_heads > self.num_kv_heads,
) )
.squeeze(0) .squeeze(0)
.movedim(query.dim() - 2, 0) .movedim(query.dim() - 2, 0)
......
...@@ -115,13 +115,16 @@ def apply_sdpa( ...@@ -115,13 +115,16 @@ def apply_sdpa(
k: torch.Tensor, k: torch.Tensor,
v: torch.Tensor, v: torch.Tensor,
scale: float | None = None, scale: float | None = None,
enable_gqa: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
Input shape: Input shape:
(batch_size x seq_len x num_heads x head_size) (batch_size x seq_len x num_heads x head_size)
""" """
q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v]) q, k, v = (einops.rearrange(x, "b s h d -> b h s d") for x in [q, k, v])
output = F.scaled_dot_product_attention(q, k, v, dropout_p=0.0, scale=scale) output = F.scaled_dot_product_attention(
q, k, v, dropout_p=0.0, scale=scale, enable_gqa=enable_gqa
)
output = einops.rearrange(output, "b h s d -> b s h d ") output = einops.rearrange(output, "b h s d -> b s h d ")
return output return output
...@@ -134,6 +137,7 @@ def torch_sdpa_wrapper( ...@@ -134,6 +137,7 @@ def torch_sdpa_wrapper(
v: torch.Tensor, v: torch.Tensor,
scale: float | None = None, scale: float | None = None,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
enable_gqa: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
# Never remove the contiguous logic for ROCm # Never remove the contiguous logic for ROCm
# Without it, hallucinations occur with the backend # Without it, hallucinations occur with the backend
...@@ -143,7 +147,7 @@ def torch_sdpa_wrapper( ...@@ -143,7 +147,7 @@ def torch_sdpa_wrapper(
v = v.contiguous() v = v.contiguous()
if cu_seqlens is None: if cu_seqlens is None:
return apply_sdpa(q, k, v, scale=scale) return apply_sdpa(q, k, v, scale=scale, enable_gqa=enable_gqa)
outputs = [] outputs = []
...@@ -152,7 +156,7 @@ def torch_sdpa_wrapper( ...@@ -152,7 +156,7 @@ def torch_sdpa_wrapper(
k_chunks = torch.split(k, lens, dim=1) k_chunks = torch.split(k, lens, dim=1)
v_chunks = torch.split(v, lens, dim=1) v_chunks = torch.split(v, lens, dim=1)
for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks): for q_i, k_i, v_i in zip(q_chunks, k_chunks, v_chunks):
output_i = apply_sdpa(q_i, k_i, v_i, scale=scale) output_i = apply_sdpa(q_i, k_i, v_i, scale=scale, enable_gqa=enable_gqa)
outputs.append(output_i) outputs.append(output_i)
context_layer = torch.cat(outputs, dim=1) context_layer = torch.cat(outputs, dim=1)
return context_layer return context_layer
...@@ -164,6 +168,7 @@ def torch_sdpa_wrapper_fake( ...@@ -164,6 +168,7 @@ def torch_sdpa_wrapper_fake(
v: torch.Tensor, v: torch.Tensor,
scale: float | None, scale: float | None,
cu_seqlens: torch.Tensor | None, cu_seqlens: torch.Tensor | None,
enable_gqa: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.empty_like(q) return torch.empty_like(q)
...@@ -181,5 +186,8 @@ def vit_torch_sdpa_wrapper( ...@@ -181,5 +186,8 @@ def vit_torch_sdpa_wrapper(
v: torch.Tensor, v: torch.Tensor,
scale: float | None = None, scale: float | None = None,
cu_seqlens: torch.Tensor | None = None, cu_seqlens: torch.Tensor | None = None,
enable_gqa: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
return torch.ops.vllm.torch_sdpa_wrapper(q, k, v, scale, cu_seqlens) return torch.ops.vllm.torch_sdpa_wrapper(
q, k, v, scale, cu_seqlens, enable_gqa=enable_gqa
)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment