# Model conventions and rules

Shared reference for model-related conventions, patterns, and gotchas.
Linked from `AGENTS.md`, `skills/model-integration/SKILL.md`, and `review-rules.md`.

## Coding style

- All layer calls should be visible directly in `forward` — avoid helper functions that hide `nn.Module` calls.
- Avoid graph breaks to stay compatible with `torch.compile`: do not use NumPy operations in `forward` implementations, or any other pattern that fails under `fullgraph=True`.
- No new mandatory dependency without discussion (e.g. `einops`). Guard optional dependencies with `is_X_available()` and provide a dummy in `utils/dummy_*.py`.
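
The optional-dependency guard from the last bullet can be sketched with the standard library alone. This is a minimal illustration under stated assumptions, not the diffusers implementation: `is_package_available`, `my_dependency`, and `require_my_dependency` are hypothetical names, and the dummy-object machinery in `utils/dummy_*.py` is not reproduced.

```python
import importlib.util


def is_package_available(name: str) -> bool:
    """Return True if the top-level package `name` can be found, without importing it."""
    return importlib.util.find_spec(name) is not None


# `my_dependency` is a hypothetical optional package used only for illustration.
if is_package_available("my_dependency"):
    import my_dependency
else:
    my_dependency = None


def require_my_dependency() -> None:
    """Raise a clear error at call time instead of failing at import time."""
    if my_dependency is None:
        raise ImportError("This feature needs `my_dependency`; install it to continue.")
```

Deferring the failure to call time keeps the rest of the module importable when the optional package is missing.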

## Common model conventions

- Models subclass `ModelMixin` and `ConfigMixin`, and decorate `__init__` with `@register_to_config` so that constructor arguments are serialized to the config.

## Attention pattern

Attention must follow the diffusers pattern: both the `Attention` class and its processor are defined in the model file. The processor's `__call__` handles the actual compute and must use `dispatch_attention_fn` rather than calling `F.scaled_dot_product_attention` directly. The attention class inherits `AttentionModuleMixin` and declares `_default_processor_cls` and `_available_processors`.

```python
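# transformer_mymodel.py
# Relative imports assume this file lives in src/diffusers/models/transformers/.
import torch.nn as nn

from ..attention import AttentionModuleMixin
from ..attention_dispatch import dispatch_attention_fn
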
class MyModelAttnProcessor:
    _attention_backend = None
    _parallel_config = None

    def __call__(self, attn, hidden_states, attention_mask=None, ...):
        query = attn.to_q(hidden_states)
        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        # reshape to (batch, seq_len, heads, head_dim), apply rope, etc.
        hidden_states = dispatch_attention_fn(
            query, key, value,
            attn_mask=attention_mask,
            backend=self._attention_backend,
            parallel_config=self._parallel_config,
        )
        hidden_states = hidden_states.flatten(2, 3)
        return attn.to_out[0](hidden_states)


class MyModelAttention(nn.Module, AttentionModuleMixin):
    _default_processor_cls = MyModelAttnProcessor
    _available_processors = [MyModelAttnProcessor]

    def __init__(self, query_dim, heads=8, dim_head=64, ...):
        super().__init__()
        self.to_q = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_k = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_v = nn.Linear(query_dim, heads * dim_head, bias=False)
        self.to_out = nn.ModuleList([nn.Linear(heads * dim_head, query_dim), nn.Dropout(0.0)])
        self.set_processor(MyModelAttnProcessor())

    def forward(self, hidden_states, attention_mask=None, **kwargs):
        return self.processor(self, hidden_states, attention_mask, **kwargs)
```

Consult the implementations in `src/diffusers/models/transformers/` if you need further references.

## Gotchas

1. **Forgetting `__init__.py` lazy imports.** Every new class must be registered in the appropriate `__init__.py` with lazy imports. Missing this causes an `ImportError` that only shows up when users try `from diffusers import YourNewClass`.

2. **Using `einops` or other non-PyTorch deps.** Reference implementations often use `einops.rearrange`. Always rewrite with native PyTorch (`reshape`, `permute`, `unflatten`). Don't add the dependency. If a dependency is truly unavoidable, guard its import: `if is_my_dependency_available(): import my_dependency`.

3. **Missing `make fix-copies` after `# Copied from`.** If you add `# Copied from` annotations, you must run `make fix-copies` to propagate them. CI will fail otherwise.

4. **Wrong `_supports_cache_class` / `_no_split_modules`.** These class attributes control KV cache and device placement. Copy them from a similar model and verify each value; wrong values cause silent correctness bugs or OOM errors.

5. **Missing `@torch.no_grad()` on pipeline `__call__`.** Without it, autograd keeps intermediate activations alive for a backward pass that never runs, which can cause GPU OOM during inference.

6. **Config serialization gaps.** Every `__init__` parameter in a `ModelMixin` subclass must be captured by `register_to_config`. If you add a new param but forget to register it, `from_pretrained` will silently use the default instead of the saved value.

7. **Forgetting to update `_import_structure` and `_lazy_modules`.** The top-level `src/diffusers/__init__.py` has both; missing either one causes partial import failures.

8. **Hardcoded dtype in model forward.** Don't hardcode `torch.float32` or `torch.bfloat16` in the model's forward pass. Use the dtype of the input tensors or `self.dtype` so the model works with any precision.
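
For gotcha 2, the rewrite from `einops.rearrange` to native PyTorch is usually mechanical. The sketch below assumes `torch` is installed; `split_heads` and `merge_heads` are hypothetical helper names, and the patterns in the comments are the `einops` expressions they stand in for.

```python
import torch


def split_heads(x: torch.Tensor, heads: int) -> torch.Tensor:
    # einops: rearrange(x, "b s (h d) -> b h s d", h=heads)
    return x.unflatten(-1, (heads, -1)).permute(0, 2, 1, 3)


def merge_heads(x: torch.Tensor) -> torch.Tensor:
    # einops: rearrange(x, "b h s d -> b s (h d)")
    return x.permute(0, 2, 1, 3).flatten(2, 3)


x = torch.randn(2, 5, 8 * 4)             # (batch, seq_len, heads * head_dim)
heads_first = split_heads(x, heads=8)    # (batch, heads, seq_len, head_dim)
roundtrip = merge_heads(heads_first)     # back to (batch, seq_len, heads * head_dim)
```

The round trip returns the original tensor, which is a quick way to check a rewritten reshape against the reference layout.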