Commit 6d45a903 authored by Yu Yao's avatar Yu Yao
Browse files

Fix DropPath for hidden shape [s, b, h]

parent 70169453
...@@ -45,7 +45,8 @@ class DropPath(MegatronModule): ...@@ -45,7 +45,8 @@ class DropPath(MegatronModule):
return hidden_state return hidden_state
keep_prob = 1 - self.drop_prob keep_prob = 1 - self.drop_prob
# work with diff dim tensors, not just 2D ConvNets # work with diff dim tensors, not just 2D ConvNets
shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1) # hidden_state: [s, b, h]
shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
random_tensor = keep_prob + \ random_tensor = keep_prob + \
torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device) torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
random_tensor.floor_() # binarize random_tensor.floor_() # binarize
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment