Unverified Commit fc5b7419 authored by Sun Haozhe's avatar Sun Haozhe Committed by GitHub
Browse files

corrected the code comment for the output of find_pruneable_heads_and_indices (#22557)

* corrected/clarified the code comment of find_pruneable_heads_and_indices

* have run make style
parent 5f3ea66b
......@@ -249,7 +249,8 @@ def find_pruneable_heads_and_indices(
already_pruned_heads (`Set[int]`): A set of already pruned heads.
Returns:
`Tuple[Set[int], torch.LongTensor]`: A tuple with the remaining heads and their corresponding indices.
`Tuple[Set[int], torch.LongTensor]`: A tuple with the indices of heads to prune taking `already_pruned_heads`
into account and the indices of rows/columns to keep in the layer weight.
"""
mask = torch.ones(n_heads, head_size)
heads = set(heads) - already_pruned_heads # Convert to set and remove already pruned heads
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment