"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "ab81d31d20901f1e298145ef24b47f4bd28e4561"
Unverified Commit a7e0ed82 authored by Norm Inui's avatar Norm Inui Committed by GitHub
Browse files

optimize VRAM for calculating pos_bias in LayoutLM v2, v3 (#26139)



* optimize layoutv2, v3 for VRAM saving

* reformat codes

---------
Co-authored-by: default avatarNormXU <xunuo@datagrand.com>
parent ab37b801
......@@ -372,31 +372,28 @@ class LayoutLMv2Encoder(nn.Module):
if self.has_relative_attention_bias:
self.rel_pos_bins = config.rel_pos_bins
self.max_rel_pos = config.max_rel_pos
self.rel_pos_onehot_size = config.rel_pos_bins
self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)
if self.has_spatial_attention_bias:
self.max_rel_2d_pos = config.max_rel_2d_pos
self.rel_2d_pos_bins = config.rel_2d_pos_bins
self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins
self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
self.gradient_checkpointing = False
def _calculate_1d_position_embeddings(self, hidden_states, position_ids):
def _calculate_1d_position_embeddings(self, position_ids):
rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
rel_pos = relative_position_bucket(
rel_pos_mat,
num_buckets=self.rel_pos_bins,
max_distance=self.max_rel_pos,
)
rel_pos = nn.functional.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states)
rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2)
rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
rel_pos = rel_pos.contiguous()
return rel_pos
def _calculate_2d_position_embeddings(self, hidden_states, bbox):
def _calculate_2d_position_embeddings(self, bbox):
position_coord_x = bbox[:, :, 0]
position_coord_y = bbox[:, :, 3]
rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
......@@ -411,10 +408,8 @@ class LayoutLMv2Encoder(nn.Module):
num_buckets=self.rel_2d_pos_bins,
max_distance=self.max_rel_2d_pos,
)
rel_pos_x = nn.functional.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
rel_pos_y = nn.functional.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2)
rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2)
rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
rel_pos_x = rel_pos_x.contiguous()
rel_pos_y = rel_pos_y.contiguous()
rel_2d_pos = rel_pos_x + rel_pos_y
......@@ -434,14 +429,8 @@ class LayoutLMv2Encoder(nn.Module):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
rel_pos = (
self._calculate_1d_position_embeddings(hidden_states, position_ids)
if self.has_relative_attention_bias
else None
)
rel_2d_pos = (
self._calculate_2d_position_embeddings(hidden_states, bbox) if self.has_spatial_attention_bias else None
)
rel_pos = self._calculate_1d_position_embeddings(position_ids) if self.has_relative_attention_bias else None
rel_2d_pos = self._calculate_2d_position_embeddings(bbox) if self.has_spatial_attention_bias else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
......
......@@ -566,15 +566,13 @@ class LayoutLMv3Encoder(nn.Module):
if self.has_relative_attention_bias:
self.rel_pos_bins = config.rel_pos_bins
self.max_rel_pos = config.max_rel_pos
self.rel_pos_onehot_size = config.rel_pos_bins
self.rel_pos_bias = nn.Linear(self.rel_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_bias = nn.Linear(self.rel_pos_bins, config.num_attention_heads, bias=False)
if self.has_spatial_attention_bias:
self.max_rel_2d_pos = config.max_rel_2d_pos
self.rel_2d_pos_bins = config.rel_2d_pos_bins
self.rel_2d_pos_onehot_size = config.rel_2d_pos_bins
self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_onehot_size, config.num_attention_heads, bias=False)
self.rel_pos_x_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
self.rel_pos_y_bias = nn.Linear(self.rel_2d_pos_bins, config.num_attention_heads, bias=False)
def relative_position_bucket(self, relative_position, bidirectional=True, num_buckets=32, max_distance=128):
ret = 0
......@@ -599,7 +597,7 @@ class LayoutLMv3Encoder(nn.Module):
ret += torch.where(is_small, n, val_if_large)
return ret
def _cal_1d_pos_emb(self, hidden_states, position_ids):
def _cal_1d_pos_emb(self, position_ids):
rel_pos_mat = position_ids.unsqueeze(-2) - position_ids.unsqueeze(-1)
rel_pos = self.relative_position_bucket(
......@@ -607,12 +605,11 @@ class LayoutLMv3Encoder(nn.Module):
num_buckets=self.rel_pos_bins,
max_distance=self.max_rel_pos,
)
rel_pos = F.one_hot(rel_pos, num_classes=self.rel_pos_onehot_size).type_as(hidden_states)
rel_pos = self.rel_pos_bias(rel_pos).permute(0, 3, 1, 2)
rel_pos = self.rel_pos_bias.weight.t()[rel_pos].permute(0, 3, 1, 2)
rel_pos = rel_pos.contiguous()
return rel_pos
def _cal_2d_pos_emb(self, hidden_states, bbox):
def _cal_2d_pos_emb(self, bbox):
position_coord_x = bbox[:, :, 0]
position_coord_y = bbox[:, :, 3]
rel_pos_x_2d_mat = position_coord_x.unsqueeze(-2) - position_coord_x.unsqueeze(-1)
......@@ -627,10 +624,8 @@ class LayoutLMv3Encoder(nn.Module):
num_buckets=self.rel_2d_pos_bins,
max_distance=self.max_rel_2d_pos,
)
rel_pos_x = F.one_hot(rel_pos_x, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
rel_pos_y = F.one_hot(rel_pos_y, num_classes=self.rel_2d_pos_onehot_size).type_as(hidden_states)
rel_pos_x = self.rel_pos_x_bias(rel_pos_x).permute(0, 3, 1, 2)
rel_pos_y = self.rel_pos_y_bias(rel_pos_y).permute(0, 3, 1, 2)
rel_pos_x = self.rel_pos_x_bias.weight.t()[rel_pos_x].permute(0, 3, 1, 2)
rel_pos_y = self.rel_pos_y_bias.weight.t()[rel_pos_y].permute(0, 3, 1, 2)
rel_pos_x = rel_pos_x.contiguous()
rel_pos_y = rel_pos_y.contiguous()
rel_2d_pos = rel_pos_x + rel_pos_y
......@@ -652,8 +647,8 @@ class LayoutLMv3Encoder(nn.Module):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
rel_pos = self._cal_1d_pos_emb(hidden_states, position_ids) if self.has_relative_attention_bias else None
rel_2d_pos = self._cal_2d_pos_emb(hidden_states, bbox) if self.has_spatial_attention_bias else None
rel_pos = self._cal_1d_pos_emb(position_ids) if self.has_relative_attention_bias else None
rel_2d_pos = self._cal_2d_pos_emb(bbox) if self.has_spatial_attention_bias else None
for i, layer_module in enumerate(self.layer):
if output_hidden_states:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment