Unverified Commit 9a161307 authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[torch.compile][ROCm][V1] Enable attention output FP8 fusion for V1 attention backends (#19767)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Signed-off-by: default avatarLuka Govedič <lgovedic@redhat.com>
Co-authored-by: default avatarLuka Govedič <lgovedic@redhat.com>
Co-authored-by: default avatarLuka Govedič <ProExpertProg@users.noreply.github.com>
parent 37e8182b
...@@ -40,13 +40,12 @@ backend_unfused: Optional[TestBackend] = None ...@@ -40,13 +40,12 @@ backend_unfused: Optional[TestBackend] = None
@pytest.mark.parametrize( @pytest.mark.parametrize(
"model, quant_key", "model, quant_key",
[("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)]) [("amd/Llama-3.1-8B-Instruct-FP8-KV", kFp8StaticTensorSym)])
@pytest.mark.parametrize( @pytest.mark.parametrize("use_triton_fa", [True, False])
"use_triton_fa", [True, False] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_cuda_alike(), @pytest.mark.skipif(not current_platform.is_rocm(),
reason="Only test CUDA and ROCm") reason="V0 attn quant fusion only on ROCm")
def test_attention_fusion(example_prompts, monkeypatch, model: str, def test_attention_fusion_v0(example_prompts, monkeypatch, model: str,
quant_key: QuantKey, use_triton_fa: bool): quant_key: QuantKey, use_triton_fa: bool):
# Clean Dynamo cache to avoid reusing other test cases # Clean Dynamo cache to avoid reusing other test cases
# (for some reason the reset at the end is not enough) # (for some reason the reset at the end is not enough)
torch._dynamo.reset() torch._dynamo.reset()
...@@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, ...@@ -69,13 +68,17 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend_unfused", backend="tests.compile.test_fusion_attn.backend_unfused",
custom_ops=["+quant_fp8"], custom_ops=["+quant_fp8"],
) )
vllm_config = VllmConfig(compilation_config=compile_config) vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
backend_unfused = TestBackend(NoOpEliminationPass(vllm_config)) backend_unfused = TestBackend(NoOpEliminationPass(vllm_config))
llm = LLM(model, llm = LLM(model,
enforce_eager=True, enforce_eager=True,
compilation_config=compile_config, compilation_config=compile_config,
gpu_memory_utilization=0.9, gpu_memory_utilization=0.5,
max_model_len=2048) max_model_len=2048)
sampling_params = SamplingParams(temperature=0.0, sampling_params = SamplingParams(temperature=0.0,
...@@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, ...@@ -93,7 +96,11 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
backend="tests.compile.test_fusion_attn.backend", backend="tests.compile.test_fusion_attn.backend",
custom_ops=["+quant_fp8"], custom_ops=["+quant_fp8"],
) )
vllm_config = VllmConfig(compilation_config=compile_config) vllm_config = VllmConfig(compilation_config=compile_config,
model_config=ModelConfig(
model=model,
dtype=torch.bfloat16,
))
# AttnFusionPass needs attention layers to be registered in config upon init # AttnFusionPass needs attention layers to be registered in config upon init
# so we initialize it during compilation. # so we initialize it during compilation.
...@@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str, ...@@ -102,7 +109,7 @@ def test_attention_fusion(example_prompts, monkeypatch, model: str,
llm2 = LLM(model, llm2 = LLM(model,
enforce_eager=True, enforce_eager=True,
compilation_config=compile_config, compilation_config=compile_config,
gpu_memory_utilization=0.9, gpu_memory_utilization=0.5,
max_model_len=2048) max_model_len=2048)
# check support # check support
...@@ -171,6 +178,8 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -171,6 +178,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
cache_config=vllm_config.cache_config, cache_config=vllm_config.cache_config,
prefix="model.layers.0.self_attn.attn", prefix="model.layers.0.self_attn.attn",
) )
self.attn._k_scale = self.attn._k_scale.to(device)
self.attn._v_scale = self.attn._v_scale.to(device)
self.block_size = 16 self.block_size = 16
...@@ -188,7 +197,7 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -188,7 +197,7 @@ class AttentionQuantPatternModel(torch.nn.Module):
device=self.device, device=self.device,
) )
def build_attn_metadata(self, batch_size: int): def build_attn_metadata(self, batch_size: int, use_hnd: bool):
"""Initialize attention metadata.""" """Initialize attention metadata."""
# Create common attn metadata # Create common attn metadata
...@@ -205,10 +214,8 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -205,10 +214,8 @@ class AttentionQuantPatternModel(torch.nn.Module):
num_blocks = batch_size * max_blocks num_blocks = batch_size * max_blocks
# Create dummy KV cache for FlashInfer TRTLLM # Create dummy KV cache for FlashInfer TRTLLM
# - NHD: [num_blocks, 2, block_size, num_kv_heads, head_size] # - NHD: [num_blocks, block_size, num_kv_heads, head_size]
# - HND: [num_blocks, 2, num_kv_heads, block_size, head_size] # - HND: [num_blocks, num_kv_heads, block_size, head_size]
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = torch.zeros(num_blocks, kv_cache = torch.zeros(num_blocks,
2, 2,
self.num_kv_heads, self.num_kv_heads,
...@@ -216,7 +223,17 @@ class AttentionQuantPatternModel(torch.nn.Module): ...@@ -216,7 +223,17 @@ class AttentionQuantPatternModel(torch.nn.Module):
self.head_size, self.head_size,
dtype=self.kv_cache_dtype, dtype=self.kv_cache_dtype,
device=self.device) device=self.device)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4) if current_platform.is_rocm():
# k/v as 1st dimention
if use_hnd:
kv_cache = kv_cache.permute(1, 0, 2, 3, 4)
else:
kv_cache = kv_cache.permute(1, 0, 3, 2, 4)
else:
# k/v as 2nd dimention
# Create kv_cache in HND layout and permute to NHD layout
# (later will be permuted back to HND layout in forward pass)
kv_cache = kv_cache.permute(0, 1, 3, 2, 4)
self.attn.kv_cache = [kv_cache] self.attn.kv_cache = [kv_cache]
# Build attn metadata # Build attn metadata
...@@ -296,28 +313,51 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel): ...@@ -296,28 +313,51 @@ class TestAttentionNvfp4QuantPatternModel(AttentionQuantPatternModel):
out_dtype=attn_output.dtype) out_dtype=attn_output.dtype)
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", [(64, 8), (40, 8)]) if current_platform.is_cuda():
MODELS = [("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8",
TestAttentionFp8StaticQuantPatternModel),
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4",
TestAttentionNvfp4QuantPatternModel)]
HEADS = [(64, 8), (40, 8)]
elif current_platform.is_rocm():
MODELS = [("amd/Llama-3.1-8B-Instruct-FP8-KV",
TestAttentionFp8StaticQuantPatternModel)]
HEADS = [(32, 8), (40, 8)]
else:
MODELS = []
HEADS = []
@pytest.mark.parametrize("num_qo_heads, num_kv_heads", HEADS)
@pytest.mark.parametrize("head_size", [128]) @pytest.mark.parametrize("head_size", [128])
@pytest.mark.parametrize("batch_size", [7, 256, 533]) @pytest.mark.parametrize("batch_size",
@pytest.mark.parametrize("dtype", [torch.bfloat16]) [7, 256, 533] if current_platform.is_cuda() else [8])
@pytest.mark.parametrize("model_name, model_class", @pytest.mark.parametrize("dtype", [torch.bfloat16, torch.float16])
[("nvidia/Llama-4-Scout-17B-16E-Instruct-FP8", @pytest.mark.parametrize("model_name, model_class", MODELS)
TestAttentionFp8StaticQuantPatternModel), @pytest.mark.parametrize("backend", [_Backend.FLASHINFER] if
("nvidia/Llama-4-Scout-17B-16E-Instruct-FP4", current_platform.is_cuda() else [_Backend.ROCM_FLASH])
TestAttentionNvfp4QuantPatternModel)]) @pytest.mark.parametrize(
@pytest.mark.parametrize("backend", [_Backend.FLASHINFER]) "split_attention",
@pytest.mark.skipif(not current_platform.is_cuda(), reason="Only test CUDA") [False, True] if current_platform.is_rocm() else [False])
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
@pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8") @pytest.mark.skipif(not current_platform.supports_fp8(), reason="Need FP8")
@pytest.mark.skipif(not current_platform.is_device_capability((10, 0)), @pytest.mark.skipif(current_platform.is_cuda()
reason="Only test on SM100(Blackwell)") and not current_platform.is_device_capability((10, 0)),
reason="On CUDA only test on SM100(Blackwell)")
@pytest.mark.skipif(not current_platform.is_cuda_alike(),
reason="Only test ROCm or CUDA")
def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
head_size: int, batch_size: int, head_size: int, batch_size: int,
dtype: torch.dtype, model_name: str, dtype: torch.dtype, model_name: str,
model_class: type[AttentionQuantPatternModel], model_class: type[AttentionQuantPatternModel],
backend: _Backend, monkeypatch, dist_init): backend: _Backend, split_attention: bool,
monkeypatch, dist_init):
"""Test AttentionStaticQuantPattern fusion pass""" """Test AttentionStaticQuantPattern fusion pass"""
monkeypatch.setenv("VLLM_USE_V1", "1") monkeypatch.setenv("VLLM_USE_V1", "1")
if split_attention:
monkeypatch.setenv("VLLM_V1_USE_PREFILL_DECODE_ATTENTION", "1")
device = torch.device("cuda:0") device = torch.device("cuda:0")
torch.manual_seed(42) torch.manual_seed(42)
...@@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, ...@@ -326,6 +366,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
model_config=ModelConfig( model_config=ModelConfig(
model=model_name, model=model_name,
max_model_len=2048, max_model_len=2048,
dtype=dtype,
), ),
scheduler_config=SchedulerConfig(max_num_seqs=1024), scheduler_config=SchedulerConfig(max_num_seqs=1024),
compilation_config=CompilationConfig( compilation_config=CompilationConfig(
...@@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, ...@@ -368,7 +409,7 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
forward_ctx = get_forward_context() forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_unfused.build_attn_metadata( forward_ctx.attn_metadata = model_unfused.build_attn_metadata(
batch_size) batch_size, use_hnd=split_attention)
# Run model directly without compilation and fusion # Run model directly without compilation and fusion
result_unfused = model_unfused(q, k, v) result_unfused = model_unfused(q, k, v)
...@@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, ...@@ -389,7 +430,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
model_fused = model_fused.to(device) model_fused = model_fused.to(device)
forward_ctx = get_forward_context() forward_ctx = get_forward_context()
forward_ctx.attn_metadata = model_fused.build_attn_metadata(batch_size) forward_ctx.attn_metadata = model_fused.build_attn_metadata(
batch_size, use_hnd=split_attention)
# Create test backend with fusion passes enabled # Create test backend with fusion passes enabled
noop_pass = NoOpEliminationPass(vllm_config) noop_pass = NoOpEliminationPass(vllm_config)
...@@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, ...@@ -404,12 +446,19 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert model_compiled.attn._o_scale_float is None assert model_compiled.attn._o_scale_float is None
result_fused_1 = model_compiled(q, k, v) result_fused_1 = model_compiled(q, k, v)
# After the 1st round of the forward pass, output quant scale should be if backend == _Backend.FLASHINFER:
# loaded into the attn layer's _o_scale_float, the 2nd round should # With the Flashinfer backend after the 1st round of the forward
# reuse the loaded _o_scale_float # pass, output quant scale should be loaded into the attn layer's
assert model_compiled.attn._o_scale_float is not None # _o_scale_float, the 2nd round should reuse the loaded
result_fused_2 = model_compiled(q, k, v) # _o_scale_float
assert model_compiled.attn._o_scale_float is not None assert model_compiled.attn._o_scale_float is not None
result_fused_2 = model_compiled(q, k, v)
assert model_compiled.attn._o_scale_float is not None
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
# Check attn fusion support # Check attn fusion support
quant_key = model_class.quant_key quant_key = model_class.quant_key
...@@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int, ...@@ -444,12 +493,8 @@ def test_attention_quant_pattern(num_qo_heads: int, num_kv_heads: int,
assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \ assert attn_nodes_post[0].kwargs.get("output_block_scale") is not None, \
"Attention should have output_block_scale after FP4 fusion" # noqa: E501 "Attention should have output_block_scale after FP4 fusion" # noqa: E501
# Check that results are closed # Check that results are close
torch.testing.assert_close(result_unfused, torch.testing.assert_close(result_unfused,
result_fused_1, result_fused_1,
atol=1e-2, atol=1e-2,
rtol=1e-2) rtol=1e-2)
torch.testing.assert_close(result_unfused,
result_fused_2,
atol=1e-2,
rtol=1e-2)
...@@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton ...@@ -15,6 +15,8 @@ from vllm.triton_utils import tl, triton
from .prefix_prefill import context_attention_fwd from .prefix_prefill import context_attention_fwd
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit @triton.jit
def cdiv_fn(x, y): def cdiv_fn(x, y):
...@@ -34,6 +36,7 @@ def kernel_paged_attention_2d( ...@@ -34,6 +36,7 @@ def kernel_paged_attention_2d(
scale, # float32 scale, # float32
k_scale, # float32 k_scale, # float32
v_scale, # float32 v_scale, # float32
out_scale_inv,
num_query_heads: tl.constexpr, # int num_query_heads: tl.constexpr, # int
num_queries_per_kv: tl.constexpr, # int num_queries_per_kv: tl.constexpr, # int
num_queries_per_kv_padded: tl.constexpr, # int num_queries_per_kv_padded: tl.constexpr, # int
...@@ -60,7 +63,9 @@ def kernel_paged_attention_2d( ...@@ -60,7 +63,9 @@ def kernel_paged_attention_2d(
filter_by_query_len: tl.constexpr, # bool filter_by_query_len: tl.constexpr, # bool
query_start_len_ptr, # [num_seqs+1] query_start_len_ptr, # [num_seqs+1]
USE_SINKS: tl.constexpr, # bool USE_SINKS: tl.constexpr, # bool
): USE_FP8: tl.constexpr,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max):
seq_idx = tl.program_id(0) seq_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1) kv_head_idx = tl.program_id(1)
...@@ -204,6 +209,9 @@ def kernel_paged_attention_2d( ...@@ -204,6 +209,9 @@ def kernel_paged_attention_2d(
# epilogue # epilogue
acc = acc / L[:, None] acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (cur_batch_in_all_start_index * output_stride_0 + output_offset = (cur_batch_in_all_start_index * output_stride_0 +
query_head_idx * output_stride_1) query_head_idx * output_stride_1)
...@@ -234,6 +242,7 @@ def chunked_prefill_paged_decode( ...@@ -234,6 +242,7 @@ def chunked_prefill_paged_decode(
alibi_slopes=None, alibi_slopes=None,
sliding_window=None, sliding_window=None,
sm_scale=None, sm_scale=None,
output_scale=None,
# Optional tensor for sinks # Optional tensor for sinks
sinks=None, sinks=None,
): ):
...@@ -266,6 +275,7 @@ def chunked_prefill_paged_decode( ...@@ -266,6 +275,7 @@ def chunked_prefill_paged_decode(
sliding_window=sliding_window, sliding_window=sliding_window,
sm_scale=sm_scale, sm_scale=sm_scale,
skip_decode=True, skip_decode=True,
fp8_out_scale=output_scale,
sinks=sinks, sinks=sinks,
) )
...@@ -316,7 +326,7 @@ def chunked_prefill_paged_decode( ...@@ -316,7 +326,7 @@ def chunked_prefill_paged_decode(
tmp_output = torch.empty( tmp_output = torch.empty(
size=(total_num_seq, num_query_heads, max_num_partitions, size=(total_num_seq, num_query_heads, max_num_partitions,
head_size), head_size),
dtype=output.dtype, dtype=query.dtype,
device=output.device, device=output.device,
) )
exp_sums = torch.empty( exp_sums = torch.empty(
...@@ -345,6 +355,7 @@ def chunked_prefill_paged_decode( ...@@ -345,6 +355,7 @@ def chunked_prefill_paged_decode(
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
k_scale=k_scale, k_scale=k_scale,
v_scale=v_scale, v_scale=v_scale,
fp8_out_scale=output_scale,
) )
else: else:
kernel_paged_attention_2d[( kernel_paged_attention_2d[(
...@@ -362,6 +373,8 @@ def chunked_prefill_paged_decode( ...@@ -362,6 +373,8 @@ def chunked_prefill_paged_decode(
scale=sm_scale, scale=sm_scale,
k_scale=k_scale, k_scale=k_scale,
v_scale=v_scale, v_scale=v_scale,
out_scale_inv=1.0 /
output_scale if output_scale is not None else 1.0,
num_query_heads=num_query_heads, num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv, num_queries_per_kv=num_queries_per_kv,
num_queries_per_kv_padded=num_queries_per_kv_padded, num_queries_per_kv_padded=num_queries_per_kv_padded,
...@@ -388,4 +401,5 @@ def chunked_prefill_paged_decode( ...@@ -388,4 +401,5 @@ def chunked_prefill_paged_decode(
filter_by_query_len=True, filter_by_query_len=True,
query_start_len_ptr=query_start_loc, query_start_len_ptr=query_start_loc,
USE_SINKS=sinks is not None, USE_SINKS=sinks is not None,
USE_FP8=output_scale is not None,
) )
...@@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8 ...@@ -15,6 +15,7 @@ NUM_WARPS = 4 if current_platform.is_rocm() else 8
# To check compatibility # To check compatibility
IS_TURING = current_platform.get_device_capability() == (7, 5) IS_TURING = current_platform.get_device_capability() == (7, 5)
float8_info = torch.finfo(current_platform.fp8_dtype())
# Here's an example autotuner config for this kernel. This config does provide # Here's an example autotuner config for this kernel. This config does provide
...@@ -43,6 +44,7 @@ def _fwd_kernel(Q, ...@@ -43,6 +44,7 @@ def _fwd_kernel(Q,
sm_scale, sm_scale,
k_scale, k_scale,
v_scale, v_scale,
out_scale_inv,
B_Start_Loc, B_Start_Loc,
B_Seqlen, B_Seqlen,
x: tl.constexpr, x: tl.constexpr,
...@@ -82,8 +84,11 @@ def _fwd_kernel(Q, ...@@ -82,8 +84,11 @@ def _fwd_kernel(Q,
num_unroll_request: tl.constexpr, num_unroll_request: tl.constexpr,
SKIP_DECODE: tl.constexpr, SKIP_DECODE: tl.constexpr,
USE_SINKS: tl.constexpr, USE_SINKS: tl.constexpr,
USE_FP8: tl.constexpr,
MAX_Q_LEN: tl.constexpr = 0, MAX_Q_LEN: tl.constexpr = 0,
MAX_CTX_LEN: tl.constexpr = 0): MAX_CTX_LEN: tl.constexpr = 0,
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max):
cur_batch = tl.program_id(0) cur_batch = tl.program_id(0)
cur_head = tl.program_id(1) cur_head = tl.program_id(1)
...@@ -284,6 +289,9 @@ def _fwd_kernel(Q, ...@@ -284,6 +289,9 @@ def _fwd_kernel(Q,
off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + off_o = ((cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs +
cur_head * stride_oh + offs_d[None, :] * stride_od) cur_head * stride_oh + offs_d[None, :] * stride_od)
out_ptrs = Out + off_o out_ptrs = Out + off_o
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
tl.store(out_ptrs, tl.store(out_ptrs,
acc, acc,
mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len)) mask=dim_mask[None, :] & (offs_m[:, None] < cur_batch_query_len))
...@@ -743,6 +751,7 @@ def context_attention_fwd(q, ...@@ -743,6 +751,7 @@ def context_attention_fwd(q,
sliding_window=None, sliding_window=None,
sm_scale=None, sm_scale=None,
skip_decode=False, skip_decode=False,
fp8_out_scale=None,
sinks=None): sinks=None):
q_dtype_is_f32 = q.dtype is torch.float32 q_dtype_is_f32 = q.dtype is torch.float32
...@@ -793,6 +802,7 @@ def context_attention_fwd(q, ...@@ -793,6 +802,7 @@ def context_attention_fwd(q,
if alibi_slopes is not None: if alibi_slopes is not None:
assert sinks is None, "Sinks arg is not supported with alibi" assert sinks is None, "Sinks arg is not supported with alibi"
assert fp8_out_scale is None, "FP8 output not supported with alibi"
# need to reduce num. blocks when using fp32 # need to reduce num. blocks when using fp32
# due to increased use of GPU shared memory # due to increased use of GPU shared memory
# if q.dtype is torch.float32: # if q.dtype is torch.float32:
...@@ -870,6 +880,7 @@ def context_attention_fwd(q, ...@@ -870,6 +880,7 @@ def context_attention_fwd(q,
sm_scale, sm_scale,
k_scale, k_scale,
v_scale, v_scale,
1.0 / fp8_out_scale if fp8_out_scale is not None else 1.0,
b_start_loc, b_start_loc,
b_seq_len, b_seq_len,
k_cache.shape[4], k_cache.shape[4],
...@@ -905,6 +916,7 @@ def context_attention_fwd(q, ...@@ -905,6 +916,7 @@ def context_attention_fwd(q,
BLOCK_DMODEL_PADDED=Lk_padded, BLOCK_DMODEL_PADDED=Lk_padded,
SLIDING_WINDOW=sliding_window, SLIDING_WINDOW=sliding_window,
SKIP_DECODE=skip_decode, SKIP_DECODE=skip_decode,
USE_FP8=fp8_out_scale is not None,
BLOCK_M=128, BLOCK_M=128,
BLOCK_N=64, BLOCK_N=64,
num_unroll_cache=4, num_unroll_cache=4,
......
...@@ -10,9 +10,11 @@ ...@@ -10,9 +10,11 @@
import torch import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
logger = init_logger(__name__) logger = init_logger(__name__)
float8_info = torch.finfo(current_platform.fp8_dtype())
@triton.jit @triton.jit
...@@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs, ...@@ -48,47 +50,51 @@ def find_seq_idx(query_start_len_ptr, target_idx, num_seqs,
@triton.jit @triton.jit
def kernel_unified_attention_2d( def kernel_unified_attention_2d(
output_ptr, # [num_tokens, num_query_heads, head_size] output_ptr, # [num_tokens, num_query_heads, head_size]
query_ptr, # [num_tokens, num_query_heads, head_size] query_ptr, # [num_tokens, num_query_heads, head_size]
key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] key_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size] value_cache_ptr, # [num_blks, blk_size, num_kv_heads, head_size]
sink_ptr, # [num_query_heads] sink_ptr, # [num_query_heads]
block_tables_ptr, # [num_seqs, max_num_blocks_per_seq] block_tables_ptr, # [num_seqs, max_num_blocks_per_seq]
seq_lens_ptr, # [num_seqs] seq_lens_ptr, # [num_seqs]
alibi_slopes_ptr, # [num_query_heads] alibi_slopes_ptr, # [num_query_heads]
qq_bias_ptr, # [num_query_tokens, num_query_tokens] qq_bias_ptr, # [num_query_tokens, num_query_tokens]
scale, # float32 scale, # float32
k_scale, # float32 k_scale, # float32
v_scale, # float32 v_scale, # float32
softcap, # float32 out_scale, # float32
num_query_heads: tl.constexpr, # int softcap, # float32
num_queries_per_kv: tl.constexpr, # int num_query_heads: tl.constexpr, # int
block_table_stride: tl.int64, # int num_queries_per_kv: tl.constexpr, # int
query_stride_0: tl.int64, # int block_table_stride: tl.int64, # int
query_stride_1: tl.int64, # int, should be equal to head_size query_stride_0: tl.int64, # int
output_stride_0: tl.int64, # int query_stride_1: tl.int64, # int, should be equal to head_size
output_stride_1: tl.int64, # int, should be equal to head_size output_stride_0: tl.int64, # int
qq_bias_stride_0: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int qq_bias_stride_0: tl.int64, # int
HEAD_SIZE: tl.constexpr, # int BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int
USE_ALIBI_SLOPES: tl.constexpr, # bool HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
USE_QQ_BIAS: tl.constexpr, # bool USE_ALIBI_SLOPES: tl.constexpr, # bool
USE_SOFTCAP: tl.constexpr, # bool USE_QQ_BIAS: tl.constexpr, # bool
USE_SINKS: tl.constexpr, # bool USE_SOFTCAP: tl.constexpr, # bool
SLIDING_WINDOW: tl.constexpr, # int USE_SINKS: tl.constexpr, # bool
stride_k_cache_0: tl.int64, # int SLIDING_WINDOW: tl.constexpr, # int
stride_k_cache_1: tl.int64, # int stride_k_cache_0: tl.int64, # int
stride_k_cache_2: tl.int64, # int stride_k_cache_1: tl.int64, # int
stride_k_cache_3: tl.constexpr, # int stride_k_cache_2: tl.int64, # int
stride_v_cache_0: tl.int64, # int stride_k_cache_3: tl.constexpr, # int
stride_v_cache_1: tl.int64, # int stride_v_cache_0: tl.int64, # int
stride_v_cache_2: tl.int64, # int stride_v_cache_1: tl.int64, # int
stride_v_cache_3: tl.constexpr, # int stride_v_cache_2: tl.int64, # int
query_start_len_ptr, # [num_seqs+1] stride_v_cache_3: tl.constexpr, # int
BLOCK_Q: tl.constexpr, # int query_start_len_ptr, # [num_seqs+1]
num_seqs: tl.int32, BLOCK_Q: tl.constexpr, # int
BLOCK_M: tl.constexpr, # int num_seqs: tl.int32,
BLOCK_M: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
): ):
q_block_global_idx = tl.program_id(0) q_block_global_idx = tl.program_id(0)
kv_head_idx = tl.program_id(1) kv_head_idx = tl.program_id(1)
...@@ -281,6 +287,9 @@ def kernel_unified_attention_2d( ...@@ -281,6 +287,9 @@ def kernel_unified_attention_2d(
# epilogue # epilogue
acc = acc / L[:, None] acc = acc / L[:, None]
if USE_FP8:
acc = acc * tl.load(out_scale)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
output_offset = (query_offset_0[:, None] * output_stride_0 + output_offset = (query_offset_0[:, None] * output_stride_0 +
query_offset_1[:, None] * output_stride_1 + query_offset_1[:, None] * output_stride_1 +
...@@ -552,23 +561,27 @@ def kernel_unified_attention_3d( ...@@ -552,23 +561,27 @@ def kernel_unified_attention_3d(
@triton.jit @triton.jit
def reduce_segments( def reduce_segments(
output_ptr, # [num_tokens, num_query_heads, head_size] output_ptr, # [num_tokens, num_query_heads, head_size]
segm_output_ptr, segm_output_ptr,
#[num_tokens, num_query_heads, max_num_segments, head_size] #[num_tokens, num_query_heads, max_num_segments, head_size]
segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments] segm_max_ptr, # [num_tokens, num_query_heads, max_num_segments]
segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments] segm_expsum_ptr, # [num_tokens, num_query_heads, max_num_segments]
seq_lens_ptr, # [num_seqs] seq_lens_ptr, # [num_seqs]
num_seqs, # int num_seqs, # int
num_query_heads: tl.constexpr, # int num_query_heads: tl.constexpr, # int
output_stride_0: tl.int64, # int out_scale_inv, # float32
output_stride_1: tl.int64, # int, should be equal to head_size output_stride_0: tl.int64, # int
block_table_stride: tl.int64, # int output_stride_1: tl.int64, # int, should be equal to head_size
BLOCK_SIZE: tl.constexpr, # int block_table_stride: tl.int64, # int
HEAD_SIZE: tl.constexpr, # int, must be power of 2 BLOCK_SIZE: tl.constexpr, # int
HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2 HEAD_SIZE: tl.constexpr, # int, must be power of 2
query_start_len_ptr, # [num_seqs+1] HEAD_SIZE_PADDED: tl.constexpr, # int, must be power of 2
BLOCK_Q: tl.constexpr, # int query_start_len_ptr, # [num_seqs+1]
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int BLOCK_Q: tl.constexpr, # int
NUM_SEGMENTS_PER_SEQ: tl.constexpr, # int
USE_FP8: tl.constexpr, # bool
FP8_MIN: tl.constexpr = float8_info.min,
FP8_MAX: tl.constexpr = float8_info.max,
): ):
query_token_idx = tl.program_id(0) query_token_idx = tl.program_id(0)
query_head_idx = tl.program_id(1) query_head_idx = tl.program_id(1)
...@@ -624,6 +637,10 @@ def reduce_segments( ...@@ -624,6 +637,10 @@ def reduce_segments(
# safely divide by overall_expsum, returning 0.0 if overall_expsum is 0 # safely divide by overall_expsum, returning 0.0 if overall_expsum is 0
acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum) acc = tl.where(overall_expsum == 0.0, 0.0, acc_sum / overall_expsum)
if USE_FP8:
acc = acc * tl.load(out_scale_inv)
acc = tl.clamp(acc, FP8_MIN, FP8_MAX)
# write result # write result
output_offset = (query_token_idx * output_stride_0 + output_offset = (query_token_idx * output_stride_0 +
query_head_idx * output_stride_1 + query_head_idx * output_stride_1 +
...@@ -649,6 +666,7 @@ def unified_attention( ...@@ -649,6 +666,7 @@ def unified_attention(
k_descale, k_descale,
v_descale, v_descale,
alibi_slopes=None, alibi_slopes=None,
output_scale=None,
qq_bias=None, qq_bias=None,
# Optional tensor for sinks # Optional tensor for sinks
sinks=None, sinks=None,
...@@ -707,6 +725,7 @@ def unified_attention( ...@@ -707,6 +725,7 @@ def unified_attention(
scale=softmax_scale, scale=softmax_scale,
k_scale=k_descale, k_scale=k_descale,
v_scale=v_descale, v_scale=v_descale,
out_scale=1 / output_scale if output_scale is not None else 1.0,
softcap=softcap, softcap=softcap,
num_query_heads=num_query_heads, num_query_heads=num_query_heads,
num_queries_per_kv=num_queries_per_kv, num_queries_per_kv=num_queries_per_kv,
...@@ -736,6 +755,7 @@ def unified_attention( ...@@ -736,6 +755,7 @@ def unified_attention(
BLOCK_Q=BLOCK_Q, BLOCK_Q=BLOCK_Q,
num_seqs=num_seqs, num_seqs=num_seqs,
BLOCK_M=BLOCK_M, BLOCK_M=BLOCK_M,
USE_FP8=output_scale is not None,
) )
else: else:
# for initial version, NUM_SEGMENTS = 16 is chosen as a default # for initial version, NUM_SEGMENTS = 16 is chosen as a default
...@@ -819,6 +839,8 @@ def unified_attention( ...@@ -819,6 +839,8 @@ def unified_attention(
seq_lens_ptr=seqused_k, seq_lens_ptr=seqused_k,
num_seqs=num_seqs, num_seqs=num_seqs,
num_query_heads=num_query_heads, num_query_heads=num_query_heads,
out_scale_inv=1 /
output_scale if output_scale is not None else 1.0,
output_stride_0=out.stride(0), output_stride_0=out.stride(0),
output_stride_1=out.stride(1), output_stride_1=out.stride(1),
block_table_stride=block_table.stride(0), block_table_stride=block_table.stride(0),
...@@ -828,4 +850,5 @@ def unified_attention( ...@@ -828,4 +850,5 @@ def unified_attention(
query_start_len_ptr=cu_seqlens_q, query_start_len_ptr=cu_seqlens_q,
BLOCK_Q=BLOCK_Q, BLOCK_Q=BLOCK_Q,
NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS, NUM_SEGMENTS_PER_SEQ=NUM_SEGMENTS,
USE_FP8=output_scale is not None,
) )
...@@ -454,11 +454,12 @@ class VllmBackend: ...@@ -454,11 +454,12 @@ class VllmBackend:
inductor_config = config.inductor_compile_config inductor_config = config.inductor_compile_config
PASS_KEY = "post_grad_custom_post_pass" PASS_KEY = "post_grad_custom_post_pass"
if PASS_KEY in inductor_config: if PASS_KEY in inductor_config:
# Config should automatically wrap all inductor passes
if isinstance(inductor_config[PASS_KEY], PostGradPassManager): if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
# PassManager already added to config, make sure it's correct
assert (inductor_config[PASS_KEY].uuid() == assert (inductor_config[PASS_KEY].uuid() ==
self.post_grad_pass_manager.uuid()) self.post_grad_pass_manager.uuid())
else: else:
# Config should automatically wrap all inductor passes
assert isinstance(inductor_config[PASS_KEY], InductorPass) assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager inductor_config[PASS_KEY] = self.post_grad_pass_manager
......
...@@ -39,6 +39,7 @@ class AttentionQuantPattern(ABC): ...@@ -39,6 +39,7 @@ class AttentionQuantPattern(ABC):
self, self,
layer: Attention, layer: Attention,
quant_key: QuantKey, quant_key: QuantKey,
dtype: torch.dtype,
): ):
self.layer = layer self.layer = layer
self.layer_name = layer.layer_name self.layer_name = layer.layer_name
...@@ -46,11 +47,16 @@ class AttentionQuantPattern(ABC): ...@@ -46,11 +47,16 @@ class AttentionQuantPattern(ABC):
self.head_size = layer.head_size self.head_size = layer.head_size
self.quant_key = quant_key self.quant_key = quant_key
self.quant_dtype = quant_key.dtype self.quant_dtype = quant_key.dtype
self.dtype = dtype
assert self.quant_key in QUANT_OPS, \ assert self.quant_key in QUANT_OPS, \
f"unsupported quantization scheme {self.quant_key}" f"unsupported quantization scheme {self.quant_key}"
self.QUANT_OP = QUANT_OPS[self.quant_key] self.QUANT_OP = QUANT_OPS[self.quant_key]
def empty(self, *args, **kwargs):
kwargs = {'dtype': self.dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs)
def empty_quant(self, *args, **kwargs): def empty_quant(self, *args, **kwargs):
kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs} kwargs = {'dtype': self.quant_dtype, 'device': "cuda", **kwargs}
return torch.empty(*args, **kwargs) return torch.empty(*args, **kwargs)
...@@ -91,12 +97,13 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -91,12 +97,13 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
def __init__( def __init__(
self, self,
layer: Attention, layer: Attention,
dtype: torch.dtype,
symmetric: bool = True, symmetric: bool = True,
): ):
quant_key = QuantKey(dtype=FP8_DTYPE, quant_key = QuantKey(dtype=FP8_DTYPE,
scale=kStaticTensorScale, scale=kStaticTensorScale,
symmetric=symmetric) symmetric=symmetric)
super().__init__(layer, quant_key) super().__init__(layer, quant_key, dtype)
def _register(self, pm_pass: PatternMatcherPass): def _register(self, pm_pass: PatternMatcherPass):
...@@ -139,10 +146,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern): ...@@ -139,10 +146,14 @@ class AttentionFp8StaticQuantPattern(AttentionQuantPattern):
return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size]) return RESHAPE_OP(at1[1], [-1, self.num_heads * self.head_size])
inputs = [ inputs = [
empty_bf16(5, self.num_heads, self.head_size), # q self.empty(5, self.num_heads, self.head_size,
empty_bf16(5, self.num_heads, self.head_size), # k dtype=self.dtype), # q
empty_bf16(5, self.num_heads, self.head_size), # v self.empty(5, self.num_heads, self.head_size,
empty_bf16(5, self.num_heads, self.head_size), # attn_output dtype=self.dtype), # k
self.empty(5, self.num_heads, self.head_size,
dtype=self.dtype), # v
self.empty(5, self.num_heads, self.head_size,
dtype=self.dtype), # attn_output
self.empty_quant(5, self.empty_quant(5,
self.num_heads * self.head_size), # quant_output self.num_heads * self.head_size), # quant_output
empty_fp32(1, 1) # scale empty_fp32(1, 1) # scale
...@@ -165,8 +176,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern): ...@@ -165,8 +176,8 @@ class AttentionNvfp4QuantPattern(AttentionQuantPattern):
will be passed into Attention op as the `output_scale` argument. will be passed into Attention op as the `output_scale` argument.
""" """
def __init__(self, layer: Attention): def __init__(self, layer: Attention, dtype: torch.dtype):
super().__init__(layer, kNvfp4Quant) super().__init__(layer, kNvfp4Quant, dtype)
def _register(self, pm_pass: PatternMatcherPass): def _register(self, pm_pass: PatternMatcherPass):
...@@ -255,12 +266,14 @@ class AttnFusionPass(VllmInductorPass): ...@@ -255,12 +266,14 @@ class AttnFusionPass(VllmInductorPass):
attn_layers = get_layers_from_vllm_config(config, Attention) attn_layers = get_layers_from_vllm_config(config, Attention)
for layer_name, layer in attn_layers.items(): for layer_name, layer in attn_layers.items():
pattern_fp8 = AttentionFp8StaticQuantPattern(layer) pattern_fp8 = AttentionFp8StaticQuantPattern(
layer, config.model_config.dtype)
pattern_fp8.register_if_supported(self.patterns) pattern_fp8.register_if_supported(self.patterns)
if current_platform.is_cuda() and hasattr(torch.ops._C, if current_platform.is_cuda() and hasattr(torch.ops._C,
"scaled_fp4_quant"): "scaled_fp4_quant"):
pattern_nvfp4 = AttentionNvfp4QuantPattern(layer) pattern_nvfp4 = AttentionNvfp4QuantPattern(
layer, config.model_config.dtype)
pattern_nvfp4.register_if_supported(self.patterns) pattern_nvfp4.register_if_supported(self.patterns)
if len(attn_layers) == 0: if len(attn_layers) == 0:
......
...@@ -171,10 +171,12 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor, ...@@ -171,10 +171,12 @@ def flashinfer_w8a8_scaled_mm(*, qinput: torch.Tensor, weight: torch.Tensor,
bias=bias) bias=bias)
def rocm_per_tensor_w8a8_scaled_mm_impl( def rocm_per_tensor_w8a8_scaled_mm_impl(qinput: torch.Tensor,
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, weight: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, out_dtype: torch.dtype,
input_2d: torch.Tensor) -> torch.Tensor: scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
from vllm.platforms.rocm import on_mi3xx from vllm.platforms.rocm import on_mi3xx
if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx( if envs.VLLM_ROCM_USE_SKINNY_GEMM and on_mi3xx(
) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0: ) and qinput.shape[0] == 1 and qinput.shape[1] % 16 == 0:
...@@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl( ...@@ -190,10 +192,12 @@ def rocm_per_tensor_w8a8_scaled_mm_impl(
return output return output
def rocm_per_tensor_w8a8_scaled_mm_fake( def rocm_per_tensor_w8a8_scaled_mm_fake(qinput: torch.Tensor,
qinput: torch.Tensor, weight: torch.Tensor, out_dtype: torch.dtype, weight: torch.Tensor,
scale_a: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor, out_dtype: torch.dtype,
input_2d: torch.Tensor) -> torch.Tensor: scale_a: torch.Tensor,
scale_b: torch.Tensor,
bias: torch.Tensor) -> torch.Tensor:
return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]), return qinput.new_empty((*qinput.shape[:-1], weight.shape[1]),
dtype=out_dtype) dtype=out_dtype)
...@@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -203,11 +207,10 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor: output_shape: list) -> torch.Tensor:
output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl( output = torch.ops.vllm.rocm_per_tensor_w8a8_scaled_mm_impl(
qinput, weight, out_dtype, scale_a, scale_b, bias, input_2d) qinput, weight, out_dtype, scale_a, scale_b, bias)
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
direct_register_custom_op( direct_register_custom_op(
...@@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -224,7 +227,6 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list) -> torch.Tensor: output_shape: list) -> torch.Tensor:
output = torch._scaled_mm(qinput, output = torch._scaled_mm(qinput,
weight, weight,
...@@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -237,7 +239,7 @@ def torch_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
if type(output) is tuple and len(output) == 2: if type(output) is tuple and len(output) == 2:
output = output[0] output = output[0]
return torch.narrow(output, 0, 0, input_2d.shape[0]).view(*output_shape) return torch.narrow(output, 0, 0, qinput.shape[0]).view(*output_shape)
def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
...@@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -245,7 +247,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor, output_shape: list, output_shape: list,
**kwargs) -> torch.Tensor: **kwargs) -> torch.Tensor:
# Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM # Note: Callers of this function should check USE_ROWWISE_TORCH_SCALED_MM
# when using it. # when using it.
...@@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -265,7 +267,7 @@ def torch_per_token_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b=scale_b.t(), scale_b=scale_b.t(),
bias=bias) bias=bias)
output = torch.narrow(output, 0, 0, input_2d.shape[0]) output = torch.narrow(output, 0, 0, qinput.shape[0])
output = output.view(*output_shape) output = output.view(*output_shape)
return output return output
...@@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -275,7 +277,6 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
out_dtype: torch.dtype, out_dtype: torch.dtype,
scale_a: torch.Tensor, scale_a: torch.Tensor,
scale_b: torch.Tensor, bias: torch.Tensor, scale_b: torch.Tensor, bias: torch.Tensor,
input_2d: torch.Tensor,
output_shape: list, output_shape: list,
**kwargs) -> torch.Tensor: **kwargs) -> torch.Tensor:
# Use unfused DQ due to limitations with scaled_mm # Use unfused DQ due to limitations with scaled_mm
...@@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor, ...@@ -305,8 +306,8 @@ def torch_channelwise_w8a8_scaled_mm(*, qinput: torch.Tensor,
if type(output) is tuple and len(output) == 2: if type(output) is tuple and len(output) == 2:
output = output[0] output = output[0]
# Unpad (undo num_token_padding) # Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input_2d.shape[0]) output = torch.narrow(output, 0, 0, qinput.shape[0])
x_scale = torch.narrow(scale_a, 0, 0, input_2d.shape[0]) x_scale = torch.narrow(scale_a, 0, 0, qinput.shape[0])
# DQ # DQ
# C = sw * sx * (X * W) + bias # C = sw * sx * (X * W) + bias
...@@ -430,7 +431,6 @@ class Fp8LinearOp: ...@@ -430,7 +431,6 @@ class Fp8LinearOp:
scale_a=x_scale, scale_a=x_scale,
scale_b=weight_scale, scale_b=weight_scale,
bias=bias, bias=bias,
input_2d=input_2d,
output_shape=output_shape) output_shape=output_shape)
......
...@@ -15,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import ( ...@@ -15,6 +15,8 @@ from vllm.attention.ops.chunked_prefill_paged_decode import (
from vllm.attention.ops.paged_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, kFp8StaticTensorSym)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionCGSupport, from vllm.v1.attention.backends.utils import (AttentionCGSupport,
...@@ -202,6 +204,9 @@ def use_aiter_unified_attention() -> bool: ...@@ -202,6 +204,9 @@ def use_aiter_unified_attention() -> bool:
class TritonAttentionImpl(AttentionImpl): class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym
def __init__( def __init__(
self, self,
num_heads: int, num_heads: int,
...@@ -297,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -297,9 +302,9 @@ class TritonAttentionImpl(AttentionImpl):
""" """
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
if output_scale is not None or output_block_scale is not None: if output_block_scale is not None:
raise NotImplementedError( raise NotImplementedError(
"fused output quantization is not yet supported" "fused block_scale output quantization is not yet supported"
" for TritonAttentionImpl") " for TritonAttentionImpl")
if attn_metadata is None: if attn_metadata is None:
...@@ -394,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -394,6 +399,7 @@ class TritonAttentionImpl(AttentionImpl):
alibi_slopes=self.alibi_slopes, alibi_slopes=self.alibi_slopes,
sliding_window=self.sliding_window[0], sliding_window=self.sliding_window[0],
sm_scale=self.scale, sm_scale=self.scale,
output_scale=output_scale,
sinks=self.sinks, sinks=self.sinks,
) )
...@@ -419,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -419,6 +425,6 @@ class TritonAttentionImpl(AttentionImpl):
k_descale=layer._k_scale.expand(descale_shape), k_descale=layer._k_scale.expand(descale_shape),
v_descale=layer._v_scale.expand(descale_shape), v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks, sinks=self.sinks,
) output_scale=output_scale)
return output return output
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment