""" Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
"""
Increase in memory consumption is stored in a `mem_rss_diff` attribute for each module and can be reset to zero with `model.reset_memory_hooks_state()`
Add a memory hook before and after each sub-module forward pass to record increase in memory consumption.
Increase in memory consumption is stored in a :obj:`mem_rss_diff` attribute for each module and can be reset to
zero with :obj:`model.reset_memory_hooks_state()`.
    def forward(
        self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
    ) -> torch.FloatTensor:
        """
        Compute a single vector summary of a sequence hidden states.

        Args:
            hidden_states (:obj:`torch.FloatTensor` of shape :obj:`[batch_size, seq_len, hidden_size]`):
                The hidden states of the last layer.
            cls_index (:obj:`torch.LongTensor` of shape :obj:`[batch_size]` or :obj:`[batch_size, ...]` where ...
                are optional leading dimensions of :obj:`hidden_states`, `optional`):
                Used if :obj:`summary_type == "cls_index"`. If :obj:`cls_index` is :obj:`None`, the last token of
                the sequence is taken as the classification token.

        Returns:
            :obj:`torch.FloatTensor`: The summary of the sequence hidden states.
        """
ifself.summary_type=="last":
ifself.summary_type=="last":
output=hidden_states[:,-1]
output=hidden_states[:,-1]
...
@@ -1239,10 +1358,19 @@ class SequenceSummary(nn.Module):
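The dispatch on `summary_type` sketched above can be illustrated without torch. The `summarize` helper below is hypothetical (not part of the library), using nested Python lists in place of a `torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`, but it mirrors the documented behavior, including the fallback to the last token when `cls_index` is `None`.

```python
def summarize(hidden_states, summary_type="last", cls_index=None):
    # hidden_states: list of batch_size sequences, each a list of
    # seq_len token vectors of length hidden_size.
    if summary_type == "last":
        # tensor notation: hidden_states[:, -1]
        return [seq[-1] for seq in hidden_states]
    if summary_type == "first":
        # tensor notation: hidden_states[:, 0]
        return [seq[0] for seq in hidden_states]
    if summary_type == "mean":
        # tensor notation: hidden_states.mean(dim=1)
        return [
            [sum(tok[d] for tok in seq) / len(seq) for d in range(len(seq[0]))]
            for seq in hidden_states
        ]
    if summary_type == "cls_index":
        if cls_index is None:
            # as documented: take the last token as classification token
            return [seq[-1] for seq in hidden_states]
        return [seq[i] for seq, i in zip(hidden_states, cls_index)]
    raise ValueError(f"unknown summary_type: {summary_type}")


# batch_size=2, seq_len=3, hidden_size=2
hs = [
    [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]],
]
print(summarize(hs, "last"))  # [[5.0, 6.0], [11.0, 12.0]]
print(summarize(hs, "mean"))  # [[3.0, 4.0], [9.0, 10.0]]
```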