Commit 52e63688 authored by Jared Casper

Merge branch 'yuya/drop_path_fix' into 'main'

Fix DropPath for hidden shape [s, b, h]

See merge request ADLR/megatron-lm!485
parents b24f4adc 6d45a903
@@ -45,7 +45,8 @@ class DropPath(MegatronModule):
             return hidden_state
         keep_prob = 1 - self.drop_prob
         # work with diff dim tensors, not just 2D ConvNets
-        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
+        # hidden_state: [s, b, h]
+        shape = (1,) + (hidden_state.shape[1],) + (1,) * (hidden_state.ndim - 2)
         random_tensor = keep_prob + \
             torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
         random_tensor.floor_()  # binarize
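
The effect of the shape change can be sketched in a small standalone example. It is only an illustration, not the Megatron-LM code: the helper `drop_path_mask` and its `batch_dim` parameter are hypothetical names introduced here, and the final rescaling by `keep_prob` is assumed from the usual stochastic-depth formulation because that step is not visible in the hunk. With hidden states laid out as [s, b, h], a mask of shape [1, b, 1] keeps or drops each sample as a whole, whereas the old shape [s, 1, 1] made one decision per sequence position shared by every sample in the batch.

```python
import torch


def drop_path_mask(hidden_state: torch.Tensor, drop_prob: float,
                   batch_dim: int = 1) -> torch.Tensor:
    """Hypothetical helper: 0/1 mask with one entry along batch_dim."""
    keep_prob = 1.0 - drop_prob
    # Size 1 everywhere except the chosen dimension, so the mask broadcasts.
    shape = [1] * hidden_state.ndim
    shape[batch_dim] = hidden_state.shape[batch_dim]
    random_tensor = keep_prob + torch.rand(
        shape, dtype=hidden_state.dtype, device=hidden_state.device)
    return random_tensor.floor_()  # binarize


torch.manual_seed(0)
s, b, h = 4, 3, 2
drop_prob = 0.5
hidden_state = torch.ones(s, b, h)

# Shape after the fix: one decision per sample (dimension 1).
new_mask = drop_path_mask(hidden_state, drop_prob, batch_dim=1)
# Shape before the fix: one decision per sequence position (dimension 0).
old_mask = drop_path_mask(hidden_state, drop_prob, batch_dim=0)

# Assumed rescaling step (not shown in the hunk above).
output = hidden_state.div(1.0 - drop_prob) * new_mask

print(tuple(new_mask.shape))  # (1, 3, 1)
print(tuple(old_mask.shape))  # (4, 1, 1)
```

In `output`, every slice `output[:, j, :]` is either entirely zero or entirely scaled by `1 / keep_prob`, which is the per-sample behaviour the commit message describes for hidden shape [s, b, h].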