Commit 6d45a903 authored by Yu Yao's avatar Yu Yao
Browse files

Fix DropPath for hidden shape [s, b, h]

parent 70169453
......@@ -45,7 +45,8 @@ class DropPath(MegatronModule):
return hidden_state
keep_prob = 1 - self.drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
# hidden_state: [s, b, h]
shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
random_tensor = keep_prob + \
torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
random_tensor.floor_() # binarize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment