Unverified Commit e97f802b authored by Gregory Shtrasberg's avatar Gregory Shtrasberg Committed by GitHub
Browse files

[FP8][Kernel] Dynamic kv cache scaling factors computation (#11906)


Signed-off-by: default avatarGregory Shtrasberg <Gregory.Shtrasberg@amd.com>
Co-authored-by: default avatarMicah Williamson <micah.williamson@amd.com>
parent 6e650f56
{
"model_type": "llama",
"kv_cache": {
"dtype": "float8_e4m3fn",
"scaling_factor": {
"0": {
"0": 0.0152239128947258,
"1": 0.0188860222697258,
"2": 0.0354178324341774,
"3": 0.0376674123108387,
"4": 0.0418526791036129,
"5": 0.0433175228536129,
"6": 0.0397600457072258,
"7": 0.0424455925822258,
"8": 0.0415387861430645,
"9": 0.0408412404358387,
"10": 0.0395856611430645,
"11": 0.0377371683716774,
"12": 0.0400739423930645,
"13": 0.040771484375,
"14": 0.0393415205180645,
"15": 0.0369001142680645,
"16": 0.03857421875,
"17": 0.0387486070394516,
"18": 0.0403180830180645,
"19": 0.0396205373108387,
"20": 0.0375627800822258,
"21": 0.0407366082072258,
"22": 0.0432477705180645,
"23": 0.0377022884786129,
"24": 0.0399693101644516,
"25": 0.0374581478536129,
"26": 0.0413295216858387,
"27": 0.0442243330180645,
"28": 0.0424804724752903,
"29": 0.0456891767680645,
"30": 0.0409109964966774,
"31": 0.0482352152466774
}
}
}
}
...@@ -182,7 +182,7 @@ def test_paged_attention( ...@@ -182,7 +182,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = 1.0 k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the paged attention kernel. # Call the paged attention kernel.
output = torch.empty_like(query) output = torch.empty_like(query)
......
...@@ -210,7 +210,7 @@ def test_paged_attention( ...@@ -210,7 +210,7 @@ def test_paged_attention(
key_cache, value_cache = key_caches[0], value_caches[0] key_cache, value_cache = key_caches[0], value_caches[0]
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = 1.0 k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
tp_rank = 0 tp_rank = 0
# Call the paged attention kernel. # Call the paged attention kernel.
......
...@@ -160,7 +160,7 @@ def test_reshape_and_cache( ...@@ -160,7 +160,7 @@ def test_reshape_and_cache(
cloned_value_cache = value_cache.clone() cloned_value_cache = value_cache.clone()
# Using default kv_scale # Using default kv_scale
k_scale = v_scale = 1.0 k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Call the reshape_and_cache kernel. # Call the reshape_and_cache kernel.
opcheck(torch.ops._C_cache_ops.reshape_and_cache, opcheck(torch.ops._C_cache_ops.reshape_and_cache,
...@@ -258,8 +258,8 @@ def test_reshape_and_cache_flash( ...@@ -258,8 +258,8 @@ def test_reshape_and_cache_flash(
del key_caches del key_caches
del value_caches del value_caches
k_scale = key.amax().item() / 256 k_scale = (key.amax() / 256.0).to(torch.float32)
v_scale = value.amax().item() / 256 v_scale = (value.amax() / 256.0).to(torch.float32)
# Clone the KV caches. # Clone the KV caches.
if kv_cache_dtype == "fp8": if kv_cache_dtype == "fp8":
...@@ -284,12 +284,12 @@ def test_reshape_and_cache_flash( ...@@ -284,12 +284,12 @@ def test_reshape_and_cache_flash(
result_key_cache = torch.empty_like(key_cache, dtype=torch.float16) result_key_cache = torch.empty_like(key_cache, dtype=torch.float16)
ops.convert_fp8(result_key_cache, ops.convert_fp8(result_key_cache,
key_cache, key_cache,
k_scale, k_scale.item(),
kv_dtype=kv_cache_dtype) kv_dtype=kv_cache_dtype)
result_value_cache = torch.empty_like(value_cache, dtype=torch.float16) result_value_cache = torch.empty_like(value_cache, dtype=torch.float16)
ops.convert_fp8(result_value_cache, ops.convert_fp8(result_value_cache,
value_cache, value_cache,
v_scale, v_scale.item(),
kv_dtype=kv_cache_dtype) kv_dtype=kv_cache_dtype)
# Run the reference implementation. # Run the reference implementation.
......
...@@ -138,6 +138,7 @@ def test_contexted_kv_attention( ...@@ -138,6 +138,7 @@ def test_contexted_kv_attention(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads, v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous() head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring # Warm up the Triton kernel by calling it once before actually measuring
# generation time # generation time
...@@ -153,6 +154,8 @@ def test_contexted_kv_attention( ...@@ -153,6 +154,8 @@ def test_contexted_kv_attention(
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window) sliding_window=sliding_window)
torch.cuda.synchronize() torch.cuda.synchronize()
start_time = time.time() start_time = time.time()
...@@ -168,6 +171,8 @@ def test_contexted_kv_attention( ...@@ -168,6 +171,8 @@ def test_contexted_kv_attention(
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale,
v_scale,
sliding_window=sliding_window) sliding_window=sliding_window)
torch.cuda.synchronize() torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
...@@ -366,6 +371,7 @@ def test_contexted_kv_attention_alibi( ...@@ -366,6 +371,7 @@ def test_contexted_kv_attention_alibi(
# to V_cache[num_blocks, num_kv_heads, head_size, block_size] # to V_cache[num_blocks, num_kv_heads, head_size, block_size]
v_cache = v_cache.view(-1, block_size, num_kv_heads, v_cache = v_cache.view(-1, block_size, num_kv_heads,
head_size).permute(0, 2, 3, 1).contiguous() head_size).permute(0, 2, 3, 1).contiguous()
k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device)
# Warm up the Triton kernel by calling it once before actually measuring # Warm up the Triton kernel by calling it once before actually measuring
# generation time # generation time
...@@ -381,6 +387,8 @@ def test_contexted_kv_attention_alibi( ...@@ -381,6 +387,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes)
torch.cuda.synchronize() torch.cuda.synchronize()
start_time = time.time() start_time = time.time()
...@@ -396,6 +404,8 @@ def test_contexted_kv_attention_alibi( ...@@ -396,6 +404,8 @@ def test_contexted_kv_attention_alibi(
b_seq_len, b_seq_len,
b_ctx_len, b_ctx_len,
max_input_len, max_input_len,
k_scale,
v_scale,
alibi_slopes=alibi_slopes) alibi_slopes=alibi_slopes)
torch.cuda.synchronize() torch.cuda.synchronize()
end_time = time.time() end_time = time.time()
......
...@@ -909,6 +909,7 @@ def make_test_metadata( ...@@ -909,6 +909,7 @@ def make_test_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping), slot_mapping=(None if kv_mmap is None else kv_mmap.slot_mapping),
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
...@@ -958,6 +959,7 @@ def make_test_metadata( ...@@ -958,6 +959,7 @@ def make_test_metadata(
num_prefills=num_prefills, num_prefills=num_prefills,
slot_mapping=kv_mmap.slot_mapping, slot_mapping=kv_mmap.slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
num_prefill_tokens=num_prefill_tokens, num_prefill_tokens=num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
......
...@@ -19,18 +19,17 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true" ...@@ -19,18 +19,17 @@ os.environ["TOKENIZERS_PARALLELISM"] = "true"
@pytest.mark.skipif(not is_quant_method_supported("fp8"), @pytest.mark.skipif(not is_quant_method_supported("fp8"),
reason="fp8 is not supported on this GPU type.") reason="fp8 is not supported on this GPU type.")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"kv_cache_dtype,base_model,test_model,scale_path", "kv_cache_dtype,base_model,test_model",
[ [
# Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors. # Test FP8 checkpoint w. fp8_e4m3 kv-cache scaling factors.
("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct", ("fp8_e4m3", "meta-llama/Llama-3.2-1B-Instruct",
"nm-testing/Llama-3.2-1B-Instruct-FP8-KV", None), "nm-testing/Llama-3.2-1B-Instruct-FP8-KV"),
# Test FP16 checkpoint w. fp8_e5m2 kv-cache. # Test FP16 checkpoint w. fp8_e5m2 kv-cache.
("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct", ("fp8_e5m2", "meta-llama/Llama-3.2-1B-Instruct",
"meta-llama/Llama-3.2-1B-Instruct", None), "meta-llama/Llama-3.2-1B-Instruct"),
# Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json. # Test FP16 checkpoint w. fp8_e4m3 kv-cache scaling factors in json.
("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf", ("fp8_e4m3", "meta-llama/Llama-2-7b-chat-hf",
"meta-llama/Llama-2-7b-chat-hf", "meta-llama/Llama-2-7b-chat-hf")
"./tests/fp8_kv/llama2-7b-fp8-kv/kv_cache_scales.json")
]) ])
# Due to low-precision numerical divergence, we only test logprob of 4 tokens # Due to low-precision numerical divergence, we only test logprob of 4 tokens
@pytest.mark.parametrize("max_tokens", [4]) @pytest.mark.parametrize("max_tokens", [4])
...@@ -48,7 +47,6 @@ def test_models( ...@@ -48,7 +47,6 @@ def test_models(
kv_cache_dtype: str, kv_cache_dtype: str,
base_model: str, base_model: str,
test_model: str, test_model: str,
scale_path: Optional[str],
max_tokens: int, max_tokens: int,
enforce_eager: bool, enforce_eager: bool,
backend: str, backend: str,
...@@ -76,10 +74,6 @@ def test_models( ...@@ -76,10 +74,6 @@ def test_models(
baseline_outputs = vllm_model.generate_greedy_logprobs( baseline_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS) example_prompts, max_tokens, NUM_LOG_PROBS)
extra_kwargs = {}
if scale_path is not None:
extra_kwargs["quantization_param_path"] = scale_path
with vllm_runner( with vllm_runner(
test_model, test_model,
max_model_len=MAX_MODEL_LEN, max_model_len=MAX_MODEL_LEN,
...@@ -87,7 +81,6 @@ def test_models( ...@@ -87,7 +81,6 @@ def test_models(
enforce_eager=enforce_eager, enforce_eager=enforce_eager,
kv_cache_dtype=kv_cache_dtype, kv_cache_dtype=kv_cache_dtype,
disable_async_output_proc=disable_async_output_proc, disable_async_output_proc=disable_async_output_proc,
**extra_kwargs,
) as vllm_model: ) as vllm_model:
test_outputs = vllm_model.generate_greedy_logprobs( test_outputs = vllm_model.generate_greedy_logprobs(
example_prompts, max_tokens, NUM_LOG_PROBS) example_prompts, max_tokens, NUM_LOG_PROBS)
......
...@@ -74,6 +74,7 @@ def test_model_runner_input(): ...@@ -74,6 +74,7 @@ def test_model_runner_input():
num_decode_tokens=3, num_decode_tokens=3,
slot_mapping=torch.zeros(1), slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
) )
model_input = ModelInputForGPUWithSamplingMetadata( model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10), input_tokens=torch.ones(10),
...@@ -126,6 +127,7 @@ def test_embedding_model_runner_input(): ...@@ -126,6 +127,7 @@ def test_embedding_model_runner_input():
num_decode_tokens=3, num_decode_tokens=3,
slot_mapping=torch.zeros(1), slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
) )
model_input = ModelInputForGPUWithPoolingMetadata( model_input = ModelInputForGPUWithPoolingMetadata(
input_tokens=torch.ones(10), input_tokens=torch.ones(10),
...@@ -177,6 +179,7 @@ def test_multi_step_model_runner_input(): ...@@ -177,6 +179,7 @@ def test_multi_step_model_runner_input():
num_decode_tokens=3, num_decode_tokens=3,
slot_mapping=torch.zeros(1), slot_mapping=torch.zeros(1),
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
) )
frozen_model_input = ModelInputForGPUWithSamplingMetadata( frozen_model_input = ModelInputForGPUWithSamplingMetadata(
input_tokens=torch.ones(10), input_tokens=torch.ones(10),
......
...@@ -48,8 +48,8 @@ def paged_attention_v1( ...@@ -48,8 +48,8 @@ def paged_attention_v1(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -80,8 +80,8 @@ def paged_attention_v2( ...@@ -80,8 +80,8 @@ def paged_attention_v2(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
tp_rank: int = 0, tp_rank: int = 0,
blocksparse_local_blocks: int = 0, blocksparse_local_blocks: int = 0,
blocksparse_vert_stride: int = 0, blocksparse_vert_stride: int = 0,
...@@ -112,8 +112,8 @@ def paged_attention_rocm( ...@@ -112,8 +112,8 @@ def paged_attention_rocm(
max_seq_len: int, max_seq_len: int,
alibi_slopes: Optional[torch.Tensor], alibi_slopes: Optional[torch.Tensor],
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
) -> None: ) -> None:
torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query, torch.ops._rocm_C.paged_attention(out, exp_sum, max_logits, tmp_out, query,
key_cache, value_cache, num_kv_heads, key_cache, value_cache, num_kv_heads,
...@@ -956,8 +956,8 @@ def reshape_and_cache( ...@@ -956,8 +956,8 @@ def reshape_and_cache(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
) -> None: ) -> None:
torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache, torch.ops._C_cache_ops.reshape_and_cache(key, value, key_cache,
value_cache, slot_mapping, value_cache, slot_mapping,
...@@ -971,8 +971,8 @@ def reshape_and_cache_flash( ...@@ -971,8 +971,8 @@ def reshape_and_cache_flash(
value_cache: torch.Tensor, value_cache: torch.Tensor,
slot_mapping: torch.Tensor, slot_mapping: torch.Tensor,
kv_cache_dtype: str, kv_cache_dtype: str,
k_scale: float, k_scale: torch.Tensor,
v_scale: float, v_scale: torch.Tensor,
) -> None: ) -> None:
torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache, torch.ops._C_cache_ops.reshape_and_cache_flash(key, value, key_cache,
value_cache, slot_mapping, value_cache, slot_mapping,
......
...@@ -123,6 +123,10 @@ class AttentionMetadata: ...@@ -123,6 +123,10 @@ class AttentionMetadata:
multi_modal_placeholder_index_maps: Optional[Dict[ multi_modal_placeholder_index_maps: Optional[Dict[
str, MultiModalPlaceholderMap.IndexMap]] str, MultiModalPlaceholderMap.IndexMap]]
# Enable/disable KV scales calculation. This is so that we can disable the
# calculation until after prefill and cuda graph capture.
enable_kv_scales_calculation: bool
@property @property
@abstractmethod @abstractmethod
def prefill_metadata(self) -> Optional["AttentionMetadata"]: def prefill_metadata(self) -> Optional["AttentionMetadata"]:
...@@ -226,8 +230,10 @@ class AttentionMetadataBuilder(ABC, Generic[T]): ...@@ -226,8 +230,10 @@ class AttentionMetadataBuilder(ABC, Generic[T]):
class AttentionLayer(Protocol): class AttentionLayer(Protocol):
_k_scale: float _k_scale: torch.Tensor
_v_scale: float _v_scale: torch.Tensor
_k_scale_float: float
_v_scale_float: float
def forward( def forward(
self, self,
......
...@@ -222,6 +222,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -222,6 +222,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps, multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
...@@ -251,6 +252,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): ...@@ -251,6 +252,7 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None, max_query_len=None,
......
...@@ -230,6 +230,7 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -230,6 +230,7 @@ class FlashAttentionMetadata(AttentionMetadata):
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps, multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
...@@ -274,6 +275,7 @@ class FlashAttentionMetadata(AttentionMetadata): ...@@ -274,6 +275,7 @@ class FlashAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None, seq_lens=None,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_decode_query_len=self.max_decode_query_len, max_decode_query_len=self.max_decode_query_len,
...@@ -557,6 +559,7 @@ class FlashAttentionMetadataBuilder( ...@@ -557,6 +559,7 @@ class FlashAttentionMetadataBuilder(
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_query_len=max_query_len, max_query_len=max_query_len,
max_decode_query_len=max_decode_query_len, max_decode_query_len=max_decode_query_len,
...@@ -675,7 +678,7 @@ class FlashAttentionImpl(AttentionImpl): ...@@ -675,7 +678,7 @@ class FlashAttentionImpl(AttentionImpl):
NOTE: It in-place updates the output tensor. NOTE: It in-place updates the output tensor.
""" """
# NOTE(woosuk): FlashAttention does not support FP8 KV cache. # NOTE(woosuk): FlashAttention does not support FP8 KV cache.
assert layer._k_scale == 1.0 and layer._v_scale == 1.0, ( assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0, (
"key/v_scale is not supported in FlashAttention.") "key/v_scale is not supported in FlashAttention.")
assert output is not None, "Output tensor must be provided." assert output is not None, "Output tensor must be provided."
......
...@@ -219,6 +219,7 @@ class FlashInferState(AttentionState): ...@@ -219,6 +219,7 @@ class FlashInferState(AttentionState):
num_prefills=0, num_prefills=0,
slot_mapping=self._graph_slot_mapping[:batch_size], slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=False,
num_prefill_tokens=0, num_prefill_tokens=0,
num_decode_tokens=batch_size, num_decode_tokens=batch_size,
max_prefill_seq_len=0, max_prefill_seq_len=0,
...@@ -733,6 +734,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]): ...@@ -733,6 +734,7 @@ class FlashInferMetadataBuilder(AttentionMetadataBuilder[FlashInferMetadata]):
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
max_prefill_seq_len=max_prefill_seq_len, max_prefill_seq_len=max_prefill_seq_len,
...@@ -888,8 +890,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -888,8 +890,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache, kv_cache,
logits_soft_cap=logits_soft_cap, logits_soft_cap=logits_soft_cap,
causal=True, causal=True,
k_scale=layer._k_scale, k_scale=layer._k_scale_float,
v_scale=layer._v_scale, v_scale=layer._v_scale_float,
window_left=window_left) window_left=window_left)
if decode_meta := attn_metadata.decode_metadata: if decode_meta := attn_metadata.decode_metadata:
assert decode_meta is not None assert decode_meta is not None
...@@ -899,8 +901,8 @@ class FlashInferImpl(AttentionImpl): ...@@ -899,8 +901,8 @@ class FlashInferImpl(AttentionImpl):
kv_cache, kv_cache,
sm_scale=softmax_scale, sm_scale=softmax_scale,
logits_soft_cap=logits_soft_cap, logits_soft_cap=logits_soft_cap,
k_scale=layer._k_scale, k_scale=layer._k_scale_float,
v_scale=layer._v_scale, v_scale=layer._v_scale_float,
window_left=window_left) window_left=window_left)
if prefill_output is None and decode_output is not None: if prefill_output is None and decode_output is not None:
......
...@@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]): ...@@ -193,7 +193,7 @@ class IpexAttnBackendImpl(AttentionImpl[IpexAttnMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert layer._k_scale == 1.0 and layer._v_scale == 1.0 assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
num_tokens, hidden_size = query.shape num_tokens, hidden_size = query.shape
# Reshape the query, key, and value tensors. # Reshape the query, key, and value tensors.
query = query.view(-1, self.num_heads, self.head_size) query = query.view(-1, self.num_heads, self.head_size)
......
...@@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl): ...@@ -173,7 +173,7 @@ class PallasAttentionBackendImpl(AttentionImpl):
Returns: Returns:
shape = [batch_size, seq_len, num_heads * head_size] shape = [batch_size, seq_len, num_heads * head_size]
""" """
assert layer._k_scale == 1.0 and layer._v_scale == 1.0 assert layer._k_scale_float == 1.0 and layer._v_scale_float == 1.0
batch_size, seq_len, hidden_size = query.shape batch_size, seq_len, hidden_size = query.shape
query = query.view(batch_size, seq_len, self.num_heads, self.head_size) query = query.view(batch_size, seq_len, self.num_heads, self.head_size)
key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size) key = key.view(batch_size, seq_len, self.num_kv_heads, self.head_size)
......
...@@ -140,6 +140,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -140,6 +140,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps, multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_decode_query_len=0, max_decode_query_len=0,
...@@ -173,6 +174,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata): ...@@ -173,6 +174,7 @@ class PlaceholderAttentionMetadata(AttentionMetadata):
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_decode_query_len=self.max_decode_query_len, max_decode_query_len=self.max_decode_query_len,
...@@ -380,6 +382,7 @@ class PlaceholderAttentionMetadataBuilder( ...@@ -380,6 +382,7 @@ class PlaceholderAttentionMetadataBuilder(
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
......
...@@ -153,6 +153,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -153,6 +153,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping=self.slot_mapping[:self.num_prefill_tokens], slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps, multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=self.seq_lens[:self.num_prefills], seq_lens=self.seq_lens[:self.num_prefills],
seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
...@@ -182,6 +183,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -182,6 +183,7 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata):
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=self.slot_mapping[self.num_prefill_tokens:], slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
max_query_len=None, max_query_len=None,
......
...@@ -379,6 +379,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): ...@@ -379,6 +379,7 @@ class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]):
prefill_block_tables=prefill_block_tables, prefill_block_tables=prefill_block_tables,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=False,
) )
return attn_metadata return attn_metadata
...@@ -454,7 +455,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): ...@@ -454,7 +455,6 @@ class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]):
Returns: Returns:
shape = [num_tokens, num_heads * head_size] shape = [num_tokens, num_heads * head_size]
""" """
assert layer._k_scale == 1.0 and layer._v_scale == 1.0
attn_type = self.attn_type attn_type = self.attn_type
if (attn_type == AttentionType.ENCODER if (attn_type == AttentionType.ENCODER
and (not attn_metadata.is_all_encoder_attn_metadata_set)): and (not attn_metadata.is_all_encoder_attn_metadata_set)):
......
...@@ -265,6 +265,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]): ...@@ -265,6 +265,7 @@ class CommonMetadataBuilder(AttentionMetadataBuilder[TAttentionMetadata]):
num_prefills=self.num_prefills, num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor, slot_mapping=slot_mapping_tensor,
multi_modal_placeholder_index_maps=placeholder_index_maps, multi_modal_placeholder_index_maps=placeholder_index_maps,
enable_kv_scales_calculation=True,
num_prefill_tokens=self.num_prefill_tokens, num_prefill_tokens=self.num_prefill_tokens,
num_decode_tokens=num_decode_tokens, num_decode_tokens=num_decode_tokens,
seq_lens=seq_lens, seq_lens=seq_lens,
...@@ -317,6 +318,7 @@ class CommonAttentionState(AttentionState): ...@@ -317,6 +318,7 @@ class CommonAttentionState(AttentionState):
num_decode_tokens=batch_size, num_decode_tokens=batch_size,
slot_mapping=self._graph_slot_mapping[:batch_size], slot_mapping=self._graph_slot_mapping[:batch_size],
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens=None, seq_lens=None,
seq_lens_tensor=self._graph_seq_lens[:batch_size], seq_lens_tensor=self._graph_seq_lens[:batch_size],
max_query_len=1, max_query_len=1,
......
...@@ -218,6 +218,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -218,6 +218,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=self. multi_modal_placeholder_index_maps=self.
multi_modal_placeholder_index_maps, multi_modal_placeholder_index_maps,
enable_kv_scales_calculation=self.enable_kv_scales_calculation,
seq_lens=seq_lens, seq_lens=seq_lens,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_query_len=self.max_query_len, max_query_len=self.max_query_len,
...@@ -262,6 +263,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): ...@@ -262,6 +263,7 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata):
num_decode_tokens=self.num_decode_tokens, num_decode_tokens=self.num_decode_tokens,
slot_mapping=slot_mapping, slot_mapping=slot_mapping,
multi_modal_placeholder_index_maps=None, multi_modal_placeholder_index_maps=None,
enable_kv_scales_calculation=True,
seq_lens_tensor=seq_lens_tensor, seq_lens_tensor=seq_lens_tensor,
max_prefill_seq_len=0, max_prefill_seq_len=0,
max_decode_seq_len=self.max_decode_seq_len, max_decode_seq_len=self.max_decode_seq_len,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment