Unverified commit 116f4be4, authored by Matthew Bonanni, committed by GitHub
Browse files

[1/N][Cleanup] Standardize on use of `is_quantized_kv_cache` (#38659)


Signed-off-by: Matthew Bonanni <mbonanni@redhat.com>
parent 7b01d97a
...@@ -13,6 +13,7 @@ from vllm.logger import init_logger ...@@ -13,6 +13,7 @@ from vllm.logger import init_logger
from vllm.model_executor.layers.attention.mla_attention import ( from vllm.model_executor.layers.attention.mla_attention import (
get_mla_dims, get_mla_dims,
) )
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -231,7 +232,7 @@ class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]): ...@@ -231,7 +232,7 @@ class XPUMLASparseImpl(SparseMLAAttentionImpl[XPUMLASparseMetadata]):
# NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use # NOTE(lucas): for the sparse FlashMLA kernels the kernels want to use
# MQA 576/512 approach for both prefill and decode # MQA 576/512 approach for both prefill and decode
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet") raise NotImplementedError("FP8 kv is not supported with XPU MLA Sparse yet")
# Concatenate q if it's a tuple (ql_nope, q_pe) # Concatenate q if it's a tuple (ql_nope, q_pe)
......
...@@ -16,6 +16,7 @@ from vllm.platforms import current_platform ...@@ -16,6 +16,7 @@ from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import cdiv from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import num_compute_units from vllm.utils.platform_utils import num_compute_units
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -291,7 +292,7 @@ if current_platform.is_rocm(): ...@@ -291,7 +292,7 @@ if current_platform.is_rocm():
new_key_cache = key_cache.view_as(k_cache_template) new_key_cache = key_cache.view_as(k_cache_template)
new_value_cache = value_cache.view_as(v_cache_template) new_value_cache = value_cache.view_as(v_cache_template)
QUANT = False QUANT = False
if kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(kv_cache_dtype):
QUANT = True QUANT = True
grid = ( grid = (
num_tokens, num_tokens,
...@@ -494,7 +495,7 @@ class AiterFlashAttentionMetadataBuilder( ...@@ -494,7 +495,7 @@ class AiterFlashAttentionMetadataBuilder(
if ( if (
rocm_aiter_ops.is_shuffle_kv_cache_enabled() rocm_aiter_ops.is_shuffle_kv_cache_enabled()
and self.scale.numel() == 1 and self.scale.numel() == 1
and self.vllm_config.cache_config.cache_dtype.startswith("fp8") and is_quantized_kv_cache(self.vllm_config.cache_config.cache_dtype)
): ):
layers = get_layers_from_vllm_config(self.vllm_config, Attention) layers = get_layers_from_vllm_config(self.vllm_config, Attention)
first_layer_name = [k for k in layers][0] first_layer_name = [k for k in layers][0]
...@@ -887,7 +888,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -887,7 +888,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv=swa_cu_seqlens, cu_seqlens_kv=swa_cu_seqlens,
token_to_batch=swa_token_to_batch, token_to_batch=swa_token_to_batch,
seq_starts=swa_seq_starts, seq_starts=swa_seq_starts,
dequant=self.kv_cache_dtype.startswith("fp8"), dequant=is_quantized_kv_cache(self.kv_cache_dtype),
kv_cache_layout="NHD", kv_cache_layout="NHD",
total_tokens=swa_total_tokens, total_tokens=swa_total_tokens,
) )
...@@ -982,7 +983,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -982,7 +983,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
cu_seqlens_kv=cu_seqlens_kv[chunk_idx], cu_seqlens_kv=cu_seqlens_kv[chunk_idx],
token_to_batch=token_to_batch[chunk_idx], token_to_batch=token_to_batch[chunk_idx],
seq_starts=chunk_starts[chunk_idx], seq_starts=chunk_starts[chunk_idx],
dequant=self.kv_cache_dtype.startswith("fp8"), dequant=is_quantized_kv_cache(self.kv_cache_dtype),
kv_cache_layout="SHUFFLE" kv_cache_layout="SHUFFLE"
if rocm_aiter_ops.is_shuffle_kv_cache_enabled() if rocm_aiter_ops.is_shuffle_kv_cache_enabled()
else "NHD", else "NHD",
...@@ -1081,7 +1082,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1081,7 +1082,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
num_actual_tokens = attn_metadata.num_actual_tokens num_actual_tokens = attn_metadata.num_actual_tokens
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(current_platform.fp8_dtype()) key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype())
...@@ -1370,7 +1371,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1370,7 +1371,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
# key and value may be None in the case of cross attention. They are # key and value may be None in the case of cross attention. They are
# calculated once based on the output from the encoder and then cached # calculated once based on the output from the encoder and then cached
# in KV cache. # in KV cache.
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(current_platform.fp8_dtype()) key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype())
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
...@@ -1436,7 +1437,7 @@ class AiterFlashAttentionImpl(AttentionImpl): ...@@ -1436,7 +1437,7 @@ class AiterFlashAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache: if is_fp8_kv_cache:
key_cache = key_cache.view(current_platform.fp8_dtype()) key_cache = key_cache.view(current_platform.fp8_dtype())
value_cache = value_cache.view(current_platform.fp8_dtype()) value_cache = value_cache.view(current_platform.fp8_dtype())
......
...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -11,6 +11,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
QuantKey, QuantKey,
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf from vllm.v1.attention.backend import AttentionLayer, AttentionType, MultipleOf
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_attn import ( from vllm.v1.attention.backends.rocm_attn import (
...@@ -200,7 +201,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -200,7 +201,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
softmax_scale = self.scale softmax_scale = self.scale
fp8_post_attn_v_rescale = False fp8_post_attn_v_rescale = False
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
# When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul). # When Q is FP8, triton kernel skips K/V dequant (for fp8xfp8 matmul).
...@@ -299,7 +300,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl): ...@@ -299,7 +300,7 @@ class RocmAiterUnifiedAttentionImpl(RocmAttentionImpl):
key_cache, value_cache = kv_cache.unbind(0) key_cache, value_cache = kv_cache.unbind(0)
flash_layout = True flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache: if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
......
...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -16,6 +16,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
kFp8StaticTensorSym, kFp8StaticTensorSym,
) )
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -315,7 +316,7 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -315,7 +316,7 @@ class RocmAttentionImpl(AttentionImpl):
layer: The attention layer layer: The attention layer
""" """
# For encoder attention, process FP8 quantization if needed # For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"quantization is not supported for encoder attention" "quantization is not supported for encoder attention"
) )
...@@ -406,7 +407,7 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -406,7 +407,7 @@ class RocmAttentionImpl(AttentionImpl):
kv_cache, self.num_kv_heads, self.head_size kv_cache, self.num_kv_heads, self.head_size
) )
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
assert layer._q_scale_float == 1.0, ( assert layer._q_scale_float == 1.0, (
...@@ -513,7 +514,7 @@ class RocmAttentionImpl(AttentionImpl): ...@@ -513,7 +514,7 @@ class RocmAttentionImpl(AttentionImpl):
) )
flash_layout = False flash_layout = False
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache: if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
......
...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import ( ...@@ -18,6 +18,7 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.platforms.interface import DeviceCapability from vllm.platforms.interface import DeviceCapability
from vllm.utils.math_utils import next_power_of_2 from vllm.utils.math_utils import next_power_of_2
from vllm.utils.torch_utils import is_quantized_kv_cache
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
AttentionBackend, AttentionBackend,
AttentionCGSupport, AttentionCGSupport,
...@@ -472,7 +473,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -472,7 +473,7 @@ class TritonAttentionImpl(AttentionImpl):
# For decoder and cross-attention, use KV cache as before # For decoder and cross-attention, use KV cache as before
key_cache, value_cache = kv_cache.unbind(1) key_cache, value_cache = kv_cache.unbind(1)
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
if key_cache.dtype != self.fp8_dtype: if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
...@@ -546,7 +547,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -546,7 +547,7 @@ class TritonAttentionImpl(AttentionImpl):
layer: The attention layer layer: The attention layer
""" """
# For encoder attention, process FP8 quantization if needed # For encoder attention, process FP8 quantization if needed
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError( raise NotImplementedError(
"quantization is not supported for encoder attention" "quantization is not supported for encoder attention"
) )
...@@ -588,7 +589,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -588,7 +589,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(1) key_cache, value_cache = kv_cache.unbind(1)
# Reshape the input keys and values and store them in the cache. # Reshape the input keys and values and store them in the cache.
if self.kv_cache_dtype.startswith("fp8"): if is_quantized_kv_cache(self.kv_cache_dtype):
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
# triton kernel does not support uint8 kv_cache # triton kernel does not support uint8 kv_cache
...@@ -623,7 +624,7 @@ class TritonAttentionImpl(AttentionImpl): ...@@ -623,7 +624,7 @@ class TritonAttentionImpl(AttentionImpl):
key_cache, value_cache = kv_cache.unbind(1) key_cache, value_cache = kv_cache.unbind(1)
flash_layout = True flash_layout = True
is_fp8_kv_cache = self.kv_cache_dtype.startswith("fp8") is_fp8_kv_cache = is_quantized_kv_cache(self.kv_cache_dtype)
if is_fp8_kv_cache: if is_fp8_kv_cache:
key_cache = key_cache.view(self.fp8_dtype) key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype) value_cache = value_cache.view(self.fp8_dtype)
......
...@@ -5,6 +5,7 @@ import torch ...@@ -5,6 +5,7 @@ import torch
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.triton_utils import tl, triton from vllm.triton_utils import tl, triton
from vllm.utils.torch_utils import is_quantized_kv_cache
@triton.jit @triton.jit
...@@ -145,16 +146,18 @@ def triton_reshape_and_cache_flash( ...@@ -145,16 +146,18 @@ def triton_reshape_and_cache_flash(
block_stride = key_cache.stride()[0] block_stride = key_cache.stride()[0]
page_stride = key_cache.stride()[1] page_stride = key_cache.stride()[1]
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
) )
kv_cache_torch_dtype = ( kv_cache_torch_dtype = (
current_platform.fp8_dtype() current_platform.fp8_dtype()
if kv_cache_dtype.startswith("fp8") if is_quantized_kv_cache(kv_cache_dtype)
else key_cache.dtype else key_cache.dtype
) )
if key_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): if key_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(
kv_cache_dtype
):
# to avoid erroneous implicit cast in triton kernel (tl.store to uint8) # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
key_cache = key_cache.view(kv_cache_torch_dtype) key_cache = key_cache.view(kv_cache_torch_dtype)
...@@ -164,7 +167,7 @@ def triton_reshape_and_cache_flash( ...@@ -164,7 +167,7 @@ def triton_reshape_and_cache_flash(
"uint8 is not supported by triton reshape_and_cache_flash" "uint8 is not supported by triton reshape_and_cache_flash"
) )
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.float8_e5m2, torch.float8_e5m2,
...@@ -323,16 +326,16 @@ def triton_reshape_and_cache_flash_diffkv( ...@@ -323,16 +326,16 @@ def triton_reshape_and_cache_flash_diffkv(
block_stride = kv_cache.stride()[0] block_stride = kv_cache.stride()[0]
page_stride = kv_cache.stride()[1] page_stride = kv_cache.stride()[1]
assert kv_cache_dtype == "auto" or kv_cache_dtype.startswith("fp8"), ( assert kv_cache_dtype == "auto" or is_quantized_kv_cache(kv_cache_dtype), (
f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}." f"unsupported kv_cache_dtype (str), got {kv_cache_dtype}."
) )
kv_cache_torch_dtype = ( kv_cache_torch_dtype = (
current_platform.fp8_dtype() current_platform.fp8_dtype()
if kv_cache_dtype.startswith("fp8") if is_quantized_kv_cache(kv_cache_dtype)
else kv_cache.dtype else kv_cache.dtype
) )
if kv_cache.dtype != kv_cache_torch_dtype and kv_cache_dtype.startswith("fp8"): if kv_cache.dtype != kv_cache_torch_dtype and is_quantized_kv_cache(kv_cache_dtype):
# to avoid erroneous implicit cast in triton kernel (tl.store to uint8) # to avoid erroneous implicit cast in triton kernel (tl.store to uint8)
# (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4) # (e.g. explicit cast to fp8e4m3fnuz is not supported in triton 3.4)
kv_cache = kv_cache.view(kv_cache_torch_dtype) kv_cache = kv_cache.view(kv_cache_torch_dtype)
...@@ -341,7 +344,7 @@ def triton_reshape_and_cache_flash_diffkv( ...@@ -341,7 +344,7 @@ def triton_reshape_and_cache_flash_diffkv(
"uint8 is not supported by triton reshape_and_cache_flash_diffkv" "uint8 is not supported by triton reshape_and_cache_flash_diffkv"
) )
FP8_KV_CACHE = kv_cache_dtype.startswith("fp8") FP8_KV_CACHE = is_quantized_kv_cache(kv_cache_dtype)
assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [ assert (not FP8_KV_CACHE) or kv_cache_torch_dtype in [
torch.float8_e4m3fn, torch.float8_e4m3fn,
torch.float8_e5m2, torch.float8_e5m2,
......
...@@ -109,6 +109,7 @@ from vllm.utils.nvtx_pytorch_hooks import PytHooks ...@@ -109,6 +109,7 @@ from vllm.utils.nvtx_pytorch_hooks import PytHooks
from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units from vllm.utils.platform_utils import is_pin_memory_available, num_compute_units
from vllm.utils.torch_utils import ( from vllm.utils.torch_utils import (
get_dtype_size, get_dtype_size,
is_quantized_kv_cache,
kv_cache_dtype_str_to_dtype, kv_cache_dtype_str_to_dtype,
) )
from vllm.v1.attention.backend import ( from vllm.v1.attention.backend import (
...@@ -896,7 +897,7 @@ class GPUModelRunner( ...@@ -896,7 +897,7 @@ class GPUModelRunner(
If these are left at 0.0 (default after wake_up), all KV cache values If these are left at 0.0 (default after wake_up), all KV cache values
become effectively zero, causing gibberish output. become effectively zero, causing gibberish output.
""" """
if not self.cache_config.cache_dtype.startswith("fp8"): if not is_quantized_kv_cache(self.cache_config.cache_dtype):
return return
kv_caches = getattr(self, "kv_caches", []) kv_caches = getattr(self, "kv_caches", [])
......
...@@ -46,7 +46,7 @@ from vllm.tasks import SupportedTask ...@@ -46,7 +46,7 @@ from vllm.tasks import SupportedTask
from vllm.tracing import instrument from vllm.tracing import instrument
from vllm.utils.mem_constants import GiB_bytes from vllm.utils.mem_constants import GiB_bytes
from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling from vllm.utils.mem_utils import MemorySnapshot, format_gib, memory_profiling
from vllm.utils.torch_utils import set_random_seed from vllm.utils.torch_utils import is_quantized_kv_cache, set_random_seed
from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
from vllm.v1.outputs import ( from vllm.v1.outputs import (
...@@ -197,7 +197,7 @@ class Worker(WorkerBase): ...@@ -197,7 +197,7 @@ class Worker(WorkerBase):
# especially the FP8 scaling factor. # especially the FP8 scaling factor.
if ( if (
(tags is None or "kv_cache" in tags) (tags is None or "kv_cache" in tags)
and self.cache_config.cache_dtype.startswith("fp8") and is_quantized_kv_cache(self.cache_config.cache_dtype)
and hasattr(self.model_runner, "init_fp8_kv_scales") and hasattr(self.model_runner, "init_fp8_kv_scales")
): ):
self.model_runner.init_fp8_kv_scales() self.model_runner.init_fp8_kv_scales()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment