# Copyright (c) 2025, NVIDIA CORPORATION.
# SPDX-License-Identifier: BSD-3-Clause

import urllib.parse as _up
from datetime import datetime, timezone
from email.utils import formatdate
from hashlib import md5
from http import HTTPStatus
from http.server import BaseHTTPRequestHandler
from typing import Protocol

from .auth import InvalidSignature, S3Auth
from .state import S3State

__all__ = ["S3RequestHandler"]

# Prolog shared by every XML document this handler emits.
_XML_DECLARATION = '<?xml version="1.0" encoding="UTF-8"?>'


class S3RequestHandler(BaseHTTPRequestHandler):
    """HTTP request handler implementing a minimal S3-compatible API.

    This handler processes HTTP requests and maps them to S3 operations.
    It supports basic S3 operations like bucket and object management,
    including multipart uploads.
    """

    server: "S3ServerProtocol"  # type: ignore[assignment]

    def log_message(self, fmt: str, *args):
        """Log a message to stdout.

        Args:
            fmt: Format string for the message.
            *args: Arguments to format the message with.
        """
        print(f"{self.client_address[0]} - - {fmt % args}")

    def do_PUT(self):
        """Handle PUT requests for object creation and bucket creation."""
        self._handle_write()

    def do_GET(self):
        """Handle GET requests for object retrieval and bucket listing."""
        self._handle_read(listing=False)

    def do_HEAD(self):
        """Handle HEAD requests for object metadata."""
        self._handle_read(listing=False, only_headers=True)

    def do_DELETE(self):
        """Handle DELETE requests for object and bucket deletion."""
        self._handle_delete()

    def do_POST(self):
        """Handle POST requests for multipart upload operations."""
        self._handle_post()

    def _read_body(self) -> bytes:
        """Read and return the request body.

        A missing, malformed, or non-positive ``Content-Length`` is treated as
        an empty body: a negative length would otherwise make ``rfile.read``
        wait for end-of-stream, and a non-numeric one would raise out of the
        handler before any response is sent.

        Returns:
            The request body as bytes.
        """
        try:
            length = int(self.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            length = 0
        if length <= 0:
            return b""
        return self.rfile.read(length)

    def _split_path(self) -> tuple[str, str, _up.ParseResult]:
        """Split the request path into bucket and key components.

        Empty path segments are dropped, so ``/bucket//a/b`` yields the key
        ``a/b``.

        Returns:
            A tuple of (bucket, key, parsed_url). ``bucket`` and ``key`` are
            empty strings when the path does not contain them.
        """
        parsed = _up.urlparse(self.path)
        parts = [p for p in parsed.path.split("/") if p]
        bucket = parts[0] if parts else ""
        key = "/".join(parts[1:])
        return bucket, key, parsed

    def _auth(self, payload: bytes, parsed: _up.ParseResult) -> bool:
        """Verify the request signature.

        On failure an error response has already been sent (403 for an invalid
        signature, 400 for a ``ValueError`` raised by the verifier), so callers
        only need to return.

        Args:
            payload: The request body.
            parsed: The parsed URL.

        Returns:
            True if authentication succeeds, False otherwise.
        """
        try:
            self.server.auth.verify(
                method=self.command,
                canonical_uri=parsed.path or "/",
                canonical_querystring=parsed.query,
                headers=self.headers,
                payload=payload,
            )
        except InvalidSignature as err:
            self._send_error(HTTPStatus.FORBIDDEN, str(err))
            return False
        except ValueError as err:
            self._send_error(HTTPStatus.BAD_REQUEST, str(err))
            return False
        return True

    def _handle_write(self):
        """Handle PUT requests: upload part, bucket creation, or object put."""
        bucket, key, parsed = self._split_path()
        body = self._read_body()
        if not self._auth(body, parsed):
            return

        qs = _up.parse_qs(parsed.query, keep_blank_values=True)

        # Multipart: upload part
        if "uploadId" in qs and "partNumber" in qs:
            upload_id = qs["uploadId"][0]
            try:
                part_no = int(qs["partNumber"][0])
            except ValueError:
                self._send_error(HTTPStatus.BAD_REQUEST, "Invalid partNumber")
                return
            try:
                self.server.state.upload_part(upload_id, part_no, body)
            except KeyError:
                self._send_error(HTTPStatus.NOT_FOUND, "Upload not found")
                return
            self._send_status(HTTPStatus.OK, extra_headers={"ETag": _etag(body)})
            return

        if not bucket:
            self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
            return

        if key == "":
            # Bucket create
            self.server.state.create_bucket(bucket)
            self._send_status(HTTPStatus.OK)
            return

        # Put object
        self.server.state.put_object(bucket, key, body)
        self._send_status(HTTPStatus.OK, extra_headers={"ETag": _etag(body)})

    def _handle_read(self, listing: bool, only_headers: bool = False):
        """Handle GET/HEAD requests for object retrieval and bucket listing.

        Args:
            listing: When True, a bucket-level request (empty key) is answered
                with 501 Not Implemented. When False (what ``do_GET`` and
                ``do_HEAD`` pass), a bucket-level request returns the XML
                object listing.
            only_headers: Whether to return only headers (HEAD request).
        """
        bucket, key, parsed = self._split_path()
        # GET/HEAD carry no body; the signature covers the hash of b"".
        if not self._auth(b"", parsed):
            return

        if not bucket:
            self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
            return

        if key == "":
            if listing:
                self._send_error(HTTPStatus.NOT_IMPLEMENTED, "Listing not implemented")
                return
            try:
                objects = self.server.state.list_objects(bucket)
            except KeyError:
                self._send_error(HTTPStatus.NOT_FOUND, "Bucket not found")
                return
            xml_body = self._render_bucket_list(bucket, objects, state=self.server.state)
            self._send_bytes(xml_body, content_type="application/xml")
            return

        try:
            data = self.server.state.get_object(bucket, key)
        except FileNotFoundError:
            self._send_error(HTTPStatus.NOT_FOUND, "Not found")
            return

        # GET and HEAD advertise the same metadata for an object. The emulator
        # has no stored modification time here, so the current time is used.
        headers = {
            "Accept-Ranges": "bytes",
            "ETag": _etag(data),
            "Last-Modified": formatdate(usegmt=True),
        }
        status = HTTPStatus.OK
        payload = data

        range_header = self.headers.get("Range")
        if range_header and range_header.startswith("bytes="):
            resolved = _resolve_range(range_header.split("=", 1)[1], len(data))
            if resolved is None:
                self._send_error(HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE, "Invalid Range")
                return
            start, end = resolved
            payload = data[start : end + 1]
            status = HTTPStatus.PARTIAL_CONTENT
            headers["Content-Range"] = f"bytes {start}-{end}/{len(data)}"

        if only_headers:
            headers["Content-Type"] = "application/octet-stream"
            headers["Content-Length"] = str(len(payload))
            self._send_status(status, extra_headers=headers)
        else:
            self._send_bytes(
                payload,
                status=status,
                content_type="application/octet-stream",
                extra_headers=headers,
            )

    def _handle_delete(self):
        """Handle DELETE requests for object and bucket deletion."""
        bucket, key, parsed = self._split_path()
        # DELETE carries no body; authenticate against the empty payload.
        if not self._auth(b"", parsed):
            return

        if not bucket:
            self._send_error(HTTPStatus.BAD_REQUEST, "Bucket must be specified")
            return

        if key == "":
            try:
                self.server.state.delete_bucket(bucket)
            except (KeyError, RuntimeError) as err:
                self._send_error(HTTPStatus.BAD_REQUEST, str(err))
                return
            self._send_status(HTTPStatus.NO_CONTENT)
            return

        try:
            self.server.state.delete_object(bucket, key)
        except FileNotFoundError:
            self._send_error(HTTPStatus.NOT_FOUND, "Not found")
            return
        self._send_status(HTTPStatus.NO_CONTENT)

    def _handle_post(self):
        """Handle POST requests for multipart upload operations.

        ``?uploads`` initiates a multipart upload and ``?uploadId=...``
        completes one; any other POST is answered with 501.
        """
        bucket, key, parsed = self._split_path()
        body = self._read_body()
        if not self._auth(body, parsed):
            return

        qs = _up.parse_qs(parsed.query, keep_blank_values=True)

        # Initiate multipart: POST ?uploads
        if "uploads" in qs or parsed.query == "uploads":
            upload_id = self.server.state.initiate_multipart(bucket, key)
            xml = (
                _XML_DECLARATION + "<InitiateMultipartUploadResult>"
                f"<Bucket>{_escape_xml(bucket)}</Bucket>"
                f"<Key>{_escape_xml(key)}</Key>"
                f"<UploadId>{_escape_xml(str(upload_id))}</UploadId>"
                "</InitiateMultipartUploadResult>"
            ).encode()
            self._send_bytes(xml, status=HTTPStatus.OK, content_type="application/xml")
            return

        # Complete multipart: POST ?uploadId=xxxx
        if "uploadId" in qs:
            upload_id = qs["uploadId"][0]
            try:
                self.server.state.complete_multipart(upload_id)
            except KeyError:
                self._send_error(HTTPStatus.NOT_FOUND, "Upload not found")
                return
            # upload_id comes from the query string, so escape it like the
            # other client-supplied values.
            xml = (
                _XML_DECLARATION + "<CompleteMultipartUploadResult>"
                f"<Bucket>{_escape_xml(bucket)}</Bucket>"
                f"<Key>{_escape_xml(key)}</Key>"
                f"<UploadId>{_escape_xml(upload_id)}</UploadId>"
                "</CompleteMultipartUploadResult>"
            ).encode()
            self._send_bytes(xml, status=HTTPStatus.OK, content_type="application/xml")
            return

        self._send_error(HTTPStatus.NOT_IMPLEMENTED, "Unsupported POST request")

    def _send_status(self, status: HTTPStatus, extra_headers: dict[str, str] | None = None):
        """Send an HTTP response with the given status code and no body.

        Args:
            status: The HTTP status code to send.
            extra_headers: Optional additional headers to include.
        """
        self.send_response(status.value)
        headers = {"Server": "s3-emulator"}
        if extra_headers:
            headers.update(extra_headers)
        for name, value in headers.items():
            self.send_header(name, value)
        self.end_headers()

    def _send_error(self, status: HTTPStatus, message: str):
        """Send an error response with a plain-text body.

        Args:
            status: The HTTP status code to send.
            message: The error message to include in the response.
        """
        print(f"Error {status}: {message}")
        self._send_bytes(message.encode(), status=status, content_type="text/plain")

    def _send_bytes(
        self,
        data: bytes,
        status: HTTPStatus = HTTPStatus.OK,
        content_type: str = "application/octet-stream",
        extra_headers: dict[str, str] | None = None,
    ) -> None:
        """Send a response with binary data.

        The body is omitted when the current request method is HEAD; the
        headers (including ``Content-Length``) are still sent.

        Args:
            data: The binary data to send.
            status: The HTTP status code to send. Defaults to 200 OK.
            content_type: The Content-Type header value. Defaults to
                application/octet-stream.
            extra_headers: Optional additional headers to include. They
                override the default headers of the same name.
        """
        self.send_response(status.value)
        headers = {
            "Server": "s3-emulator",
            "Content-Type": content_type,
            "Content-Length": str(len(data)),
        }
        if extra_headers:
            headers.update(extra_headers)
        for name, value in headers.items():
            self.send_header(name, value)
        self.end_headers()
        if self.command != "HEAD":
            self.wfile.write(data)

    @staticmethod
    def _render_bucket_list(
        bucket: str,
        objects: list[str],
        state: S3State | None = None,
    ) -> bytes:
        """Generate an XML listing of objects in a bucket.

        Args:
            bucket: The bucket name.
            objects: List of object keys in the bucket.
            state: Storage used to look up each object's size and ETag. When
                omitted, or when a lookup fails, the entry is reported with
                size 0 and an empty ETag.

        Returns:
            The XML document as bytes.
        """
        # No per-object modification time is available here; every entry is
        # stamped with the time the listing was rendered.
        now = datetime.now(timezone.utc).strftime("%Y-%m-%dT%H:%M:%S.000Z")
        entries = []
        for key in objects:
            size, etag = 0, '""'
            if state is not None:
                try:
                    data = state.get_object(bucket, key)
                    size, etag = len(data), _etag(data)
                except Exception:  # noqa: BLE001
                    # Best effort: one unreadable object (e.g. deleted since
                    # it was listed) must not fail the whole listing.
                    size, etag = 0, '""'
            entries.append(
                "<Contents>"
                f"<Key>{_escape_xml(key)}</Key>"
                f"<LastModified>{now}</LastModified>"
                f"<ETag>{etag}</ETag>"
                f"<Size>{size}</Size>"
                "</Contents>"
            )
        obj_elems = "".join(entries)
        xml = (
            _XML_DECLARATION + "<ListBucketResult>"
            f"<Name>{_escape_xml(bucket)}</Name>"
            f"{obj_elems}"
            "</ListBucketResult>"
        )
        return xml.encode()


class S3ServerProtocol(Protocol):
    """Structural type of the server object the handler expects."""

    # Object/bucket storage backing the emulator.
    state: S3State
    # Verifier for request signatures.
    auth: S3Auth


def _resolve_range(spec: str, size: int) -> tuple[int, int] | None:
    """Resolve a single HTTP byte-range spec against a payload size.

    Args:
        spec: The part of the ``Range`` header after ``bytes=``, e.g.
            ``"0-99"``, ``"100-"`` or ``"-50"`` (the last 50 bytes).
        size: Total length of the payload in bytes.

    Returns:
        Inclusive ``(start, end)`` offsets clamped to the payload, or None
        when the spec is malformed or cannot be satisfied.
    """
    if "-" not in spec:
        return None
    start_str, end_str = spec.split("-", 1)
    try:
        if not start_str:
            # Suffix form "-N": the final N bytes of the payload.
            suffix_len = int(end_str)
            if suffix_len <= 0 or size == 0:
                return None
            return max(size - suffix_len, 0), size - 1
        start = int(start_str)
        end = int(end_str) if end_str else size - 1
    except ValueError:
        return None
    if start < 0 or start > end or start >= size:
        return None
    return start, min(end, size - 1)


def _escape_xml(text: str) -> str:  # noqa: D401
    """Escape special characters for XML.

    Args:
        text: The text to escape.

    Returns:
        The escaped text.
    """
    # "&" must be replaced first so the entities added below are not
    # themselves re-escaped.
    return (
        text.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
        .replace('"', "&quot;")
        .replace("'", "&apos;")
    )


def _etag(data: bytes) -> str:  # noqa: D401
    """Generate an ETag for binary data.

    Args:
        data: The binary data to generate an ETag for.

    Returns:
        The MD5 hash of the data as a hex string.
    """
    return md5(data).hexdigest()