build.rs 9.84 KB
Newer Older
Ryan Olson's avatar
Ryan Olson committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0

use std::env;
use std::fs;
use std::io::Read;
use std::path::Path;
use std::process::Command;

/// Build-script entry point.
///
/// Declares the inputs that should trigger a re-run, picks between the prebuilt-kernel path and
/// compiling the CUDA source, and (only when compiling from source) emits the linker directives
/// for the CUDA runtime library.
fn main() {
    println!("cargo:rerun-if-changed=cuda/tensor_kernels.cu");
    println!("cargo:rerun-if-env-changed=DYNAMO_USE_PREBUILT_KERNELS");
    println!("cargo:rerun-if-env-changed=CUDA_ARCHS");
    // CUDA_PATH / CUDA_HOME decide the link search path emitted below, so a change to either
    // one must re-run this script; otherwise a stale search path stays cached.
    println!("cargo:rerun-if-env-changed=CUDA_PATH");
    println!("cargo:rerun-if-env-changed=CUDA_HOME");

    if determine_build_mode() {
        // Prebuilt mode only reads these files (it never writes them), so tracking them is safe
        // and makes an edited .md5/.fatbin get re-validated on the next build.
        println!("cargo:rerun-if-changed=cuda/prebuilt/tensor_kernels.md5");
        println!("cargo:rerun-if-changed=cuda/prebuilt/tensor_kernels.fatbin");
        build_with_prebuilt_kernels();
        return;
    }

    build_from_source();

    // Only link against the CUDA runtime when building from source.
    // Toolkit root precedence: CUDA_PATH, then CUDA_HOME, then the standard install location.
    let cuda_root = env::var("CUDA_PATH")
        .or_else(|_| env::var("CUDA_HOME"))
        .unwrap_or_else(|_| "/usr/local/cuda".to_string());
    println!("cargo:rustc-link-search=native={}/lib64", cuda_root);
    println!("cargo:rustc-link-search=native={}/lib", cuda_root);

    println!("cargo:rustc-link-lib=cudart");
}

/// Determine whether to use prebuilt kernels based on:
/// 1. Feature flag (highest precedence)
/// 2. Environment variable
/// 3. Auto-detection of nvcc
///
/// Returns `true` when the prebuilt kernels should be used and `false` when the CUDA source
/// should be compiled. Every outcome is reported through a `cargo:warning` line.
fn determine_build_mode() -> bool {
    // Check feature flag first.
    // `cfg!` is used instead of a `#[cfg]` block ending in `return`, so the code below stays
    // reachable (and warning-free) when the feature is enabled.
    if cfg!(feature = "prebuilt-kernels") {
        println!("cargo:warning=Using prebuilt kernels (feature flag enabled)");
        return true;
    }

    // Check environment variable
    if dynamo_config::env_is_truthy("DYNAMO_USE_PREBUILT_KERNELS") {
        println!("cargo:warning=Using prebuilt kernels (DYNAMO_USE_PREBUILT_KERNELS set)");
        return true;
    }

    // Auto-detect nvcc
    if !is_nvcc_available() {
        println!("cargo:warning=nvcc not found, using prebuilt kernels");
        return true;
    }

    println!("cargo:warning=Building CUDA kernels from source");
    false
}

/// Reports whether a working `nvcc` can be run from `PATH`.
///
/// Runs `nvcc --version` with its output captured (so nothing leaks into the build-script
/// stdout that cargo parses) and returns `true` only if the process both started and exited
/// successfully. A binary that cannot be spawned, or that exits with a failure status, counts
/// as unavailable.
fn is_nvcc_available() -> bool {
    Command::new("nvcc")
        .arg("--version")
        .output()
        .map(|output| output.status.success())
        .unwrap_or(false)
}

