// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. // SPDX-License-Identifier: Apache-2.0 use super::{cache::LoRACache, source::LoRASource}; use anyhow::Result; use std::{path::PathBuf, sync::Arc}; pub struct LoRADownloader { sources: Vec>, cache: LoRACache, } impl LoRADownloader { pub fn new(sources: Vec>, cache: LoRACache) -> Self { Self { sources, cache } } /// Download LoRA if not in cache, return local path /// /// For local file:// URIs, this will return the original path without copying. /// For remote URIs (s3://, gcs://, etc.), this will download to cache. pub async fn download_if_needed(&self, lora_uri: &str) -> Result { // For local file:// URIs, don't use cache - just validate and return if lora_uri.starts_with("file://") { for source in &self.sources { // Ignore errors from incompatible sources if let Ok(exists) = source.exists(lora_uri).await && exists { // LocalLoRASource.download() returns the original path return source.download(lora_uri, &PathBuf::new()).await; } } anyhow::bail!("Local LoRA not found: {}", lora_uri); } // For remote URIs, use the URI as the cache key let cache_key = self.uri_to_cache_key(lora_uri); // Check cache first if self.cache.is_cached(&cache_key) && self.cache.validate_cached(&cache_key)? { tracing::debug!("LoRA found in cache: {}", cache_key); return Ok(self.cache.get_cache_path(&cache_key)); } // Try sources in order let dest_path = self.cache.get_cache_path(&cache_key); for source in &self.sources { if let Ok(exists) = source.exists(lora_uri).await && exists { let downloaded_path = source.download(lora_uri, &dest_path).await?; if self.cache.validate_cached(&cache_key)? { return Ok(downloaded_path); } else { tracing::warn!( "Downloaded LoRA at {} failed validation", downloaded_path.display() ); } } } anyhow::bail!("LoRA {} not found in any source", lora_uri) } /// Convert URI to cache key fn uri_to_cache_key(&self, uri: &str) -> String { uri.replace("://", "_").replace(['/', '\\'], "_") } }