""" Compactor 压缩:与 kvpress ``CompactorPress`` / ``LeverageScorePress`` / ``NonCausalAttnPress`` 算法对齐(Cholesky 杠杆分、右高斯 sketch、非因果分块注意力无 1/sqrt(d) 缩放、×||V||、avg_pool、 全局 z-score、blending 与首尾 sink pad)。 非因果分块注意力与 ``×||V||``+``avg_pool1d(k=3)`` 在 CUDA 上为 Triton;非 CUDA 回退 PyTorch。 杠杆分路径使用 batched ``torch.matmul``;在 transpose 与进入线性代数前对张量 ``.contiguous()``。 CUDA 上用 ``cholesky_solve``;在 HIP/ROCm 上对小的 sketch 维 ``k`` 用 ``linalg.inv(G+λI) @ X^T`` 代替 ``cholesky_solve``,避开 rocBLAS TRSM 的 launch-bounds 告警与部分栈上的不稳定行为。 非因果 PyTorch 回退同理。 """ from __future__ import annotations import math from typing import List, Optional import torch import triton import triton.language as tl from transformers.models.llama.modeling_llama import repeat_kv from vllm.kvprune.compression.common import BaseCompressionMethod from vllm.kvprune.utils.helpers import maybe_execute_in_stream def resolve_kvpress_compactor_blending(compression_context) -> float: """与 kvpress ``CompactorPress.score`` 相同:``blending`` 或 ``compression_ratio``,再否则 0.35。""" if compression_context is None: return 0.35 b = getattr(compression_context, "compactor_blending", None) if b is not None: return float(b) cr = getattr(compression_context, "compression_ratio", None) if cr is not None: return float(cr) return 0.35 class CompactorCompression(BaseCompressionMethod): """与 kvpress ``CompactorPress`` / ``NonCausalAttnPress`` 默认 ``chunk_size=256`` 一致。""" chunk_size: int = 256 @staticmethod def pre_rope_scoring( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, context ) -> Optional[torch.Tensor]: compression_context = context.compression_context return maybe_execute_in_stream( kvpress_leverage_scores_packed, k, context.cu_seqlens_q, compression_context, STORE_STREAM=context.STORE_STREAM, ) @staticmethod def post_rope_scoring( q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, pre_rope_scores: torch.Tensor, context, ) -> Optional[torch.Tensor]: compression_context = context.compression_context blending = 
resolve_kvpress_compactor_blending(compression_context) return maybe_execute_in_stream( kvpress_compactor_post_rope, q, k, v, context.cu_seqlens_q, pre_rope_scores, compression_context, context.max_seqlen_q, chunk_size=CompactorCompression.chunk_size, blending=float(blending), STORE_STREAM=context.STORE_STREAM, ) # --------------------------------------------------------------------------- # Cholesky 杠杆分(kvpress ``LeverageScorePress``) # --------------------------------------------------------------------------- def chol_with_jitter( G: torch.Tensor, jitter: float = 0.0, max_tries: int = 5 ) -> torch.Tensor: identity = torch.eye(G.shape[-1], device=G.device, dtype=G.dtype) cur = float(jitter) for _ in range(max_tries): L, info = torch.linalg.cholesky_ex( (G + cur * identity).contiguous(), upper=False ) if bool((info == 0).all()): return L cur = max(1e-8, (1e-2 if cur == 0.0 else 10.0 * cur)) raise RuntimeError(f"Cholesky failed after {max_tries} tries.") def compute_leverage_scores_mid( key_states: torch.Tensor, sketch_dimension: int ) -> torch.Tensor: """ 与 kvpress ``LeverageScorePress.compute_leverage_scores`` 相同;输入 ``[L, H, D]``, 返回 ``[L, H]``(未 z-score)。 维序与 kvpress 的 ``(B, H, S, D)`` 对齐;batched GEMM + ``.contiguous()`` 以利于后端库。 """ d, k = key_states.shape[-1], sketch_dimension device, dtype = key_states.device, key_states.dtype H = key_states.shape[1] Phi = torch.randn(1, H, d, k, device=device, dtype=dtype) * (1.0 / math.sqrt(k)) X0 = key_states.transpose(0, 1).unsqueeze(0).contiguous() X = (X0 - X0.mean(dim=-2, keepdim=True)).contiguous() Phi = Phi.contiguous() X = torch.matmul(X, Phi).to(torch.float32).contiguous() XT = X.transpose(-2, -1).contiguous() G = torch.matmul(XT, X) G_sym = 0.5 * (G + G.transpose(-2, -1)).contiguous() # HIP: avoid batched cholesky_solve -> rocBLAS TRSM (launch_bounds noise / edge cases). # k is sketch_dim (typically modest); inv is O(k^3) but batched over heads. 
if torch.version.hip is not None: kk = G_sym.shape[-1] eye = torch.eye( kk, device=G_sym.device, dtype=G_sym.dtype, requires_grad=False ) G_reg = G_sym + 1e-2 * eye inv_Xt = torch.linalg.inv(G_reg) @ XT else: L_mat = chol_with_jitter(G_sym, jitter=1e-2, max_tries=5) inv_Xt = torch.cholesky_solve(XT, L_mat, upper=False) inv_Xt_T = inv_Xt.transpose(-2, -1).contiguous() scores = (X * inv_Xt_T).sum(dim=-1).clamp_min(0) return scores.squeeze(0).transpose(0, 1).contiguous() def kvpress_leverage_scores_packed( key_states: torch.Tensor, cu_seqlens: torch.Tensor, compression_ctx, ) -> torch.Tensor: device = key_states.device N, Hkv, _D = key_states.shape sketch_dim = int(getattr(compression_ctx, "sketch_dimension", 48)) sink_start = int(getattr(compression_ctx, "sink_size_start", 8)) sink_end = int(getattr(compression_ctx, "sink_size_end", 4)) out = torch.zeros(N, Hkv, device=device, dtype=torch.float32) mids_flat: list[torch.Tensor] = [] mid_ranges: list[tuple[int, int, int]] = [] for b in range(cu_seqlens.numel() - 1): k_beg = int(cu_seqlens[b].item()) k_end = int(cu_seqlens[b + 1].item()) L = k_end - k_beg if L == 0: continue left_keep = min(sink_start, L) right_keep = min(sink_end, max(0, L - left_keep)) mid_start = k_beg + left_keep mid_end = k_end - right_keep if mid_start >= mid_end: continue k_mid = key_states[mid_start:mid_end, :, :].contiguous() raw = compute_leverage_scores_mid(k_mid, sketch_dim) mids_flat.append(raw.reshape(-1)) mid_ranges.append((mid_start, mid_end, Hkv)) if not mids_flat: return out flat = torch.cat(mids_flat, dim=0) z = _zscore_flat_f32_global(flat) offset = 0 for (mid_start, mid_end, _Hkv), r in zip(mid_ranges, mids_flat): n = r.numel() seg = z[offset : offset + n].view(mid_end - mid_start, Hkv) out[mid_start:mid_end, :] = seg offset += n return out # --------------------------------------------------------------------------- # 非因果分块注意力(kvpress ``NonCausalAttnPress.non_causal_chunked_attn``)— Triton # 
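
# Illustrative sketch, not used by the scoring path: a hypothetical dense
# reference (``_non_causal_chunked_attn_dense_sketch`` is an illustrative
# name, not a kvpress API) for the non-causal chunked attention score. It
# assumes the sequence length L is a multiple of ``chunk_size``, so no padding
# or masking is needed. Inside each chunk, every query row is soft-maxed over
# the keys of the same chunk (raw logits, no 1/sqrt(d)), and the attention
# each key receives is summed over that chunk's queries.
import torch


def _non_causal_chunked_attn_dense_sketch(
    q: torch.Tensor, k: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """q, k: ``[L, H, d]`` with ``L % chunk_size == 0`` -> ``[L, H]`` (fp32)."""
    L, H, d = q.shape
    assert k.shape == q.shape and chunk_size > 0 and L % chunk_size == 0
    n = L // chunk_size
    # [L, H, d] -> [H, n, chunk_size, d]; token index = chunk * chunk_size + row.
    qc = q.float().permute(1, 0, 2).reshape(H, n, chunk_size, d)
    kc = k.float().permute(1, 0, 2).reshape(H, n, chunk_size, d)
    # Per-chunk softmax over keys: [H, n, chunk_size, chunk_size].
    attn = torch.softmax(qc @ kc.transpose(-2, -1), dim=-1)
    # Sum over query rows, flatten chunks back to tokens, return [L, H].
    return attn.sum(dim=-2).reshape(H, L).transpose(0, 1)


# Usage (sketch): for L divisible by chunk_size and fp32 inputs, the result
# should agree with non_causal_chunked_attn(q, k, chunk_size) up to
# floating-point error. For example, compare the two with torch.allclose on
# q = torch.randn(512, 4, 64), k = torch.randn(512, 4, 64), chunk_size=256.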


def _non_causal_chunked_attn_pytorch(
    q: torch.Tensor, k: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """Reference implementation that matches kvpress op by op.

    ``q``, ``k``: ``[L, H, d]`` -> ``[L, H]``. The sequence is cut into blocks
    of ``chunk_size`` (the last block is zero-padded when ``L`` is not a
    multiple of ``chunk_size``). Inside each block the raw logits ``q @ k^T``
    (no ``1/sqrt(d)``) are soft-maxed over keys, and the attention each key
    receives is summed over the block's queries.
    """
    assert chunk_size > 0 and q.shape == k.shape
    L, H, d = q.shape
    B = 1
    # [L, H, d] -> [1, H, L, d], the kvpress (B, H, S, D) layout.
    q = q.permute(1, 0, 2).unsqueeze(0).contiguous()
    k = k.permute(1, 0, 2).unsqueeze(0).contiguous()
    _B, H, S, _d = k.shape
    S_pad = math.ceil(S / chunk_size) * chunk_size
    pad_len = S_pad - S
    if pad_len > 0:
        q_padded = torch.cat(
            [q, torch.zeros(B, H, pad_len, d, device=q.device, dtype=q.dtype)], dim=2
        )
        k_padded = torch.cat(
            [k, torch.zeros(B, H, pad_len, d, device=k.device, dtype=k.dtype)], dim=2
        )
        last_chunk_start = (S // chunk_size) * chunk_size
    else:
        q_padded, k_padded = q, k
        last_chunk_start = ((S - 1) // chunk_size) * chunk_size
    # Padding mask of the last chunk (all False when no padding was added).
    in_valid = torch.arange(last_chunk_start, S_pad, device=q.device) >= S
    query_mask = key_mask = in_valid.view(1, 1, chunk_size).expand(B, H, chunk_size)

    num_chunks = S_pad // chunk_size
    q_chunks = q_padded.contiguous().view(B, H, num_chunks, chunk_size, d)
    k_chunks = k_padded.contiguous().view(B, H, num_chunks, chunk_size, d)
    # Raw logits per chunk: [B, H, num_chunks, chunk_size, chunk_size].
    dots = torch.matmul(q_chunks, k_chunks.transpose(-2, -1).contiguous())
    # Padded query rows get logits 0 and padded key columns get -1e-9. Note that
    # -1e-9 (not -1e9) leaves exp() essentially unchanged; the value is kept
    # as-is for parity with kvpress.
    dots[:, :, -1].masked_fill_(query_mask.unsqueeze(-1), 0)
    dots[:, :, -1].masked_fill_(key_mask.unsqueeze(-2), -1e-9)
    attn = torch.softmax(dots.to(torch.float32), dim=-1)
    # Sum over queries, then drop the padded positions.
    out = attn.sum(dim=-2).view(B, H, S_pad)[..., :S]
    return out.squeeze(0).transpose(0, 1).contiguous()


@triton.jit
def _non_causal_chunk_row_kernel(
    Q_ptr, K_ptr, Out_ptr,
    stride_qh, stride_qs, stride_qd,
    stride_kh, stride_ks, stride_kd,
    stride_oh, stride_os,
    S, S_pad, num_chunks,
    CHUNK_SIZE: tl.constexpr,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
    ND: tl.constexpr,
):
    """
    One program per (head, chunk, query row). The program computes the row's
    logits against every key of its chunk, applies softmax over keys
    (dim=-1), and ``atomic_add``s each probability into the output slot of
    key ``j``. Accumulating over all query rows equals the sum over queries.
    ``CHUNK_SIZE`` must be a power of two (``tl.arange`` requirement).
    """
    h = tl.program_id(0)
    c = tl.program_id(1)
    iq = tl.program_id(2)
    g_i = c * CHUNK_SIZE + iq  # global index of this query row
    offs_j = tl.arange(0, CHUNK_SIZE)
    logits = tl.zeros([CHUNK_SIZE], dtype=tl.float32)
    # Dot products accumulated over head_dim in BLOCK_D slices.
    for db in range(ND):
        offs_d = tl.arange(0, BLOCK_D) + db * BLOCK_D
        mask_d = offs_d < D
        q_off = h * stride_qh + g_i * stride_qs + offs_d * stride_qd
        qd = tl.load(Q_ptr + q_off, mask=mask_d, other=0.0).to(tl.float32)
        g_j = c * CHUNK_SIZE + offs_j
        k_row_off = h * stride_kh + g_j[:, None] * stride_ks + offs_d[None, :] * stride_kd
        kj = tl.load(K_ptr + k_row_off, mask=mask_d[None, :], other=0.0).to(tl.float32)
        logits += tl.sum(qd[None, :] * kj, axis=1)
    # Padded query rows get all-zero logits; padded key columns of valid rows
    # get -1e-9. This mirrors the PyTorch reference up to that negligible
    # offset on padded rows.
    row_invalid = g_i >= S
    g_j_all = c * CHUNK_SIZE + offs_j
    col_invalid = g_j_all >= S
    logits = tl.where(row_invalid, tl.zeros([CHUNK_SIZE], dtype=tl.float32), logits)
    logits = tl.where(
        row_invalid,
        logits,
        tl.where(col_invalid, tl.full([CHUNK_SIZE], -1e-9, dtype=tl.float32), logits),
    )
    # Numerically stable softmax over the keys of this chunk.
    m = tl.max(logits)
    logits = logits - m
    exp_v = tl.exp(logits)
    denom = tl.sum(exp_v)
    p = exp_v / denom
    # Only valid key columns are written back.
    out_base = h * stride_oh + g_j_all * stride_os
    tl.atomic_add(Out_ptr + out_base, p, mask=g_j_all < S)


def _non_causal_chunked_attn_triton(
    q: torch.Tensor, k: torch.Tensor, chunk_size: int
) -> torch.Tensor:
    """CUDA Triton path; same algorithm as ``_non_causal_chunked_attn_pytorch``."""
    assert q.is_cuda and k.is_cuda and q.shape == k.shape
    L, H, d = q.shape
    # The kernel uses tl.arange(0, CHUNK_SIZE), which needs a power of two.
    assert chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0
    S_pad = math.ceil(L / chunk_size) * chunk_size
    pad_len = S_pad - L
    if pad_len > 0:
        zq = torch.zeros(
            pad_len, H, d, device=q.device, dtype=q.dtype, requires_grad=False
        )
        zk = torch.zeros(
            pad_len, H, d, device=k.device, dtype=k.dtype, requires_grad=False
        )
        q = torch.cat([q, zq], dim=0)
        k = torch.cat([k, zk], dim=0)
    # [S_pad, H, d] -> [H, S_pad, d] in fp32.
    Q = q.transpose(0, 1).contiguous().to(dtype=torch.float32)
    K = k.transpose(0, 1).contiguous().to(dtype=torch.float32)
    num_chunks = S_pad // chunk_size
    # Accumulator for the atomic adds; it must start at zero.
    out_acc = torch.zeros(H, S_pad, device=q.device, dtype=torch.float32)
    S = int(L)
    # One program per (head, chunk, query row).
    grid = (H, num_chunks, chunk_size)
    BLOCK_D = 32 if d <= 128 else 64
    ND = (d + BLOCK_D - 1) // BLOCK_D
    _non_causal_chunk_row_kernel[grid](
        Q, K, out_acc,
        Q.stride(0), Q.stride(1), Q.stride(2),
        K.stride(0), K.stride(1), K.stride(2),
        out_acc.stride(0), out_acc.stride(1),
        S, S_pad, int(num_chunks),
        CHUNK_SIZE=chunk_size,
        D=d,
        BLOCK_D=BLOCK_D,
        ND=ND,
        num_warps=4,
    )
    return out_acc[:, :S].transpose(0, 1).contiguous()


def non_causal_chunked_attn(q: torch.Tensor, k: torch.Tensor, chunk_size: int) -> torch.Tensor:
    """q, k: ``[L, H, d]`` -> ``[L, H]``; **no** ``1/sqrt(d)`` scaling.

    Uses Triton on CUDA when ``chunk_size`` is a power of two, PyTorch otherwise.
    """
    pow2_chunk = chunk_size > 0 and (chunk_size & (chunk_size - 1)) == 0
    if q.is_cuda and k.is_cuda and pow2_chunk:
        return _non_causal_chunked_attn_triton(q, k, chunk_size)
    return _non_causal_chunked_attn_pytorch(q, k, chunk_size)


# ---------------------------------------------------------------------------
# ×||V|| + avg_pool1d(k=3), Triton (CUDA)
# ---------------------------------------------------------------------------


@triton.jit
def _mul_vnorm_avgpool3_kernel(
    A_ptr, V_ptr, OUT_ptr,
    stride_al, stride_ah,
    stride_vl, stride_vh, stride_vd,
    stride_ol, stride_oh,
    L,
    D: tl.constexpr,
    BLOCK_D: tl.constexpr,
):
    """
    out[l, h] = mean over t in {l-1, l, l+1} of A[t, h] * ||V[t, h, :]||, where
    out-of-range positions contribute 0 (``avg_pool1d(kernel_size=3,
    padding=1)`` semantics). Triton does not support nested ``def``, so the
    per-position logic is unrolled once each for ``l-1``, ``l`` and ``l+1``.
    ``BLOCK_D`` is ``D`` rounded up to a power of two; lanes ``>= D`` are masked.
    """
    l = tl.program_id(0)
    h = tl.program_id(1)
    offs = tl.arange(0, BLOCK_D)
    mask_d = offs < D

    # Position l - 1 (the index is clamped to 0 when out of range; loads are masked).
    pos_m1 = l - 1
    inb_m1 = (pos_m1 >= 0) & (pos_m1 < L)
    ps_m1 = tl.where(inb_m1, pos_m1, 0)
    a_m1 = tl.load(
        A_ptr + ps_m1 * stride_al + h * stride_ah,
        mask=inb_m1,
        other=0.0,
    ).to(tl.float32)
    v_m1 = tl.load(
        V_ptr + ps_m1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_m1 & mask_d,
        other=0.0,
    ).to(tl.float32)
    s_m1 = tl.where(inb_m1, a_m1 * tl.sqrt(tl.sum(v_m1 * v_m1)), 0.0)

    # Position l
    inb_0 = (l >= 0) & (l < L)
    ps0 = tl.where(inb_0, l, 0)
    a0 = tl.load(
        A_ptr + ps0 * stride_al + h * stride_ah,
        mask=inb_0,
        other=0.0,
    ).to(tl.float32)
    v0 = tl.load(
        V_ptr + ps0 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_0 & mask_d,
        other=0.0,
    ).to(tl.float32)
    s_0 = tl.where(inb_0, a0 * tl.sqrt(tl.sum(v0 * v0)), 0.0)

    # Position l + 1
    pos_p1 = l + 1
    inb_p1 = (pos_p1 >= 0) & (pos_p1 < L)
    ps_p1 = tl.where(inb_p1, pos_p1, 0)
    a_p1 = tl.load(
        A_ptr + ps_p1 * stride_al + h * stride_ah,
        mask=inb_p1,
        other=0.0,
    ).to(tl.float32)
    v_p1 = tl.load(
        V_ptr + ps_p1 * stride_vl + h * stride_vh + offs * stride_vd,
        mask=inb_p1 & mask_d,
        other=0.0,
    ).to(tl.float32)
    s_p1 = tl.where(inb_p1, a_p1 * tl.sqrt(tl.sum(v_p1 * v_p1)), 0.0)

    # Zero-padded mean over the 3-wide window (count_include_pad semantics).
    out = (s_m1 + s_0 + s_p1) * (1.0 / 3.0)
    tl.store(OUT_ptr + l * stride_ol + h * stride_oh, out)


def _mul_vnorm_avgpool3_fused(
    a: torch.Tensor, v: torch.Tensor, out: torch.Tensor | None = None
) -> torch.Tensor:
    """Triton path: ``a`` ``[L, H]``, ``v`` ``[L, H, D]`` -> fp32 ``[L, H]``."""
    assert (
        a.dim() == 2
        and v.dim() == 3
        and a.shape[0] == v.shape[0]
        and a.shape[1] == v.shape[1]
    )
    L, H, D = v.shape
    a = a.contiguous()
    v = v.contiguous()
    if a.dtype != torch.float32:
        a = a.float()
    if out is None:
        out = torch.empty((L, H), device=v.device, dtype=torch.float32)
    if L == 0 or H == 0:
        return out
    grid = (L, H)
    _mul_vnorm_avgpool3_kernel[grid](
        a, v, out,
        a.stride(0), a.stride(1),
        v.stride(0), v.stride(1), v.stride(2),
        out.stride(0), out.stride(1),
        L,
        D=D,
        BLOCK_D=triton.next_power_of_2(D),
        num_warps=4,
    )
    return out


def _maybe_mul_vnorm_avgpool3_fused(a: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """``avg_pool1d(a * ||v||, k=3, padding=1)`` over the sequence dimension.

    Triton on CUDA, PyTorch otherwise.
    """
    if not a.is_cuda or not v.is_cuda:
        import torch.nn.functional as F

        s = a * v.norm(dim=-1)  # [L, H]
        # avg_pool1d pools over the last dim: [L, H] -> [1, H, L] -> [L, H].
        # count_include_pad defaults to True, so the edges divide by 3 like the kernel.
        return (
            F.avg_pool1d(s.transpose(0, 1).unsqueeze(0), kernel_size=3, padding=1, stride=1)
            .squeeze(0)
            .transpose(0, 1)
        )
    return _mul_vnorm_avgpool3_fused(a, v)


@triton.jit
def _zscore_elem_1d_kernel(
    X_ptr,
    OUT_ptr,
    n,
    mean,
    inv_std,
    BLOCK: tl.constexpr,
):
    # out = (x - mean) * inv_std, one BLOCK-sized tile per program.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(X_ptr + offs, mask=mask, other=0.0)
    tl.store(OUT_ptr + offs, (x - mean) * inv_std, mask=mask)


def _zscore_flat_f32_global(x: torch.Tensor) -> torch.Tensor:
    """
    Global 1-D z-score, matching kvpress ``(t - t.mean()) / t.std()`` (unbiased
    std, clamped to at least ``1e-6``). ``mean``/``std`` are computed with
    PyTorch; on CUDA the elementwise rescaling runs as a Triton kernel (the
    ``.item()`` calls synchronize with the device).
    """
    if x.numel() == 0:
        return x
    mu = x.mean()
    sig = x.std().clamp_min(1e-6)
    inv = 1.0 / sig
    if not x.is_cuda:
        return (x - mu) * inv
    x = x.contiguous()
    out = torch.empty_like(x)
    n = x.numel()
    BLOCK = 1024
    grid = (triton.cdiv(n, BLOCK),)
    _zscore_elem_1d_kernel[grid](
        x,
        out,
        n,
        float(mu.item()),
        float(inv.item()),
        BLOCK=BLOCK,
        num_warps=4,
    )
    return out


def _attn_scores_kvpress_middle(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    sink_start: int,
    sink_end: int,
    chunk_size: int,
    do_zscore: bool = True,
) -> torch.Tensor:
    """Non-causal scores, ``×||V||`` and ``avg_pool`` on the middle sub-sequences only.

    Returns full-length ``[N, Hkv]`` scores that are 0 outside the middles.
    """
    N, HQ, D = q.shape
    Hkv = k.shape[1]
    G = HQ // Hkv  # query heads per KV head (GQA group size)
    device = q.device
    attn_out = torch.zeros(N, Hkv, device=device, dtype=torch.float32)
    parts: list[torch.Tensor] = []
    for b in range(cu_seqlens.numel() - 1):
        k_beg = int(cu_seqlens[b].item())
        k_end = int(cu_seqlens[b + 1].item())
        L = k_end - k_beg
        if L == 0:
            continue
        left_keep = min(sink_start, L)
        right_keep = min(sink_end, max(0, L - left_keep))
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if mid_start >= mid_end:
            continue
        q_m = q[mid_start:mid_end, :, :].contiguous()
        k_m = k[mid_start:mid_end, :, :].contiguous()
        v_m = v[mid_start:mid_end, :, :].contiguous()
        # HF ``repeat_kv`` convention: ``[batch, num_kv_heads, seq_len, head_dim]``
        k_4d = k_m.unsqueeze(0).transpose(1, 2).contiguous()  # [1, Hkv, Lm, D]
        k_rep = repeat_kv(k_4d, G)[0].transpose(0, 1).contiguous()  # [Lm, HQ, D]
        A = non_causal_chunked_attn(q_m, k_rep, chunk_size)  # [Lm, HQ]
        Lm, HQa = A.shape
        assert HQa == HQ
        # Average the G query heads that share each KV head: [Lm, HQ] -> [Lm, Hkv].
        A = A.view(Lm, Hkv, G).mean(dim=-1)
        scores = _maybe_mul_vnorm_avgpool3_fused(A, v_m)
        parts.append(scores.reshape(-1))
    if not parts:
        return attn_out
    flat_a = torch.cat(parts, dim=0)
    if do_zscore:
        z_a = _zscore_flat_f32_global(flat_a)
    else:
        z_a = flat_a
    # Write the (optionally z-scored) middles back in the same order.
    offset = 0
    for b in range(cu_seqlens.numel() - 1):
        k_beg = int(cu_seqlens[b].item())
        k_end = int(cu_seqlens[b + 1].item())
        L = k_end - k_beg
        if L == 0:
            continue
        left_keep = min(sink_start, L)
        right_keep = min(sink_end, max(0, L - left_keep))
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if mid_start >= mid_end:
            continue
        n = (mid_end - mid_start) * Hkv
        attn_out[mid_start:mid_end, :] = z_a[offset : offset + n].view(
            mid_end - mid_start, Hkv
        )
        offset += n
    return attn_out


def non_causal_attn_scores(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens_qk: torch.Tensor,
    max_seqlen_qk: int,
    chunk_size: int,
    sm_scale: Optional[float] = None,
    normalize: bool = True,
    context_lens: Optional[List[int]] = None,
    protected_first_tokens: Optional[List[int]] = None,
    protected_last_tokens: Optional[List[int]] = None,
    *,
    accum_scores: Optional[torch.Tensor] = None,
    accum_blending: Optional[float] = None,
) -> torch.Tensor:
    """
    Matches the kvpress non-causal branch. ``sm_scale`` is **ignored**: the dot
    products are not multiplied by ``1/sqrt(d)``.

    With ``normalize=True``, the middle sub-sequences are concatenated and
    z-scored globally (as in the standalone non-causal press). Then
    ``out += accum_blending * accum_scores`` if ``accum_scores`` is given
    (``accum_blending`` defaults to 0.5). Finally, if the protected-token lists
    are given, the first/last protected tokens of each context are set to ``inf``.
    """
    del sm_scale, max_seqlen_qk
    # Sink sizes are fixed here; they match the defaults read in
    # kvpress_compactor_post_rope.
    sink_start, sink_end = 8, 4
    out = _attn_scores_kvpress_middle(
        q,
        k,
        v,
        cu_seqlens_qk,
        sink_start,
        sink_end,
        chunk_size,
        do_zscore=normalize,
    )
    if accum_scores is not None:
        w = 0.5 if accum_blending is None else float(accum_blending)
        out = out + w * accum_scores.to(device=out.device, dtype=out.dtype)
    if protected_first_tokens is not None and protected_last_tokens is not None and context_lens:
        start = 0
        for first, last, Lc in zip(
            protected_first_tokens, protected_last_tokens, context_lens
        ):
            out[start : start + int(first)].fill_(torch.inf)
            out[start + int(Lc) - int(last) : start + int(Lc)].fill_(torch.inf)
            start += int(Lc)
    return out


def kvpress_compactor_post_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    cu_seqlens: torch.Tensor,
    pre_rope_scores: torch.Tensor,
    compression_ctx,
    max_seqlen_q: int,
    chunk_size: int,
    blending: float,
) -> torch.Tensor:
    """Final Compactor scores ``[N, Hkv]``.

    On each middle sub-sequence the score is ``blending * leverage + attention``,
    where both terms are z-scored upstream. Sink positions are padded with the
    maximum blended score (or 1.0 if that maximum is 0 or non-finite), so they
    score at least as high as every middle token. Protected first/last tokens
    are set to ``inf``.
    """
    del max_seqlen_q  # unused
    Hkv = k.shape[1]
    device = q.device
    sink_start = int(getattr(compression_ctx, "sink_size_start", 8))
    sink_end = int(getattr(compression_ctx, "sink_size_end", 4))
    context_lens: Optional[List[int]] = getattr(
        compression_ctx, "context_lens", None
    )
    protected_first: Optional[List[int]] = getattr(
        compression_ctx, "protected_first_tokens", None
    )
    protected_last: Optional[List[int]] = getattr(
        compression_ctx, "protected_last_tokens", None
    )
    attn_out = _attn_scores_kvpress_middle(
        q, k, v, cu_seqlens, sink_start, sink_end, chunk_size
    )
    lev = pre_rope_scores.to(device=device, dtype=torch.float32)
    blended = torch.zeros_like(lev)
    # Blend leverage and attention scores on the middle sub-sequences.
    for b in range(cu_seqlens.numel() - 1):
        k_beg = int(cu_seqlens[b].item())
        k_end = int(cu_seqlens[b + 1].item())
        L = k_end - k_beg
        if L == 0:
            continue
        left_keep = min(sink_start, L)
        right_keep = min(sink_end, max(0, L - left_keep))
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if mid_start >= mid_end:
            continue
        blended[mid_start:mid_end, :] = (
            blending * lev[mid_start:mid_end, :] + attn_out[mid_start:mid_end, :]
        )
    # Pad the sink tokens of every sequence with the global maximum.
    pad_val = blended.max()
    if not torch.isfinite(pad_val) or pad_val == 0:
        pad_val = torch.tensor(1.0, device=device, dtype=torch.float32)
    for b in range(cu_seqlens.numel() - 1):
        k_beg = int(cu_seqlens[b].item())
        k_end = int(cu_seqlens[b + 1].item())
        L = k_end - k_beg
        if L == 0:
            continue
        left_keep = min(sink_start, L)
        right_keep = min(sink_end, max(0, L - left_keep))
        mid_start = k_beg + left_keep
        mid_end = k_end - right_keep
        if left_keep > 0:
            blended[k_beg:mid_start, :] = pad_val
        if right_keep > 0:
            blended[mid_end:k_end, :] = pad_val
    if protected_first is not None and protected_last is not None and context_lens:
        start = 0
        for first, last, Lc in zip(
            protected_first, protected_last, context_lens
        ):
            blended[start : start + int(first)].fill_(torch.inf)
            blended[start + int(Lc) - int(last) : start + int(Lc)].fill_(torch.inf)
            start += int(Lc)
    return blended
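

# Illustrative sketch, not used by the scoring path: a hypothetical helper
# (``_exact_leverage_scores_sketch`` is an illustrative name) that computes the
# exact leverage scores x_i^T (X^T X)^{-1} x_i of one head's centered keys.
# With a thin QR factorization X = QR of a full-column-rank X, the leverage of
# row i is ||Q[i, :]||^2; the scores lie in [0, 1] and sum to the number of
# columns. compute_leverage_scores_mid applies the same formula to the
# k-column sketch X @ Phi with a ridge term, so its values are a different,
# approximate quantity and are not expected to equal these.
import torch


def _exact_leverage_scores_sketch(key_states: torch.Tensor) -> torch.Tensor:
    """key_states: ``[L, D]`` keys of a single head -> ``[L]`` leverage scores."""
    X = key_states.double()
    # Center over the sequence, as compute_leverage_scores_mid does.
    X = X - X.mean(dim=0, keepdim=True)
    # Thin QR: Q has orthonormal columns spanning col(X) when X has full column rank.
    Q, _ = torch.linalg.qr(X, mode="reduced")
    return (Q * Q).sum(dim=-1)


# Usage (sketch): exact scores for KV head 0 of a middle segment k_mid [L, H, D]:
#   lev0 = _exact_leverage_scores_sketch(k_mid[:, 0, :])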