Unverified Commit 5fb36a17 authored by Vasilis Vryniotis's avatar Vasilis Vryniotis Committed by GitHub
Browse files

Resolving tracing problem on StochasticDepth iterator. (#4372)


Co-authored-by: default avatarFrancisco Massa <fvsmassa@gmail.com>
parent 6ce278bb
...@@ -28,9 +28,10 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True) ...@@ -28,9 +28,10 @@ def stochastic_depth(input: Tensor, p: float, mode: str, training: bool = True)
return input return input
survival_rate = 1.0 - p survival_rate = 1.0 - p
size = [1] * input.ndim
if mode == "row": if mode == "row":
size[0] = input.shape[0] size = [input.shape[0]] + [1] * (input.ndim - 1)
else:
size = [1] * input.ndim
noise = torch.empty(size, dtype=input.dtype, device=input.device) noise = torch.empty(size, dtype=input.dtype, device=input.device)
noise = noise.bernoulli_(survival_rate).div_(survival_rate) noise = noise.bernoulli_(survival_rate).div_(survival_rate)
return input * noise return input * noise
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment