from flash_attn import (
    flash_attn_varlen_func as flash_attn_varlen_func_interface, 
    flash_attn_with_kvcache as flash_attn_with_kvcache_interface
)
from typing import Optional, Union

import torch

# Widest head dimension handed to a single flash_attn kernel call. The varlen
# path of flash_attn_with_kvcache splits V at this column boundary when V is
# wider and concatenates the partial outputs.
# NOTE(review): presumably mirrors the flash_attn kernels' head-dim limit — confirm.
MAX_FLASH_ATTN_KERNEL_HEADDIM = 256

def flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    k=None,
    v=None,
    qv=None,
    rotary_cos=None,
    rotary_sin=None,
    cache_seqlens: Optional[Union[int, torch.Tensor]] = None,
    cache_batch_idx: Optional[torch.Tensor] = None,
    cache_leftpad: Optional[torch.Tensor] = None,
    page_table: Optional[torch.Tensor] = None,
    cu_seqlens_q: Optional[torch.Tensor] = None,
    cu_seqlens_k_new: Optional[torch.Tensor] = None,
    max_seqlen_q: Optional[int] = None,
    rotary_seqlens: Optional[torch.Tensor] = None,
    q_descale: Optional[torch.Tensor] = None,
    k_descale: Optional[torch.Tensor] = None,
    v_descale: Optional[torch.Tensor] = None,
    softmax_scale=None,
    causal=False,
    window_size=(-1, -1),  # -1 means infinite context window
    attention_chunk: Optional[int] = None,
    softcap=0.0,  # 0.0 means deactivated
    rotary_interleaved=True,
    scheduler_metadata=None,
    num_splits=0,  # Can be tuned for speed
    pack_gqa=None,  # Can be tuned for speed
    sm_margin=0,  # Can be tuned if some SMs are used for communication
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    """FA3-style ``flash_attn_with_kvcache`` implemented with the flash_attn package.

    Only a subset of the FA3 arguments is honored. ``k``, ``v``, ``qv``, the
    rotary arguments, ``cache_batch_idx``, ``cache_leftpad``, ``rotary_seqlens``,
    the descale tensors, ``attention_chunk``, ``rotary_interleaved``,
    ``scheduler_metadata``, ``pack_gqa``, ``sm_margin``, ``sinks`` and ``ver`` are
    accepted for signature compatibility and ignored.

    Dispatch:
      * Varlen path: taken when ``cu_seqlens_q`` is given and the packed ``q``
        rows cannot be viewed as ``(batch, max_seqlen_q, ...)``, i.e. the query
        sequences have different lengths. ``k_cache`` / ``v_cache`` are flattened
        to ``(-1, num_heads_k, head_dim)`` and passed to flash_attn's
        ``flash_attn_varlen_func`` with ``cu_seqlens_k_new`` (falling back to
        ``cu_seqlens_q``) as key offsets and ``max_seqlen_q`` as the key maximum.
        ``page_table`` and ``cache_seqlens`` are NOT used on this path, so the
        flattened cache rows must line up with those offsets.
      * KV-cache path: otherwise. ``q`` is viewed as
        ``(-1, max_seqlen_q, num_heads, head_dim)`` (or passed unchanged when
        ``max_seqlen_q`` is None) and forwarded, with ``page_table`` as
        ``block_table``, to flash_attn's ``flash_attn_with_kvcache``.

    Returns:
        The attention output, or ``(out, softmax_lse)`` when
        ``return_softmax_lse`` is True (on either path).
    """
    # cu_seqlens_q holds batch + 1 offsets, so the batch size is shape[0] - 1.
    # When every query sequence is exactly max_seqlen_q long, q can be reshaped
    # to (batch, max_seqlen_q, ...) and the paged KV-cache kernel is used.
    is_ragged = cu_seqlens_q is not None and (
        q.shape[0] != (cu_seqlens_q.shape[0] - 1) * max_seqlen_q
    )

    if is_ragged:
        k_flat = k_cache.view(-1, k_cache.shape[-2], k_cache.shape[-1])  # (total_k, num_heads_k, head_dim)
        v_flat = v_cache.view(-1, v_cache.shape[-2], v_cache.shape[-1])  # (total_k, num_heads_k, head_dim_v)
        # Keys are assumed to line up with the query sequences when no explicit
        # key offsets are given (max_seqlen_k below is max_seqlen_q as well).
        cu_seqlens_k = cu_seqlens_k_new if cu_seqlens_k_new is not None else cu_seqlens_q

        # For fixed softmax weights the output is linear in V, so a V wider than
        # one kernel call accepts is processed in column chunks and concatenated.
        # NOTE(review): assumes the installed flash_attn accepts a V head dim that
        # differs from Q/K — confirm.
        out_chunks = []
        softmax_lse = None
        for v_chunk in v_flat.split(MAX_FLASH_ATTN_KERNEL_HEADDIM, dim=-1):
            result = flash_attn_varlen_func_interface(
                q=q,  # (total_q, num_heads, head_dim)
                k=k_flat,
                v=v_chunk,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_k,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_q,
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=window_size,
                softcap=softcap,
                return_attn_probs=return_softmax_lse,
            )
            if return_softmax_lse:
                # flash_attn returns (out, softmax_lse, S_dmask) here. The LSE
                # depends only on q and k, so every chunk yields the same one.
                chunk_out = result[0]
                if softmax_lse is None:
                    softmax_lse = result[1]
            else:
                chunk_out = result
            out_chunks.append(chunk_out)

        # Avoid a needless copy from torch.cat when V fit in a single call.
        out = out_chunks[0] if len(out_chunks) == 1 else torch.cat(out_chunks, dim=-1)
        return (out, softmax_lse) if return_softmax_lse else out

    # Without max_seqlen_q there is no length to reshape by; pass q through as-is.
    q_batched = (
        q
        if max_seqlen_q is None
        else q.contiguous().view(-1, max_seqlen_q, q.shape[-2], q.shape[-1])
    )
    return flash_attn_with_kvcache_interface(
        q=q_batched,
        k_cache=k_cache,
        v_cache=v_cache,
        block_table=page_table,
        cache_seqlens=cache_seqlens,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        return_softmax_lse=return_softmax_lse,
        num_splits=num_splits,
    )

def flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    seqused_q=None,
    seqused_k=None,
    page_table=None,
    softmax_scale=None,
    causal=False,
    qv=None,
    q_descale=None,
    k_descale=None,
    v_descale=None,
    window_size=(-1, -1),
    attention_chunk=0,
    softcap=0.0,
    num_splits=1,
    pack_gqa=None,
    sm_margin=0,
    return_softmax_lse=False,
    sinks=None,
    ver=3,
):
    """FA3-style ``flash_attn_varlen_func`` implemented with the flash_attn package.

    Forwards q/k/v, the query and key offsets/maxima, ``softmax_scale``,
    ``causal``, ``window_size`` and ``softcap`` to flash_attn's
    ``flash_attn_varlen_func``. The FA3-only arguments (``seqused_q``,
    ``seqused_k``, ``page_table``, ``qv``, the descale tensors,
    ``attention_chunk``, ``num_splits``, ``pack_gqa``, ``sm_margin``, ``sinks``
    and ``ver``) are accepted for signature compatibility and ignored.

    ``cu_seqlens_k`` / ``max_seqlen_k`` fall back to ``cu_seqlens_q`` /
    ``max_seqlen_q`` when None, i.e. keys packed exactly like the queries.

    Returns:
        The attention output, or ``(out, softmax_lse)`` when
        ``return_softmax_lse`` is True.
    """
    if cu_seqlens_k is None:
        cu_seqlens_k = cu_seqlens_q
    if max_seqlen_k is None:
        max_seqlen_k = max_seqlen_q

    result = flash_attn_varlen_func_interface(
        q=q,
        k=k,
        v=v,
        cu_seqlens_q=cu_seqlens_q,
        cu_seqlens_k=cu_seqlens_k,
        max_seqlen_q=max_seqlen_q,
        max_seqlen_k=max_seqlen_k,
        softmax_scale=softmax_scale,
        causal=causal,
        window_size=window_size,
        softcap=softcap,
        return_attn_probs=return_softmax_lse,
    )
    if return_softmax_lse:
        # With return_attn_probs flash_attn returns (out, softmax_lse, S_dmask);
        # FA3 callers expect just (out, softmax_lse).
        return result[0], result[1]
    return result