"zh_CN/docs/vscode:/vscode.git/clone" did not exist on "5d7923eb1c13d40f037b997dbfc8a5cf1a7980d8"
Commit f9b1a89a authored by HHL's avatar HHL
Browse files

v

parent 60e27226
import torch
from torch import nn
from torchvision.ops import roi_align


def convert_to_roi_format(lines_box):
    """Convert a list of per-image box tensors into RoI format.

    Each row of the result is (batch_idx, x1, y1, x2, y2), as expected by
    torchvision's roi_align.
    """
    concat_boxes = torch.cat(lines_box, dim=0)
    device, dtype = concat_boxes.device, concat_boxes.dtype
    # Prepend each box's batch index as the first column.
    ids = torch.cat(
        [
            torch.full((lines_box_pi.shape[0], 1), i, dtype=dtype, device=device)
            for i, lines_box_pi in enumerate(lines_box)
        ],
        dim=0,
    )
    rois = torch.cat([ids, concat_boxes], dim=1)
    return rois
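
# Usage sketch (illustrative values, not from this commit): two images
# with 2 and 1 boxes respectively yield a (3, 5) RoI tensor whose first
# column is the batch index.
if __name__ == "__main__":
    boxes_img0 = torch.tensor([[0., 0., 10., 10.], [5., 5., 20., 20.]])
    boxes_img1 = torch.tensor([[2., 2., 8., 8.]])
    print(convert_to_roi_format([boxes_img0, boxes_img1]).shape)  # torch.Size([3, 5])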
class RoIPool(nn.Module):
    """Pool a feature map into a fixed pool_W x pool_H grid of cell features."""

    def __init__(self, pool_size):
        super().__init__()
        self.pool_size = pool_size

    def gen_rois(self, feats):
        *_, H, W = feats.shape
        pool_W, pool_H = self.pool_size
        cell_w = W / pool_W
        cell_h = H / pool_H
        # Cell boundaries along each axis.
        bbox_x = torch.arange(0, pool_W + 1).to(feats) * cell_w
        bbox_y = torch.arange(0, pool_H + 1).to(feats) * cell_h
        # Build (x1, y1, x2, y2) for every grid cell. The x coordinates are
        # tiled over the pool_H rows and the y coordinates over the pool_W
        # columns, so non-square pool sizes also work (the original repeat
        # counts only matched when pool_W == pool_H).
        bboxes = torch.stack(
            [
                bbox_x[:-1].repeat(pool_H, 1),
                bbox_y[:-1].repeat(pool_W, 1).transpose(0, 1),
                bbox_x[1:].repeat(pool_H, 1),
                bbox_y[1:].repeat(pool_W, 1).transpose(0, 1),
            ],
            dim=-1,
        ).view(-1, 4)
        rois = []
        for batch_idx in range(feats.shape[0]):
            ids = torch.full(
                (bboxes.shape[0], 1), batch_idx,
                dtype=feats.dtype, device=feats.device,
            )
            rois.append(torch.cat([ids, bboxes], dim=-1))
        return torch.cat(rois, dim=0)

    def forward(self, feats):
        rois = self.gen_rois(feats)
        bboxes_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=(1, 1),
            spatial_scale=1.0,
            sampling_ratio=1,
        )
        bs = feats.shape[0]
        num_cells = int(self.pool_size[0] * self.pool_size[1])  # avoid shadowing len()
        bboxes_feat = bboxes_feat.reshape(bs, num_cells, -1)
        return bboxes_feat
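
# Usage sketch (shapes assumed for illustration): pool a (B, C, H, W)
# feature map into pool_W * pool_H per-cell vectors.
if __name__ == "__main__":
    pool = RoIPool(pool_size=(4, 4))
    feats = torch.randn(2, 256, 32, 32)
    print(pool(feats).shape)  # torch.Size([2, 16, 256])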
def tensor_convert_to_roi_format(line_bboxes):
    """Convert a (B, L, 4) box tensor into a (B * L, 5) RoI tensor."""
    B, L, _ = line_bboxes.shape
    roi_ids = torch.zeros((B, L, 1)).to(line_bboxes).float()
    for batch_idx in range(B):  # first column holds the batch index
        roi_ids[batch_idx] = batch_idx
    rois = torch.cat([roi_ids, line_bboxes], dim=-1).reshape(-1, 5)
    return rois
class RoiFeatExtraxtor(nn.Module):
    """Extract one feature vector per line box via 1x1 RoIAlign."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, line_bboxes):
        rois = tensor_convert_to_roi_format(line_bboxes)
        lines_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=(1, 1),
            spatial_scale=self.scale,
            sampling_ratio=1,
        )
        lines_feat = lines_feat.reshape(lines_feat.shape[0], -1)
        # Restore the (B, L, C) layout of the input boxes.
        view_shape = line_bboxes.shape[:2]
        lines_feat = lines_feat.view(*view_shape, -1)
        return lines_feat
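
# Usage sketch (scale and shapes are assumptions for illustration): boxes
# given in image coordinates are mapped onto a 4x-downsampled feature map
# via spatial_scale=0.25.
if __name__ == "__main__":
    extractor = RoiFeatExtraxtor(scale=0.25)
    feats = torch.randn(2, 128, 64, 64)
    xy = torch.rand(2, 5, 2) * 200
    wh = torch.rand(2, 5, 2) * 50
    line_bboxes = torch.cat([xy, xy + wh], dim=-1)  # valid (x1, y1, x2, y2)
    print(extractor(feats, line_bboxes).shape)  # torch.Size([2, 5, 128])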
class RecogFeatExtraxtor(nn.Module):
    """Extract a fixed-size feature patch per line box for recognition."""

    def __init__(self, scale):
        super().__init__()
        self.scale = scale

    def forward(self, feats, line_bboxes, output_size=(1, 1)):
        rois = tensor_convert_to_roi_format(line_bboxes)
        lines_feat = roi_align(
            input=feats,
            boxes=rois,
            output_size=output_size,
            spatial_scale=self.scale,
            sampling_ratio=2,
        )
        return lines_feat
class ImageRegionExtractor(nn.Module):
    """Crop and resample image regions for each line box."""

    def __init__(self, scale, output_size):
        super().__init__()
        self.scale = scale
        self.output_size = output_size

    def forward(self, images, line_bboxes):
        rois = tensor_convert_to_roi_format(line_bboxes)
        images_feat = roi_align(
            input=images,
            boxes=rois,
            output_size=self.output_size,
            spatial_scale=self.scale,
            sampling_ratio=1,
        )
        # Flatten each crop, then restore the (B, L, ...) layout.
        images_feat = images_feat.reshape(images_feat.shape[0], -1)
        view_shape = line_bboxes.shape[:2]
        images_feat = images_feat.view(*view_shape, -1)
        return images_feat
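
# Usage sketch (all values assumed): crop 32x32 patches straight from the
# input images (scale=1.0), flattened to one 3*32*32 vector per box.
if __name__ == "__main__":
    cropper = ImageRegionExtractor(scale=1.0, output_size=(32, 32))
    images = torch.randn(2, 3, 256, 256)
    xy = torch.rand(2, 4, 2) * 200
    wh = torch.rand(2, 4, 2) * 50
    boxes = torch.cat([xy, xy + wh], dim=-1)
    print(cropper(images, boxes).shape)  # torch.Size([2, 4, 3072])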
import torch
import torch.nn.functional as F


def sigmoid_focal_loss(pred,
                       target,
                       weight=1.0,
                       gamma=2.0,
                       alpha=0.25,
                       reduction='mean'):
    """Focal loss on sigmoid logits (Lin et al., "Focal Loss for Dense
    Object Detection")."""
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    # pt is the probability assigned to the wrong class, so pt ** gamma
    # down-weights well-classified examples.
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
    weight = weight * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * weight
    reduction_enum = F._Reduction.get_enum(reduction)
    # none: 0, mean: 1, sum: 2
    if reduction_enum == 0:
        return loss
    elif reduction_enum == 1:
        return loss.mean()
    elif reduction_enum == 2:
        return loss.sum()
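
# Usage sketch (shapes assumed): raw logits against per-class binary
# labels; integer targets are cast to float inside the loss.
if __name__ == "__main__":
    logits = torch.randn(8, 4)
    targets = torch.randint(0, 2, (8, 4))
    print(sigmoid_focal_loss(logits, targets).item())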
import math
from typing import Optional

import torch
from torch import nn


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored.
    """

    def __init__(self, embedding_dim, padding_idx, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size, embedding_dim, padding_idx
        )
        self.max_positions = int(1e5)

    @staticmethod
    def get_embedding(
        num_embeddings: int, embedding_dim: int, padding_idx: Optional[int] = None
    ):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim, dtype=torch.float) * -emb)
        emb = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(
            1
        ) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(
            num_embeddings, -1
        )
        if embedding_dim % 2 == 1:
            # Zero-pad the last dimension for odd embedding sizes.
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, positions, max_pos=1024):
        """Input is expected to be of size [bsz x seqlen]."""
        bsz, seq_len = positions.size()
        # Rebuild the table so every position index in the input is covered.
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            max_pos, self.embedding_dim, self.padding_idx
        )
        self.weights = self.weights.to(positions.device)
        return (
            self.weights.index_select(0, positions.view(-1))
            .view(bsz, seq_len, -1)
            .detach()
        )
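
# Usage sketch (dimensions assumed): look up embeddings for explicit
# position indices; the padding_idx row (0 here) maps to all zeros.
if __name__ == "__main__":
    pos_emb = SinusoidalPositionalEmbedding(embedding_dim=64, padding_idx=0)
    positions = torch.arange(1, 11).unsqueeze(0)  # (1, 10)
    print(pos_emb(positions).shape)  # torch.Size([1, 10, 64])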