Unverified Commit 0324117f authored by guoshzhao's avatar guoshzhao Committed by GitHub
Browse files

Benchmarks: Fix Bug - Fix dataset precision for CNN and LSTM benchmarks.

parent 2299d238
......@@ -46,7 +46,7 @@ def _generate_dataset(self):
self._dataset = TorchRandomDataset(
[self._args.sample_count, 3, self._args.image_size, self._args.image_size],
self._world_size,
dtype=torch.long
dtype=torch.float32
)
if len(self._dataset) == 0:
logger.error('Generate random dataset failed - model: {}'.format(self._name))
......
......@@ -92,7 +92,7 @@ def _generate_dataset(self):
True if dataset is created successfully.
"""
self._dataset = TorchRandomDataset(
[self._args.sample_count, self._args.seq_len, self._args.input_size], self._world_size, dtype=torch.float
[self._args.sample_count, self._args.seq_len, self._args.input_size], self._world_size, dtype=torch.float32
)
if len(self._dataset) == 0:
logger.error('Generate random dataset failed - model: {}'.format(self._name))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment