"docs/source/vscode:/vscode.git/clone" did not exist on "940d1a76b0f2ebf98b18326bffc7e4bfc8c416d7"
Unverified Commit eceb1042 authored by Navjot's avatar Navjot Committed by GitHub
Browse files

flax.linen.apply takes state as the first param, followed by the input (#12510)

parent f1c81d6b
...@@ -621,7 +621,7 @@ state = model_flax.init(rng, dummy_input_ids) ...@@ -621,7 +621,7 @@ state = model_flax.init(rng, dummy_input_ids)
and then we can do the forward pass. and then we can do the forward pass.
```python ```python
sequences = model_flax.apply(input_ids, state) sequences = model_flax.apply(state, input_ids)
``` ```
Visually, the forward pass would now be represented as passing all tensors required for the computation to the model's object: Visually, the forward pass would now be represented as passing all tensors required for the computation to the model's object:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment