""" Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
"""
Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to
zero with :obj:`model.reset_memory_hooks_state()`.
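A minimal sketch of how such hooks could be wired up with `psutil` and PyTorch's forward hooks. Only `mem_rss_diff` and `reset_memory_hooks_state` come from the docstring above; the helper names and the `psutil`-based RSS probe are illustrative assumptions, not the library's exact code:

```python
import psutil
import torch.nn as nn


def _rss():
    # Resident set size of the current process, in bytes (assumed probe).
    return psutil.Process().memory_info().rss


def _pre_forward_mem_hook(module, inputs):
    # Snapshot RSS just before the sub-module runs.
    module.mem_rss_pre_forward = _rss()


def _post_forward_mem_hook(module, inputs, outputs):
    # Accumulate the RSS increase observed across this forward pass.
    diff = _rss() - module.mem_rss_pre_forward
    module.mem_rss_diff = getattr(module, "mem_rss_diff", 0) + diff


def add_memory_hooks(model: nn.Module):
    # Register the hook pair on every sub-module of the model.
    for module in model.modules():
        module.register_forward_pre_hook(_pre_forward_mem_hook)
        module.register_forward_hook(_post_forward_mem_hook)


def reset_memory_hooks_state(model: nn.Module):
    # Zero the accumulated counters, as the docstring describes.
    for module in model.modules():
        module.mem_rss_diff = 0
        module.mem_rss_pre_forward = 0
```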
Compute a single vector summary of a sequence's hidden states.
Args:
hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
The hidden states of the last layer.
cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ... are optional leading dimensions of :obj:`hidden_states`, `optional`):
    Used if :obj:`summary_type == "cls_index"`: the position of the classification token for each
    sequence. If :obj:`None`, the last token of the sequence is taken as the classification token.
Returns:
:obj:`torch.FloatTensor`: The summary of the sequence hidden states.
"""
ifself.summary_type=="last":
output=hidden_states[:,-1]
...
@@ -1239,10 +1358,19 @@ class SequenceSummary(nn.Module):
return output
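For context, a toy shape check of the two summary modes the docstring describes. The random tensors and the gather construction are illustrative; only the shapes and the `"last"` slice come from the snippet above:

```python
import torch

batch_size, seq_len, hidden_size = 2, 5, 8
hidden_states = torch.randn(batch_size, seq_len, hidden_size)

# summary_type == "last": keep the hidden state of the final token.
last = hidden_states[:, -1]
assert last.shape == (batch_size, hidden_size)

# summary_type == "cls_index": gather the hidden state at a per-example
# token position (the classification token index for each sequence).
cls_index = torch.tensor([4, 2])                        # shape [batch_size]
idx = cls_index[:, None, None].expand(-1, 1, hidden_size)
gathered = hidden_states.gather(1, idx).squeeze(1)
assert gathered.shape == (batch_size, hidden_size)
assert torch.equal(gathered[0], hidden_states[0, 4])
```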
def prune_linear_layer(layer, index, dim=0):
""" Prune a linear layer (a model parameters) to keep only entries in index.
Return the pruned layer as a new layer with requires_grad=True.
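A sketch of a body consistent with this docstring, assuming the usual PyTorch idiom of copying the selected rows (or columns, for `dim=1`) into a fresh `nn.Linear`; the construction details are an assumption, since the diff does not show the function body:

```python
import torch
import torch.nn as nn


def prune_linear_layer(layer, index, dim=0):
    # dim=0 prunes output units (weight rows); dim=1 prunes input features.
    index = index.to(layer.weight.device)
    W = layer.weight.index_select(dim, index).clone().detach()
    if layer.bias is not None:
        # The bias only shrinks when output units (dim=0) are pruned.
        b = layer.bias.clone().detach() if dim == 1 else layer.bias[index].clone().detach()

    new_size = list(layer.weight.size())
    new_size[dim] = len(index)
    new_layer = nn.Linear(new_size[1], new_size[0], bias=layer.bias is not None).to(layer.weight.device)

    # Copy the kept entries in-place, then restore requires_grad=True.
    new_layer.weight.requires_grad = False
    new_layer.weight.copy_(W.contiguous())
    new_layer.weight.requires_grad = True
    if layer.bias is not None:
        new_layer.bias.requires_grad = False
        new_layer.bias.copy_(b.contiguous())
        new_layer.bias.requires_grad = True
    return new_layer
```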