"tests/bert_generation/test_modeling_bert_generation.py" did not exist on "7fd1febf38bd01ad413abc56ed06700a9675c143"
Unverified commit 21e259d8, authored by Kevin Koehncke, committed by GitHub
Browse files

Fix training speed regression introduced by "optimize VRAM for calculating pos_bias in LayoutLM v2, v3 (#26139)" (#30988)

Fix training speed regression introduced by "optimize VRAM for calculating pos_bias in LayoutLM v2, v3 (#26139)" (#30988)

* Revert "optimize VRAM for calculating pos_bias in LayoutLM v2, v3 (#26139)"

This reverts commit a7e0ed82.

* Instead of reverting commit, wrap indexing in torch.no_grad context

* Apply wrapping in LayoutLMv2

* Add comments explaining reason for no_grad

* Fix code format

---------
Co-authored-by: Kevin Koehncke <kevin.koehncke@uipath.com>
parent 7f6e8741
...@@ -383,6 +383,11 @@ class LayoutLMv2Encoder(nn.Module):
num_buckets=self.rel_pos_bins,
max_distance=self.max_rel_pos,
)
# Since this is a simple indexing operation that is independent of the input,
# no need to track gradients for this operation
#
# Without this no_grad context, training speed slows down significantly
with torch.no_grad():
rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
rel_pos = rel_pos.contiguous()
return rel_pos
...@@ -402,6 +407,11 @@ class LayoutLMv2Encoder(nn.Module):
num_buckets=self.rel_2d_pos_bins,
max_distance=self.max_rel_2d_pos,
)
# Since this is a simple indexing operation that is independent of the input,
# no need to track gradients for this operation
#
# Without this no_grad context, training speed slows down significantly
with torch.no_grad():
rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
rel_pos_x = rel_pos_x.contiguous()
......
...@@ -600,6 +600,11 @@ class LayoutLMv3Encoder(nn.Module):
num_buckets=self.rel_pos_bins,
max_distance=self.max_rel_pos,
)
# Since this is a simple indexing operation that is independent of the input,
# no need to track gradients for this operation
#
# Without this no_grad context, training speed slows down significantly
with torch.no_grad():
rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
rel_pos = rel_pos.contiguous()
return rel_pos
...@@ -619,6 +624,11 @@ class LayoutLMv3Encoder(nn.Module):
num_buckets=self.rel_2d_pos_bins,
max_distance=self.max_rel_2d_pos,
)
# Since this is a simple indexing operation that is independent of the input,
# no need to track gradients for this operation
#
# Without this no_grad context, training speed slows down significantly
with torch.no_grad():
rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
rel_pos_x = rel_pos_x.contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.