Commit 844000eb (unverified), authored by Ameen Patel, committed by GitHub
Browse files

fix: extend LoRA download S3 timeout and stream large LoRA downloads to disk (#6544)


Signed-off-by: AmeenP <ameenp360@gmail.com>
Co-authored-by: Biswa Panda <biswa.panda@gmail.com>
parent f1e8ea6e
...@@ -3,14 +3,14 @@ ...@@ -3,14 +3,14 @@
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use async_trait::async_trait; use async_trait::async_trait;
use bytes::Bytes;
use futures::StreamExt; use futures::StreamExt;
use object_store::{ObjectStore, aws::AmazonS3Builder, path::Path as ObjectPath}; use object_store::{ClientOptions, ObjectStore, aws::AmazonS3Builder, path::Path as ObjectPath};
use std::{ use std::{
path::{Path, PathBuf}, path::{Path, PathBuf},
sync::Arc, sync::Arc,
time::Duration, time::Duration,
}; };
use tokio::io::AsyncWriteExt;
use url::Url; use url::Url;
/// Minimal trait for LoRA sources /// Minimal trait for LoRA sources
...@@ -65,8 +65,6 @@ impl LoRASource for LocalLoRASource { ...@@ -65,8 +65,6 @@ impl LoRASource for LocalLoRASource {
anyhow::bail!("LoRA path is not a directory: {}", source_path.display()); anyhow::bail!("LoRA path is not a directory: {}", source_path.display());
} }
// For local files, we don't copy - just return the source path
// This avoids unnecessary disk I/O
tracing::info!("Using local LoRA at: {:?}", source_path); tracing::info!("Using local LoRA at: {:?}", source_path);
Ok(source_path) Ok(source_path)
...@@ -87,50 +85,68 @@ pub struct S3LoRASource { ...@@ -87,50 +85,68 @@ pub struct S3LoRASource {
endpoint: Option<String>, endpoint: Option<String>,
} }
/// Retry configuration for S3 operations
impl S3LoRASource { impl S3LoRASource {
/// Maximum number of retry attempts for S3 operations
const MAX_RETRIES: u32 = 3; const MAX_RETRIES: u32 = 3;
/// Initial backoff duration in milliseconds
const INITIAL_BACKOFF_MS: u64 = 1000; const INITIAL_BACKOFF_MS: u64 = 1000;
/// Maximum backoff duration in milliseconds
const MAX_BACKOFF_MS: u64 = 30000; const MAX_BACKOFF_MS: u64 = 30000;
/// Download a single file with retry logic and exponential backoff async fn stream_to_file(
store: &Arc<dyn ObjectStore>,
location: &ObjectPath,
dest: &std::path::Path,
) -> Result<u64> {
let get_result = store
.get(location)
.await
.with_context(|| format!("Failed to GET {}", location))?;
let mut stream = get_result.into_stream();
let mut file = tokio::fs::File::create(dest)
.await
.with_context(|| format!("Failed to create file {:?}", dest))?;
let mut total_bytes: u64 = 0;
while let Some(chunk) = stream.next().await {
let chunk = chunk.with_context(|| format!("Error reading stream for {}", location))?;
file.write_all(&chunk)
.await
.with_context(|| format!("Failed to write chunk to {:?}", dest))?;
total_bytes += chunk.len() as u64;
}
file.flush().await?;
Ok(total_bytes)
}
async fn download_file_with_retry( async fn download_file_with_retry(
store: &Arc<dyn ObjectStore>, store: &Arc<dyn ObjectStore>,
location: &ObjectPath, location: &ObjectPath,
) -> Result<Bytes> { dest: &std::path::Path,
) -> Result<u64> {
for attempt in 1..=Self::MAX_RETRIES { for attempt in 1..=Self::MAX_RETRIES {
let result = store.get(location).await; match Self::stream_to_file(store, location, dest).await {
let error = match result { Ok(bytes_written) => return Ok(bytes_written),
Ok(get_result) => match get_result.bytes().await { Err(error) => {
Ok(bytes) => return Ok(bytes), if attempt >= Self::MAX_RETRIES {
Err(e) => anyhow::anyhow!("Failed to read bytes: {}", e), return Err(error);
}, }
Err(e) => anyhow::anyhow!("Failed to get object: {}", e),
}; let backoff_ms = std::cmp::min(
Self::INITIAL_BACKOFF_MS * 2u64.pow(attempt - 1),
if attempt >= Self::MAX_RETRIES { Self::MAX_BACKOFF_MS,
return Err(error); );
tracing::warn!(
"S3 download failed (attempt {}/{}), retrying in {}ms: {}",
attempt,
Self::MAX_RETRIES,
backoff_ms,
error
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
}
} }
// Calculate backoff with exponential increase, capped at MAX_BACKOFF_MS
let backoff_ms = std::cmp::min(
Self::INITIAL_BACKOFF_MS * 2u64.pow(attempt - 1),
Self::MAX_BACKOFF_MS,
);
tracing::warn!(
"S3 download failed (attempt {}/{}), retrying in {}ms: {}",
attempt,
Self::MAX_RETRIES,
backoff_ms,
error
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
} }
// This should be unreachable, but provide a fallback
Err(anyhow::anyhow!( Err(anyhow::anyhow!(
"S3 download failed after {} retries", "S3 download failed after {} retries",
Self::MAX_RETRIES Self::MAX_RETRIES
...@@ -160,22 +176,26 @@ impl S3LoRASource { ...@@ -160,22 +176,26 @@ impl S3LoRASource {
}) })
} }
/// Build an ObjectStore for a specific bucket
fn build_store(&self, bucket: &str) -> Result<Arc<dyn ObjectStore>> { fn build_store(&self, bucket: &str) -> Result<Arc<dyn ObjectStore>> {
let timeout_secs: u64 = std::env::var("LORA_DOWNLOAD_TIMEOUT_SECS")
.ok()
.and_then(|v| v.parse().ok())
.unwrap_or(3600);
let client_opts = ClientOptions::new().with_timeout(Duration::from_secs(timeout_secs));
let mut builder = AmazonS3Builder::new() let mut builder = AmazonS3Builder::new()
.with_access_key_id(&self.access_key_id) .with_access_key_id(&self.access_key_id)
.with_secret_access_key(&self.secret_access_key) .with_secret_access_key(&self.secret_access_key)
.with_region(&self.region) .with_region(&self.region)
.with_bucket_name(bucket); .with_bucket_name(bucket)
.with_client_options(client_opts);
if let Some(ref endpoint) = self.endpoint { if let Some(ref endpoint) = self.endpoint {
builder = builder builder = builder
.with_endpoint(endpoint) .with_endpoint(endpoint)
// Use path-style URLs for custom endpoints (e.g., MinIO)
.with_virtual_hosted_style_request(false); .with_virtual_hosted_style_request(false);
// Only allow HTTP when explicitly enabled via environment variable
// HTTPS is the default for security
if std::env::var("AWS_ALLOW_HTTP") if std::env::var("AWS_ALLOW_HTTP")
.map(|v| v.eq_ignore_ascii_case("true")) .map(|v| v.eq_ignore_ascii_case("true"))
.unwrap_or(false) .unwrap_or(false)
...@@ -219,15 +239,10 @@ impl LoRASource for S3LoRASource { ...@@ -219,15 +239,10 @@ impl LoRASource for S3LoRASource {
prefix prefix
); );
// Build store for this specific bucket
let bucket_store = self.build_store(&bucket)?; let bucket_store = self.build_store(&bucket)?;
// List all objects under the prefix
let object_prefix = ObjectPath::from(prefix.clone()); let object_prefix = ObjectPath::from(prefix.clone());
let mut list_stream = bucket_store.list(Some(&object_prefix)); let mut list_stream = bucket_store.list(Some(&object_prefix));
// Create a temporary directory in the same parent as dest_path for atomic download
// This prevents data loss if dest_path already exists
let parent = dest_path let parent = dest_path
.parent() .parent()
.ok_or_else(|| anyhow::anyhow!("Destination path has no parent directory"))?; .ok_or_else(|| anyhow::anyhow!("Destination path has no parent directory"))?;
...@@ -236,7 +251,6 @@ impl LoRASource for S3LoRASource { ...@@ -236,7 +251,6 @@ impl LoRASource for S3LoRASource {
.and_then(|n| n.to_str()) .and_then(|n| n.to_str())
.ok_or_else(|| anyhow::anyhow!("Destination path has no file name"))?; .ok_or_else(|| anyhow::anyhow!("Destination path has no file name"))?;
// Generate unique temp directory name
let temp_suffix = std::time::SystemTime::now() let temp_suffix = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH) .duration_since(std::time::UNIX_EPOCH)
.unwrap() .unwrap()
...@@ -244,12 +258,10 @@ impl LoRASource for S3LoRASource { ...@@ -244,12 +258,10 @@ impl LoRASource for S3LoRASource {
let temp_dir_name = format!("{}.tmp.{}", dest_name, temp_suffix); let temp_dir_name = format!("{}.tmp.{}", dest_name, temp_suffix);
let temp_path = parent.join(&temp_dir_name); let temp_path = parent.join(&temp_dir_name);
// Create temporary directory
tokio::fs::create_dir_all(&temp_path) tokio::fs::create_dir_all(&temp_path)
.await .await
.context("Failed to create temporary directory")?; .context("Failed to create temporary directory")?;
// Cleanup closure that only removes the temp directory on error
let cleanup_on_error = async |err: anyhow::Error| -> anyhow::Error { let cleanup_on_error = async |err: anyhow::Error| -> anyhow::Error {
tracing::warn!( tracing::warn!(
"S3 download failed, cleaning up temporary directory at {:?}", "S3 download failed, cleaning up temporary directory at {:?}",
...@@ -268,7 +280,6 @@ impl LoRASource for S3LoRASource { ...@@ -268,7 +280,6 @@ impl LoRASource for S3LoRASource {
Err(e) => return Err(cleanup_on_error(e.into()).await), Err(e) => return Err(cleanup_on_error(e.into()).await),
}; };
// Get relative path (remove prefix)
let rel_path = meta let rel_path = meta
.location .location
.as_ref() .as_ref()
...@@ -277,12 +288,11 @@ impl LoRASource for S3LoRASource { ...@@ -277,12 +288,11 @@ impl LoRASource for S3LoRASource {
.trim_start_matches('/'); .trim_start_matches('/');
if rel_path.is_empty() { if rel_path.is_empty() {
continue; // Skip the prefix itself continue;
} }
let file_path = temp_path.join(rel_path); let file_path = temp_path.join(rel_path);
// Create parent directories
#[allow(clippy::collapsible_if)] #[allow(clippy::collapsible_if)]
if let Some(parent) = file_path.parent() { if let Some(parent) = file_path.parent() {
if let Err(e) = tokio::fs::create_dir_all(parent).await { if let Err(e) = tokio::fs::create_dir_all(parent).await {
...@@ -290,18 +300,16 @@ impl LoRASource for S3LoRASource { ...@@ -290,18 +300,16 @@ impl LoRASource for S3LoRASource {
} }
} }
// Download file with retry logic let bytes_written =
let bytes = match Self::download_file_with_retry(&bucket_store, &meta.location).await { match Self::download_file_with_retry(&bucket_store, &meta.location, &file_path)
Ok(b) => b, .await
Err(e) => return Err(cleanup_on_error(e).await), {
}; Ok(n) => n,
Err(e) => return Err(cleanup_on_error(e).await),
if let Err(e) = tokio::fs::write(&file_path, &bytes).await { };
return Err(cleanup_on_error(e.into()).await);
}
file_count += 1; file_count += 1;
tracing::debug!("Downloaded: {} ({} bytes)", rel_path, bytes.len()); tracing::debug!("Downloaded: {} ({} bytes)", rel_path, bytes_written);
} }
if file_count == 0 { if file_count == 0 {
...@@ -310,14 +318,11 @@ impl LoRASource for S3LoRASource { ...@@ -310,14 +318,11 @@ impl LoRASource for S3LoRASource {
); );
} }
// Atomically rename temp directory to final destination
// Remove dest_path if it exists (only after successful download to avoid data loss)
if dest_path.exists() { if dest_path.exists() {
tokio::fs::remove_dir_all(dest_path) tokio::fs::remove_dir_all(dest_path)
.await .await
.context("Failed to remove existing destination directory")?; .context("Failed to remove existing destination directory")?;
} }
// Rename is atomic on most filesystems
tokio::fs::rename(&temp_path, dest_path) tokio::fs::rename(&temp_path, dest_path)
.await .await
.context("Failed to atomically move temporary directory to destination")?; .context("Failed to atomically move temporary directory to destination")?;
...@@ -335,7 +340,6 @@ impl LoRASource for S3LoRASource { ...@@ -335,7 +340,6 @@ impl LoRASource for S3LoRASource {
let object_prefix = ObjectPath::from(prefix); let object_prefix = ObjectPath::from(prefix);
let mut list_stream = bucket_store.list(Some(&object_prefix)); let mut list_stream = bucket_store.list(Some(&object_prefix));
// Check if at least one object exists, propagating errors
match list_stream.next().await { match list_stream.next().await {
Some(Ok(_)) => Ok(true), Some(Ok(_)) => Ok(true),
Some(Err(e)) => Err(e.into()), Some(Err(e)) => Err(e.into()),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment