Unverified Commit 844000eb authored by Ameen Patel's avatar Ameen Patel Committed by GitHub
Browse files

fix: extend LoRA download S3 timeout and stream large LoRA downloads to disk (#6544)


Signed-off-by: AmeenP <ameenp360@gmail.com>
Co-authored-by: Biswa Panda <biswa.panda@gmail.com>
parent f1e8ea6e
......@@ -3,14 +3,14 @@
use anyhow::{Context, Result};
use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt;
use object_store::{ObjectStore, aws::AmazonS3Builder, path::Path as ObjectPath};
use object_store::{ClientOptions, ObjectStore, aws::AmazonS3Builder, path::Path as ObjectPath};
use std::{
path::{Path, PathBuf},
sync::Arc,
time::Duration,
};
use tokio::io::AsyncWriteExt;
use url::Url;
/// Minimal trait for LoRA sources
......@@ -65,8 +65,6 @@ impl LoRASource for LocalLoRASource {
anyhow::bail!("LoRA path is not a directory: {}", source_path.display());
}
// For local files, we don't copy - just return the source path
// This avoids unnecessary disk I/O
tracing::info!("Using local LoRA at: {:?}", source_path);
Ok(source_path)
......@@ -87,35 +85,52 @@ pub struct S3LoRASource {
endpoint: Option<String>,
}
/// Retry configuration for S3 operations
impl S3LoRASource {
/// Maximum number of retry attempts for S3 operations
const MAX_RETRIES: u32 = 3;
/// Initial backoff duration in milliseconds
const INITIAL_BACKOFF_MS: u64 = 1000;
/// Maximum backoff duration in milliseconds
const MAX_BACKOFF_MS: u64 = 30000;
/// Stream a single object from the store into a local file.
///
/// Issues one GET for `location`, creates `dest` (truncating any existing
/// file at that path), and writes the response body to it chunk by chunk so
/// the whole object is never held in memory at once.
///
/// This function makes a single attempt and performs no retries or backoff;
/// callers that need retries must wrap it.
///
/// # Returns
/// The total number of bytes written to `dest`.
///
/// # Errors
/// Returns an error if the GET request fails, `dest` cannot be created, a
/// chunk cannot be read from the stream or written to the file, or the final
/// flush fails. Nothing here removes `dest` on failure, so a partially
/// written file may be left behind.
async fn stream_to_file(
    store: &Arc<dyn ObjectStore>,
    location: &ObjectPath,
    dest: &std::path::Path,
) -> Result<u64> {
    let get_result = store
        .get(location)
        .await
        .with_context(|| format!("Failed to GET {}", location))?;
    let mut stream = get_result.into_stream();

    let mut file = tokio::fs::File::create(dest)
        .await
        .with_context(|| format!("Failed to create file {:?}", dest))?;

    let mut total_bytes: u64 = 0;
    while let Some(chunk) = stream.next().await {
        let chunk = chunk.with_context(|| format!("Error reading stream for {}", location))?;
        file.write_all(&chunk)
            .await
            .with_context(|| format!("Failed to write chunk to {:?}", dest))?;
        total_bytes += chunk.len() as u64;
    }

    // Flush before returning so pending writes complete before the file
    // handle is dropped; attach the path so a failure here is diagnosable
    // like every other step above.
    file.flush()
        .await
        .with_context(|| format!("Failed to flush file {:?}", dest))?;

    Ok(total_bytes)
}
async fn download_file_with_retry(
store: &Arc<dyn ObjectStore>,
location: &ObjectPath,
) -> Result<Bytes> {
dest: &std::path::Path,
) -> Result<u64> {
for attempt in 1..=Self::MAX_RETRIES {
let result = store.get(location).await;
let error = match result {
Ok(get_result) => match get_result.bytes().await {
Ok(bytes) => return Ok(bytes),
Err(e) => anyhow::anyhow!("Failed to read bytes: {}", e),
},
Err(e) => anyhow::anyhow!("Failed to get object: {}", e),
};
match Self::stream_to_file(store, location, dest).await {
Ok(bytes_written) => return Ok(bytes_written),
Err(error) => {
if attempt >= Self::MAX_RETRIES {
return Err(error);
}
// Calculate backoff with exponential increase, capped at MAX_BACKOFF_MS
let backoff_ms = std::cmp::min(
Self::INITIAL_BACKOFF_MS * 2u64.pow(attempt - 1),
Self::MAX_BACKOFF_MS,
......@@ -129,8 +144,9 @@ impl S3LoRASource {
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
}
}
// This should be unreachable, but provide a fallback
Err(anyhow::anyhow!(
"S3 download failed after {} retries",
Self::MAX_RETRIES
......@@ -160,22 +176,26 @@ impl S3LoRASource {
})
}
/// Build an ObjectStore for a specific bucket
fn build_store(&self, bucket: &str) -> Result<Arc<dyn ObjectStore>> {
let timeout_secs: u64 = std::env::var("LORA_DOWNLOAD_TIMEOUT_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(3600);
let client_opts = ClientOptions::new().with_timeout(Duration::from_secs(timeout_secs));
let mut builder = AmazonS3Builder::new()
.with_access_key_id(&self.access_key_id)
.with_secret_access_key(&self.secret_access_key)
.with_region(&self.region)
.with_bucket_name(bucket);
.with_bucket_name(bucket)
.with_client_options(client_opts);
if let Some(ref endpoint) = self.endpoint {
builder = builder
.with_endpoint(endpoint)
// Use path-style URLs for custom endpoints (e.g., MinIO)
.with_virtual_hosted_style_request(false);
// Only allow HTTP when explicitly enabled via environment variable
// HTTPS is the default for security
if std::env::var("AWS_ALLOW_HTTP")
.map(|v| v.eq_ignore_ascii_case("true"))
.unwrap_or(false)
......@@ -219,15 +239,10 @@ impl LoRASource for S3LoRASource {
prefix
);
// Build store for this specific bucket
let bucket_store = self.build_store(&bucket)?;
// List all objects under the prefix
let object_prefix = ObjectPath::from(prefix.clone());
let mut list_stream = bucket_store.list(Some(&object_prefix));
// Create a temporary directory in the same parent as dest_path for atomic download
// This prevents data loss if dest_path already exists
let parent = dest_path
.parent()
.ok_or_else(|| anyhow::anyhow!("Destination path has no parent directory"))?;
......@@ -236,7 +251,6 @@ impl LoRASource for S3LoRASource {
.and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Destination path has no file name"))?;
// Generate unique temp directory name
let temp_suffix = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
......@@ -244,12 +258,10 @@ impl LoRASource for S3LoRASource {
let temp_dir_name = format!("{}.tmp.{}", dest_name, temp_suffix);
let temp_path = parent.join(&temp_dir_name);
// Create temporary directory
tokio::fs::create_dir_all(&temp_path)
.await
.context("Failed to create temporary directory")?;
// Cleanup closure that only removes the temp directory on error
let cleanup_on_error = async |err: anyhow::Error| -> anyhow::Error {
tracing::warn!(
"S3 download failed, cleaning up temporary directory at {:?}",
......@@ -268,7 +280,6 @@ impl LoRASource for S3LoRASource {
Err(e) => return Err(cleanup_on_error(e.into()).await),
};
// Get relative path (remove prefix)
let rel_path = meta
.location
.as_ref()
......@@ -277,12 +288,11 @@ impl LoRASource for S3LoRASource {
.trim_start_matches('/');
if rel_path.is_empty() {
continue; // Skip the prefix itself
continue;
}
let file_path = temp_path.join(rel_path);
// Create parent directories
#[allow(clippy::collapsible_if)]
if let Some(parent) = file_path.parent() {
if let Err(e) = tokio::fs::create_dir_all(parent).await {
......@@ -290,18 +300,16 @@ impl LoRASource for S3LoRASource {
}
}
// Download file with retry logic
let bytes = match Self::download_file_with_retry(&bucket_store, &meta.location).await {
Ok(b) => b,
let bytes_written =
match Self::download_file_with_retry(&bucket_store, &meta.location, &file_path)
.await
{
Ok(n) => n,
Err(e) => return Err(cleanup_on_error(e).await),
};
if let Err(e) = tokio::fs::write(&file_path, &bytes).await {
return Err(cleanup_on_error(e.into()).await);
}
file_count += 1;
tracing::debug!("Downloaded: {} ({} bytes)", rel_path, bytes.len());
tracing::debug!("Downloaded: {} ({} bytes)", rel_path, bytes_written);
}
if file_count == 0 {
......@@ -310,14 +318,11 @@ impl LoRASource for S3LoRASource {
);
}
// Atomically rename temp directory to final destination
// Remove dest_path if it exists (only after successful download to avoid data loss)
if dest_path.exists() {
tokio::fs::remove_dir_all(dest_path)
.await
.context("Failed to remove existing destination directory")?;
}
// Rename is atomic on most filesystems
tokio::fs::rename(&temp_path, dest_path)
.await
.context("Failed to atomically move temporary directory to destination")?;
......@@ -335,7 +340,6 @@ impl LoRASource for S3LoRASource {
let object_prefix = ObjectPath::from(prefix);
let mut list_stream = bucket_store.list(Some(&object_prefix));
// Check if at least one object exists, propagating errors
match list_stream.next().await {
Some(Ok(_)) => Ok(true),
Some(Err(e)) => Err(e.into()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment