import torch
import os
from loguru import logger

from lightx2v_platform.ops.attn.template import AttnWeightTemplate
from lightx2v_platform.registry_factory import PLATFORM_ATTN_WEIGHT_REGISTER

# Optional sparse-linear-attention (SLA) kernel.
# ImportError (the parent class of ModuleNotFoundError) is caught on purpose:
# a flash_attn build that is installed but does not export
# `sparse_attn_with_sla` raises a plain ImportError ("cannot import name"),
# and that must only disable the feature, not abort importing this module.
try:
    from flash_attn import sparse_attn_with_sla

    SAPRDE_LINEAR_ATTN = True
except ImportError:
    # Keep the name bound so later references fail predictably, mirroring the
    # handling of `flash_attn_varlen_func` in this file.
    sparse_attn_with_sla = None
    SAPRDE_LINEAR_ATTN = False

# Probe for the ROCm build of Flash Attention (2.6.1). The outcome is recorded
# in FLASH_ATTN_AVAILABLE; when the import fails, `flash_attn_varlen_func` is
# bound to None so the name always exists at module level.
try:
    from flash_attn import flash_attn_varlen_func
except ImportError:
    flash_attn_varlen_func = None
    FLASH_ATTN_AVAILABLE = False
    logger.warning("Flash Attention not found. Will use PyTorch SDPA as fallback.")
else:
    FLASH_ATTN_AVAILABLE = True
    logger.info("Flash Attention (ROCm) is available")


