Unverified Commit 7cb66c97 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression] fix speedup randomize bool (#5191)

parent 57fde460
...@@ -952,7 +952,7 @@ class TaylorFOWeightPruner(EvaluatorBasedPruner): ...@@ -952,7 +952,7 @@ class TaylorFOWeightPruner(EvaluatorBasedPruner):
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0. The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__. If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_steps training_steps
The step number used to collect activations. The step number used to collect gradients.
mode mode
'normal', 'dependency_aware' or 'global'. 'normal', 'dependency_aware' or 'global'.
......
...@@ -69,6 +69,8 @@ def randomize_tensor(tensor, start=1, end=100): ...@@ -69,6 +69,8 @@ def randomize_tensor(tensor, start=1, end=100):
assert isinstance(tensor, torch.Tensor) assert isinstance(tensor, torch.Tensor)
if tensor.dtype in torch_integer_dtype: if tensor.dtype in torch_integer_dtype:
# integer tensor can only be randomized by the torch.randint # integer tensor can only be randomized by the torch.randint
if tensor.dtype is torch.bool:
start, end = 0, 2
torch.randint(int(start), int(end), tensor.size(), torch.randint(int(start), int(end), tensor.size(),
out=tensor.data, dtype=tensor.dtype) out=tensor.data, dtype=tensor.dtype)
# pass # pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment