Unverified Commit eceb1042 authored by Navjot's avatar Navjot Committed by GitHub
Browse files

flax.linen.apply takes state as the first param, followed by the input (#12510)

parent f1c81d6b
......@@ -621,7 +621,7 @@ state = model_flax.init(rng, dummy_input_ids)
and then we can do the forward pass.
```python
sequences = model_flax.apply(input_ids, state)
sequences = model_flax.apply(state, input_ids)
```
Visually, the forward pass would now be represented as passing all tensors required for the computation to the model's object:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment