Commit e63cf68a authored by chenzk's avatar chenzk
Browse files

v1.0

parents
Pipeline #2842 canceled with stages
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
def adjust_bboxes_to_image_border(boxes, image_shape, threshold=20):
"""
Adjust bounding boxes to stick to image border if they are within a certain threshold.
Args:
boxes (torch.Tensor): (n, 4)
image_shape (tuple): (height, width)
threshold (int): pixel threshold
Returns:
adjusted_boxes (torch.Tensor): adjusted bounding boxes
"""
# Image dimensions
h, w = image_shape
# Adjust boxes
boxes[boxes[:, 0] < threshold, 0] = 0 # x1
boxes[boxes[:, 1] < threshold, 1] = 0 # y1
boxes[boxes[:, 2] > w - threshold, 2] = w # x2
boxes[boxes[:, 3] > h - threshold, 3] = h # y2
return boxes
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from ultralytics.models.yolo.segment import SegmentationValidator
from ultralytics.utils.metrics import SegmentMetrics
class FastSAMValidator(SegmentationValidator):
"""
Custom validation class for fast SAM (Segment Anything Model) segmentation in Ultralytics YOLO framework.
Extends the SegmentationValidator class, customizing the validation process specifically for fast SAM. This class
sets the task to 'segment' and uses the SegmentMetrics for evaluation. Additionally, plotting features are disabled
to avoid errors during validation.
Attributes:
dataloader: The data loader object used for validation.
save_dir (str): The directory where validation results will be saved.
pbar: A progress bar object.
args: Additional arguments for customization.
_callbacks: List of callback functions to be invoked during validation.
"""
def __init__(self, dataloader=None, save_dir=None, pbar=None, args=None, _callbacks=None):
"""
Initialize the FastSAMValidator class, setting the task to 'segment' and metrics to SegmentMetrics.
Args:
dataloader (torch.utils.data.DataLoader): Dataloader to be used for validation.
save_dir (Path, optional): Directory to save results.
pbar (tqdm.tqdm): Progress bar for displaying progress.
args (SimpleNamespace): Configuration for the validator.
_callbacks (dict): Dictionary to store various callback functions.
Notes:
Plots for ConfusionMatrix and other related metrics are disabled in this class to avoid errors.
"""
super().__init__(dataloader, save_dir, pbar, args, _callbacks)
self.args.task = "segment"
self.args.plots = False # disable ConfusionMatrix and other plots to avoid errors
self.metrics = SegmentMetrics(save_dir=self.save_dir, on_plot=self.on_plot)
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .model import NAS
from .predict import NASPredictor
from .val import NASValidator
__all__ = "NASPredictor", "NASValidator", "NAS"
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
YOLO-NAS model interface.
Example:
```python
from ultralytics import NAS
model = NAS("yolo_nas_s")
results = model.predict("ultralytics/assets/bus.jpg")
```
"""
from pathlib import Path
import torch
from ultralytics.engine.model import Model
from ultralytics.utils.downloads import attempt_download_asset
from ultralytics.utils.torch_utils import model_info
from .predict import NASPredictor
from .val import NASValidator
class NAS(Model):
"""
YOLO NAS model for object detection.
This class provides an interface for the YOLO-NAS models and extends the `Model` class from Ultralytics engine.
It is designed to facilitate the task of object detection using pre-trained or custom-trained YOLO-NAS models.
Example:
```python
from ultralytics import NAS
model = NAS("yolo_nas_s")
results = model.predict("ultralytics/assets/bus.jpg")
```
Attributes:
model (str): Path to the pre-trained model or model name. Defaults to 'yolo_nas_s.pt'.
Note:
YOLO-NAS models only support pre-trained models. Do not provide YAML configuration files.
"""
def __init__(self, model="yolo_nas_s.pt") -> None:
"""Initializes the NAS model with the provided or default 'yolo_nas_s.pt' model."""
assert Path(model).suffix not in {".yaml", ".yml"}, "YOLO-NAS models only support pre-trained models."
super().__init__(model, task="detect")
def _load(self, weights: str, task=None) -> None:
"""Loads an existing NAS model weights or creates a new NAS model with pretrained weights if not provided."""
import super_gradients
suffix = Path(weights).suffix
if suffix == ".pt":
self.model = torch.load(attempt_download_asset(weights))
elif suffix == "":
self.model = super_gradients.training.models.get(weights, pretrained_weights="coco")
# Override the forward method to ignore additional arguments
def new_forward(x, *args, **kwargs):
"""Ignore additional __call__ arguments."""
return self.model._original_forward(x)
self.model._original_forward = self.model.forward
self.model.forward = new_forward
# Standardize model
self.model.fuse = lambda verbose=True: self.model
self.model.stride = torch.tensor([32])
self.model.names = dict(enumerate(self.model._class_names))
self.model.is_fused = lambda: False # for info()
self.model.yaml = {} # for info()
self.model.pt_path = weights # for export()
self.model.task = "detect" # for export()
def info(self, detailed=False, verbose=True):
"""
Logs model info.
Args:
detailed (bool): Show detailed information about model.
verbose (bool): Controls verbosity.
"""
return model_info(self.model, detailed=detailed, verbose=verbose, imgsz=640)
@property
def task_map(self):
"""Returns a dictionary mapping tasks to respective predictor and validator classes."""
return {"detect": {"predictor": NASPredictor, "validator": NASValidator}}
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
class NASPredictor(BasePredictor):
"""
Ultralytics YOLO NAS Predictor for object detection.
This class extends the `BasePredictor` from Ultralytics engine and is responsible for post-processing the
raw predictions generated by the YOLO NAS models. It applies operations like non-maximum suppression and
scaling the bounding boxes to fit the original image dimensions.
Attributes:
args (Namespace): Namespace containing various configurations for post-processing.
Example:
```python
from ultralytics import NAS
model = NAS("yolo_nas_s")
predictor = model.predictor
# Assumes that raw_preds, img, orig_imgs are available
results = predictor.postprocess(raw_preds, img, orig_imgs)
```
Note:
Typically, this class is not instantiated directly. It is used internally within the `NAS` class.
"""
def postprocess(self, preds_in, img, orig_imgs):
"""Postprocess predictions and returns a list of Results objects."""
# Cat boxes and class scores
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
preds = ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
agnostic=self.args.agnostic_nms,
max_det=self.args.max_det,
classes=self.args.classes,
)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for pred, orig_img, img_path in zip(preds, orig_imgs, self.batch[0]):
pred[:, :4] = ops.scale_boxes(img.shape[2:], pred[:, :4], orig_img.shape)
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import ops
__all__ = ["NASValidator"]
class NASValidator(DetectionValidator):
"""
Ultralytics YOLO NAS Validator for object detection.
Extends `DetectionValidator` from the Ultralytics models package and is designed to post-process the raw predictions
generated by YOLO NAS models. It performs non-maximum suppression to remove overlapping and low-confidence boxes,
ultimately producing the final detections.
Attributes:
args (Namespace): Namespace containing various configurations for post-processing, such as confidence and IoU.
lb (torch.Tensor): Optional tensor for multilabel NMS.
Example:
```python
from ultralytics import NAS
model = NAS("yolo_nas_s")
validator = model.validator
# Assumes that raw_preds are available
final_preds = validator.postprocess(raw_preds)
```
Note:
This class is generally not instantiated directly but is used internally within the `NAS` class.
"""
def postprocess(self, preds_in):
"""Apply Non-maximum suppression to prediction outputs."""
boxes = ops.xyxy2xywh(preds_in[0][0])
preds = torch.cat((boxes, preds_in[0][1]), -1).permute(0, 2, 1)
return ops.non_max_suppression(
preds,
self.args.conf,
self.args.iou,
labels=self.lb,
multi_label=False,
agnostic=self.args.single_cls or self.args.agnostic_nms,
max_det=self.args.max_det,
max_time_img=0.5,
)
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from .model import RTDETR
from .predict import RTDETRPredictor
from .val import RTDETRValidator
__all__ = "RTDETRPredictor", "RTDETRValidator", "RTDETR"
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
"""
Interface for Baidu's RT-DETR, a Vision Transformer-based real-time object detector. RT-DETR offers real-time
performance and high accuracy, excelling in accelerated backends like CUDA with TensorRT. It features an efficient
hybrid encoder and IoU-aware query selection for enhanced detection accuracy.
For more information on RT-DETR, visit: https://arxiv.org/pdf/2304.08069.pdf
"""
from ultralytics.engine.model import Model
from ultralytics.nn.tasks import RTDETRDetectionModel
from .predict import RTDETRPredictor
from .train import RTDETRTrainer
from .val import RTDETRValidator
class RTDETR(Model):
"""
Interface for Baidu's RT-DETR model. This Vision Transformer-based object detector provides real-time performance
with high accuracy. It supports efficient hybrid encoding, IoU-aware query selection, and adaptable inference speed.
Attributes:
model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
"""
def __init__(self, model="rtdetr-l.pt") -> None:
"""
Initializes the RT-DETR model with the given pre-trained model file. Supports .pt and .yaml formats.
Args:
model (str): Path to the pre-trained model. Defaults to 'rtdetr-l.pt'.
Raises:
NotImplementedError: If the model file extension is not 'pt', 'yaml', or 'yml'.
"""
super().__init__(model=model, task="detect")
@property
def task_map(self) -> dict:
"""
Returns a task map for RT-DETR, associating tasks with corresponding Ultralytics classes.
Returns:
dict: A dictionary mapping task names to Ultralytics task classes for the RT-DETR model.
"""
return {
"detect": {
"predictor": RTDETRPredictor,
"validator": RTDETRValidator,
"trainer": RTDETRTrainer,
"model": RTDETRDetectionModel,
}
}
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.data.augment import LetterBox
from ultralytics.engine.predictor import BasePredictor
from ultralytics.engine.results import Results
from ultralytics.utils import ops
class RTDETRPredictor(BasePredictor):
"""
RT-DETR (Real-Time Detection Transformer) Predictor extending the BasePredictor class for making predictions using
Baidu's RT-DETR model.
This class leverages the power of Vision Transformers to provide real-time object detection while maintaining
high accuracy. It supports key features like efficient hybrid encoding and IoU-aware query selection.
Example:
```python
from ultralytics.utils import ASSETS
from ultralytics.models.rtdetr import RTDETRPredictor
args = dict(model="rtdetr-l.pt", source=ASSETS)
predictor = RTDETRPredictor(overrides=args)
predictor.predict_cli()
```
Attributes:
imgsz (int): Image size for inference (must be square and scale-filled).
args (dict): Argument overrides for the predictor.
"""
def postprocess(self, preds, img, orig_imgs):
"""
Postprocess the raw predictions from the model to generate bounding boxes and confidence scores.
The method filters detections based on confidence and class if specified in `self.args`.
Args:
preds (list): List of [predictions, extra] from the model.
img (torch.Tensor): Processed input images.
orig_imgs (list or torch.Tensor): Original, unprocessed images.
Returns:
(list[Results]): A list of Results objects containing the post-processed bounding boxes, confidence scores,
and class labels.
"""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
nd = preds[0].shape[-1]
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
if not isinstance(orig_imgs, list): # input images are a torch.Tensor, not a list
orig_imgs = ops.convert_torch2numpy_batch(orig_imgs)
results = []
for bbox, score, orig_img, img_path in zip(bboxes, scores, orig_imgs, self.batch[0]): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
max_score, cls = score.max(-1, keepdim=True) # (300, 1)
idx = max_score.squeeze(-1) > self.args.conf # (300, )
if self.args.classes is not None:
idx = (cls == torch.tensor(self.args.classes, device=cls.device)).any(1) & idx
pred = torch.cat([bbox, max_score, cls], dim=-1)[idx] # filter
oh, ow = orig_img.shape[:2]
pred[..., [0, 2]] *= ow
pred[..., [1, 3]] *= oh
results.append(Results(orig_img, path=img_path, names=self.model.names, boxes=pred))
return results
def pre_transform(self, im):
"""
Pre-transforms the input images before feeding them into the model for inference. The input images are
letterboxed to ensure a square aspect ratio and scale-filled. The size must be square(640) and scaleFilled.
Args:
im (list[np.ndarray] |torch.Tensor): Input images of shape (N,3,h,w) for tensor, [(h,w,3) x N] for list.
Returns:
(list): List of pre-transformed images ready for model inference.
"""
letterbox = LetterBox(self.imgsz, auto=False, scaleFill=True)
return [letterbox(image=x) for x in im]
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
from copy import copy
import torch
from ultralytics.models.yolo.detect import DetectionTrainer
from ultralytics.nn.tasks import RTDETRDetectionModel
from ultralytics.utils import RANK, colorstr
from .val import RTDETRDataset, RTDETRValidator
class RTDETRTrainer(DetectionTrainer):
"""
Trainer class for the RT-DETR model developed by Baidu for real-time object detection. Extends the DetectionTrainer
class for YOLO to adapt to the specific features and architecture of RT-DETR. This model leverages Vision
Transformers and has capabilities like IoU-aware query selection and adaptable inference speed.
Notes:
- F.grid_sample used in RT-DETR does not support the `deterministic=True` argument.
- AMP training can lead to NaN outputs and may produce errors during bipartite graph matching.
Example:
```python
from ultralytics.models.rtdetr.train import RTDETRTrainer
args = dict(model="rtdetr-l.yaml", data="coco8.yaml", imgsz=640, epochs=3)
trainer = RTDETRTrainer(overrides=args)
trainer.train()
```
"""
def get_model(self, cfg=None, weights=None, verbose=True):
"""
Initialize and return an RT-DETR model for object detection tasks.
Args:
cfg (dict, optional): Model configuration. Defaults to None.
weights (str, optional): Path to pre-trained model weights. Defaults to None.
verbose (bool): Verbose logging if True. Defaults to True.
Returns:
(RTDETRDetectionModel): Initialized model.
"""
model = RTDETRDetectionModel(cfg, nc=self.data["nc"], verbose=verbose and RANK == -1)
if weights:
model.load(weights)
return model
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build and return an RT-DETR dataset for training or validation.
Args:
img_path (str): Path to the folder containing images.
mode (str): Dataset mode, either 'train' or 'val'.
batch (int, optional): Batch size for rectangle training. Defaults to None.
Returns:
(RTDETRDataset): Dataset object for the specific mode.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=mode == "train",
hyp=self.args,
rect=False,
cache=self.args.cache or None,
single_cls=self.args.single_cls or False,
prefix=colorstr(f"{mode}: "),
classes=self.args.classes,
data=self.data,
fraction=self.args.fraction if mode == "train" else 1.0,
)
def get_validator(self):
"""
Returns a DetectionValidator suitable for RT-DETR model validation.
Returns:
(RTDETRValidator): Validator object for model validation.
"""
self.loss_names = "giou_loss", "cls_loss", "l1_loss"
return RTDETRValidator(self.test_loader, save_dir=self.save_dir, args=copy(self.args))
def preprocess_batch(self, batch):
"""
Preprocess a batch of images. Scales and converts the images to float format.
Args:
batch (dict): Dictionary containing a batch of images, bboxes, and labels.
Returns:
(dict): Preprocessed batch.
"""
batch = super().preprocess_batch(batch)
bs = len(batch["img"])
batch_idx = batch["batch_idx"]
gt_bbox, gt_class = [], []
for i in range(bs):
gt_bbox.append(batch["bboxes"][batch_idx == i].to(batch_idx.device))
gt_class.append(batch["cls"][batch_idx == i].to(device=batch_idx.device, dtype=torch.long))
return batch
# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license
import torch
from ultralytics.data import YOLODataset
from ultralytics.data.augment import Compose, Format, v8_transforms
from ultralytics.models.yolo.detect import DetectionValidator
from ultralytics.utils import colorstr, ops
__all__ = ("RTDETRValidator",) # tuple or list
class RTDETRDataset(YOLODataset):
"""
Real-Time DEtection and TRacking (RT-DETR) dataset class extending the base YOLODataset class.
This specialized dataset class is designed for use with the RT-DETR object detection model and is optimized for
real-time detection and tracking tasks.
"""
def __init__(self, *args, data=None, **kwargs):
"""Initialize the RTDETRDataset class by inheriting from the YOLODataset class."""
super().__init__(*args, data=data, **kwargs)
# NOTE: add stretch version load_image for RTDETR mosaic
def load_image(self, i, rect_mode=False):
"""Loads 1 image from dataset index 'i', returns (im, resized hw)."""
return super().load_image(i=i, rect_mode=rect_mode)
def build_transforms(self, hyp=None):
"""Temporary, only for evaluation."""
if self.augment:
hyp.mosaic = hyp.mosaic if self.augment and not self.rect else 0.0
hyp.mixup = hyp.mixup if self.augment and not self.rect else 0.0
transforms = v8_transforms(self, self.imgsz, hyp, stretch=True)
else:
# transforms = Compose([LetterBox(new_shape=(self.imgsz, self.imgsz), auto=False, scaleFill=True)])
transforms = Compose([])
transforms.append(
Format(
bbox_format="xywh",
normalize=True,
return_mask=self.use_segments,
return_keypoint=self.use_keypoints,
batch_idx=True,
mask_ratio=hyp.mask_ratio,
mask_overlap=hyp.overlap_mask,
)
)
return transforms
class RTDETRValidator(DetectionValidator):
"""
RTDETRValidator extends the DetectionValidator class to provide validation capabilities specifically tailored for
the RT-DETR (Real-Time DETR) object detection model.
The class allows building of an RTDETR-specific dataset for validation, applies Non-maximum suppression for
post-processing, and updates evaluation metrics accordingly.
Example:
```python
from ultralytics.models.rtdetr import RTDETRValidator
args = dict(model="rtdetr-l.pt", data="coco8.yaml")
validator = RTDETRValidator(args=args)
validator()
```
Note:
For further details on the attributes and methods, refer to the parent DetectionValidator class.
"""
def build_dataset(self, img_path, mode="val", batch=None):
"""
Build an RTDETR Dataset.
Args:
img_path (str): Path to the folder containing images.
mode (str): `train` mode or `val` mode, users are able to customize different augmentations for each mode.
batch (int, optional): Size of batches, this is for `rect`. Defaults to None.
"""
return RTDETRDataset(
img_path=img_path,
imgsz=self.args.imgsz,
batch_size=batch,
augment=False, # no augmentation
hyp=self.args,
rect=False, # no rect
cache=self.args.cache or None,
prefix=colorstr(f"{mode}: "),
data=self.data,
)
def postprocess(self, preds):
"""Apply Non-maximum suppression to prediction outputs."""
if not isinstance(preds, (list, tuple)): # list for PyTorch inference but list[0] Tensor for export inference
preds = [preds, None]
bs, _, nd = preds[0].shape
bboxes, scores = preds[0].split((4, nd - 4), dim=-1)
bboxes *= self.args.imgsz
outputs = [torch.zeros((0, 6), device=bboxes.device)] * bs
for i, bbox in enumerate(bboxes): # (300, 4)
bbox = ops.xywh2xyxy(bbox)
score, cls = scores[i].max(-1) # (300, )
# Do not need threshold for evaluation as only got 300 boxes here
# idx = score > self.args.conf
pred = torch.cat([bbox, score[..., None], cls[..., None]], dim=-1) # filter
# Sort by confidence to correctly get internal metrics
pred = pred[score.argsort(descending=True)]
outputs[i] = pred # [idx]
return outputs
def _prepare_batch(self, si, batch):
"""Prepares a batch for training or inference by applying transformations."""
idx = batch["batch_idx"] == si
cls = batch["cls"][idx].squeeze(-1)
bbox = batch["bboxes"][idx]
ori_shape = batch["ori_shape"][si]
imgsz = batch["img"].shape[2:]
ratio_pad = batch["ratio_pad"][si]
if len(cls):
bbox = ops.xywh2xyxy(bbox) # target boxes
bbox[..., [0, 2]] *= ori_shape[1] # native-space pred
bbox[..., [1, 3]] *= ori_shape[0] # native-space pred
return {"cls": cls, "bbox": bbox, "ori_shape": ori_shape, "imgsz": imgsz, "ratio_pad": ratio_pad}
def _prepare_pred(self, pred, pbatch):
"""Prepares and returns a batch with transformed bounding boxes and class labels."""
predn = pred.clone()
predn[..., [0, 2]] *= pbatch["ori_shape"][1] / self.args.imgsz # native-space pred
predn[..., [1, 3]] *= pbatch["ori_shape"][0] / self.args.imgsz # native-space pred
return predn.float()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment