package web import ( "encoding/base64" "get-container/cmd/opsflow/backend" "net/http" "strings" "time" "github.com/gin-gonic/gin" "github.com/pquerna/otp" "github.com/pquerna/otp/totp" "github.com/spf13/viper" swaggerFiles "github.com/swaggo/files" ginSwagger "github.com/swaggo/gin-swagger" ) // swagger embed files var ( globalCfg *viper.Viper = nil ) func Init(cfg *viper.Viper) { globalCfg = cfg } type RestfulResult[T any] struct { Code int `json:"code"` Msg string `json:"msg"` Data T `json:"data,omitempty"` } type RestfulNoDataResult struct { Code int `json:"code"` Msg string `json:"msg"` } type RestfulListResult[T any] struct { Code int `json:"code"` Msg string `json:"msg"` Data []T `json:"data,omitempty"` } func ReturnGin(ctx *gin.Context, data any, err error) { if err != nil { ctx.JSON(500, RestfulNoDataResult{ Code: 500, Msg: err.Error(), }) return } ctx.JSON(200, RestfulResult[any]{ Code: 200, Msg: "ok", Data: data, }) } func ReturnGinList[T any](ctx *gin.Context, data []T, err error) { if err != nil { ctx.JSON(500, RestfulNoDataResult{ Code: 500, Msg: err.Error(), }) return } ctx.JSON(200, RestfulListResult[T]{ Code: 200, Msg: "ok", Data: data, }) } func webAuth(ctx *gin.Context) { authHeader := ctx.GetHeader("Authorization") if authHeader == "" { ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Not find header Authorization"}) ctx.Abort() return } fields := strings.SplitN(authHeader, " ", 2) if len(fields) != 2 || fields[0] != "Bearer" { ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization error"}) ctx.Abort() return } code, err := base64.StdEncoding.DecodeString(fields[1]) if err != nil { ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) ctx.Abort() return } sec := globalCfg.GetString("auth_key") ok, err := totp.ValidateCustom(string(code), sec, time.Now(), totp.ValidateOpts{ Period: 30, // 每 30 秒更新一次 Skew: 1, // 允许前后偏移 1 个周期(即允许 30 秒的时间误差) Digits: otp.DigitsSix, Algorithm: otp.AlgorithmSHA1, }) if err != nil { ctx.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()}) ctx.Abort() return } if ok { ctx.Next() } else { ctx.JSON(http.StatusUnauthorized, gin.H{"error": "Authorization failed"}) ctx.Abort() return } } func WebServer(addr string) error { engine := gin.Default() cmdGroup := engine.Group("/api/cmd") if globalCfg.GetBool("debug_mode") { // 调试模式 engine.GET("/swagger/*any", ginSwagger.WrapHandler(swaggerFiles.Handler)) } else { // 非调试模式,添加认证 cmdGroup.Use(webAuth) } cmdGroup.GET("/all", _controller.GetAllInfo) cmdGroup.GET("/loginUser", _controller.GetOnlineUser) cmdGroup.GET("/sysload", _controller.GetSysLoad) cmdGroup.GET("/dcuload", _controller.GetDCULoad) cmdGroup.GET("/rcclinfo", _controller.GetRcclHandler) cmdGroup.POST("/rccl/post", _controller.PostRcclHandler) return engine.Run(addr) } type controller struct{} var _controller = controller{} // GetRcclHandler godoc // @Summary 获取 rccl all_reduce_perf 性能信息 // @Description 获取 rccl all_reduce_perf 性能信息 // @Accept json // @Produce json // @Success 200 {object} RestfulResult[backend.RcclTestAllReducePrefResult] // @Failure 500 {object} RestfulNoDataResult // @Router /rcclinfo [get] func (c controller) GetRcclHandler(ctx *gin.Context) { r, err := backend.AllReducePerf(globalCfg.GetString("rccl_test_path"), globalCfg.GetString("rccl_all_reduce_perf_args")) ReturnGin(ctx, r, err) } // GetDCULoad godoc // @Summary 获取 DCU 负载信息 // @Description 获取 DCU 负载信息 // @Accept json // @Produce json // @Success 200 {object} RestfulListResult[backend.DCULoad] // @Failure 500 {object} RestfulNoDataResult // @Router /dcuload [get] func (c controller) GetDCULoad(ctx *gin.Context) { dcu, err := backend.GetDCULoad() ReturnGinList(ctx, dcu, err) } // GetOnlineUser godoc // @Summary 获取在线用户信息 // @Description 获取在线用户信息 // @Accept json // @Produce json // @Success 200 {object} RestfulListResult[backend.LoginUserInfo] // @Failure 500 {object} RestfulNoDataResult // @Router /loginUser [get] func (c controller) GetOnlineUser(ctx *gin.Context) { olu, err := backend.GetOnlineUser() ReturnGinList(ctx, olu, err) } // GetSysLoad godoc // @Summary 获取系统负载信息 // @Description 获取系统负载信息 // @Accept json // @Produce json // @Success 200 {object} RestfulResult[backend.SysInfo] // @Failure 500 {object} RestfulNoDataResult // @Router /sysload [get] func (c controller) GetSysLoad(ctx *gin.Context) { sys, err := backend.GetSysLoad() ReturnGin(ctx, sys, err) } // GetAllInfo godoc // @Summary 获取所有信息(系统负载、DCU 负载、在线用户) // @Description 获取所有信息(系统负载、DCU 负载、在线用户) // @Accept json // @Produce json // @Success 200 {object} RestfulResult[backend.AllInfo] // @Failure 500 {object} RestfulNoDataResult // @Router /all [get] func (c controller) GetAllInfo(ctx *gin.Context) { olu, err := backend.GetOnlineUser() if err != nil { ReturnGin(ctx, nil, err) } sys, err := backend.GetSysLoad() if err != nil { ReturnGin(ctx, nil, err) } dcu, err := backend.GetDCULoad() if err != nil { ReturnGin(ctx, nil, err) } ReturnGin(ctx, backend.AllInfo{ DCUInfo: dcu, SysInfo: *sys, OnlineUserInfo: olu, }, err) } type RcclArgs struct { Args []string `json:"args"` } // PostRcclHandler godoc // @Summary 给出rccl all_reduce_perf参数,执行单机测试 // @Description 给出rccl all_reduce_perf参数,执行单机测试 // @Accept json // @Produce json // @Param args body RcclArgs true "rccl all reduce perf args" // @Success 200 {object} RestfulResult[backend.RcclTestAllReducePrefResult] // @Failure 500 {object} RestfulNoDataResult // @Router /rccl/post [post] func (c controller) PostRcclHandler(ctx *gin.Context) { args := RcclArgs{} err := ctx.BindJSON(&args) if err != nil { ReturnGin(ctx, nil, err) return } if len(args.Args) == 0 { r, err := backend.AllReducePerf(globalCfg.GetString("rccl_test_path"), globalCfg.GetString("rccl_all_reduce_perf_args")) ReturnGin(ctx, r, err) return } arg := strings.Join(args.Args, " ") r, err := backend.AllReducePerf(globalCfg.GetString("rccl_test_path"), arg) ReturnGin(ctx, r, err) }