from __future__ import annotations

import warnings

import torch

from sglang.srt.utils import get_bool_env_var

# Opt-in switch for the lightop concat kernel, read once at import time from
# the SGLANG_USE_OPT_CAT environment variable.
_USE_OPT_CAT = get_bool_env_var("SGLANG_USE_OPT_CAT")

# ``ds_cat`` stays ``None`` unless the switch is on AND the import succeeds,
# so callers can test ``ds_cat is None`` to detect that the kernel is absent.
ds_cat = None
if _USE_OPT_CAT:
    try:
        from lightop import ds_cat  # type: ignore
    except ImportError:  # pragma: no cover
        # Warning text, in English: "SGLANG_USE_OPT_CAT is enabled but
        # lightop.ds_cat could not be imported; falling back to torch.cat".
        warnings.warn(
            "SGLANG_USE_OPT_CAT 已开启但无法导入 lightop.ds_cat，退回 torch.cat"
        )


def concat_decode_opt(A: torch.Tensor, B: torch.Tensor, dim: int) -> torch.Tensor:
    """Concatenate ``A`` and ``B`` along ``dim`` and return the new tensor.

    When the module-level ``ds_cat`` kernel is available, the output buffer is
    allocated here (same device and dtype as ``A``) and handed to
    ``ds_cat(A, B, C, mode)``, which is expected to fill it. When ``ds_cat``
    is ``None`` (the opt-in flag is off or ``lightop`` could not be imported),
    the function falls back to ``torch.cat`` instead of failing on a call to
    ``None``.

    Args:
        A: First input tensor; its shape, device and dtype define the output.
        B: Second input tensor, appended after ``A`` along ``dim``.
        dim: Concatenation axis. Only ``2`` is supported.

    Returns:
        A new tensor whose size along ``dim`` is
        ``A.shape[dim] + B.shape[dim]``.

    Raises:
        ValueError: If ``dim`` is not ``2``.

    Note:
        The error message states that the tensors must be 3-D, but the rank
        itself is not verified here -- presumably a ``ds_cat`` requirement;
        confirm against the lightop kernel.
    """
    # Explicit raise rather than ``assert`` so the check survives ``python -O``.
    if dim != 2:
        raise ValueError("tensor dim must be 3 and concat dim must be 2")

    # No optimized kernel available: use the stock op, as the import-time
    # warning in this module promises.
    if ds_cat is None:
        return torch.cat([A, B], dim=dim)

    output_shape = list(A.shape)
    output_shape[dim] = A.shape[dim] + B.shape[dim]
    C = torch.empty(output_shape, device=A.device, dtype=A.dtype)
    # Passed through to ds_cat unchanged; its meaning is defined by lightop
    # (TODO confirm what mode 0 selects).
    mode = 0
    ds_cat(A, B, C, mode)
    return C