Commit 84c276d7 authored by Dao007forever's avatar Dao007forever Committed by khluu
Browse files

[Bugfix] Cap SWA/chunked-local runtime admission to startup pool-sizing bound (#40946)


Signed-off-by: default avatarDao Le <Dao007forever@gmail.com>
Signed-off-by: default avatarNick Hill <nickhill123@gmail.com>
Co-authored-by: default avatarClaude <noreply@anthropic.com>
Co-authored-by: default avatarNick Hill <nickhill123@gmail.com>
(cherry picked from commit 7b1bc0a3eb01a6bc2650eda9970049f7825240d7)
parent 5eb36575
...@@ -2512,3 +2512,111 @@ def test_block_lookup_cache_multi_blocks_per_key(): ...@@ -2512,3 +2512,111 @@ def test_block_lookup_cache_multi_blocks_per_key():
assert cache.pop(key1, 11) is block11 assert cache.pop(key1, 11) is block11
assert cache.get_one_block(key1) is None assert cache.get_one_block(key1) is None
assert cache.pop(key1, 12) is None assert cache.pop(key1, 12) is None
def test_can_fit_full_sequence_swa_cap_admits_long_prompt():
"""Hybrid full+SWA model with a pool sized at the startup minimum should
admit a prompt longer than the SWA cap, because SlidingWindowManager
recycles blocks during chunked prefill (issue #39734)."""
block_size = 16
sliding_window = 4 * block_size # 64 tokens
max_num_batched_tokens = 8 * block_size # 128 tokens
max_model_len = 64 * block_size # 1024 tokens — much larger than the SWA cap
# Startup pool sizing: full demands cdiv(max_model_len, bs) = 64 blocks,
# SWA demands cdiv(SW-1+max_batched, bs) + 1 = cdiv(191, 16) + 1 = 13.
# Pool minimum = 64 + 13 = 77; +1 for the null block.
num_blocks = 64 + 13 + 1
config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer_full"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["layer_swa"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window,
),
),
],
)
manager = KVCacheManager(
config,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_caching=True,
hash_block_size=block_size,
)
# A prompt that is shorter than max_model_len but longer than SW + chunk:
# cdiv(prompt_len, bs) = 32 blocks. Without the cap, admission would
# demand 32 (full) + 32 (SWA) = 64 blocks. With the cap, SWA contributes
# only 13, so total = 32 + 13 = 45 ≤ pool size.
prompt_len = 32 * block_size
req = make_request("long", list(range(prompt_len)), block_size, sha256)
assert manager.can_fit_full_sequence(req)
def test_can_fit_full_sequence_full_attention_still_gates_oversized():
"""The cap only loosens the SWA group; a prompt that exceeds the
full-attention pool capacity must still be rejected."""
block_size = 16
sliding_window = 4 * block_size
max_num_batched_tokens = 8 * block_size
max_model_len = 64 * block_size
# Provide a tiny pool — even a small prompt should be rejected.
num_blocks = 5
config = KVCacheConfig(
num_blocks=num_blocks,
kv_cache_tensors=[],
kv_cache_groups=[
KVCacheGroupSpec(
["layer_full"],
FullAttentionSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
),
),
KVCacheGroupSpec(
["layer_swa"],
SlidingWindowSpec(
block_size=block_size,
num_kv_heads=1,
head_size=1,
dtype=torch.float32,
sliding_window=sliding_window,
),
),
],
)
manager = KVCacheManager(
config,
max_model_len=max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
enable_caching=True,
hash_block_size=block_size,
)
# 16 blocks of full attention demand alone exceeds the 5-block pool.
prompt_len = 16 * block_size
req = make_request("oversized", list(range(prompt_len)), block_size, sha256)
assert not manager.can_fit_full_sequence(req)
...@@ -22,11 +22,13 @@ pytestmark = pytest.mark.cpu_test ...@@ -22,11 +22,13 @@ pytestmark = pytest.mark.cpu_test
def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True): def get_sliding_window_manager(sliding_window_spec, block_pool, enable_caching=True):
# Tests don't exercise admission gating; pass a large cap that is a no-op.
return SlidingWindowManager( return SlidingWindowManager(
sliding_window_spec, sliding_window_spec,
block_pool=block_pool, block_pool=block_pool,
enable_caching=enable_caching, enable_caching=enable_caching,
kv_cache_group_id=0, kv_cache_group_id=0,
max_admission_blocks_per_request=10**9,
) )
...@@ -38,6 +40,7 @@ def get_chunked_local_attention_manager( ...@@ -38,6 +40,7 @@ def get_chunked_local_attention_manager(
block_pool=block_pool, block_pool=block_pool,
enable_caching=enable_caching, enable_caching=enable_caching,
kv_cache_group_id=0, kv_cache_group_id=0,
max_admission_blocks_per_request=10**9,
) )
......
...@@ -34,6 +34,7 @@ class KVCacheCoordinator(ABC): ...@@ -34,6 +34,7 @@ class KVCacheCoordinator(ABC):
self, self,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool, use_eagle: bool,
enable_caching: bool, enable_caching: bool,
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
...@@ -65,6 +66,8 @@ class KVCacheCoordinator(ABC): ...@@ -65,6 +66,8 @@ class KVCacheCoordinator(ABC):
self.single_type_managers = tuple( self.single_type_managers = tuple(
get_manager_for_kv_cache_spec( get_manager_for_kv_cache_spec(
kv_cache_spec=kv_cache_group.kv_cache_spec, kv_cache_spec=kv_cache_group.kv_cache_spec,
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
block_pool=self.block_pool, block_pool=self.block_pool,
enable_caching=enable_caching, enable_caching=enable_caching,
kv_cache_group_id=i, kv_cache_group_id=i,
...@@ -271,6 +274,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): ...@@ -271,6 +274,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
self, self,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool, use_eagle: bool,
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
dcp_world_size: int, dcp_world_size: int,
...@@ -281,6 +285,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator): ...@@ -281,6 +285,7 @@ class KVCacheCoordinatorNoPrefixCache(KVCacheCoordinator):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
max_model_len, max_model_len,
max_num_batched_tokens,
use_eagle, use_eagle,
False, False,
enable_kv_cache_events, enable_kv_cache_events,
...@@ -316,6 +321,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): ...@@ -316,6 +321,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
self, self,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool, use_eagle: bool,
enable_caching: bool, enable_caching: bool,
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
...@@ -327,6 +333,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator): ...@@ -327,6 +333,7 @@ class UnitaryKVCacheCoordinator(KVCacheCoordinator):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
max_model_len, max_model_len,
max_num_batched_tokens,
use_eagle, use_eagle,
enable_caching, enable_caching,
enable_kv_cache_events, enable_kv_cache_events,
...@@ -381,6 +388,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -381,6 +388,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
self, self,
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool, use_eagle: bool,
enable_caching: bool, enable_caching: bool,
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
...@@ -392,6 +400,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -392,6 +400,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
super().__init__( super().__init__(
kv_cache_config, kv_cache_config,
max_model_len, max_model_len,
max_num_batched_tokens,
use_eagle, use_eagle,
enable_caching, enable_caching,
enable_kv_cache_events, enable_kv_cache_events,
...@@ -574,6 +583,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator): ...@@ -574,6 +583,7 @@ class HybridKVCacheCoordinator(KVCacheCoordinator):
def get_kv_cache_coordinator( def get_kv_cache_coordinator(
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
max_num_batched_tokens: int,
use_eagle: bool, use_eagle: bool,
enable_caching: bool, enable_caching: bool,
enable_kv_cache_events: bool, enable_kv_cache_events: bool,
...@@ -586,6 +596,7 @@ def get_kv_cache_coordinator( ...@@ -586,6 +596,7 @@ def get_kv_cache_coordinator(
return KVCacheCoordinatorNoPrefixCache( return KVCacheCoordinatorNoPrefixCache(
kv_cache_config, kv_cache_config,
max_model_len, max_model_len,
max_num_batched_tokens,
use_eagle, use_eagle,
enable_kv_cache_events, enable_kv_cache_events,
dcp_world_size=dcp_world_size, dcp_world_size=dcp_world_size,
...@@ -597,6 +608,7 @@ def get_kv_cache_coordinator( ...@@ -597,6 +608,7 @@ def get_kv_cache_coordinator(
return UnitaryKVCacheCoordinator( return UnitaryKVCacheCoordinator(
kv_cache_config, kv_cache_config,
max_model_len, max_model_len,
max_num_batched_tokens,
use_eagle, use_eagle,
enable_caching, enable_caching,
enable_kv_cache_events, enable_kv_cache_events,
...@@ -608,6 +620,7 @@ def get_kv_cache_coordinator( ...@@ -608,6 +620,7 @@ def get_kv_cache_coordinator(
return HybridKVCacheCoordinator( return HybridKVCacheCoordinator(
kv_cache_config, kv_cache_config,
max_model_len, max_model_len,
max_num_batched_tokens,
use_eagle, use_eagle,
enable_caching, enable_caching,
enable_kv_cache_events, enable_kv_cache_events,
......
...@@ -109,6 +109,7 @@ class KVCacheManager: ...@@ -109,6 +109,7 @@ class KVCacheManager:
kv_cache_config: KVCacheConfig, kv_cache_config: KVCacheConfig,
max_model_len: int, max_model_len: int,
hash_block_size: int, hash_block_size: int,
max_num_batched_tokens: int | None = None,
enable_caching: bool = True, enable_caching: bool = True,
use_eagle: bool = False, use_eagle: bool = False,
log_stats: bool = False, log_stats: bool = False,
...@@ -118,6 +119,11 @@ class KVCacheManager: ...@@ -118,6 +119,11 @@ class KVCacheManager:
metrics_collector: KVCacheMetricsCollector | None = None, metrics_collector: KVCacheMetricsCollector | None = None,
) -> None: ) -> None:
self.max_model_len = max_model_len self.max_model_len = max_model_len
# When unset, fall back to `max_model_len` so the recycling-aware cap
# collapses to the prior (uncapped) admission behavior. The scheduler
# always supplies the real value at runtime.
if max_num_batched_tokens is None:
max_num_batched_tokens = max_model_len
self.enable_caching = enable_caching self.enable_caching = enable_caching
self.use_eagle = use_eagle self.use_eagle = use_eagle
...@@ -131,6 +137,7 @@ class KVCacheManager: ...@@ -131,6 +137,7 @@ class KVCacheManager:
self.coordinator = get_kv_cache_coordinator( self.coordinator = get_kv_cache_coordinator(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
max_num_batched_tokens=max_num_batched_tokens,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
enable_caching=self.enable_caching, enable_caching=self.enable_caching,
enable_kv_cache_events=enable_kv_cache_events, enable_kv_cache_events=enable_kv_cache_events,
......
...@@ -228,6 +228,7 @@ class Scheduler(SchedulerInterface): ...@@ -228,6 +228,7 @@ class Scheduler(SchedulerInterface):
self.kv_cache_manager = KVCacheManager( self.kv_cache_manager = KVCacheManager(
kv_cache_config=kv_cache_config, kv_cache_config=kv_cache_config,
max_model_len=self.max_model_len, max_model_len=self.max_model_len,
max_num_batched_tokens=self.scheduler_config.max_num_batched_tokens,
enable_caching=self.cache_config.enable_prefix_caching, enable_caching=self.cache_config.enable_prefix_caching,
use_eagle=self.use_eagle, use_eagle=self.use_eagle,
log_stats=self.log_stats, log_stats=self.log_stats,
......
...@@ -41,6 +41,7 @@ class SingleTypeKVCacheManager(ABC): ...@@ -41,6 +41,7 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_group_id: int, kv_cache_group_id: int,
dcp_world_size: int = 1, dcp_world_size: int = 1,
pcp_world_size: int = 1, pcp_world_size: int = 1,
max_admission_blocks_per_request: int | None = None,
) -> None: ) -> None:
""" """
Initializes the SingleTypeKVCacheManager. Initializes the SingleTypeKVCacheManager.
...@@ -48,6 +49,12 @@ class SingleTypeKVCacheManager(ABC): ...@@ -48,6 +49,12 @@ class SingleTypeKVCacheManager(ABC):
kv_cache_spec: The kv_cache_spec for this manager. kv_cache_spec: The kv_cache_spec for this manager.
block_pool: The block pool. block_pool: The block pool.
kv_cache_group_id: The id of the kv cache group of this manager. kv_cache_group_id: The id of the kv cache group of this manager.
max_admission_blocks_per_request: Recycling-aware per-request
block cap used by `get_num_blocks_to_allocate`. Only set for
spec types that recycle blocks across chunks (SWA,
chunked-local); `None` (the default) means no cap, which is
correct for full-attention-style specs that hold every
block until the request finishes.
""" """
self.block_size = kv_cache_spec.block_size self.block_size = kv_cache_spec.block_size
self.dcp_world_size = dcp_world_size self.dcp_world_size = dcp_world_size
...@@ -57,6 +64,7 @@ class SingleTypeKVCacheManager(ABC): ...@@ -57,6 +64,7 @@ class SingleTypeKVCacheManager(ABC):
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
self.block_pool = block_pool self.block_pool = block_pool
self.enable_caching = enable_caching self.enable_caching = enable_caching
self._max_admission_blocks_per_request = max_admission_blocks_per_request
self.new_block_ids: list[int] = [] self.new_block_ids: list[int] = []
# Mapping from request ID to blocks to track the blocks allocated # Mapping from request ID to blocks to track the blocks allocated
...@@ -105,6 +113,19 @@ class SingleTypeKVCacheManager(ABC): ...@@ -105,6 +113,19 @@ class SingleTypeKVCacheManager(ABC):
""" """
num_required_blocks = cdiv(num_tokens, self.block_size) num_required_blocks = cdiv(num_tokens, self.block_size)
if self._max_admission_blocks_per_request is not None:
# Recycling-aware specs (SWA, chunked-local) cap the per-request
# reservation here so admission matches the startup pool sizer
# (`SlidingWindowSpec.max_admission_blocks_per_request` / its
# chunked-local counterpart). `remove_skipped_blocks` runs from
# `allocate_slots` before each chunk's `get_num_blocks_to_allocate`,
# so per-request peak real-held blocks <= this cap, which keeps
# `sum(reservations) <= pool` <=> `sum(peak_real_held) <= pool`.
# Drift between the two would re-introduce the deadlock from
# issue #39734 or, worse, mid-prefill OOM.
num_required_blocks = min(
num_required_blocks, self._max_admission_blocks_per_request
)
num_req_blocks = len(self.req_to_blocks.get(request_id, ())) num_req_blocks = len(self.req_to_blocks.get(request_id, ()))
if request_id in self.num_cached_block: if request_id in self.num_cached_block:
...@@ -1126,8 +1147,21 @@ spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = { ...@@ -1126,8 +1147,21 @@ spec_manager_map: dict[type[KVCacheSpec], type[SingleTypeKVCacheManager]] = {
def get_manager_for_kv_cache_spec( def get_manager_for_kv_cache_spec(
kv_cache_spec: KVCacheSpec, **kwargs kv_cache_spec: KVCacheSpec,
max_num_batched_tokens: int,
max_model_len: int,
**kwargs,
) -> SingleTypeKVCacheManager: ) -> SingleTypeKVCacheManager:
manager_class = spec_manager_map[type(kv_cache_spec)] manager_class = spec_manager_map[type(kv_cache_spec)]
# SlidingWindow / ChunkedLocalAttention managers recycle blocks across
# chunks; the runtime admission cap must match the recycling-aware bound
# the startup pool sizer uses (single source of truth: the spec method).
if isinstance(kv_cache_spec, (SlidingWindowSpec, ChunkedLocalAttentionSpec)):
kwargs["max_admission_blocks_per_request"] = (
kv_cache_spec.max_admission_blocks_per_request(
max_num_batched_tokens=max_num_batched_tokens,
max_model_len=max_model_len,
)
)
manager = manager_class(kv_cache_spec, **kwargs) manager = manager_class(kv_cache_spec, **kwargs)
return manager return manager
...@@ -376,19 +376,28 @@ class MLAAttentionSpec(FullAttentionSpec): ...@@ -376,19 +376,28 @@ class MLAAttentionSpec(FullAttentionSpec):
class ChunkedLocalAttentionSpec(AttentionSpec): class ChunkedLocalAttentionSpec(AttentionSpec):
attention_chunk_size: int attention_chunk_size: int
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_admission_blocks_per_request(
max_model_len = vllm_config.model_config.max_model_len self, max_num_batched_tokens: int, max_model_len: int
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens ) -> int:
"""Per-request admission cap, in blocks.
# During chunked prefill, we allocate KV cache for at most Single source of truth for both startup pool sizing
# `self.attention_chunk_size` computed tokens plus the newly scheduled (`max_memory_usage_bytes`) and the runtime admission gate, so requests
# tokens. And we won't allocate KV cache for more than `max_model_len` admitted by startup can also be admitted at runtime.
# tokens. """
# During chunked prefill, we hold KV for at most one chunk window.
num_tokens = min( num_tokens = min(
self.attention_chunk_size + max_num_batched_tokens, max_model_len self.attention_chunk_size + max_num_batched_tokens, max_model_len
) )
return cdiv(num_tokens, self.block_size)
return cdiv(num_tokens, self.block_size) * self.page_size_bytes def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
max_blocks = self.max_admission_blocks_per_request(
max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len
)
return max_blocks * self.page_size_bytes
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
...@@ -409,26 +418,38 @@ class SlidingWindowSpec(AttentionSpec): ...@@ -409,26 +418,38 @@ class SlidingWindowSpec(AttentionSpec):
* get_dtype_size(self.dtype) * get_dtype_size(self.dtype)
) )
def max_admission_blocks_per_request(
self, max_num_batched_tokens: int, max_model_len: int
) -> int:
"""Per-request admission cap, in blocks.
Single source of truth for both startup pool sizing
(`max_memory_usage_bytes`) and the runtime admission gate. Per-request
real-held blocks plateau at this bound because
`SlidingWindowManager.remove_skipped_blocks` runs from `allocate_slots`
before each chunk's `get_num_blocks_to_allocate`.
"""
# During chunked prefill, we hold KV for the last `sliding_window-1`
# computed tokens plus the newly scheduled tokens, and never more
# than `max_model_len`.
num_tokens = min(
self.sliding_window - 1 + max_num_batched_tokens, max_model_len
)
# +1 because the sliding window may not start from the beginning of
# the block. E.g. block size 4 and num_token 4 needs two blocks
# [XXCD][EF] to store the 6-token window [CDEF].
return cdiv(num_tokens, self.block_size) + 1
def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int: def max_memory_usage_bytes(self, vllm_config: VllmConfig) -> int:
assert vllm_config.parallel_config.decode_context_parallel_size == 1, ( assert vllm_config.parallel_config.decode_context_parallel_size == 1, (
"DCP not support sliding window." "DCP not support sliding window."
) )
max_model_len = vllm_config.model_config.max_model_len max_model_len = vllm_config.model_config.max_model_len
max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens max_num_batched_tokens = vllm_config.scheduler_config.max_num_batched_tokens
max_blocks = self.max_admission_blocks_per_request(
# During chunked prefill, we allocate KV cache for the last max_num_batched_tokens=max_num_batched_tokens, max_model_len=max_model_len
# `self.sliding_window-1` computed tokens plus the newly scheduled
# tokens. And we won't allocate KV cache for more than `max_model_len`
# tokens.
num_tokens = min(
self.sliding_window - 1 + max_num_batched_tokens, max_model_len
) )
return max_blocks * self.page_size_bytes
# +1 here because the sliding window may not start from the beginning
# of the block. For example, if the block size is 4 and num_token
# is 4, we need two blocks [XXCD] [EF] to store the sliding
# window [CDEF] of 6 tokens.
return (cdiv(num_tokens, self.block_size) + 1) * self.page_size_bytes
@dataclass(frozen=True, kw_only=True) @dataclass(frozen=True, kw_only=True)
......
...@@ -110,6 +110,9 @@ class SimpleCPUOffloadScheduler: ...@@ -110,6 +110,9 @@ class SimpleCPUOffloadScheduler:
self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator( self.cpu_coordinator: KVCacheCoordinator = get_kv_cache_coordinator(
kv_cache_config=self.cpu_kv_cache_config, kv_cache_config=self.cpu_kv_cache_config,
max_model_len=vllm_config.model_config.max_model_len, max_model_len=vllm_config.model_config.max_model_len,
max_num_batched_tokens=(
vllm_config.scheduler_config.max_num_batched_tokens
),
use_eagle=False, use_eagle=False,
enable_caching=True, enable_caching=True,
enable_kv_cache_events=self.enable_kv_cache_events, enable_kv_cache_events=self.enable_kv_cache_events,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment