Commit 1225db14 authored by liming6's avatar liming6
Browse files

feature 添加快速查找子进程的功能

parent abfc4893
...@@ -22,3 +22,7 @@ sshd-tool监听这个unix socket,过滤出需要的信息 ...@@ -22,3 +22,7 @@ sshd-tool监听这个unix socket,过滤出需要的信息
- 适配ubuntu系统,对于Ubuntu,系统who -u中的pid是ssh日志中pid的子进程,需要处理一下 - 适配ubuntu系统,对于Ubuntu,系统who -u中的pid是ssh日志中pid的子进程,需要处理一下
- 能查询出以前的,没有被sshd-tool记录的在线情况 - 能查询出以前的,没有被sshd-tool记录的在线情况
## 注意
- sftp登录在sshd日志里有记录,而who -u的输出是没有记录的
- who -u的输出里,可能有多个pid相同的数据,那是同一个ssh连接的多个虚拟终端,由于没有登录动作,所以sshd日志里没有对应的日志条目
...@@ -2,7 +2,6 @@ package asset ...@@ -2,7 +2,6 @@ package asset
import ( import (
"regexp" "regexp"
"strings"
"testing" "testing"
"time" "time"
) )
...@@ -53,6 +52,19 @@ func Test3(t *testing.T) { ...@@ -53,6 +52,19 @@ func Test3(t *testing.T) {
} }
func Test4(t *testing.T) { func Test4(t *testing.T) {
a := "abc===" c := make(chan int, 128)
t.Log(strings.Trim(a, "=")) go func(c chan<- int) {
for i := range 100 {
t.Logf("put: %d", i)
c <- i
}
}(c)
go func() {
time.Sleep(time.Second * 5)
close(c)
}()
for i := range c {
t.Logf("%d", i)
}
} }
...@@ -54,7 +54,6 @@ func main() { ...@@ -54,7 +54,6 @@ func main() {
for { for {
n, _, err := conn.ReadFrom(buffer) n, _, err := conn.ReadFrom(buffer)
if err != nil { if err != nil {
// log.Printf("error read unix socket: %v", err)
globalCancelFunc() globalCancelFunc()
break break
} }
......
...@@ -5,7 +5,6 @@ import ( ...@@ -5,7 +5,6 @@ import (
"fmt" "fmt"
"log" "log"
"maps" "maps"
"os"
"regexp" "regexp"
"slices" "slices"
"sshd-tool/utils" "sshd-tool/utils"
...@@ -24,15 +23,53 @@ var ( ...@@ -24,15 +23,53 @@ var (
haveNis bool = false haveNis bool = false
globalCtx context.Context = nil globalCtx context.Context = nil
globalCancelFunc context.CancelFunc = nil globalCancelFunc context.CancelFunc = nil
onlineInfo map[string]utils.OnlineUser = make(map[string]utils.OnlineUser) // key pidstr onlineInfo map[int]*utils.OnlineUser = make(map[int]*utils.OnlineUser) // key是pid
onlineLock = sync.RWMutex{} onlineLock = sync.RWMutex{} // 锁
hostname string loginedUser map[int]*LoginedUser = make(map[int]*LoginedUser) // key是pid
loginedUser map[string]LoginedUser = make(map[string]LoginedUser) loginedLock = sync.RWMutex{} // 锁
loginedLock = sync.RWMutex{}
) )
// getUserInfo 根据用户名获取系统用户信息 // getOnline 获取在线用户信息,参数是多个pid,但只返回找到的第一个在线用户信息
func getUserInfo(name string) *utils.LinuxSysUser { func getOnline(pid int, other ...int) *utils.OnlineUser {
updateOnline()
rl := onlineLock.RLocker()
rl.Lock()
defer rl.Unlock()
on, have := onlineInfo[pid]
if have {
return on
}
for _, v := range other {
on, have := onlineInfo[v]
if have {
return on
}
}
return nil
}
// syncOnlineToLogin 剔除loginedInfo中有而onlineInfo里没有的数据
func syncOnlineToLogin() {
updateOnline()
rl := onlineLock.RLocker()
rl.Lock()
loginedLock.Lock()
toDelPids := make([]int, 0, 16)
for k := range loginedUser {
_, have := onlineInfo[k]
if !have {
toDelPids = append(toDelPids, k)
}
}
rl.Unlock()
for _, v := range toDelPids {
delete(loginedUser, v)
}
loginedLock.Unlock()
}
// getSysUserInfo 根据用户名获取系统用户信息
func getSysUserInfo(name string) *utils.LinuxSysUser {
rl := sysuserLock.RLocker() rl := sysuserLock.RLocker()
rl.Lock() rl.Lock()
uid, have := username2uid[name] uid, have := username2uid[name]
...@@ -76,11 +113,29 @@ func getUserInfo(name string) *utils.LinuxSysUser { ...@@ -76,11 +113,29 @@ func getUserInfo(name string) *utils.LinuxSysUser {
type LoginedUser struct { type LoginedUser struct {
Online utils.OnlineUser `json:"online"` // 在线信息 Online utils.OnlineUser `json:"online"` // 在线信息
AuthType string `json:"authType"` // 登录时的认证方式 AuthType *string `json:"authType,omitempty"` // 登录时的认证方式
KeyHash *string `json:"keyHash,omitempty"` // 登录时使用的公钥hash KeyHash *string `json:"keyHash,omitempty"` // 登录时使用的公钥hash
KeyUser *string `json:"keyUser,omitempty"` // 登录公钥的用户信息 KeyUser *string `json:"keyUser,omitempty"` // 登录公钥的用户信息
} }
func NewLoginedUer(ou *utils.OnlineUser, auth *string, kh, ku *string) *LoginedUser {
if ou == nil {
return nil
}
return &LoginedUser{
Online: utils.OnlineUser{
Name: ou.Name,
Type: ou.Type,
When: ou.When,
Pid: ou.Pid,
LoginFrom: ou.LoginFrom,
},
AuthType: auth,
KeyHash: kh,
KeyUser: ku,
}
}
// updateOnline 更新在线用户信息 // updateOnline 更新在线用户信息
func updateOnline() { func updateOnline() {
us, err := utils.GetOnlineUser() us, err := utils.GetOnlineUser()
...@@ -90,7 +145,14 @@ func updateOnline() { ...@@ -90,7 +145,14 @@ func updateOnline() {
onlineLock.Lock() onlineLock.Lock()
clear(onlineInfo) clear(onlineInfo)
for _, v := range us { for _, v := range us {
onlineInfo[v.PidString()] = v old, have := onlineInfo[v.Pid]
if !have {
onlineInfo[v.Pid] = &v
} else {
if old.When.After(v.When) {
onlineInfo[v.Pid] = &v
}
}
} }
onlineLock.Unlock() onlineLock.Unlock()
} }
...@@ -121,17 +183,21 @@ func InitSSH() { ...@@ -121,17 +183,21 @@ func InitSSH() {
username2uid[v.Name] = v.Uid username2uid[v.Name] = v.Uid
} }
globalCtx, globalCancelFunc = context.WithCancel(context.Background()) globalCtx, globalCancelFunc = context.WithCancel(context.Background())
n, err := os.Hostname()
if err != nil {
log.Fatalf("error get hostname: %v", err)
}
hostname = n
us, err := utils.GetOnlineUser() us, err := utils.GetOnlineUser()
if err != nil { if err != nil {
log.Fatalf("error get online user: %v", err) log.Fatalf("error get online user: %v", err)
} }
for _, v := range us { for _, v := range us {
onlineInfo[v.PidString()] = v old, have := onlineInfo[v.Pid]
if !have {
onlineInfo[v.Pid] = &v
loginedUser[v.Pid] = NewLoginedUer(&v, nil, nil, nil)
} else {
if old.When.After(v.When) {
onlineInfo[v.Pid] = &v
loginedUser[v.Pid] = NewLoginedUer(&v, nil, nil, nil)
}
}
} }
go updateSysuser() go updateSysuser()
} }
...@@ -176,6 +242,7 @@ func updateSysuser() { ...@@ -176,6 +242,7 @@ func updateSysuser() {
} }
sysuserLock.Unlock() sysuserLock.Unlock()
} }
syncOnlineToLogin()
case <-globalCtx.Done(): case <-globalCtx.Done():
ticker.Stop() ticker.Stop()
return return
...@@ -220,20 +287,23 @@ func ParseSSHLog(str string) { ...@@ -220,20 +287,23 @@ func ParseSSHLog(str string) {
func handleSSHLogout(str string) { func handleSSHLogout(str string) {
fields := ReSSHLogout.FindStringSubmatch(str) fields := ReSSHLogout.FindStringSubmatch(str)
user := fields[3] user := fields[3]
pidstr := strList(fields[2]) pid, err := strconv.Atoi(fields[2])
if err != nil {
return
}
rl := loginedLock.RLocker() rl := loginedLock.RLocker()
rl.Lock() rl.Lock()
u, have := loginedUser[pidstr] u, have := loginedUser[pid]
rl.Unlock() rl.Unlock()
if !have { if !have {
return return
} }
if u.Online.Name == user { if u.Online.Name == user {
loginedLock.Lock() loginedLock.Lock()
delete(loginedUser, pidstr) delete(loginedUser, pid)
loginedLock.Unlock() loginedLock.Unlock()
onlineLock.Lock() onlineLock.Lock()
delete(onlineInfo, pidstr) delete(onlineInfo, pid)
onlineLock.Unlock() onlineLock.Unlock()
} }
} }
...@@ -241,71 +311,69 @@ func handleSSHLogout(str string) { ...@@ -241,71 +311,69 @@ func handleSSHLogout(str string) {
func handleSSHLoginPK(str string) { func handleSSHLoginPK(str string) {
fields := ReSSHLoginPK.FindStringSubmatch(str) fields := ReSSHLoginPK.FindStringSubmatch(str)
name := fields[3] name := fields[3]
pidstr := strList(fields[2]) pid, err := strconv.Atoi(fields[2])
if err != nil {
return
}
keyHash := strings.Trim(fields[6], "=") keyHash := strings.Trim(fields[6], "=")
auth := "publickey" auth := "publickey"
user := getUserInfo(name) user := getSysUserInfo(name)
if user == nil { if user == nil {
log.Fatal("unknow error, can't find user") log.Printf("unknow error, can't find user")
return
} }
if user.SSHkeyInfo == nil { if user.SSHkeyInfo == nil {
log.Fatal("error, login use publickey, but can't find the key") log.Printf("error, login use publickey, but can't find the key")
return
} }
key, have := user.SSHkeyInfo[keyHash] key, have := user.SSHkeyInfo[keyHash]
if !have { if !have {
log.Fatal("error, login use publickey, but can't find the key") log.Printf("error, login use publickey, but can't find the key")
return
} }
updateOnline() var on *utils.OnlineUser
cPids, err := utils.GetPidChild2(pid)
u := LoginedUser{} if err == nil {
u.AuthType = auth on = getOnline(pid, cPids...)
u.KeyHash = &keyHash } else {
keyUser := key.UserInfo on = getOnline(pid)
u.KeyUser = &keyUser
rl := onlineLock.RLocker()
rl.Lock()
on, have := onlineInfo[pidstr]
rl.Unlock()
if !have {
log.Fatalf("sshd login, but who not find: %s", pidstr)
} }
u.Online = utils.OnlineUser{ if on == nil {
Name: on.Name, log.Printf("sshd login, but who not find: %d,%s", pid, name)
Type: on.Type, return
When: on.When,
Pids: slices.Clone(on.Pids),
LoginFrom: on.LoginFrom,
} }
keyUser := key.UserInfo
u := NewLoginedUer(on, &auth, &keyHash, &keyUser)
loginedLock.Lock() loginedLock.Lock()
loginedUser[u.Online.PidString()] = u loginedUser[u.Online.Pid] = u
loginedLock.Unlock() loginedLock.Unlock()
} }
func handleSSHLogin(str string) { func handleSSHLogin(str string) {
fields := ReSSHLogin.FindStringSubmatch(str) fields := ReSSHLogin.FindStringSubmatch(str)
from := fields[5]
auth := fields[3] auth := fields[3]
pidstr := strList(fields[2]) pid, err := strconv.Atoi(fields[2])
updateOnline() if err != nil {
rl := onlineLock.RLocker() return
rl.Lock() }
on, have := onlineInfo[pidstr] var on *utils.OnlineUser
rl.Unlock() cPids, err := utils.GetPidChild2(pid)
if !have { if err == nil {
log.Fatalf("sshd login, but who not find: %s", pidstr) on = getOnline(pid, cPids...)
} else {
on = getOnline(pid)
}
if on == nil {
log.Printf("sshd login, but who not find: %d,%s", pid, fields[4])
return
} }
u := LoginedUser{} if on == nil {
u.AuthType = auth log.Printf("sshd login, but who not find: %d,%s", pid, fields[4])
u.Online = utils.OnlineUser{ return
Name: on.Name,
Type: on.Type,
When: on.When,
Pids: slices.Clone(on.Pids),
LoginFrom: from,
} }
u := NewLoginedUer(on, &auth, nil, nil)
loginedLock.Lock() loginedLock.Lock()
loginedUser[u.Online.PidString()] = u loginedUser[u.Online.Pid] = u
loginedLock.Unlock() loginedLock.Unlock()
} }
......
...@@ -10,6 +10,7 @@ import ( ...@@ -10,6 +10,7 @@ import (
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
) )
...@@ -221,27 +222,20 @@ type OnlineUser struct { ...@@ -221,27 +222,20 @@ type OnlineUser struct {
Name string `json:"name"` Name string `json:"name"`
Type string `json:"type"` Type string `json:"type"`
When time.Time `json:"loginTime"` When time.Time `json:"loginTime"`
Pids []int `json:"pids"` Pid int `json:"pid"`
LoginFrom string `json:"loginForm"` LoginFrom string `json:"loginForm"`
} }
func (ou OnlineUser) String() string { func (ou OnlineUser) String() string {
return fmt.Sprintf("name:%s type:%s when:%s pids:%v login from:%s", ou.Name, ou.Type, ou.When.Format("2006-01-02 15:04"), ou.Pids, ou.LoginFrom) return fmt.Sprintf("name:%s type:%s when:%s pids:%v login from:%s", ou.Name, ou.Type, ou.When.Format("2006-01-02 15:04"), ou.Pid, ou.LoginFrom)
} }
func (ou OnlineUser) Sha256sum() [32]byte { func (ou OnlineUser) Sha256sum() [32]byte {
return sha256.Sum256([]byte(ou.String())) return sha256.Sum256([]byte(ou.String()))
} }
func (ou OnlineUser) PidString() string {
if len(ou.Pids) == 0 {
return "[]"
}
return fmt.Sprintf("%v", ou.Pids)
}
var ( var (
ReOnLineUser = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+([a-zA-Z0-9/]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+(?:old|\.|\d{2}:\d{2})\s+(\d*(?:,\d*)*)\s+\((.*)\)$`) // sshd远程登录的 ReOnLineUser = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+([a-zA-Z0-9/]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+(?:old|\.|\d{2}:\d{2})\s+(\d+)\s+\((.*)\)$`) // sshd远程登录的
ReOnLineUserTTY = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+(tty[0-9]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+(?:old|\.|\d{2}:\d{2})\s+(\d*(?:,\d*)*)$`) // 通过控制台登录的 ReOnLineUserTTY = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+(tty[0-9]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+(?:old|\.|\d{2}:\d{2})\s+(\d*(?:,\d*)*)$`) // 通过控制台登录的
ReOnLineUserX = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+(:[0-9]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+\?\s+(\d*(?:,\d*)*).*$`) // 通过图像界面 ReOnLineUserX = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+(:[0-9]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+\?\s+(\d*(?:,\d*)*).*$`) // 通过图像界面
) )
...@@ -269,20 +263,121 @@ func GetOnlineUser() ([]OnlineUser, error) { ...@@ -269,20 +263,121 @@ func GetOnlineUser() ([]OnlineUser, error) {
} }
u.When = t u.When = t
u.LoginFrom = m[5] u.LoginFrom = m[5]
pids := strings.Split(m[4], ",") p, err := strconv.Atoi(m[4])
u.Pids = make([]int, 0, len(pids))
for _, v := range pids {
p, err := strconv.Atoi(v)
if err != nil { if err != nil {
return nil, err return nil, err
} }
u.Pids = append(u.Pids, p) u.Pid = p
}
slices.Sort(u.Pids)
result = append(result, u) result = append(result, u)
} else if ReOnLineUserTTY.MatchString(line) {
// todo
} }
} }
return result, nil return result, nil
} }
var (
ReNum = regexp.MustCompile(`^\d+$`)
)
func GetPidChild(pid int) ([]int, error) {
items, err := os.ReadDir("/proc")
if err != nil {
return nil, err
}
result := make([]int, 0, 8)
for _, item := range items {
if !item.IsDir() {
continue
}
if !ReNum.MatchString(item.Name()) {
continue
}
ppid, err := getPPID(item.Name())
if err != nil {
continue
}
if ppid == pid {
i, err := strconv.Atoi(item.Name())
if err == nil {
result = append(result, i)
}
}
}
slices.Sort(result)
return result, nil
}
// GetPidChild2 并发获取进程的子进程,并发数为8
func GetPidChild2(pid int) ([]int, error) {
items, err := os.ReadDir("/proc")
if err != nil {
return nil, err
}
result := make([]int, 0, 16)
resultLock := sync.Mutex{}
// 指定并发数为8
goroutineNum := 8
wg := sync.WaitGroup{}
wg.Add(goroutineNum)
// 创建管道,并为每个管道添加一个处理goroutine
cs := make([]chan string, 0, goroutineNum)
for range goroutineNum {
c := make(chan string, 128)
cs = append(cs, c)
go func(sc <-chan string, ppid int, wg *sync.WaitGroup) {
for pid := range sc {
if pid == "0" {
break
}
p, err := getPPID(pid)
if err == nil && p == ppid {
pp, err := strconv.Atoi(pid)
if err == nil {
resultLock.Lock()
result = append(result, pp)
resultLock.Unlock()
}
}
}
wg.Done()
}(c, pid, &wg)
}
// 向goroutine分发任务
for i, item := range items {
cs[i%goroutineNum] <- item.Name()
}
// 向goroutine发送关闭信号
for _, c := range cs {
c <- "0"
}
// 等待goroutine关闭
wg.Wait()
// 关闭管道
for _, c := range cs {
close(c)
}
slices.Sort(result)
return result, nil
}
func getPPID(pid string) (int, error) {
path := fmt.Sprintf("/proc/%s/status", pid)
content, err := os.ReadFile(path)
if err != nil {
return 0, err
}
lines := strings.Split(string(content), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "PPid:") {
continue
}
fields := strings.Fields(line)
if len(fields) != 2 {
continue
}
return strconv.Atoi(fields[1])
}
return 0, fmt.Errorf("error not find PPid in %s", path)
}
...@@ -4,6 +4,7 @@ import ( ...@@ -4,6 +4,7 @@ import (
"crypto/sha256" "crypto/sha256"
"encoding/base64" "encoding/base64"
"testing" "testing"
"time"
) )
func TestFindCmd(t *testing.T) { func TestFindCmd(t *testing.T) {
...@@ -132,3 +133,27 @@ func TestGetOneNisUser(t *testing.T) { ...@@ -132,3 +133,27 @@ func TestGetOneNisUser(t *testing.T) {
} }
} }
} }
func TestGetPidChild(t *testing.T) {
start := time.Now()
pids, err := GetPidChild(1)
d := time.Since(start)
t.Logf("%d ms", d.Milliseconds())
if err != nil {
t.Error(err)
}
t.Logf("%v", pids)
t.Logf("num: %d",len(pids))
}
func TestGetPidChild2(t *testing.T) {
start := time.Now()
pids, err := GetPidChild2(1)
d := time.Since(start)
t.Logf("%d ms", d.Milliseconds())
if err != nil {
t.Error(err)
}
t.Logf("%v", pids)
t.Logf("num: %d",len(pids))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment