"docs/getting_started/installation/cpu.md" did not exist on "550b2801ad57b8f1e1782037c40fd01b2632435e"
Unverified Commit 24d0c9e6 authored by elvischenv's avatar elvischenv Committed by GitHub
Browse files

[NVIDIA][torch.compile] Support Flashinfer TRTLLM FP8-q/kv NVFP4-out Attention Kernel (#22703)


Signed-off-by: default avatarelvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent cc7ae5e7
...@@ -428,6 +428,7 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -428,6 +428,7 @@ class FlexAttentionImpl(AttentionImpl):
attn_metadata: FlexAttentionMetadata, attn_metadata: FlexAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FLexAttention. """Forward pass with FLexAttention.
...@@ -441,7 +442,7 @@ class FlexAttentionImpl(AttentionImpl): ...@@ -441,7 +442,7 @@ class FlexAttentionImpl(AttentionImpl):
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for FlexAttentionImpl") " for FlexAttentionImpl")
......
...@@ -1138,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1138,10 +1138,11 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
attn_metadata: M, attn_metadata: M,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for MLACommonImpl") " for MLACommonImpl")
......
...@@ -227,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -227,6 +227,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
attn_metadata: PallasMetadata, attn_metadata: PallasMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with Pallas attention. """Forward pass with Pallas attention.
...@@ -239,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -239,7 +240,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for PallasAttentionBackendImpl") " for PallasAttentionBackendImpl")
......
...@@ -421,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -421,6 +421,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
attn_metadata: AiterFlashAttentionMetadata, attn_metadata: AiterFlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with AiterFlashAttention. """Forward pass with AiterFlashAttention.
...@@ -438,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -438,7 +439,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for FlashAttentionImpl") " for FlashAttentionImpl")
......
...@@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl): ...@@ -354,6 +354,7 @@ class TreeAttentionImpl(AttentionImpl):
attn_metadata: TreeAttentionMetadata, attn_metadata: TreeAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with TreeAttention. """Forward pass with TreeAttention.
...@@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl): ...@@ -368,7 +369,7 @@ class TreeAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for TreeAttentionImpl") " for TreeAttentionImpl")
......
...@@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -277,6 +277,7 @@ class TritonAttentionImpl(AttentionImpl):
attn_metadata: FlashAttentionMetadata, attn_metadata: FlashAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with FlashAttention. """Forward pass with FlashAttention.
...@@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -291,7 +292,7 @@ class TritonAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for TritonAttentionImpl") " for TritonAttentionImpl")
......
...@@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl): ...@@ -322,6 +322,7 @@ class XFormersAttentionImpl(AttentionImpl):
attn_metadata: XFormersAttentionMetadata, attn_metadata: XFormersAttentionMetadata,
output: Optional[torch.Tensor] = None, output: Optional[torch.Tensor] = None,
output_scale: Optional[torch.Tensor] = None, output_scale: Optional[torch.Tensor] = None,
output_block_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
"""Forward pass with XFormers. """Forward pass with XFormers.
...@@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl): ...@@ -336,7 +337,7 @@ class XFormersAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None: if output_scale is not None or output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused output quantization is not yet supported"
" for XFormersAttentionImpl") " for XFormersAttentionImpl")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment