Unverified Commit 7d6abdd0 authored by elvischenv's avatar elvischenv Committed by GitHub
Browse files

[Fix] Use torch.empty for output in attention+quant fusion (#31785)


Signed-off-by: default avatarelvischenv <219235043+elvischenv@users.noreply.github.com>
parent a8ff2cca
...@@ -170,9 +170,8 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -170,9 +170,8 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
kv_cache_dummy_dep: torch.Tensor, kv_cache_dummy_dep: torch.Tensor,
) -> torch.Tensor: ) -> torch.Tensor:
# attn output in quant_dtype # attn output in quant_dtype
output_attn = torch.ops.aten.full.default( output_attn = torch.empty(
[q.shape[0], self.num_heads, self.head_size], [q.shape[0], self.num_heads, self.head_size],
0.0,
dtype=self.quant_dtype, dtype=self.quant_dtype,
device=q.device, device=q.device,
) )
...@@ -271,9 +270,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -271,9 +270,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
kv_cache_dummy_dep: torch.Tensor, kv_cache_dummy_dep: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]: ) -> tuple[torch.Tensor, torch.Tensor]:
# attention output in quant_dtype # attention output in quant_dtype
output_attn = torch.ops.aten.full.default( output_attn = torch.empty(
[q.shape[0], self.num_heads, self.head_size // 2], [q.shape[0], self.num_heads, self.head_size // 2],
0.0,
dtype=self.quant_dtype, dtype=self.quant_dtype,
device=q.device, device=q.device,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment