Commit 1225db14 authored by liming6's avatar liming6
Browse files

feature 添加快速查找子进程的功能

parent abfc4893
......@@ -22,3 +22,7 @@ sshd-tool监听这个unix socket,过滤出需要的信息
- 适配ubuntu系统,对于Ubuntu,系统who -u中的pid是ssh日志中pid的子进程,需要处理一下
- 能查询出以前的,没有被sshd-tool记录的在线情况
## 注意
- sftp登录在sshd日志里有记录,而who -u的输出是没有记录的
- who -u的输出里,可能有多个pid相同的数据,那是同一个ssh连接的多个虚拟终端,由于没有登录动作,所以sshd日志里没有对应的日志条目
......@@ -2,7 +2,6 @@ package asset
import (
"regexp"
"strings"
"testing"
"time"
)
......@@ -53,6 +52,19 @@ func Test3(t *testing.T) {
}
func Test4(t *testing.T) {
a := "abc==="
t.Log(strings.Trim(a, "="))
c := make(chan int, 128)
go func(c chan<- int) {
for i := range 100 {
t.Logf("put: %d", i)
c <- i
}
}(c)
go func() {
time.Sleep(time.Second * 5)
close(c)
}()
for i := range c {
t.Logf("%d", i)
}
}
......@@ -54,7 +54,6 @@ func main() {
for {
n, _, err := conn.ReadFrom(buffer)
if err != nil {
// log.Printf("error read unix socket: %v", err)
globalCancelFunc()
break
}
......
......@@ -5,7 +5,6 @@ import (
"fmt"
"log"
"maps"
"os"
"regexp"
"slices"
"sshd-tool/utils"
......@@ -18,21 +17,59 @@ import (
)
var (
sysuerInfo map[int]utils.LinuxSysUser = nil // key uid
sysuserLock = sync.RWMutex{}
username2uid map[string]int = nil
haveNis bool = false
globalCtx context.Context = nil
globalCancelFunc context.CancelFunc = nil
onlineInfo map[string]utils.OnlineUser = make(map[string]utils.OnlineUser) // key pidstr
onlineLock = sync.RWMutex{}
hostname string
loginedUser map[string]LoginedUser = make(map[string]LoginedUser)
loginedLock = sync.RWMutex{}
sysuerInfo map[int]utils.LinuxSysUser = nil // key uid
sysuserLock = sync.RWMutex{}
username2uid map[string]int = nil
haveNis bool = false
globalCtx context.Context = nil
globalCancelFunc context.CancelFunc = nil
onlineInfo map[int]*utils.OnlineUser = make(map[int]*utils.OnlineUser) // key是pid
onlineLock = sync.RWMutex{} // 锁
loginedUser map[int]*LoginedUser = make(map[int]*LoginedUser) // key是pid
loginedLock = sync.RWMutex{} // 锁
)
// getUserInfo 根据用户名获取系统用户信息
func getUserInfo(name string) *utils.LinuxSysUser {
// getOnline 获取在线用户信息,参数是多个pid,但只返回找到的第一个在线用户信息
func getOnline(pid int, other ...int) *utils.OnlineUser {
updateOnline()
rl := onlineLock.RLocker()
rl.Lock()
defer rl.Unlock()
on, have := onlineInfo[pid]
if have {
return on
}
for _, v := range other {
on, have := onlineInfo[v]
if have {
return on
}
}
return nil
}
// syncOnlineToLogin 剔除loginedInfo中有而onlineInfo里没有的数据
func syncOnlineToLogin() {
updateOnline()
rl := onlineLock.RLocker()
rl.Lock()
loginedLock.Lock()
toDelPids := make([]int, 0, 16)
for k := range loginedUser {
_, have := onlineInfo[k]
if !have {
toDelPids = append(toDelPids, k)
}
}
rl.Unlock()
for _, v := range toDelPids {
delete(loginedUser, v)
}
loginedLock.Unlock()
}
// getSysUserInfo 根据用户名获取系统用户信息
func getSysUserInfo(name string) *utils.LinuxSysUser {
rl := sysuserLock.RLocker()
rl.Lock()
uid, have := username2uid[name]
......@@ -75,10 +112,28 @@ func getUserInfo(name string) *utils.LinuxSysUser {
}
type LoginedUser struct {
Online utils.OnlineUser `json:"online"` // 在线信息
AuthType string `json:"authType"` // 登录时的认证方式
KeyHash *string `json:"keyHash,omitempty"` // 登录时使用的公钥hash
KeyUser *string `json:"keyUser,omitempty"` // 登录公钥的用户信息
Online utils.OnlineUser `json:"online"` // 在线信息
AuthType *string `json:"authType,omitempty"` // 登录时的认证方式
KeyHash *string `json:"keyHash,omitempty"` // 登录时使用的公钥hash
KeyUser *string `json:"keyUser,omitempty"` // 登录公钥的用户信息
}
func NewLoginedUer(ou *utils.OnlineUser, auth *string, kh, ku *string) *LoginedUser {
if ou == nil {
return nil
}
return &LoginedUser{
Online: utils.OnlineUser{
Name: ou.Name,
Type: ou.Type,
When: ou.When,
Pid: ou.Pid,
LoginFrom: ou.LoginFrom,
},
AuthType: auth,
KeyHash: kh,
KeyUser: ku,
}
}
// updateOnline 更新在线用户信息
......@@ -90,7 +145,14 @@ func updateOnline() {
onlineLock.Lock()
clear(onlineInfo)
for _, v := range us {
onlineInfo[v.PidString()] = v
old, have := onlineInfo[v.Pid]
if !have {
onlineInfo[v.Pid] = &v
} else {
if old.When.After(v.When) {
onlineInfo[v.Pid] = &v
}
}
}
onlineLock.Unlock()
}
......@@ -121,17 +183,21 @@ func InitSSH() {
username2uid[v.Name] = v.Uid
}
globalCtx, globalCancelFunc = context.WithCancel(context.Background())
n, err := os.Hostname()
if err != nil {
log.Fatalf("error get hostname: %v", err)
}
hostname = n
us, err := utils.GetOnlineUser()
if err != nil {
log.Fatalf("error get online user: %v", err)
}
for _, v := range us {
onlineInfo[v.PidString()] = v
old, have := onlineInfo[v.Pid]
if !have {
onlineInfo[v.Pid] = &v
loginedUser[v.Pid] = NewLoginedUer(&v, nil, nil, nil)
} else {
if old.When.After(v.When) {
onlineInfo[v.Pid] = &v
loginedUser[v.Pid] = NewLoginedUer(&v, nil, nil, nil)
}
}
}
go updateSysuser()
}
......@@ -176,6 +242,7 @@ func updateSysuser() {
}
sysuserLock.Unlock()
}
syncOnlineToLogin()
case <-globalCtx.Done():
ticker.Stop()
return
......@@ -220,20 +287,23 @@ func ParseSSHLog(str string) {
func handleSSHLogout(str string) {
fields := ReSSHLogout.FindStringSubmatch(str)
user := fields[3]
pidstr := strList(fields[2])
pid, err := strconv.Atoi(fields[2])
if err != nil {
return
}
rl := loginedLock.RLocker()
rl.Lock()
u, have := loginedUser[pidstr]
u, have := loginedUser[pid]
rl.Unlock()
if !have {
return
}
if u.Online.Name == user {
loginedLock.Lock()
delete(loginedUser, pidstr)
delete(loginedUser, pid)
loginedLock.Unlock()
onlineLock.Lock()
delete(onlineInfo, pidstr)
delete(onlineInfo, pid)
onlineLock.Unlock()
}
}
......@@ -241,71 +311,69 @@ func handleSSHLogout(str string) {
func handleSSHLoginPK(str string) {
fields := ReSSHLoginPK.FindStringSubmatch(str)
name := fields[3]
pidstr := strList(fields[2])
pid, err := strconv.Atoi(fields[2])
if err != nil {
return
}
keyHash := strings.Trim(fields[6], "=")
auth := "publickey"
user := getUserInfo(name)
user := getSysUserInfo(name)
if user == nil {
log.Fatal("unknow error, can't find user")
log.Printf("unknow error, can't find user")
return
}
if user.SSHkeyInfo == nil {
log.Fatal("error, login use publickey, but can't find the key")
log.Printf("error, login use publickey, but can't find the key")
return
}
key, have := user.SSHkeyInfo[keyHash]
if !have {
log.Fatal("error, login use publickey, but can't find the key")
log.Printf("error, login use publickey, but can't find the key")
return
}
updateOnline()
u := LoginedUser{}
u.AuthType = auth
u.KeyHash = &keyHash
keyUser := key.UserInfo
u.KeyUser = &keyUser
rl := onlineLock.RLocker()
rl.Lock()
on, have := onlineInfo[pidstr]
rl.Unlock()
if !have {
log.Fatalf("sshd login, but who not find: %s", pidstr)
var on *utils.OnlineUser
cPids, err := utils.GetPidChild2(pid)
if err == nil {
on = getOnline(pid, cPids...)
} else {
on = getOnline(pid)
}
u.Online = utils.OnlineUser{
Name: on.Name,
Type: on.Type,
When: on.When,
Pids: slices.Clone(on.Pids),
LoginFrom: on.LoginFrom,
if on == nil {
log.Printf("sshd login, but who not find: %d,%s", pid, name)
return
}
keyUser := key.UserInfo
u := NewLoginedUer(on, &auth, &keyHash, &keyUser)
loginedLock.Lock()
loginedUser[u.Online.PidString()] = u
loginedUser[u.Online.Pid] = u
loginedLock.Unlock()
}
func handleSSHLogin(str string) {
fields := ReSSHLogin.FindStringSubmatch(str)
from := fields[5]
auth := fields[3]
pidstr := strList(fields[2])
updateOnline()
rl := onlineLock.RLocker()
rl.Lock()
on, have := onlineInfo[pidstr]
rl.Unlock()
if !have {
log.Fatalf("sshd login, but who not find: %s", pidstr)
pid, err := strconv.Atoi(fields[2])
if err != nil {
return
}
var on *utils.OnlineUser
cPids, err := utils.GetPidChild2(pid)
if err == nil {
on = getOnline(pid, cPids...)
} else {
on = getOnline(pid)
}
u := LoginedUser{}
u.AuthType = auth
u.Online = utils.OnlineUser{
Name: on.Name,
Type: on.Type,
When: on.When,
Pids: slices.Clone(on.Pids),
LoginFrom: from,
if on == nil {
log.Printf("sshd login, but who not find: %d,%s", pid, fields[4])
return
}
if on == nil {
log.Printf("sshd login, but who not find: %d,%s", pid, fields[4])
return
}
u := NewLoginedUer(on, &auth, nil, nil)
loginedLock.Lock()
loginedUser[u.Online.PidString()] = u
loginedUser[u.Online.Pid] = u
loginedLock.Unlock()
}
......
......@@ -10,6 +10,7 @@ import (
"slices"
"strconv"
"strings"
"sync"
"time"
)
......@@ -221,29 +222,22 @@ type OnlineUser struct {
Name string `json:"name"`
Type string `json:"type"`
When time.Time `json:"loginTime"`
Pids []int `json:"pids"`
Pid int `json:"pid"`
LoginFrom string `json:"loginForm"`
}
func (ou OnlineUser) String() string {
return fmt.Sprintf("name:%s type:%s when:%s pids:%v login from:%s", ou.Name, ou.Type, ou.When.Format("2006-01-02 15:04"), ou.Pids, ou.LoginFrom)
return fmt.Sprintf("name:%s type:%s when:%s pids:%v login from:%s", ou.Name, ou.Type, ou.When.Format("2006-01-02 15:04"), ou.Pid, ou.LoginFrom)
}
func (ou OnlineUser) Sha256sum() [32]byte {
return sha256.Sum256([]byte(ou.String()))
}
func (ou OnlineUser) PidString() string {
if len(ou.Pids) == 0 {
return "[]"
}
return fmt.Sprintf("%v", ou.Pids)
}
var (
ReOnLineUser = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+([a-zA-Z0-9/]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+(?:old|\.|\d{2}:\d{2})\s+(\d*(?:,\d*)*)\s+\((.*)\)$`) // sshd远程登录的
ReOnLineUserTTY = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+(tty[0-9]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+(?:old|\.|\d{2}:\d{2})\s+(\d*(?:,\d*)*)$`) // 通过控制台登录的
ReOnLineUserX = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+(:[0-9]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+\?\s+(\d*(?:,\d*)*).*$`) // 通过图像界面
ReOnLineUser = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+([a-zA-Z0-9/]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+(?:old|\.|\d{2}:\d{2})\s+(\d+)\s+\((.*)\)$`) // sshd远程登录的
ReOnLineUserTTY = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+(tty[0-9]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+(?:old|\.|\d{2}:\d{2})\s+(\d*(?:,\d*)*)$`) // 通过控制台登录的
ReOnLineUserX = regexp.MustCompile(`^(?i)([a-zA-Z_0-9]*)\s+(:[0-9]*)\s+(\d{4}-\d{1,2}-\d{1,2} \d{2}:\d{2})\s+\?\s+(\d*(?:,\d*)*).*$`) // 通过图像界面
)
// GetOnlineUser
......@@ -269,20 +263,121 @@ func GetOnlineUser() ([]OnlineUser, error) {
}
u.When = t
u.LoginFrom = m[5]
pids := strings.Split(m[4], ",")
u.Pids = make([]int, 0, len(pids))
for _, v := range pids {
p, err := strconv.Atoi(v)
if err != nil {
return nil, err
}
u.Pids = append(u.Pids, p)
p, err := strconv.Atoi(m[4])
if err != nil {
return nil, err
}
slices.Sort(u.Pids)
u.Pid = p
result = append(result, u)
} else if ReOnLineUserTTY.MatchString(line) {
// todo
}
}
return result, nil
}
var (
ReNum = regexp.MustCompile(`^\d+$`)
)
func GetPidChild(pid int) ([]int, error) {
items, err := os.ReadDir("/proc")
if err != nil {
return nil, err
}
result := make([]int, 0, 8)
for _, item := range items {
if !item.IsDir() {
continue
}
if !ReNum.MatchString(item.Name()) {
continue
}
ppid, err := getPPID(item.Name())
if err != nil {
continue
}
if ppid == pid {
i, err := strconv.Atoi(item.Name())
if err == nil {
result = append(result, i)
}
}
}
slices.Sort(result)
return result, nil
}
// GetPidChild2 并发获取进程的子进程,并发数为8
func GetPidChild2(pid int) ([]int, error) {
items, err := os.ReadDir("/proc")
if err != nil {
return nil, err
}
result := make([]int, 0, 16)
resultLock := sync.Mutex{}
// 指定并发数为8
goroutineNum := 8
wg := sync.WaitGroup{}
wg.Add(goroutineNum)
// 创建管道,并为每个管道添加一个处理goroutine
cs := make([]chan string, 0, goroutineNum)
for range goroutineNum {
c := make(chan string, 128)
cs = append(cs, c)
go func(sc <-chan string, ppid int, wg *sync.WaitGroup) {
for pid := range sc {
if pid == "0" {
break
}
p, err := getPPID(pid)
if err == nil && p == ppid {
pp, err := strconv.Atoi(pid)
if err == nil {
resultLock.Lock()
result = append(result, pp)
resultLock.Unlock()
}
}
}
wg.Done()
}(c, pid, &wg)
}
// 向goroutine分发任务
for i, item := range items {
cs[i%goroutineNum] <- item.Name()
}
// 向goroutine发送关闭信号
for _, c := range cs {
c <- "0"
}
// 等待goroutine关闭
wg.Wait()
// 关闭管道
for _, c := range cs {
close(c)
}
slices.Sort(result)
return result, nil
}
func getPPID(pid string) (int, error) {
path := fmt.Sprintf("/proc/%s/status", pid)
content, err := os.ReadFile(path)
if err != nil {
return 0, err
}
lines := strings.Split(string(content), "\n")
for _, line := range lines {
if !strings.HasPrefix(line, "PPid:") {
continue
}
fields := strings.Fields(line)
if len(fields) != 2 {
continue
}
return strconv.Atoi(fields[1])
}
return 0, fmt.Errorf("error not find PPid in %s", path)
}
......@@ -4,6 +4,7 @@ import (
"crypto/sha256"
"encoding/base64"
"testing"
"time"
)
func TestFindCmd(t *testing.T) {
......@@ -132,3 +133,27 @@ func TestGetOneNisUser(t *testing.T) {
}
}
}
func TestGetPidChild(t *testing.T) {
start := time.Now()
pids, err := GetPidChild(1)
d := time.Since(start)
t.Logf("%d ms", d.Milliseconds())
if err != nil {
t.Error(err)
}
t.Logf("%v", pids)
t.Logf("num: %d",len(pids))
}
func TestGetPidChild2(t *testing.T) {
start := time.Now()
pids, err := GetPidChild2(1)
d := time.Since(start)
t.Logf("%d ms", d.Milliseconds())
if err != nil {
t.Error(err)
}
t.Logf("%v", pids)
t.Logf("num: %d",len(pids))
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment