Unverified Commit 8cbd0bd1 authored by Lysandre Debut's avatar Lysandre Debut Committed by GitHub
Browse files

Specify dataset dtype (#10195)


Co-authored-by: default avatarQuentin Lhoest <lhoest.q@gmail.com>
Co-authored-by: default avatarQuentin Lhoest <lhoest.q@gmail.com>
parent 0b1f552a
......@@ -500,7 +500,7 @@ class TrainerIntegrationTest(unittest.TestCase):
self.check_trained_model(trainer.model)
# Can return tensors.
train_dataset.set_format(type="torch")
train_dataset.set_format(type="torch", dtype=torch.float32)
model = RegressionModel()
trainer = Trainer(model, args, train_dataset=train_dataset)
trainer.train()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment