package main

import (
	"context"
	"fmt"
	"log"
	"maps"
	"regexp"
	"slices"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/gin-gonic/gin"
	"sshd-tool/utils"
)

var (
	// sysuerInfo caches system users (plus NIS users when ypcat was found at
	// startup) keyed by uid. The misspelled name is kept for compatibility.
	sysuerInfo map[int]utils.LinuxSysUser

	// username2uid maps a user name to its key in sysuerInfo.
	username2uid map[string]int

	// sysuserLock guards sysuerInfo and username2uid.
	sysuserLock sync.RWMutex

	// haveNis is set by InitSSH when the NIS ypcat command is available.
	haveNis bool

	// globalCtx is created by InitSSH; updateSysuser returns once it is done.
	globalCtx context.Context

	// globalCancelFunc cancels globalCtx.
	globalCancelFunc context.CancelFunc

	// onlineInfo is the latest online-session snapshot, keyed by pid.
	onlineInfo = make(map[int]*utils.OnlineUser)

	// onlineLock guards onlineInfo.
	onlineLock sync.RWMutex

	// loginedUser holds tracked logins keyed by the session pid.
	loginedUser = make(map[int]*LoginedUser)

	// loginedLock guards loginedUser.
	loginedLock sync.RWMutex
)

// getOnline refreshes the online-session snapshot and returns the session for
// pid or, failing that, for the first pid in other that has one. It returns
// nil when none of the given pids is online.
func getOnline(pid int, other ...int) *utils.OnlineUser {
	updateOnline()
	onlineLock.RLock()
	defer onlineLock.RUnlock()
	if on, ok := onlineInfo[pid]; ok {
		return on
	}
	for _, p := range other {
		if on, ok := onlineInfo[p]; ok {
			return on
		}
	}
	return nil
}

// syncOnlineToLogin refreshes the online snapshot and removes every
// loginedUser entry whose pid is no longer present in onlineInfo.
func syncOnlineToLogin() {
	updateOnline()
	// Lock order: onlineLock, then loginedLock. No other function in this file
	// holds both at once.
	onlineLock.RLock()
	defer onlineLock.RUnlock()
	loginedLock.Lock()
	defer loginedLock.Unlock()
	for pid := range loginedUser {
		if _, ok := onlineInfo[pid]; !ok {
			delete(loginedUser, pid) // deleting the current key during range is allowed
		}
	}
}

// getSysUserInfo returns a copy of the cached user called name. On a cache
// miss it looks the user up via utils.GetOneSysUser (then utils.GetOneNisUser
// when NIS is available), caches a copy of the result and returns it. It
// returns nil when the user cannot be found; lookup errors are treated as
// "not found".
func getSysUserInfo(name string) *utils.LinuxSysUser {
	// clone copies u so the caller never shares the cached SSHkeyInfo map.
	clone := func(u *utils.LinuxSysUser) utils.LinuxSysUser {
		return utils.LinuxSysUser{
			Name:       u.Name,
			Home:       u.Home,
			Shell:      u.Shell,
			Uid:        u.Uid,
			Gid:        u.Gid,
			SSHkeyInfo: maps.Clone(u.SSHkeyInfo),
		}
	}

	sysuserLock.RLock()
	var cached utils.LinuxSysUser
	uid, found := username2uid[name]
	if found {
		// The uid may no longer be cached, or may now belong to another name;
		// both count as a miss instead of returning a zero/wrong user.
		cached, found = sysuerInfo[uid]
		found = found && cached.Name == name
	}
	sysuserLock.RUnlock()
	if found {
		result := clone(&cached)
		return &result
	}

	user, _ := utils.GetOneSysUser(name)
	if user == nil && haveNis {
		user, _ = utils.GetOneNisUser(name)
	}
	if user == nil {
		return nil
	}
	sysuserLock.Lock()
	sysuerInfo[user.Uid] = clone(user)
	username2uid[user.Name] = user.Uid
	sysuserLock.Unlock()
	return user
}

// LoginedUser is a tracked login: the online session plus, when the login was
// parsed from the sshd log, how it authenticated. The pointer fields are nil
// for sessions seeded by InitSSH; KeyHash and KeyUser are set only for
// publickey logins.
type LoginedUser struct {
	Online   utils.OnlineUser `json:"online"`             // session entry from the online list
	AuthType *string          `json:"authType,omitempty"` // authentication method logged by sshd
	KeyHash  *string          `json:"keyHash,omitempty"`  // SHA256 fingerprint of the key used, without '=' padding
	KeyUser  *string          `json:"keyUser,omitempty"`  // UserInfo recorded for that key in SSHkeyInfo
}

// NewLoginedUer builds a LoginedUser from a copy of ou's Name, Type, When, Pid
// and LoginFrom fields plus the given authentication details. It returns nil
// when ou is nil. (The misspelled name is kept for compatibility.)
func NewLoginedUer(ou *utils.OnlineUser, auth *string, kh, ku *string) *LoginedUser {
	if ou == nil {
		return nil
	}
	return &LoginedUser{
		Online: utils.OnlineUser{
			Name:      ou.Name,
			Type:      ou.Type,
			When:      ou.When,
			Pid:       ou.Pid,
			LoginFrom: ou.LoginFrom,
		},
		AuthType: auth,
		KeyHash:  kh,
		KeyUser:  ku,
	}
}

// updateOnline replaces onlineInfo with a fresh snapshot from
// utils.GetOnlineUser. When several entries share a pid, the one with the
// earliest When is kept. If the snapshot cannot be fetched the error is logged
// and the previous snapshot is left untouched. This function runs from the
// refresh ticker and from per-log-line goroutines, so it must not exit the
// process.
func updateOnline() {
	us, err := utils.GetOnlineUser()
	if err != nil {
		log.Printf("error get online user: %v", err)
		return
	}
	onlineLock.Lock()
	defer onlineLock.Unlock()
	clear(onlineInfo)
	for _, v := range us {
		entry := v // fresh copy per iteration so stored pointers never alias (pre-Go 1.22 range semantics)
		if old, ok := onlineInfo[entry.Pid]; !ok || old.When.After(entry.When) {
			onlineInfo[entry.Pid] = &entry
		}
	}
}

// InitSSH loads the user cache (NIS users first when ypcat is available, with
// local users overriding them by uid), creates globalCtx, takes the initial
// online-session snapshot, seeds loginedUser with one record (no auth details)
// per session, and starts updateSysuser. Any error here is fatal.
func InitSSH() {
	if utils.FindCmd(utils.NIS_YPCAT) != nil {
		nis, err := utils.GetNisUser()
		if err != nil {
			log.Fatalf("error get nis user: %v", err)
		}
		sysuerInfo = nis
		haveNis = true
	}
	local, err := utils.GetSysUser()
	if err != nil {
		log.Fatalf("error get sys user: %v", err)
	}
	if sysuerInfo == nil {
		sysuerInfo = local
	} else {
		maps.Copy(sysuerInfo, local)
	}
	if sysuerInfo == nil {
		// getSysUserInfo writes into this map, so it must never be nil.
		sysuerInfo = make(map[int]utils.LinuxSysUser)
	}
	username2uid = make(map[string]int, len(sysuerInfo))
	for _, u := range sysuerInfo {
		username2uid[u.Name] = u.Uid
	}

	globalCtx, globalCancelFunc = context.WithCancel(context.Background())

	us, err := utils.GetOnlineUser()
	if err != nil {
		log.Fatalf("error get online user: %v", err)
	}
	for _, v := range us {
		entry := v // fresh copy per iteration so stored pointers never alias
		if old, ok := onlineInfo[entry.Pid]; !ok || old.When.After(entry.When) {
			onlineInfo[entry.Pid] = &entry
			loginedUser[entry.Pid] = NewLoginedUer(&entry, nil, nil, nil)
		}
	}
	go updateSysuser()
}

