Unverified commit 9f34c1b8 authored by UncleCode, committed by GitHub

Fix typo in INFERENCE.md: change return_tensors and correct the usage of device (#98)

parent 6185106e
@@ -62,7 +62,7 @@ model.generation_config.cache_implementation = "static"
 model.forward = torch.compile(model.forward, mode=compile_mode)
 # warmup
-inputs = tokenizer("This is for compilation", return_tensors="pt", padding="max_length", max_length=max_length).to(device)
+inputs = tokenizer("This is for compilation", return_tensors="pt", padding="max_length", max_length=max_length).to(torch_device)
 model_kwargs = {**inputs, "prompt_input_ids": inputs.input_ids, "prompt_attention_mask": inputs.attention_mask, }
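The fix matters because the warmup tensors must live on the same device as the compiled model: the surrounding document defines `torch_device`, while `device` was an undefined name. A minimal sketch of the compile-then-warmup pattern, using a hypothetical toy module rather than the model and tokenizer from the document:

```python
import torch

# Pick the device once and reuse the same variable everywhere,
# as the corrected documentation does with torch_device.
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

class ToyModel(torch.nn.Module):
    # Stand-in for the real model; any nn.Module works for the pattern.
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)

model = ToyModel().to(torch_device)
# Compile the forward pass; the first call triggers compilation,
# which is why a throwaway warmup call is done up front.
model.forward = torch.compile(model.forward)

# Warmup inputs must be moved to the same device as the model,
# otherwise the call fails with a device-mismatch error.
inputs = torch.randn(1, 8).to(torch_device)
with torch.no_grad():
    out = model(inputs)
print(out.shape)  # torch.Size([1, 8])
```

Later calls with same-shaped inputs reuse the compiled graph, which is the point of doing the dummy "This is for compilation" pass in the original snippet.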