Unverified Commit d33d3131 authored by Omar Sanseviero's avatar Omar Sanseviero Committed by GitHub
Browse files

Nits in Llama2 docstring (#26996)

Update llama2.md
parent ef978d0a
...@@ -28,12 +28,12 @@ Checkout all Llama2 models [here](https://huggingface.co/models?search=llama2) ...@@ -28,12 +28,12 @@ Checkout all Llama2 models [here](https://huggingface.co/models?search=llama2)
<Tip warning={true}> <Tip warning={true}>
The `Llama2` models were trained using `bfloat16`, but the original inference uses `float16. The checkpoints uploaded on the hub use `torch_dtype = 'float16'` which will be The `Llama2` models were trained using `bfloat16`, but the original inference uses `float16`. The checkpoints uploaded on the Hub use `torch_dtype = 'float16'`, which will be
used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`. used by the `AutoModel` API to cast the checkpoints from `torch.float32` to `torch.float16`.
The `dtype` of the online weights is mostly irrelevant, unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online) then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`) and finally, if there is a `torch_dtype` provided in the config, it will be used. The `dtype` of the online weights is mostly irrelevant unless you are using `torch_dtype="auto"` when initializing a model using `model = AutoModelForCausalLM.from_pretrained("path", torch_dtype = "auto")`. The reason is that the model will first be downloaded ( using the `dtype` of the checkpoints online), then it will be casted to the default `dtype` of `torch` (becomes `torch.float32`), and finally, if there is a `torch_dtype` provided in the config, it will be used.
Training the model in `float16` is not recommended and known to produce `nan`, as such the model should be trained in `bfloat16`. Training the model in `float16` is not recommended and is known to produce `nan`; as such, the model should be trained in `bfloat16`.
</Tip> </Tip>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment