use std::ffi::{CStr, CString};
use std::fs::OpenOptions;
use std::os::raw::{c_int,c_char};
use std::ptr;
use std::sync::atomic::{AtomicBool, Ordering::SeqCst};
use anyhow::Result;
use libpam_sys::{pam_get_user,pam_handle,PAM_SUCCESS};
use serde_json::json;
use ureq::{Agent};
use totp_rs::{Algorithm, TOTP, Secret};
use base64::{Engine as _, engine::general_purpose};
use std::io::Write;

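/*
Authorization flow (account-management phase, `pam_sm_acct_mgmt`):

1. Get the user name and this host's IPv4 addresses; the `root` user is always allowed.
2. Send a request to the endpoint configured with `url=`; if this step fails for any
   reason (including a timeout), the login is allowed.
3. A response of `0` allows the login; any other response denies it.
*/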

/// Return code used to deny access.
///
/// NOTE(review): this value is hard-coded. On Linux-PAM `PAM_AUTH_ERR` is 7 and 9 is
/// `PAM_AUTHINFO_UNAVAIL` (OpenPAM uses 9 for `PAM_AUTH_ERR`). Either one is a non-success
/// code, but consider using the constant exported by `libpam_sys` — confirm for the target.
const PAM_AUTH_ERR: c_int = 9;

/// Debug-logging switch, set from the `debug` module argument.
///
/// This must be a `static`. A `const` is inlined as a brand-new `AtomicBool` at every use
/// site, so `store(true)` would be thrown away and `load` would always return `false`.
static DEBUG: AtomicBool = AtomicBool::new(false);


/// PAM authentication hook.
///
/// This module performs no authentication of its own (its decision logic lives in
/// `pam_sm_acct_mgmt`), so this hook ignores its arguments and always reports success.
#[unsafe(no_mangle)]
pub extern "C" fn pam_sm_authenticate(pamh: *mut pam_handle,flags: c_int,argc: c_int,argv: *const*const c_char) -> c_int {
    // None of the PAM-supplied arguments are needed here.
    let _ = (pamh, flags, argc, argv);
    PAM_SUCCESS
}

#[unsafe(no_mangle)]
pub unsafe extern "C" fn pam_sm_acct_mgmt(pamh: *mut pam_handle,flags: c_int,argc: c_int,argv: *const*const c_char) -> c_int {
    unsafe {
        let mut user_ptr: *const c_char = ptr::null();
        let prompt = CString::new("Username: ").unwrap();
        let status = pam_get_user(pamh,&mut user_ptr, prompt.as_ptr());
        if status != PAM_SUCCESS as c_int || user_ptr.is_null() {
            return PAM_AUTH_ERR;
        }
        let uname = CStr::from_ptr(user_ptr).to_string_lossy();
        wirte_log(&uname);
        // 如果是root用户，必然成功
        if uname == "root" {
            return PAM_SUCCESS;
        }
        let res_arg = parse_args(argc, argv);
        if res_arg.is_err() {
            wirte_log(&format!("parse_args error: {}", res_arg.err().unwrap()));
            return PAM_SUCCESS;
        }
        wirte_log(&format!("parse_args: {:?}", res_arg.as_ref().unwrap()));
        let arg = res_arg.unwrap();
        let res = query_access(&uname, &arg.0, &arg.1);
        if res.is_err() {
            wirte_log(&format!("query_access error: {}", res.err().unwrap()));
            return PAM_SUCCESS;
        } else {
            wirte_log(&format!("query_access success: {}", res.as_ref().unwrap()));
            if *res.as_ref().unwrap() {
                return PAM_SUCCESS;
            } else {
                return PAM_AUTH_ERR;
            }
        }
    }
}

/// PAM credential hook.
///
/// This module sets no credentials, so the hook ignores its arguments and always reports
/// success.
#[unsafe(no_mangle)]
pub extern "C" fn pam_sm_setcred(pamh: *mut pam_handle,flags: c_int,argc: c_int,argv: *const*const c_char) -> c_int {
    // None of the PAM-supplied arguments are needed here.
    let _ = (pamh, flags, argc, argv);
    PAM_SUCCESS
}

/// Parses the PAM module arguments and returns `(url, totp_secret)`.
///
/// Recognised arguments:
/// * `url=<endpoint>`: the access-control endpoint (empty string if absent);
/// * `totp=<secret>`: the encoded TOTP secret, as later passed to `totp_rs::Secret::Encoded`
///   (empty string if absent);
/// * `debug`: enables logging through `wirte_log`.
///
/// A later `url=`/`totp=` overrides an earlier one. Unrecognised arguments and arguments
/// that are not valid UTF-8 are ignored.
///
/// The debug flag is written on every call, and cleared when `debug` is absent, so that one
/// PAM stack entry's `debug` does not leak into later calls in the same process.
///
/// Currently never returns `Err`; the `Result` is kept for callers.
///
/// # Safety
///
/// `argv` must point to at least `argc` pointers, each either null or a valid NUL-terminated
/// C string. A null `argv` or a non-positive `argc` is treated as "no arguments".
unsafe fn parse_args(argc: c_int, argv:*const*const c_char) -> Result<(String,String)> {
    let count = if argv.is_null() {
        0
    } else {
        usize::try_from(argc).unwrap_or(0)
    };

    let mut url = String::new();
    let mut secret = String::new();
    let mut debug = false;
    for i in 0..count {
        // SAFETY: the caller guarantees `argv` has at least `argc` readable entries.
        let arg_ptr = unsafe { *argv.add(i) };
        if arg_ptr.is_null() {
            continue;
        }
        // SAFETY: non-null entries are NUL-terminated C strings (caller contract).
        let Ok(arg) = unsafe { CStr::from_ptr(arg_ptr) }.to_str() else {
            continue;
        };
        if let Some(value) = arg.strip_prefix("url=") {
            url = value.to_string();
        } else if let Some(value) = arg.strip_prefix("totp=") {
            secret = value.to_string();
        } else if arg == "debug" {
            debug = true;
        }
    }
    DEBUG.store(debug, SeqCst);
    Ok((url, secret))
}



/// Appends `msg` as one line to `/tmp/pam_rs.log` when debug logging is enabled.
///
/// Logging is best effort: any failure to open or write the file is ignored. This matters
/// because the callers are `extern "C"` PAM entry points, where a panic would abort the
/// process that loaded the module.
fn wirte_log(msg: &str) {
    use std::os::unix::fs::OpenOptionsExt;

    if !DEBUG.load(SeqCst) {
        return;
    }
    // NOTE(review): a privileged process opening a fixed path in /tmp follows symlinks
    // planted there; consider a root-owned directory (e.g. under /var/log) instead.
    let opened = OpenOptions::new()
        .append(true)
        .create(true)
        // A newly created log is readable by its owner only, since it records user names.
        .mode(0o600)
        .open("/tmp/pam_rs.log");
    if let Ok(mut file) = opened {
        let _ = writeln!(file, "{}", msg);
    }
}



/// 查询是否允许登录，远端返回0表示允许，其他值表示不允许
fn query_access(user: &str, url: &str, secret: &str) -> Result<bool> {

    let sec = Secret::Encoded(secret.to_string()).to_bytes()?;
    let totp = TOTP::new(Algorithm::SHA1,6,1,30,sec)?;
    let mut code = totp.generate_current()?;
    code = general_purpose::STANDARD.encode(code);

    let a = local_ip_address::list_afinet_netifas()?;
    let mut ips = Vec::<String>::with_capacity(4);
    for i in &a {
        if i.1.is_ipv4() && i.1.to_string() != "127.0.0.1" {
            ips.push(i.1.to_string());
        }
    }
    let config = ureq::Agent::config_builder().timeout_global(Some(std::time::Duration::from_secs(1))).build();
    let agent: Agent = config.into();
    let mut res = agent.post(url)
        .header("Authorization", format!("Bearer {}", code))
        .send(json!({
            "user": user,
            "host": &ips,
        }).to_string())?;
    if "0" == res.body_mut().read_to_string()? {
        return Ok(true);
    } else {
        return Ok(false);
    }
}


/// Smoke test of the HTTP client setup used by `query_access`: a GET with the same 1 s
/// global timeout.
///
/// It contacts a public website, so it is ignored by default; run it with
/// `cargo test -- --ignored url_get`. A failed request is printed, not asserted.
#[test]
#[ignore = "requires access to the public internet"]
fn url_get() {
    let agent: Agent = ureq::Agent::config_builder()
        .timeout_global(Some(std::time::Duration::from_secs(1)))
        .build()
        .into();
    match agent.get("https://www.baidu.com").call() {
        Ok(mut response) => {
            let body = response
                .body_mut()
                .read_to_string()
                .expect("response body should be readable as text");
            print!("ok: {}", body)
        }
        Err(e) => eprint!("{}", e),
    }
}

/// Round-trip check of the TOTP parameters used by `query_access` (`TOTP::new(SHA1, 6, 1,
/// 30, …)`): a freshly generated code must validate against the same secret.
///
/// The original version only printed "check error" on failure and so could never fail.
#[test]
fn totp() {
    let key = Secret::Encoded("FRZPBN2FAZMJY7G2FKTBZVXNNU".to_string())
        .to_bytes()
        .expect("test secret should decode");
    let totp = TOTP::new(Algorithm::SHA1, 6, 1, 30, key).expect("TOTP parameters should be valid");
    let code = totp.generate_current().expect("current time should be available");
    println!("{}", code);
    let valid = totp.check_current(&code).expect("current time should be available");
    assert!(valid, "freshly generated code {} was rejected", code);
}

/// Manual probe of `query_access` against a service expected at `http://127.0.0.1:99`.
///
/// Prints either the access decision or the error; it never fails on its own.
#[test]
fn test_query() {
    let outcome = query_access("liming6",  "http://127.0.0.1:99", "FRZPBN2FAZMJY7G2FKTBZVXNNU");
    let report = match outcome {
        Ok(allowed) => format!("access: {}", allowed),
        Err(err) => format!("err: {}", err),
    };
    println!("{}", report);
}

/// Prints this host's IPv4 addresses other than `127.0.0.1`, one per line.
#[test]
fn test_get_ip() -> anyhow::Result<()>{
    local_ip_address::list_afinet_netifas()?
        .iter()
        .map(|(_, ip)| ip)
        .filter(|ip| ip.is_ipv4() && ip.to_string() != "127.0.0.1")
        .for_each(|ip| println!("{}", ip));
    Ok(())
}