# fallback_fp8.py
# PyTorch fallback implementation for DeepGEMM-like fp8 logits ops
from sglang.srt.utils import ceil_div
import torch

@torch.no_grad()
def fallback_fp8_mqa_logits(q: torch.Tensor,
                             kv: torch.Tensor,
                             weights: torch.Tensor,
                             ks: torch.Tensor,
                             ke: torch.Tensor, cost_only: bool = False) -> torch.Tensor:
    """PyTorch fallback for fp8_mqa_logits.

    No real fp8 arithmetic is performed: ``q`` and ``kv`` are cast to float32
    before the dot products are taken.

    For query row ``m`` and key row ``n`` the logit is
    ``sum_h relu(q[m, h] . kv[n]) * weights[m, h]`` when
    ``ks[m] <= n < ke[m]`` and ``-inf`` otherwise.

    Args:
        q: (M, H, D) queries.
        kv: (N, D) keys.
        weights: (M, H) per-head weights.
        ks: (M,) integer start (inclusive) of the valid key range per query.
        ke: (M,) integer end (exclusive) of the valid key range per query.
        cost_only: If True, skip the logits and return only the number of
            unmasked (query, key) pairs.

    Returns:
        (M, N) float32 logits, or a 0-dim tensor holding the unmasked-pair
        count when ``cost_only`` is True.
    """
    seq_len_kv = kv.shape[0]

    if cost_only:
        # Clamp both bounds into [0, N] so out-of-range windows are not
        # over-counted; inverted windows (end < start) contribute zero.
        start = ks.clamp(min=0, max=seq_len_kv)
        end = ke.clamp(min=0, max=seq_len_kv)
        return (end - start).clamp(min=0).sum()

    q = q.float()
    k = kv.float()

    # Build the key positions on the same device as the inputs rather than a
    # hard-coded 'cuda', so the fallback also runs on CPU or a non-default GPU.
    positions = torch.arange(seq_len_kv, device=q.device)[None, :]
    mask = (positions >= ks[:, None]) & (positions < ke[:, None])

    score = torch.einsum('mhd,nd->hmn', q, k)
    # weights (M, H) -> (H, M, 1) so it broadcasts over the key axis of score.
    logits = (score.relu() * weights.unsqueeze(-1).transpose(0, 1)).sum(dim=0)
    return logits.masked_fill(~mask, float('-inf'))


@torch.no_grad()
def fallback_fp8_paged_mqa_logits(q: torch.Tensor,
                                   kv_cache: torch.Tensor,
                                   weights: torch.Tensor,
                                   context_lens: torch.Tensor,
                                   block_tables: torch.Tensor,
                                   max_model_len: int) -> torch.Tensor:
    """PyTorch fallback for fp8_paged_mqa_logits.

    No fp8 kernel is used: the q @ k product runs in the dtype of the inputs
    and the result is cast to float32 before the ReLU / weighted head sum.

    For sequence ``i`` the ``next_n`` query rows sit at absolute positions
    ``context_len - next_n .. context_len - 1``; each row only sees keys at
    positions ``<=`` its own. Every other entry of the output is ``-inf``.

    Args:
        q: (B, next_n, H, D) queries.
        kv_cache: (num_blocks, block_size, 1, D) paged key cache.
        weights: (B * next_n, H) per-head weights.
        context_lens: (B,) context length of each sequence.
        block_tables: (B, max_blocks) logical-to-physical block index map.
        max_model_len: Number of columns in the returned logits.

    Returns:
        (B * next_n, max_model_len) float32 logits.
    """
    batch_size, next_n, _, _ = q.size()
    _, block_size, _, _ = kv_cache.size()
    neg_inf = float('-inf')
    logits = torch.full([batch_size * next_n, max_model_len], neg_inf,
                        device=q.device, dtype=torch.float32)

    # One host transfer up front instead of a sync per sequence.
    context_len_list = context_lens.tolist()
    for i in range(batch_size):
        context_len = context_len_list[i]
        rows = slice(i * next_n, (i + 1) * next_n)
        # Absolute position of each query row, as a column for broadcasting.
        q_offsets = torch.arange(context_len - next_n, context_len, device=q.device)[:, None]
        # (next_n, H) -> (H, next_n, 1) to broadcast over the key axis.
        weight_slice = weights[rows, :].transpose(0, 1).contiguous()[..., None]
        qx = q[i].transpose(0, 1)  # (H, next_n, D)

        # Number of blocks needed to cover context_len (ceiling division).
        num_ctx_blocks = (context_len + block_size - 1) // block_size
        for block_rk in range(num_ctx_blocks):
            col_start = block_rk * block_size
            if col_start >= max_model_len:
                # Remaining blocks lie entirely outside the output.
                break
            # Clip the final block so a max_model_len that is not a multiple
            # of block_size does not cause a shape mismatch on assignment.
            col_end = min(col_start + block_size, max_model_len)

            kx = kv_cache[block_tables[i][block_rk]]  # (block_size, 1, D)
            k_offsets = torch.arange(col_start, col_start + block_size, device=q.device)[None, :]
            # q_offsets never exceeds context_len - 1, so this causal test
            # also excludes key slots at or beyond context_len.
            visible = k_offsets <= q_offsets  # (next_n, block_size)

            # (H, next_n, D) @ (1, D, block_size) -> (H, next_n, block_size)
            s = (qx @ kx.transpose(0, 1).transpose(1, 2)).to(logits.dtype)
            s = (torch.relu(s) * weight_slice).sum(dim=0)
            s = torch.where(visible, s, neg_inf)
            logits[rows, col_start:col_end] = s[:, :col_end - col_start]
    return logits