/// Validates the checked-in prebuilt kernel artifacts and stages the `.fatbin` in `OUT_DIR`.
///
/// `cuda/prebuilt/tensor_kernels.md5` must contain three MD5 hashes, one per line, in this
/// order: `build.rs`, `cuda/tensor_kernels.cu`, `cuda/prebuilt/tensor_kernels.fatbin`. Each is
/// compared with the hash of the corresponding file under `CARGO_MANIFEST_DIR`. Whitespace
/// around a hash and blank lines in the `.md5` file are ignored.
///
/// # Panics
///
/// Panics if either prebuilt file is missing, if the `.md5` file does not hold exactly three
/// hashes, if any hash differs from the current file contents, or if copying the `.fatbin`
/// into `OUT_DIR` fails.
fn build_with_prebuilt_kernels() {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let manifest_dir = Path::new(&manifest_dir);
    let cu_path = manifest_dir.join("cuda/tensor_kernels.cu");
    let md5_path = manifest_dir.join("cuda/prebuilt/tensor_kernels.md5");
    let fatbin_path = manifest_dir.join("cuda/prebuilt/tensor_kernels.fatbin");

    // Validate that prebuilt files exist
    if !md5_path.exists() {
        panic!(
            "Prebuilt mode requires cuda/prebuilt/tensor_kernels.md5 but it does not exist. \
             Please build with nvcc available first to generate the prebuilt artifacts."
        );
    }

    if !fatbin_path.exists() {
        panic!(
            "Prebuilt mode requires cuda/prebuilt/tensor_kernels.fatbin but it does not exist. \
             Please build with nvcc available first to generate the prebuilt artifacts."
        );
    }

    // Read stored hashes (three lines: build.rs, .cu, .fatbin)
    let stored_hashes_content =
        fs::read_to_string(&md5_path).expect("Failed to read cuda/prebuilt/tensor_kernels.md5");

    // Trim each line and drop blank ones so stray whitespace or an extra trailing newline
    // (editor or VCS normalisation) does not invalidate an otherwise correct file.
    let stored_hashes: Vec<&str> = stored_hashes_content
        .lines()
        .map(str::trim)
        .filter(|line| !line.is_empty())
        .collect();
    if stored_hashes.len() != 3 {
        panic!(
            "Invalid .md5 file format. Expected 3 lines (build.rs, .cu, .fatbin hashes), found {}.\n\
             Please rebuild with nvcc available to regenerate the prebuilt artifacts.",
            stored_hashes.len()
        );
    }

    // Compute current hashes
    let current_build_rs_hash = compute_file_hash(&manifest_dir.join("build.rs"));
    let current_cu_hash = compute_file_hash(&cu_path);
    let current_fatbin_hash = compute_file_hash(&fatbin_path);

    // Validate all three hashes; collect every mismatch so one failure reports them all.
    let checks = [
        ("build.rs", &current_build_rs_hash, stored_hashes[0]),
        (".cu source", &current_cu_hash, stored_hashes[1]),
        (".fatbin", &current_fatbin_hash, stored_hashes[2]),
    ];

    let mut mismatches = Vec::new();
    for (label, current, stored) in checks.iter() {
        if current.as_str() != *stored {
            mismatches.push(format!("  {}: current={}, stored={}", label, current, stored));
        }
    }

    if !mismatches.is_empty() {
        panic!(
            "Hash mismatch! The prebuilt .fatbin is out of sync:\n{}\n\
             Please rebuild with nvcc available to regenerate the prebuilt artifacts.",
            mismatches.join("\n")
        );
    }

    println!("cargo:warning=Hash validation passed:");
    println!("cargo:warning=  build.rs: {}", current_build_rs_hash);
    println!("cargo:warning=  .cu source: {}", current_cu_hash);
    println!("cargo:warning=  .fatbin: {}", current_fatbin_hash);

    // Stage a copy of the validated fatbin in OUT_DIR and expose that directory as a native
    // library search path.
    let out_dir = env::var("OUT_DIR").unwrap();
    let fatbin_copy = Path::new(&out_dir).join("tensor_kernels.fatbin");
    fs::copy(&fatbin_path, &fatbin_copy).expect("Failed to copy .fatbin to OUT_DIR");

    println!("cargo:rustc-link-search=native={}", out_dir);

    // No `rustc-link-lib` directive is emitted on this path: a .fatbin cannot be linked
    // directly, so nothing here makes the linker consume it.
    // NOTE(review): presumably the crate loads the fatbin at runtime (e.g. via
    // cuModuleLoadFatBinary) -- confirm against the loading code.
    println!(
        "cargo:warning=Prebuilt kernel loaded from {}",
        fatbin_path.display()
    );
}

/// Compiles `cuda/tensor_kernels.cu` with the `cc` crate in CUDA mode, then regenerates the
/// prebuilt `.fatbin`/`.md5` artifacts from the same source and architecture flags.
///
/// Panics if `CARGO_MANIFEST_DIR` or `OUT_DIR` is not set.
fn build_from_source() {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let kernel_source = Path::new(&manifest_dir).join("cuda/tensor_kernels.cu");
    let out_dir = env::var("OUT_DIR").unwrap();

    // Fixed compiler options, applied in this order ahead of the per-architecture flags.
    let mut compiler = cc::Build::new();
    compiler.cuda(true);
    compiler.file(&kernel_source);
    for fixed_flag in &["-std=c++17", "-O3", "-Xcompiler", "-fPIC"] {
        compiler.flag(*fixed_flag);
    }

    // One -gencode flag per requested CUDA architecture.
    let gencode_flags = get_cuda_arch_flags();
    gencode_flags.iter().for_each(|gencode| {
        compiler.flag(gencode);
    });

    compiler.compile("tensor_kernels");

    // Refresh the .fatbin and .md5 so later builds can use prebuilt mode.
    generate_prebuilt_artifacts(&kernel_source, &gencode_flags, &out_dir);
}

/// Builds the nvcc `-gencode` flags for the CUDA architectures to target.
///
/// `CUDA_ARCHS` is read as a comma-separated list of architecture numbers (for example
/// `80,90`); entries are trimmed and empty entries are skipped. Each remaining entry `N`
/// becomes `-gencode=arch=compute_N,code=sm_N`.
///
/// When `CUDA_ARCHS` is unset, or is set but contains no entries, the defaults are used:
/// Ampere (SM 80) and Hopper (SM 90).
fn get_cuda_arch_flags() -> Vec<String> {
    if let Ok(arch_list) = env::var("CUDA_ARCHS") {
        let flags: Vec<String> = arch_list
            .split(',')
            .map(str::trim)
            .filter(|arch| !arch.is_empty())
            .map(|arch| format!("-gencode=arch=compute_{},code=sm_{}", arch, arch))
            .collect();

        if !flags.is_empty() {
            return flags;
        }

        // An empty list would hand nvcc no -gencode flags at all, silently compiling for
        // whatever architecture nvcc defaults to. Fall back to the defaults instead.
        println!(
            "cargo:warning=CUDA_ARCHS is set but lists no architectures; using default architectures"
        );
    }

    // Default to Ampere (SM 80) and Hopper (SM 90) support.
    vec![
        "-gencode=arch=compute_80,code=sm_80".to_string(),
        "-gencode=arch=compute_90,code=sm_90".to_string(),
    ]
}

/// Regenerates the prebuilt kernel artifacts under `cuda/prebuilt/`.
///
/// Runs `nvcc -fatbin` on `cu_path` with `arch_flags`, writing the result to `out_dir` first and
/// then copying it to `cuda/prebuilt/tensor_kernels.fatbin`. Afterwards writes
/// `cuda/prebuilt/tensor_kernels.md5` containing three MD5 hashes, one per line, in this order:
/// `build.rs`, the `.cu` source, the generated `.fatbin`.
///
/// # Parameters
///
/// * `cu_path` - CUDA source file to compile.
/// * `arch_flags` - extra nvcc arguments (the `-gencode` flags), appended after the output path.
/// * `out_dir` - directory that receives the intermediate `.fatbin`.
///
/// # Panics
///
/// Panics if `CARGO_MANIFEST_DIR` is unset, if the prebuilt directory cannot be created, if
/// nvcc cannot be started or exits unsuccessfully, or if copying/writing the artifacts fails.
fn generate_prebuilt_artifacts(cu_path: &Path, arch_flags: &[String], out_dir: &str) {
    let manifest_dir = env::var("CARGO_MANIFEST_DIR").unwrap();
    let prebuilt_dir = Path::new(&manifest_dir).join("cuda/prebuilt");
    let fatbin_path = prebuilt_dir.join("tensor_kernels.fatbin");
    let md5_path = prebuilt_dir.join("tensor_kernels.md5");

    // Ensure prebuilt directory exists
    fs::create_dir_all(&prebuilt_dir).expect("Failed to create cuda/prebuilt directory");

    // Generate .fatbin using nvcc
    let temp_fatbin = Path::new(out_dir).join("tensor_kernels.fatbin");

    let mut nvcc_cmd = Command::new("nvcc");
    nvcc_cmd
        .arg("-fatbin")
        .arg("-std=c++17")
        .arg("-O3")
        .arg(cu_path)
        .arg("-o")
        .arg(&temp_fatbin)
        .args(arch_flags);

    println!("cargo:warning=Generating .fatbin with nvcc...");
    let status = nvcc_cmd
        .status()
        .expect("Failed to execute nvcc for .fatbin generation");

    if !status.success() {
        // Include the exit status so the failure is diagnosable from the panic message alone.
        panic!("nvcc failed to generate .fatbin ({})", status);
    }

    // Copy .fatbin to prebuilt directory
    fs::copy(&temp_fatbin, &fatbin_path).expect("Failed to copy .fatbin to cuda/prebuilt/");

    // Generate MD5 hashes of all three files for consistency validation
    let build_rs_path = Path::new(&manifest_dir).join("build.rs");
    let build_rs_hash = compute_file_hash(&build_rs_path);
    let cu_hash = compute_file_hash(cu_path);
    let fatbin_hash = compute_file_hash(&fatbin_path);

    // Write all three hashes (one per line)
    let hashes = format!("{}\n{}\n{}\n", build_rs_hash, cu_hash, fatbin_hash);
    fs::write(&md5_path, hashes).expect("Failed to write .md5 file");

    // Cargo directives are parsed one line at a time, so every line of the report needs its own
    // `cargo:warning=` prefix; an embedded newline would drop the lines that follow it.
    println!("cargo:warning=Generated prebuilt artifacts:");
    println!("cargo:warning=  {}", fatbin_path.display());
    println!("cargo:warning=  {}", md5_path.display());
    println!("cargo:warning=build.rs hash: {}", build_rs_hash);
    println!("cargo:warning=.cu source hash: {}", cu_hash);
    println!("cargo:warning=.fatbin hash: {}", fatbin_hash);
}

/// Returns the MD5 digest of the file at `path` as a lowercase hexadecimal string.
///
/// The whole file is read into memory before hashing.
///
/// # Panics
///
/// Panics, naming the path and the underlying error, if the file cannot be opened or read.
fn compute_file_hash(path: &Path) -> String {
    let mut contents = Vec::new();

    match fs::File::open(path) {
        Ok(mut handle) => {
            if let Err(e) = handle.read_to_end(&mut contents) {
                panic!("Failed to read {} for hashing: {}", path.display(), e);
            }
        }
        Err(e) => panic!("Failed to open {} for hashing: {}", path.display(), e),
    }

    let digest = md5::compute(&contents);
    format!("{:x}", digest)
}