@PLATFORM_ATTN_WEIGHT_REGISTER("flash_attn_hygon_dcu")
class FlashAttnHygonDcu(AttnWeightTemplate):
    """
    Hygon DCU Flash Attention implementation.

    Uses AMD ROCm version of Flash Attention 2.6.1 when available.
    Falls back to PyTorch SDPA (Scaled Dot Product Attention) if Flash Attention is not installed.
    When the optional `sparse_attn_with_sla` kernel was importable and the
    environment variable USE_SLA is a non-zero integer, self-attention-shaped
    calls are routed to that kernel instead (see `apply`).

    Tested Environment:
    - PyTorch: 2.7.1
    - Python: 3.10
    - Flash Attention: 2.6.1 (ROCm)
    Reference: https://developer.sourcefind.cn/codes/modelzoo/wan2.1_pytorch/-/blob/master/wan/modules/attention.py
    """

    def __init__(self, weight_name="flash_attn_hygon_dcu"):
        """Record whether Flash Attention was importable and log the chosen backend."""
        super().__init__(weight_name)
        self.use_flash_attn = FLASH_ATTN_AVAILABLE

        if self.use_flash_attn:
            logger.info("Flash Attention 2.6.1 (ROCm) is available and will be used.")
        else:
            logger.warning("Flash Attention not available. Using PyTorch SDPA fallback.")

    @staticmethod
    def _prepare_cu_seqlens(cu_seqlens, device):
        """Return `cu_seqlens` as an int32 tensor on `device` (None is passed through)."""
        if cu_seqlens is None:
            return None
        # `.to` returns the tensor itself when device and dtype already match.
        return cu_seqlens.to(device=device, dtype=torch.int32, non_blocking=True)

    def apply(
        self,
        q,
        k,
        v,
        cu_seqlens_q=None,
        cu_seqlens_kv=None,
        max_seqlen_q=None,
        max_seqlen_kv=None,
        dropout_p=0.0,
        softmax_scale=None,
        causal=False,
        window_size=(-1, -1),
        deterministic=False,
        **kwargs,
    ):
        """
        Execute Flash Attention computation with variable-length sequences.

        This method signature matches the standard LightX2V attention interface,
        compatible with other platform implementations (e.g., MLU, NVIDIA).

        Backend selection, in order:
        1. Flash Attention not importable -> `_sdpa_fallback` (PyTorch SDPA).
        2. SLA kernel importable, env USE_SLA is a non-zero integer,
           max_seqlen_q == max_seqlen_kv and every sequence is exactly
           max_seqlen_q long -> `sparse_attn_with_sla`, with the top-k ratio
           read from env SPARSE_ATTN_TOPK (default "0.5").
        3. Otherwise -> `flash_attn_varlen_func`.

        Args:
            q: [B*Lq, Nq, C1] Query tensor (flattened batch)
            k: [B*Lk, Nk, C1] Key tensor (flattened batch)
            v: [B*Lk, Nk, C2] Value tensor (flattened batch)
            cu_seqlens_q: [B+1] Cumulative sequence lengths for queries
            cu_seqlens_kv: [B+1] Cumulative sequence lengths for keys/values
            max_seqlen_q: Maximum sequence length in queries
            max_seqlen_kv: Maximum sequence length in keys/values
            dropout_p: Dropout probability
            softmax_scale: Scaling factor for QK^T before softmax
                (defaults to 1/sqrt(head_dim) on the flash-attention path)
            causal: Whether to apply causal mask
            window_size: Sliding window size tuple (left, right)
            deterministic: Whether to use deterministic algorithm
            **kwargs: Extra interface arguments; accepted and ignored.
        Returns:
            Output tensor: [B*Lq, Nq*C2] (flattened batch, heads merged into the
            last dimension). On the flash-attention and SLA paths the result is
            cast back to the dtype of `q`.

        Note:
            The SLA call only receives q, k, v and topk, so dropout_p,
            softmax_scale, causal, window_size and deterministic have no effect
            on that path. The SDPA fallback does not use cu_seqlens_kv,
            window_size or deterministic.
        """
        if not self.use_flash_attn:
            # Fallback to PyTorch SDPA; forward the scale and the KV length so the
            # fallback computes the same attention as the flash path.
            return self._sdpa_fallback(
                q,
                k,
                v,
                cu_seqlens_q,
                max_seqlen_q,
                causal=causal,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                max_seqlen_kv=max_seqlen_kv,
            )

        import math

        # The kernels need half precision: keep q's dtype when it already is
        # fp16/bf16, otherwise compute in bf16. All three inputs are cast to
        # that single dtype so q/k/v can never disagree.
        half_dtypes = (torch.float16, torch.bfloat16)
        out_dtype = q.dtype
        compute_dtype = q.dtype if q.dtype in half_dtypes else torch.bfloat16
        q_flat = q.to(compute_dtype)
        k_flat = k.to(compute_dtype)
        v_flat = v.to(compute_dtype)

        # Flash Attention requires cu_seqlens on q's device with int32 dtype.
        device = q.device
        cu_seqlens_q = self._prepare_cu_seqlens(cu_seqlens_q, device)
        cu_seqlens_kv = self._prepare_cu_seqlens(cu_seqlens_kv, device)

        use_sla = SAPRDE_LINEAR_ATTN and int(os.getenv("USE_SLA", 0)) and max_seqlen_q == max_seqlen_kv
        if use_sla:
            bs = cu_seqlens_q.shape[0] - 1
            # The SLA kernel is fed dense [B, L, H, D] tensors, which is only a
            # valid view of the flattened batch when every sequence has exactly
            # max_seqlen_q tokens; otherwise use the varlen kernel instead.
            dense_rows = bs * max_seqlen_q
            use_sla = q_flat.shape[0] == dense_rows and k_flat.shape[0] == dense_rows

        if use_sla:
            topk_value = float(os.getenv("SPARSE_ATTN_TOPK", "0.5"))
            q_blhd = q_flat.reshape(bs, max_seqlen_q, q_flat.shape[-2], q_flat.shape[-1])
            k_blhd = k_flat.reshape(bs, max_seqlen_q, k_flat.shape[-2], k_flat.shape[-1])
            v_blhd = v_flat.reshape(bs, max_seqlen_q, v_flat.shape[-2], v_flat.shape[-1])
            # NOTE(review): output is presumably [B, L, H, D] like the inputs - confirm
            # against the sparse_attn_with_sla implementation.
            output = sparse_attn_with_sla(
                q=q_blhd,
                k=k_blhd,
                v=v_blhd,
                topk=topk_value,
            )
        else:
            # Use Flash Attention 2.6.1 (ROCm version) with varlen interface
            if softmax_scale is None:
                softmax_scale = 1.0 / math.sqrt(q.shape[-1])
            output = flash_attn_varlen_func(
                q=q_flat,
                k=k_flat,
                v=v_flat,
                cu_seqlens_q=cu_seqlens_q,
                cu_seqlens_k=cu_seqlens_kv,
                max_seqlen_q=max_seqlen_q,
                max_seqlen_k=max_seqlen_kv,
                dropout_p=dropout_p,
                softmax_scale=softmax_scale,
                causal=causal,
                window_size=window_size,
                deterministic=deterministic,
            )

        # Merge heads: one row per query token. Using the real token count
        # (q.shape[0]) instead of bs * max_seqlen_q keeps this correct when the
        # sequences in a varlen batch are not all of maximal length.
        output = output.reshape(q.shape[0], -1)
        return output.to(out_dtype)

    def _sdpa_fallback(
        self,
        q,
        k,
        v,
        cu_seqlens_q,
        max_seqlen_q,
        causal=False,
        dropout_p=0.0,
        softmax_scale=None,
        max_seqlen_kv=None,
    ):
        """
        Fallback to PyTorch Scaled Dot Product Attention when Flash Attention is not available.

        The flattened batch is viewed as dense [B, L, N, C] tensors, so every
        sequence must have exactly `max_seqlen_q` query tokens and
        `max_seqlen_kv` key/value tokens.

        Args:
            q: [B*Lq, Nq, C] Query tensor (flattened batch)
            k: [B*Lk, Nk, C] Key tensor (flattened batch)
            v: [B*Lk, Nk, C] Value tensor (flattened batch)
            cu_seqlens_q: [B+1] Cumulative sequence lengths for queries
                (only its length is used, to derive the batch size)
            max_seqlen_q: Maximum sequence length in queries
            causal: Whether to apply causal mask
            dropout_p: Dropout probability
            softmax_scale: Scaling factor for QK^T before softmax; None keeps
                SDPA's own default
            max_seqlen_kv: Sequence length of keys/values; None means "same as
                max_seqlen_q" (the previous behaviour)
        Returns:
            Output tensor: [B*Lq, Nq*C] (flattened batch)
        """
        bs = cu_seqlens_q.shape[0] - 1
        # Keys/values may have a different length than queries (cross-attention).
        kv_len = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv

        # [B*L, N, C] -> [B, L, N, C] -> [B, N, L, C], the layout SDPA expects.
        q = q.reshape(bs, max_seqlen_q, q.shape[-2], q.shape[-1]).transpose(1, 2)
        k = k.reshape(bs, kv_len, k.shape[-2], k.shape[-1]).transpose(1, 2)
        v = v.reshape(bs, kv_len, v.shape[-2], v.shape[-1]).transpose(1, 2)

        # Only pass `scale` when the caller set one, so PyTorch versions whose
        # SDPA lacks that keyword keep working with the default scale.
        sdpa_kwargs = {} if softmax_scale is None else {"scale": softmax_scale}
        out = torch.nn.functional.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=None,
            is_causal=causal,
            dropout_p=dropout_p,
            **sdpa_kwargs,
        )

        # Transpose back to [B, L, Nq, C] and flatten
        out = out.transpose(1, 2).contiguous()
        return out.reshape(bs * max_seqlen_q, -1)
