build.rs 1.62 KB
Newer Older
yongshk's avatar
yongshk committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
use anyhow::{Context, Result};
use std::path::PathBuf;
use std::env;

fn main() -> Result<()> {

    env::set_var("CANDLE_FLASH_ATTN_BUILD_DIR", "./candle-flash-attention-lib/");

    println!("cargo:rerun-if-changed=build.rs");

    // 获取外部库路径
    let build_dir = match std::env::var("CANDLE_FLASH_ATTN_BUILD_DIR") {
        Err(_) => {
            // 如果没有设置,返回错误
            return Err(anyhow::anyhow!("Error: CANDLE_FLASH_ATTN_BUILD_DIR not set. Please set it to the directory containing the static library."));
        }
        Ok(path) => {
            // 使用环境变量指定的路径
            let path = PathBuf::from(path);
            path.canonicalize().context(format!(
                "Error: Directory doesn't exist: {} (the current directory is {})",
                path.display(),
                std::env::current_dir()?.display()
            ))?
        }
    };

    // 检查外部库是否存在
    let out_file = build_dir.join("libflashattention.a");
    if out_file.exists() {
        println!("Using cached static library at {}", out_file.display());
    } else {
        // 库未找到,返回错误并中断执行
        return Err(anyhow::anyhow!("Error: Static library not found at {}. Please ensure the library is present.", out_file.display()));
    }

    // 添加链接搜索路径和链接库
    println!("cargo:rustc-link-search={}", build_dir.display());
    println!("cargo:rustc-link-lib=static=flashattention");
    println!("cargo:rustc-link-lib=dylib=cudart"); // CUDA Runtime
    println!("cargo:rustc-link-lib=dylib=stdc++"); // C++ Standard Library

    Ok(())
}