package web

import (
	"encoding/base64"
	"get-container/cmd/opsflow/backend"
	"net/http"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/pquerna/otp"
	"github.com/pquerna/otp/totp"
	"github.com/spf13/viper"

	swaggerFiles "github.com/swaggo/files"
	ginSwagger "github.com/swaggo/gin-swagger"
)

// Swagger UI assets come from github.com/swaggo/files; WebServer mounts them only in debug mode.

// globalCfg is the configuration installed by Init. It stays nil until Init
// is called; WebServer and the handlers call methods on it, so Init must run
// before the server is started.
var globalCfg *viper.Viper

// Init records cfg as the configuration used by this package's server and
// handlers.
func Init(cfg *viper.Viper) {
	globalCfg = cfg
}

// JSON response envelopes shared by the handlers. Each body carries a numeric
// code and a message; success bodies may also carry a "data" payload.
type (
	// RestfulResult wraps a single payload of type T. Data is tagged
	// omitempty, so empty values of T are left out of the JSON.
	RestfulResult[T any] struct {
		Code int    `json:"code"`
		Msg  string `json:"msg"`
		Data T      `json:"data,omitempty"`
	}

	// RestfulNoDataResult is a body with only a code and a message; it is
	// used for error responses.
	RestfulNoDataResult struct {
		Code int    `json:"code"`
		Msg  string `json:"msg"`
	}

	// RestfulListResult wraps a slice payload. Data is omitted from the JSON
	// when the slice is nil or empty.
	RestfulListResult[T any] struct {
		Code int    `json:"code"`
		Msg  string `json:"msg"`
		Data []T    `json:"data,omitempty"`
	}
)

// ReturnGin writes the JSON response for a handler that produced a single
// value.
//
// If err is non-nil it responds with HTTP 500 and a RestfulNoDataResult whose
// msg is err.Error(). Otherwise it responds with HTTP 200 and a
// RestfulResult[any] wrapping data.
func ReturnGin(ctx *gin.Context, data any, err error) {
	if err == nil {
		ctx.JSON(http.StatusOK, RestfulResult[any]{Code: http.StatusOK, Msg: "ok", Data: data})
		return
	}
	ctx.JSON(http.StatusInternalServerError, RestfulNoDataResult{
		Code: http.StatusInternalServerError,
		Msg:  err.Error(),
	})
}

// ReturnGinList writes the JSON response for a handler that produced a slice.
//
// If err is non-nil it responds with HTTP 500 and a RestfulNoDataResult whose
// msg is err.Error(). Otherwise it responds with HTTP 200 and a
// RestfulListResult[T] wrapping data. A nil or empty slice means the "data"
// field is omitted.
func ReturnGinList[T any](ctx *gin.Context, data []T, err error) {
	if err == nil {
		ctx.JSON(http.StatusOK, RestfulListResult[T]{Code: http.StatusOK, Msg: "ok", Data: data})
		return
	}
	ctx.JSON(http.StatusInternalServerError, RestfulNoDataResult{
		Code: http.StatusInternalServerError,
		Msg:  err.Error(),
	})
}

// webAuth is the authentication middleware for the /api/cmd routes.
//
// Clients send "Authorization: Bearer <token>". The token is the standard
// base64 encoding of a TOTP code, which is validated against the secret in
// the "auth_key" config entry. Validation uses 6 digits, SHA1, a 30-second
// period and a skew of one period. Any failure aborts the request with 401
// and a JSON {"error": ...} body.
//
// If "auth_key" is empty the middleware refuses every request with 500. TOTP
// with an empty secret uses an empty HMAC key, so anyone could compute valid
// codes.
//
// NOTE(review): a valid code can be replayed within its validity window;
// there is no used-code tracking here.
func webAuth(ctx *gin.Context) {
	authHeader := ctx.GetHeader("Authorization")
	if authHeader == "" {
		ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Not find header Authorization"})
		return
	}
	scheme, token, found := strings.Cut(authHeader, " ")
	if !found || scheme != "Bearer" {
		ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization error"})
		return
	}
	code, err := base64.StdEncoding.DecodeString(token)
	if err != nil {
		ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
		return
	}

	secret := globalCfg.GetString("auth_key")
	if strings.TrimSpace(secret) == "" {
		// Fail closed: a missing secret is a server misconfiguration, and
		// validating against it would accept codes anyone can generate.
		ctx.AbortWithStatusJSON(http.StatusInternalServerError, gin.H{"error": "auth_key is not configured"})
		return
	}

	valid, err := totp.ValidateCustom(string(code), secret, time.Now(), totp.ValidateOpts{
		Period:    30, // a new code every 30 seconds
		Skew:      1,  // also accept the previous/next period (±30 s of clock drift)
		Digits:    otp.DigitsSix,
		Algorithm: otp.AlgorithmSHA1,
	})
	if err != nil {
		ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
		return
	}
	if !valid {
		ctx.AbortWithStatusJSON(http.StatusUnauthorized, gin.H{"error": "Authorization failed"})
		return
	}
	ctx.Next()
}

