Unverified Commit 7cb66c97 authored by J-shang's avatar J-shang Committed by GitHub
Browse files

[Compression] fix speedup randomize bool (#5191)

parent 57fde460
......@@ -952,7 +952,7 @@ class TaylorFOWeightPruner(EvaluatorBasedPruner):
The old API (``trainer``, ``traced_optimizer`` and ``criterion``) is still supported and will be deprecated in v3.0.
If you want to consult the old API, please refer to `v2.8 pruner API <https://nni.readthedocs.io/en/v2.8/reference/compression/pruner.html>`__.
training_steps
The step number used to collect activations.
The step number used to collect gradients.
mode
'normal', 'dependency_aware' or 'global'.
......
......@@ -69,6 +69,8 @@ def randomize_tensor(tensor, start=1, end=100):
assert isinstance(tensor, torch.Tensor)
if tensor.dtype in torch_integer_dtype:
# integer tensor can only be randomized by the torch.randint
if tensor.dtype is torch.bool:
start, end = 0, 2
torch.randint(int(start), int(end), tensor.size(),
out=tensor.data, dtype=tensor.dtype)
# pass
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment