# amp: Automatic Mixed Precision

## Annotating User Functions

Nearly all PyTorch user code needs nothing more than the two steps
above to use amp. After all, custom layers are built out of simpler
PyTorch components, and amp can already see those.

However, any custom C++ or CUDA code is outside of amp's (default)
view of things. For example, suppose I implemented a new recurrent
cell called a "forgetful recurrent unit" that calls directly into a
CUDA backend:

```python
from backend import FRUBackend

def fru(input, hidden, weight, bias):
    # call to CUDA code
    return FRUBackend(input, hidden, weight, bias)
```

In this case, it is possible to get a runtime type mismatch. For
example, you might have `input` in fp16, and `weight` in fp32, and amp
doesn't have the visibility to insert an appropriate cast.
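
For instance, here is a minimal illustration of the failure mode. The
shapes and variable names are hypothetical stand-ins, and the exact
error text varies by op and PyTorch version:

```python
import torch

# Hypothetical stand-ins for the FRU arguments: amp has already cast
# `input` to fp16 upstream, but `weight` is still fp32.
input = torch.randn(8, 16, device='cuda', dtype=torch.float16)
weight = torch.randn(16, 16, device='cuda', dtype=torch.float32)

# An op that requires uniform dtypes fails at runtime with an error
# along the lines of "expected scalar type Half but found Float".
output = input @ weight
```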

amp exposes two ways to handle "invisible" backend code: function
annotations and explicit registration.

#### Function annotation

The first way to handle backend code is a set of function annotations:

- `@amp.half_function`
- `@amp.float_function`
- `@amp.promote_function`

These correspond to:

- Cast all arguments to fp16
- Cast all arguments to fp32
- If there are any type mismatches, cast everything to the widest type
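
As a mental model, `@amp.half_function` behaves roughly like the
wrapper below. This is a simplified sketch of the casting behavior,
not apex's actual implementation:

```python
import functools
import torch

def half_function_sketch(fn):
    """Simplified stand-in for @amp.half_function (not the real apex code):
    cast floating-point tensor arguments to fp16, pass everything else through."""
    def to_half(x):
        if isinstance(x, torch.Tensor) and x.is_floating_point():
            return x.half()
        return x

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        args = [to_half(a) for a in args]
        kwargs = {k: to_half(v) for k, v in kwargs.items()}
        return fn(*args, **kwargs)

    return wrapper
```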

In our example, we believe that the FRU unit is fp16-safe and will get
performance gains from casting its arguments to fp16, so we write:

```python
@amp.half_function
def fru(input, hidden, weight, bias):
    #...
```

#### Explicit registration

The other way to handle backend code is with explicit function
registration:

- `amp.register_half_function(module, function_name)`
- `amp.register_float_function(module, function_name)`
- `amp.register_promote_function(module, function_name)`

When using this API, `module` is the containing class or module for
the function, and `function_name` is the _string_ name of the
function. Note that the function must be registered before the call to
`amp.initialize()`.

For our FRU unit, we can register the backend function directly:

```python
import backend

amp.register_half_function(backend, 'FRUBackend')
```
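
Putting the pieces together, registration has to happen before
initialization. A minimal end-to-end sketch, assuming the
`amp.initialize()` API named above and a hypothetical model and
optimizer:

```python
import torch
from apex import amp
import backend

# Hypothetical model and optimizer so the sketch is self-contained.
model = torch.nn.Linear(16, 16).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-3)

# Register the opaque backend function first, so the fp16 cast wrapper
# is in place when amp initializes.
amp.register_half_function(backend, 'FRUBackend')
model, optimizer = amp.initialize(model, optimizer, opt_level='O1')
```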