// WebServer builds the gin engine, registers the /api/cmd API and serves HTTP
// on addr via engine.Run, returning the error that Run reports.
//
// When the "debug_mode" config flag is set, the Swagger UI is mounted at
// /swagger/*any and the API routes are served without authentication.
// Otherwise the Swagger UI is not mounted and every /api/cmd route passes
// through webAuth.
func WebServer(addr string) error {
	router := gin.Default()
	api := router.Group("/api/cmd")

	if !globalCfg.GetBool("debug_mode") {
		// Attach the middleware before registering the routes below so that
		// each route's handler chain includes it.
		api.Use(webAuth)
	} else {
		// Debug mode: expose the interactive API docs and skip auth.
		router.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler))
	}

	routes := []struct {
		method  string
		path    string
		handler gin.HandlerFunc
	}{
		{http.MethodGet, "/all", _controller.GetAllInfo},
		{http.MethodGet, "/loginUser", _controller.GetOnlineUser},
		{http.MethodGet, "/sysload", _controller.GetSysLoad},
		{http.MethodGet, "/dcuload", _controller.GetDCULoad},
		{http.MethodGet, "/rcclinfo", _controller.GetRcclHandler},
		{http.MethodPost, "/rccl/post", _controller.PostRcclHandler},
	}
	for _, rt := range routes {
		api.Handle(rt.method, rt.path, rt.handler)
	}

	return router.Run(addr)
}

// controller groups the HTTP handlers that WebServer registers. It holds no
// state.
type controller struct{}

// _controller is the shared zero-value controller whose methods are used as
// route handlers.
var _controller controller

// GetRcclHandler runs rccl all_reduce_perf via backend.AllReducePerf. It uses
// the configured "rccl_test_path" and the default
// "rccl_all_reduce_perf_args", and responds through ReturnGin (500 on error).
//
// @Summary 获取 rccl all_reduce_perf 性能信息
// @Description 获取 rccl all_reduce_perf 性能信息
// @Accept json
// @Produce json
// @Success 200 {object} RestfulResult[backend.RcclTestAllReducePrefResult]
// @Failure 500 {object} RestfulNoDataResult
// @Router /rcclinfo [get]
func (controller) GetRcclHandler(ctx *gin.Context) {
	testPath := globalCfg.GetString("rccl_test_path")
	defaultArgs := globalCfg.GetString("rccl_all_reduce_perf_args")
	perf, perfErr := backend.AllReducePerf(testPath, defaultArgs)
	ReturnGin(ctx, perf, perfErr)
}

// GetDCULoad responds with the list returned by backend.GetDCULoad, or with a
// 500 if that call fails.
//
// @Summary 获取 DCU 负载信息
// @Description 获取 DCU 负载信息
// @Accept json
// @Produce json
// @Success 200 {object} RestfulListResult[backend.DCULoad]
// @Failure 500 {object} RestfulNoDataResult
// @Router /dcuload [get]
func (controller) GetDCULoad(ctx *gin.Context) {
	loads, loadErr := backend.GetDCULoad()
	ReturnGinList(ctx, loads, loadErr)
}

// GetOnlineUser responds with the list returned by backend.GetOnlineUser, or
// with a 500 if that call fails.
//
// @Summary 获取在线用户信息
// @Description 获取在线用户信息
// @Accept json
// @Produce json
// @Success 200 {object} RestfulListResult[backend.LoginUserInfo]
// @Failure 500 {object} RestfulNoDataResult
// @Router /loginUser [get]
func (controller) GetOnlineUser(ctx *gin.Context) {
	users, usersErr := backend.GetOnlineUser()
	ReturnGinList(ctx, users, usersErr)
}

// GetSysLoad responds with the value returned by backend.GetSysLoad, or with
// a 500 if that call fails.
//
// @Summary 获取系统负载信息
// @Description 获取系统负载信息
// @Accept json
// @Produce json
// @Success 200 {object} RestfulResult[backend.SysInfo]
// @Failure 500 {object} RestfulNoDataResult
// @Router /sysload [get]
func (controller) GetSysLoad(ctx *gin.Context) {
	load, loadErr := backend.GetSysLoad()
	ReturnGin(ctx, load, loadErr)
}

// GetAllInfo collects online users, system load and DCU load in a single
// response.
//
// The three backend queries run in order (online users, system load, DCU
// load). The first one that fails ends the request with a 500 through
// ReturnGin. Returning immediately guarantees exactly one response per
// request, and ensures sys is only dereferenced after GetSysLoad succeeded.
//
// @Summary 获取所有信息（系统负载、DCU 负载、在线用户）
// @Description 获取所有信息（系统负载、DCU 负载、在线用户）
// @Accept json
// @Produce json
// @Success 200 {object} RestfulResult[backend.AllInfo]
// @Failure 500 {object} RestfulNoDataResult
// @Router /all [get]
func (c controller) GetAllInfo(ctx *gin.Context) {
	olu, err := backend.GetOnlineUser()
	if err != nil {
		ReturnGin(ctx, nil, err)
		return
	}
	sys, err := backend.GetSysLoad()
	if err != nil {
		ReturnGin(ctx, nil, err)
		return
	}
	dcu, err := backend.GetDCULoad()
	if err != nil {
		ReturnGin(ctx, nil, err)
		return
	}
	ReturnGin(ctx, backend.AllInfo{
		DCUInfo:        dcu,
		SysInfo:        *sys,
		OnlineUserInfo: olu,
	}, nil)
}

// RcclArgs is the JSON body accepted by POST /api/cmd/rccl/post. Args holds
// the individual all_reduce_perf command-line arguments; when it is empty the
// configured defaults are used.
type RcclArgs struct{ Args []string `json:"args"` }

// PostRcclHandler runs a single-node rccl all_reduce_perf test with
// caller-supplied arguments.
//
// The JSON body is decoded into RcclArgs. If decoding fails, the handler
// responds once with HTTP 400 and a RestfulNoDataResult. If Args is empty,
// the configured "rccl_all_reduce_perf_args" are used. Otherwise the request
// arguments are joined with single spaces and passed to backend.AllReducePerf
// together with "rccl_test_path". The result is written through ReturnGin
// (500 on failure).
//
// @Summary 给出rccl all_reduce_perf参数，执行单机测试
// @Description 给出rccl all_reduce_perf参数，执行单机测试
// @Accept json
// @Produce json
// @Param   args body  RcclArgs true "rccl all reduce perf args"
// @Success 200 {object} RestfulResult[backend.RcclTestAllReducePrefResult]
// @Failure 400 {object} RestfulNoDataResult
// @Failure 500 {object} RestfulNoDataResult
// @Router /rccl/post [post]
func (c controller) PostRcclHandler(ctx *gin.Context) {
	var args RcclArgs
	// ShouldBindJSON, unlike BindJSON, does not abort and write a 400 itself,
	// so this handler emits exactly one response on a bad body.
	if err := ctx.ShouldBindJSON(&args); err != nil {
		ctx.JSON(http.StatusBadRequest, RestfulNoDataResult{
			Code: http.StatusBadRequest,
			Msg:  err.Error(),
		})
		return
	}

	testPath := globalCfg.GetString("rccl_test_path")
	perfArgs := globalCfg.GetString("rccl_all_reduce_perf_args")
	if len(args.Args) > 0 {
		// NOTE(review): these arguments come straight from the request body.
		// Whether that is safe depends on how backend.AllReducePerf launches
		// the binary (shell vs. direct exec) — confirm it cannot be abused
		// for command injection. Joining with spaces also means a single
		// argument containing a space cannot be passed intact.
		perfArgs = strings.Join(args.Args, " ")
	}
	r, err := backend.AllReducePerf(testPath, perfArgs)
	ReturnGin(ctx, r, err)
}
