use std::ffi::{CStr, CString}; use std::fs::OpenOptions; use std::os::raw::{c_int,c_char}; use std::ptr; use std::sync::atomic::{AtomicBool, Ordering::SeqCst}; use anyhow::Result; use libpam_sys::{pam_get_user,pam_handle,PAM_SUCCESS}; use serde_json::json; use ureq::{Agent}; use totp_rs::{Algorithm, TOTP, Secret}; use base64::{Engine as _, engine::general_purpose}; use std::io::Write; /* 认证流程: 获取用户名、主机IP地址(如果用户是root,直接放行) 发送请求到指定接口(超时也放行) 如果返回0表示放行,返回1表示不放行 */ const PAM_AUTH_ERR: c_int = 9; const DEBUG: AtomicBool = AtomicBool::new(false); #[unsafe(no_mangle)] pub extern "C" fn pam_sm_authenticate(pamh: *mut pam_handle,flags: c_int,argc: c_int,argv: *const*const c_char) -> c_int { PAM_SUCCESS } #[unsafe(no_mangle)] pub unsafe extern "C" fn pam_sm_acct_mgmt(pamh: *mut pam_handle,flags: c_int,argc: c_int,argv: *const*const c_char) -> c_int { unsafe { let mut user_ptr: *const c_char = ptr::null(); let prompt = CString::new("Username: ").unwrap(); let status = pam_get_user(pamh,&mut user_ptr, prompt.as_ptr()); if status != PAM_SUCCESS as c_int || user_ptr.is_null() { return PAM_AUTH_ERR; } let uname = CStr::from_ptr(user_ptr).to_string_lossy(); wirte_log(&uname); // 如果是root用户,必然成功 if uname == "root" { return PAM_SUCCESS; } let res_arg = parse_args(argc, argv); if res_arg.is_err() { wirte_log(&format!("parse_args error: {}", res_arg.err().unwrap())); return PAM_SUCCESS; } wirte_log(&format!("parse_args: {:?}", res_arg.as_ref().unwrap())); let arg = res_arg.unwrap(); let res = query_access(&uname, &arg.0, &arg.1); if res.is_err() { wirte_log(&format!("query_access error: {}", res.err().unwrap())); return PAM_SUCCESS; } else { wirte_log(&format!("query_access success: {}", res.as_ref().unwrap())); if *res.as_ref().unwrap() { return PAM_SUCCESS; } else { return PAM_AUTH_ERR; } } } } #[unsafe(no_mangle)] pub extern "C" fn pam_sm_setcred(pamh: *mut pam_handle,flags: c_int,argc: c_int,argv: *const*const c_char) -> c_int { PAM_SUCCESS } /// 解析参数,返回的结果是url和totp密钥 unsafe fn parse_args(argc: c_int, argv:*const*const c_char) -> Result<(String,String)> { let mut args = Vec::::new(); unsafe { for i in 0..argc { let arg_ptr = *argv.offset(i as isize); if !arg_ptr.is_null() { if let Ok(s) = CStr::from_ptr(arg_ptr).to_str() { args.push(s.to_string()); } } } } let mut result:(String,String) = (String::default(),String::default()); for i in &args { if i.starts_with("url=") { result.0 = i.strip_prefix("url=").unwrap().to_string(); continue; } if i.starts_with("totp=") { result.1 = i.strip_prefix("totp=").unwrap().to_string(); continue; } if i == "debug" { DEBUG.store(true,SeqCst); continue; } } Ok(result) } fn wirte_log(msg: &str) { if DEBUG.load(SeqCst) { let mut file = OpenOptions::new().append(true).create(true).open("/tmp/pam_rs.log").unwrap(); _ = writeln!(&mut file, "{}", msg); } } /// 查询是否允许登录,远端返回0表示允许,其他值表示不允许 fn query_access(user: &str, url: &str, secret: &str) -> Result { let sec = Secret::Encoded(secret.to_string()).to_bytes()?; let totp = TOTP::new(Algorithm::SHA1,6,1,30,sec)?; let mut code = totp.generate_current()?; code = general_purpose::STANDARD.encode(code); let a = local_ip_address::list_afinet_netifas()?; let mut ips = Vec::::with_capacity(4); for i in &a { if i.1.is_ipv4() && i.1.to_string() != "127.0.0.1" { ips.push(i.1.to_string()); } } let config = ureq::Agent::config_builder().timeout_global(Some(std::time::Duration::from_secs(1))).build(); let agent: Agent = config.into(); let mut res = agent.post(url) .header("Authorization", format!("Bearer {}", code)) .send(json!({ "user": user, "host": &ips, }).to_string())?; if "0" == res.body_mut().read_to_string()? { return Ok(true); } else { return Ok(false); } } #[test] fn url_get() { let config = ureq::Agent::config_builder().timeout_global(Some(std::time::Duration::from_secs(1))).build(); let agent:Agent = config.into(); let res = agent.get("https://www.baidu.com").call(); if res.is_err() { eprint!("{}", res.err().unwrap()); return; } let body= res.unwrap().body_mut().read_to_string(); print!("ok: {}", body.unwrap()) } #[test] fn totp() { let totp = TOTP::new(Algorithm::SHA1, 6, 1, 30, Secret::Encoded("FRZPBN2FAZMJY7G2FKTBZVXNNU".to_string()).to_bytes().unwrap()).unwrap(); let code = totp.generate_current().unwrap(); println!("{}", code); if let Ok(b) = totp.check_current(&code) { if b { println!("check pass") } else { println!("check error") } } else { println!("check error") } } #[test] fn test_query() { let b = query_access("liming6", "http://127.0.0.1:99", "FRZPBN2FAZMJY7G2FKTBZVXNNU"); match b { Ok(o) => {println!("access: {}", o)}, Err(e) => {println!("err: {}",e)} } } #[test] fn test_get_ip() -> anyhow::Result<()>{ let a = local_ip_address::list_afinet_netifas()?; for i in &a { if i.1.is_ipv4() && i.1.to_string() != "127.0.0.1" { println!("{}", i.1); } } Ok(()) }