"...git@developer.sourcefind.cn:OpenDAS/megatron-lm.git" did not exist on "688448db7547be90203440cfd105703d8a853f39"
Unverified commit dc0c1029, authored by Younes Belkada and committed by GitHub

[`Docs`] More clarifications on BT + FA (#25823)

parent c9bae84e
@@ -74,7 +74,7 @@ import torch
 from transformers import AutoModelForCausalLM, AutoTokenizer
 tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
-model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to("cuda")
+model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m", torch_dtype=torch.float16).to("cuda")
 # convert the model to BetterTransformer
 model.to_bettertransformer()
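
The change above matters because the Flash Attention path of SDPA only runs on `float16`/`bfloat16` tensors, so the checkpoint should be loaded in half precision before converting. A minimal sketch of the updated snippet, end to end, assuming a CUDA GPU and the `optimum` package (which `to_bettertransformer()` relies on):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")

# Load the weights directly in float16 so SDPA can dispatch to Flash Attention.
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", torch_dtype=torch.float16
).to("cuda")

# Convert the model to BetterTransformer, which swaps the attention
# implementation for torch.nn.functional.scaled_dot_product_attention.
model = model.to_bettertransformer()

print(model.dtype)  # torch.float16
```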
@@ -99,6 +99,8 @@ try using the PyTorch nightly version, which may have a broader coverage for Flash Attention
 pip3 install -U --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cu118
 ```
+
+Or make sure your model is correctly cast to `float16` or `bfloat16`.
 Have a look at [this detailed blogpost](https://pytorch.org/blog/out-of-the-box-acceleration/) to read more about what is possible with the `BetterTransformer` + SDPA API.
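
The added sentence is about dtype rather than installation: if the model was already loaded in the default `float32`, it can be re-cast in place. A small sketch of that check using standard PyTorch/`transformers` calls (the dtype prints are illustrative, not part of the diff):

```python
import torch
from transformers import AutoModelForCausalLM

# Default load is float32; Flash Attention will not be used for float32 inputs.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-350m").to("cuda")
print(model.dtype)  # torch.float32

# Re-cast the weights to half precision (bfloat16 here, or model.half() for
# float16), then convert to BetterTransformer.
model = model.to(torch.bfloat16)
model = model.to_bettertransformer()
print(model.dtype)  # torch.bfloat16
```

Loading with `torch_dtype=...` as in the first hunk avoids materializing the full `float32` copy first, so it is the preferable route when starting from scratch.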
@@ -270,4 +272,4 @@ with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
     outputs = model.generate(**inputs)
 print(tokenizer.decode(outputs[0], skip_special_tokens=True))
-```
\ No newline at end of file
+```
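
The last hunk only fixes a missing trailing newline, but since it shows the tail of the `sdp_kernel` snippet, here is a self-contained sketch of that generation step that forces the Flash Attention backend and fails loudly when it cannot be used. The prompt and `max_new_tokens` value are illustrative assumptions, not taken from the doc:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-350m", torch_dtype=torch.float16
).to("cuda")
model = model.to_bettertransformer()  # requires `optimum`

inputs = tokenizer("Hello my dog is cute and", return_tensors="pt").to("cuda")

try:
    # Allow only the Flash Attention backend; the math and memory-efficient
    # fallbacks are disabled so unsupported cases surface as errors instead
    # of silently using another kernel.
    with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=False):
        outputs = model.generate(**inputs, max_new_tokens=20)
    print(tokenizer.decode(outputs[0], skip_special_tokens=True))
except RuntimeError as err:
    # Typically raised when the installed PyTorch build has no Flash Attention
    # kernel for this model/dtype/shape; the doc suggests trying a nightly
    # build or casting to float16/bfloat16.
    print(f"Flash Attention kernel unavailable: {err}")
```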