// updateSysuser reloads the user cache every 10 seconds and then prunes ended
// sessions via syncOnlineToLogin. The name index is rebuilt from scratch on
// each reload so removed users do not linger. If loading users fails, the
// error is logged and globalCancelFunc is called, which makes this loop return
// on a later iteration. It also returns as soon as globalCtx is done.
func updateSysuser() {
	ticker := time.NewTicker(10 * time.Second)
	defer ticker.Stop()
	for {
		select {
		case <-globalCtx.Done():
			return
		case <-ticker.C:
		}

		var users map[int]utils.LinuxSysUser
		if haveNis {
			nis, err := utils.GetNisUser()
			if err != nil {
				log.Printf("error get nis user: %v", err)
				globalCancelFunc()
				continue
			}
			users = nis
		}
		local, err := utils.GetSysUser()
		if err != nil {
			log.Printf("error get sys user: %v", err)
			globalCancelFunc()
			continue
		}
		if users == nil {
			users = make(map[int]utils.LinuxSysUser, len(local))
		}
		maps.Copy(users, local) // local users override NIS users with the same uid

		byName := make(map[string]int, len(users))
		for _, u := range users {
			byName[u.Name] = u.Uid
		}
		sysuserLock.Lock()
		sysuerInfo = users
		username2uid = byName
		sysuserLock.Unlock()

		syncOnlineToLogin()
	}
}

// Sample sshd lines (as received over syslog they also carry a leading <PRI>
// field, which the regexes below require):
//
//	login:  Sep 30 11:08:47 login01 sshd[1768988]: Accepted keyboard-interactive/pam for alice from 203.0.113.7 port 4733 ssh2
//	logout: Sep 30 11:08:19 login01 sshd[1767042]: pam_unix(sshd:session): session closed for user alice
//
// RFC 3164 pads single-digit days with a space ("Oct  1"), hence " +\d+".
var (
	// ReSSHLogin matches an "Accepted" line that ends right after the ssh
	// protocol token. Groups: 1 host, 2 sshd pid, 3 auth method, 4 user,
	// 5 remote address, 6 remote port.
	ReSSHLogin = regexp.MustCompile(`^<\d+>[A-Z][a-z]{2} +\d+ \d+:\d+:\d+ (\S+) sshd\[(\d+)\]: Accepted (\S+) for (\S+) from (\S+) port (\d+) ssh\d*$`)

	// ReSSHLoginPK matches an "Accepted publickey" line carrying a SHA256 key
	// fingerprint. Groups: 1 host, 2 sshd pid, 3 user, 4 remote address,
	// 5 key type, 6 fingerprint.
	ReSSHLoginPK = regexp.MustCompile(`^<\d+>[A-Z][a-z]{2} +\d+ \d+:\d+:\d+ (\S+) sshd\[(\d+)\]: Accepted publickey for (\S+) from (\S+) port \d+ ssh\d*:\s+(\S+)\s+(?:sha|SHA)256:(.*)$`)

	// ReSSHLogout matches a pam_unix session-close line. Groups: 1 host,
	// 2 sshd pid, 3 user.
	ReSSHLogout = regexp.MustCompile(`^<\d+>[A-Z][a-z]{2} +\d+ \d+:\d+:\d+ (\S+) sshd\[(\d+)\]: pam_unix\(sshd:session\): session closed for user (.*)$`)
)

// ParseSSHLog classifies one syslog line and, for sshd login/logout lines,
// updates the global session state in a new goroutine. Other lines are ignored.
func ParseSSHLog(str string) {
	if !strings.Contains(str, "sshd") { // cheap pre-filter before running regexes
		return
	}
	switch {
	case ReSSHLogin.MatchString(str):
		go handleSSHLogin(str)
	case ReSSHLoginPK.MatchString(str):
		go handleSSHLoginPK(str)
	case ReSSHLogout.MatchString(str):
		go handleSSHLogout(str)
	}
}

// handleSSHLogout removes the session of the sshd pid named in a logout line
// from loginedUser and onlineInfo, but only if the login recorded under that
// pid belongs to the user named in the line.
func handleSSHLogout(str string) {
	fields := ReSSHLogout.FindStringSubmatch(str)
	if fields == nil {
		return
	}
	pid, err := strconv.Atoi(fields[2])
	if err != nil {
		return
	}
	user := fields[3]

	// Check and delete under one write lock so a concurrent login stored
	// under the same pid cannot be removed by mistake.
	loginedLock.Lock()
	u, ok := loginedUser[pid]
	matched := ok && u != nil && u.Online.Name == user
	if matched {
		delete(loginedUser, pid)
	}
	loginedLock.Unlock()
	if !matched {
		return
	}

	onlineLock.Lock()
	delete(onlineInfo, pid)
	onlineLock.Unlock()
}

// handleSSHLoginPK records a publickey login: it resolves the login's user and
// key fingerprint to the key's UserInfo, finds the online session for the sshd
// pid (or one of its children), and stores a LoginedUser keyed by that
// session's pid. Failures are logged and the line is dropped.
func handleSSHLoginPK(str string) {
	fields := ReSSHLoginPK.FindStringSubmatch(str)
	if fields == nil {
		return
	}
	pid, err := strconv.Atoi(fields[2])
	if err != nil {
		return
	}
	name := fields[3]
	keyHash := strings.Trim(fields[6], "=") // fingerprint without base64 padding
	auth := "publickey"

	user := getSysUserInfo(name)
	if user == nil {
		log.Printf("unknow error, can't find user")
		return
	}
	if user.SSHkeyInfo == nil {
		log.Printf("error, login use publickey, but can't find the key")
		return
	}
	key, ok := user.SSHkeyInfo[keyHash]
	if !ok {
		log.Printf("error, login use publickey, but can't find the key")
		return
	}

	// NOTE(review): the one-second wait presumably gives the new session time
	// to appear in the online list — confirm.
	time.Sleep(time.Second)
	var on *utils.OnlineUser
	if children, err := utils.GetPidChild2(pid); err == nil {
		on = getOnline(pid, children...)
	} else {
		on = getOnline(pid)
	}
	if on == nil {
		log.Printf("sshd login, but who not find: %d,%s", pid, name)
		return
	}

	keyUser := key.UserInfo
	u := NewLoginedUer(on, &auth, &keyHash, &keyUser)
	loginedLock.Lock()
	loginedUser[u.Online.Pid] = u
	loginedLock.Unlock()
}

// handleSSHLogin records a non-publickey login: it finds the online session
// for the sshd pid (or one of its children) and stores a LoginedUser, with the
// logged auth method, keyed by that session's pid.
func handleSSHLogin(str string) {
	fields := ReSSHLogin.FindStringSubmatch(str)
	if fields == nil {
		return
	}
	pid, err := strconv.Atoi(fields[2])
	if err != nil {
		return
	}
	auth := fields[3]

	// NOTE(review): the one-second wait presumably gives the new session time
	// to appear in the online list — confirm.
	time.Sleep(time.Second)
	var on *utils.OnlineUser
	if children, err := utils.GetPidChild2(pid); err == nil {
		on = getOnline(pid, children...)
	} else {
		on = getOnline(pid)
	}
	if on == nil {
		log.Printf("sshd login, but who not find: %d,%s", pid, fields[4])
		return
	}

	u := NewLoginedUer(on, &auth, nil, nil)
	loginedLock.Lock()
	loginedUser[u.Online.Pid] = u
	loginedLock.Unlock()
}

// strList parses a comma-separated list of integers (leading/trailing space
// characters around each item are ignored), sorts the values and returns them
// formatted with %v, e.g. "3, 1,2" -> "[1 2 3]". Any item that is not an
// integer, including an empty input, is fatal.
func strList(str string) string {
	parts := strings.Split(str, ",")
	nums := make([]int, 0, len(parts))
	for _, part := range parts {
		n, err := strconv.Atoi(strings.Trim(part, " "))
		if err != nil {
			log.Fatalf("error convert string to int: %v", err)
		}
		nums = append(nums, n)
	}
	slices.Sort(nums)
	return fmt.Sprintf("%v", nums)
}

// getLoginedUserInfo is a gin handler that responds 200 with a JSON array of
// all tracked logins, in unspecified order (an empty array when there are
// none).
func getLoginedUserInfo(ctx *gin.Context) {
	loginedLock.RLock()
	result := make([]*LoginedUser, 0, len(loginedUser))
	for _, v := range loginedUser {
		result = append(result, v)
	}
	loginedLock.RUnlock()
	// Serialize outside the lock. In this file, stored *LoginedUser values are
	// only ever replaced, never mutated, so the snapshot stays consistent.
	ctx.JSON(200, result)
}