Commit da3f0934 authored by zhuwenwen's avatar zhuwenwen
Browse files

delete unused files

parent c4dd1fd4
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import distributed as dist
from colossalai.communication import ring_forward
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
from colossalai.utils import get_current_device
from torch.cuda.amp import custom_bwd, custom_fwd
class RingQK(torch.autograd.Function):
"""
Calculate QK in a ring-exchange style
"""
@staticmethod
@custom_fwd
def forward(ctx,
sub_q,
sub_k,
batch_size,
num_attention_heads,
sub_seq_length):
# save tensor for backward
ctx.save_for_backward(sub_q, sub_k)
ctx.sub_seq_length = sub_seq_length
# create local segment of attention score
attention_score = torch.empty(
batch_size * num_attention_heads,
sub_seq_length,
sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
dtype=sub_q.dtype,
device=get_current_device()
)
# compute local QK^T
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
start_idx = local_rank * sub_seq_length
end_idx = (local_rank + 1) * sub_seq_length
attention_score[:, :, start_idx: end_idx] = part_a
# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
attention_score[:, :, start_idx:end_idx] = part_a
return attention_score
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
sub_q, sub_k, = ctx.saved_tensors
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
# calculate gradient of sub_k
grad_k = torch.matmul(
grad_output.transpose(2, 1),
sub_q
)
dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
grad_k /= local_world_size
# calculate gradient for sub_q
grad_q = torch.zeros_like(sub_q,
dtype=sub_q.dtype,
device=get_current_device(), )
# compute with local sub_k
start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)
# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k)
grad_q /= local_world_size
return grad_q, grad_k, None, None, None
class RingAV(torch.autograd.Function):
"""
Calculate AV in a ring-exchange style
"""
@staticmethod
@custom_fwd
def forward(ctx,
attention_score,
sub_v,
batch_size,
num_attention_heads,
attention_head_size,
sub_seq_length):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
sub_attention_result = torch.zeros(
batch_size * num_attention_heads,
sub_seq_length,
attention_head_size,
device=get_current_device(),
dtype=attention_score.dtype)
# save tensors for backward
ctx.save_for_backward(attention_score, sub_v)
ctx.sub_seq_length = sub_seq_length
# compute local AV
part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v)
sub_attention_result += part_av
# compute AV in ring - all - reduce style
for i in range(local_world_size - 1):
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
# compute QK^T
part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v)
sub_attention_result += part_av
return sub_attention_result
@staticmethod
@custom_bwd
def backward(ctx, grad_output):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
attention_scores, sub_v = ctx.saved_tensors
# calculate gradient of v
grad_v = torch.matmul(
attention_scores.transpose(2, 1),
grad_output
)
dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_v = grad_v[:, local_start_idx:local_end_idx]
grad_v /= local_world_size
# calculate gradient for attention score
grad_attention_score = torch.zeros_like(attention_scores,
dtype=grad_output.dtype,
device=get_current_device())
# compute with local sub_k
grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(
grad_output,
sub_v.transpose(2, 1))
# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
# compute grad_q
grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(
grad_output,
sub_v.transpose(2, 1))
return grad_attention_score, grad_v, None, None, None, None
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
def _calc_incoming_device_range(i, rank, world_size, sub_seq_length):
device_of_incoming_k = (rank - i - 1) % world_size
start_idx = sub_seq_length * device_of_incoming_k
end_idx = sub_seq_length * (device_of_incoming_k + 1)
return start_idx, end_idx
def _calc_current_device_range(rank, sub_seq_length):
start_idx = sub_seq_length * rank
end_idx = sub_seq_length * (rank + 1)
return start_idx, end_idx
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import colossalai
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Parameter
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
from colossalai.registry import LAYERS
from colossalai.kernel.cuda_native.scaled_softmax import AttnMaskType
from colossalai.kernel import FusedScaleMaskSoftmax
from colossalai.context import seed
@LAYERS.register_module
class TransformerSelfAttentionRing(nn.Module):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
:param hidden_size: hidden size
:type hidden_size: int
:param kv_channels: channels of key/value tensor
:type kv_channels: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout: dropout probability for attention layer
:type attention_dropout: float
"""
def __init__(self,
hidden_size,
num_attention_heads,
attention_dropout,
attention_mask_func,
layer_number,
apply_query_key_layer_scaling: bool = False,
convert_fp16_to_fp32_in_softmax: bool = False,
attn_mask_type=AttnMaskType.padding,
masked_softmax_fusion=True,
fp16=False,
bf16=False
):
super().__init__()
self.convert_fp16_to_fp32_in_softmax = convert_fp16_to_fp32_in_softmax
self.apply_query_key_layer_scaling = apply_query_key_layer_scaling
self.attention_mask_func = attention_mask_func
self.layer_number = layer_number
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
self.attn_mask_type = attn_mask_type
assert self.layer_number > 0
self.attention_dropout = attention_dropout
if self.apply_query_key_layer_scaling:
self.convert_fp16_to_fp32_in_softmax = True
assert self.hidden_size % self.num_attention_heads == 0, \
'hidden size is not divisible by the number of attention heads'
self.hidden_size_per_attention_head = self.hidden_size // num_attention_heads
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
# Strided linear layer.
self.query_key_value = _Linear(
hidden_size,
3 * self.hidden_size,
)
self.coeff = None
self.norm_factor = math.sqrt(self.hidden_size)
if self.apply_query_key_layer_scaling:
self.coeff = layer_number
self.norm_factor *= self.coeff
self.scale_mask_softmax = FusedScaleMaskSoftmax(
fp16, bf16,
self.attn_mask_type,
masked_softmax_fusion,
self.attention_mask_func,
self.convert_fp16_to_fp32_in_softmax,
self.coeff)
self.attention_dropout = nn.Dropout(attention_dropout)
# Output.
self.dense = _Linear(hidden_size,
hidden_size,
bias=True,
skip_bias_add=True)
def forward(self, hidden_states, attention_mask):
# hidden_states: [sub_seq_len, batch_size, hidden_size]
# attention_mask: [batch_size, 1, sub_seq_len, seq_len]
sub_seq_length, batch_size, hidden_size = hidden_states.size()
# =====================
# Query, Key, and Value
# =====================
# Attention heads shape change:
# [sub_seq_len, batch_size, hidden_size] --> [sub_seq_len, batch_size, (3 * head_size * num_heads)]
mixed_x_layer = self.query_key_value(hidden_states)
# [sub_seq_len, batch_size, num_heads, 3 * head_size] --> 3 [sub_seq_len, batch_size, num_heads, head_size]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# split into query, key and value
last_dim = mixed_x_layer.dim() - 1
last_dim_value = mixed_x_layer.size(-1)
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
'cannot be divided into query, key and value'
partition_size = last_dim_value // 3
(query_layer, key_layer, value_layer) = torch.split(
mixed_x_layer, partition_size, dim=last_dim)
# attention scores: [batch_size, num_heads, sub_seq_len, seq_len]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0) * self.world_size)
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size * num_heads, head_size]
key_layer = key_layer.view(key_layer.size(0),
output_size[0] * output_size[1], -1)
# attention_scores: [batch_size * num_heads, sub_seq_len, seq_len]
attention_scores = RingQK.apply(
query_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size]
key_layer.transpose(0, 1).contiguous(), # [batch_size * num_heads, sub_seq_len, head_size],
batch_size,
self.num_attention_heads,
sub_seq_length
)
attention_scores /= self.norm_factor
# change view to [batch_size, num_heads, sub_seq_len, seq_len]
attention_scores = attention_scores.view(*output_size)
# change shape to [batch_size, num_heads, sub_seq_len, seq_len]
attention_probs = self.scale_mask_softmax(attention_scores, attention_mask)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
with seed(ParallelMode.TENSOR):
attention_probs = self.attention_dropout(attention_probs)
# context layer shape: [batch_size, num_heads, sub_seq_len, head_size]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
# change view [sub_seq_len, batch_size * num_heads, head_size]
value_layer = value_layer.contiguous().view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# # change view [b * num_heads, sub_seq_len, seq_len]
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
attention_probs.size(2),
attention_probs.size(3))
# matmul: [batch_size * num_heads, sub_seq_len, head_size]
context_layer = RingAV.apply(
attention_probs,
value_layer.transpose(0, 1).contiguous(),
batch_size,
self.num_attention_heads,
self.hidden_size_per_attention_head,
sub_seq_length
)
# change view [batch_size, num_heads, sub_seq_len, head_size]
context_layer = context_layer.view(*output_size)
# [batch_size, num_heads, sub_seq_len, head_size] -> [sub_seq_len, batch_size, num_heads, head_size]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# [sub_seq_len, batch_size, num_heads, head_size] -> [sub_seq_len, batch_size, hidden_size]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_attention_head * self.num_attention_heads,)
context_layer = context_layer.view(*new_context_layer_shape)
output, bias = self.dense(context_layer)
return output, bias
def __repr__(self):
return f'TransformerSelfAttentionRing(apply_query_key_layer_scaling={self.apply_query_key_layer_scaling}, ' \
f'layer_number={self.layer_number}, hidden_size:{self.hidden_size}, attention_dropout={self.attention_dropout}, ' \
f'attn_mask_type={self.attn_mask_type}, num_attention_heads={self.num_attention_heads}, ' \
f'hidden_size_per_attention_head={self.hidden_size_per_attention_head}, coeff={self.coeff}, norm_factor={self.norm_factor}, ' \
f'convert_fp16_to_fp32_in_softmax={self.convert_fp16_to_fp32_in_softmax})'
class _Linear(nn.Module):
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
its second dimension as A = [A_1, ..., A_p].
Arguments:
input_size: first dimension of matrix A.
output_size: second dimension of matrix A.
bias: If true, add bias
init_method: method to initialize weights. Note that bias is always set
to zero.
stride: For the strided linear layers.
keep_master_weight_for_test: This was added for testing and should be
set to False. It returns the master weights
used for initialization.
skip_bias_add: This was added to enable performance optimations where bias
can be fused with other elementwise operations. we skip
adding bias but instead return it.
"""
def __init__(self,
input_size,
output_size,
bias=True,
skip_bias_add=False):
super(_Linear, self).__init__()
# Keep input parameters
self.input_size = input_size
self.output_size = output_size
self.skip_bias_add = skip_bias_add
self.weight = Parameter(torch.empty(self.output_size,
self.input_size,
))
nn.init.xavier_normal_(self.weight)
if bias:
self.bias = Parameter(torch.empty(self.output_size))
# Always initialize bias to zero.
with torch.no_grad():
self.bias.zero_()
else:
self.register_parameter('bias', None)
def forward(self, input_):
# Matrix multiply.
bias = self.bias if not self.skip_bias_add else None
output = F.linear(input_, self.weight, bias)
if self.skip_bias_add:
return output, self.bias
else:
return output
def __repr__(self):
return f'Linear(in_features={self.input_size}, out_features={self.output_size}, ' + \
f'bias={self.bias is not None}, skip_bias_add={self.skip_bias_add})'
from .common import (ACT2FN, CheckpointModule, _ntuple, divide, get_tensor_parallel_mode,
set_tensor_parallel_attribute_by_partition, set_tensor_parallel_attribute_by_size, to_2tuple)
__all__ = [
'CheckpointModule', 'divide', 'ACT2FN', 'set_tensor_parallel_attribute_by_size',
'set_tensor_parallel_attribute_by_partition', 'get_tensor_parallel_mode', '_ntuple', 'to_2tuple'
]
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import collections.abc
from itertools import repeat
import numpy as np
import torch
from colossalai.constants import IS_TENSOR_PARALLEL, NUM_PARTITIONS
from colossalai.global_variables import tensor_parallel_env as env
from colossalai.utils import checkpoint
from torch import Tensor, nn
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = True):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def _forward(self, *args, **kwargs):
raise NotImplementedError('CheckpointModule should implement _forward method instead of origin forward')
def forward(self, *args, **kwargs):
if self._use_checkpoint:
return checkpoint(self._forward, *args, **kwargs)
else:
return self._forward(*args, **kwargs)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
return super().train(mode=mode)
def eval(self):
self._use_checkpoint = False
return super().eval()
def divide(numerator, denominator):
"""Only allow exact division
:param numerator: Numerator of the division
:param denominator: Denominator of the division
"""
assert numerator % denominator == 0, \
'{} is not divisible by {}'.format(numerator, denominator)
return numerator // denominator
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
ACT2FN = {"gelu": torch.nn.functional.gelu, "relu": torch.nn.functional.relu, "swish": swish}
def set_tensor_parallel_attribute_by_size(param, size):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, size // np.prod(param.shape))
def set_tensor_parallel_attribute_by_partition(param, num_partitions):
setattr(param, IS_TENSOR_PARALLEL, True)
setattr(param, NUM_PARTITIONS, num_partitions)
def get_tensor_parallel_mode():
return env.mode
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
from .layers import DropPath, VanillaClassifier, VanillaPatchEmbedding, \
WrappedDropout, WrappedDropPath
__all__ = ['VanillaPatchEmbedding', 'VanillaClassifier', 'DropPath',
'WrappedDropout', 'WrappedDropPath']
import math
from typing import Callable
import torch
import torch.nn.functional as F
from colossalai.context import seed
from colossalai.nn import init as init
from colossalai.registry import LAYERS
from colossalai.utils.cuda import get_current_device
from torch import Tensor
from torch import nn as nn
from ..utils import to_2tuple
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
class DropPath(nn.Module):
"""
Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Adapted from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
"""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
class WrappedDropout(nn.Module):
"""Same as torch.nn.Dropout. But it is wrapped with the context of seed manager.
"""
def __init__(self, p: float = 0.5, inplace: bool = False, mode=None):
super().__init__()
if p < 0 or p > 1:
raise ValueError("dropout probability has to be between 0 and 1, "
"but got {}".format(p))
self.p = p
self.inplace = inplace
if mode is None:
self.func = self.nonefunc
else:
self.func = self.normalfunc
self.mode = mode
def nonefunc(self, inputs):
return F.dropout(inputs, self.p, self.training, self.inplace)
def normalfunc(self, inputs):
with seed(self.mode):
return F.dropout(inputs, self.p, self.training, self.inplace)
def forward(self, inputs):
return self.func(inputs)
class WrappedDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
Here, it is wrapped with the context of seed manager.
"""
def __init__(self, p: float = 0., mode=None):
super().__init__()
self.p = p
self.mode = mode
if self.mode is None:
self.func = self.nonefunc
else:
self.func = self.normalfunc
self.mode = mode
def nonefunc(self, inputs):
return drop_path(inputs, self.p, self.training)
def normalfunc(self, inputs):
with seed(self.mode):
return drop_path(inputs, self.p, self.training)
def forward(self, inputs):
return self.func(inputs)
@LAYERS.register_module
class VanillaPatchEmbedding(nn.Module):
"""
2D Image to Patch Embedding
:param img_size: image size
:type img_size: int
:param patch_size: patch size
:type patch_size: int
:param in_chans: number of channels of input image
:type in_chans: int
:param embed_size: size of embedding
:type embed_size: int
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param flatten: whether to flatten output tensor, defaults to True
:type flatten: bool, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
:param position_embed_initializer: The intializer of position embedding, defaults to zero
:type position_embed_initializer: typing.Callable, optional
"""
def __init__(self,
img_size: int,
patch_size: int,
in_chans: int,
embed_size: int,
flatten: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
position_embed_initializer: Callable = init.zeros_()):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.weight = nn.Parameter(
torch.empty((embed_size, in_chans, *self.patch_size), device=get_current_device(), dtype=dtype))
self.bias = nn.Parameter(torch.empty(embed_size, device=get_current_device(), dtype=dtype))
self.cls_token = nn.Parameter(torch.zeros((1, 1, embed_size), device=get_current_device(), dtype=dtype))
self.pos_embed = nn.Parameter(
torch.zeros((1, self.num_patches + 1, embed_size), device=get_current_device(), dtype=dtype))
self.reset_parameters(weight_initializer, bias_initializer, position_embed_initializer)
def reset_parameters(self, weight_initializer, bias_initializer, position_embed_initializer):
fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(self.weight)
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
bias_initializer(self.bias, fan_in=fan_in)
position_embed_initializer(self.pos_embed)
def forward(self, input_: Tensor) -> Tensor:
B, C, H, W = input_.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
if self.flatten:
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
output = torch.cat((cls_token, output), dim=1)
output = output + self.pos_embed
return output
@LAYERS.register_module
class VanillaClassifier(nn.Module):
"""
Dense linear classifier
:param in_features: size of each input sample
:type in_features: int
:param num_classes: number of classes
:type num_classes: int
:param weight: weight of the classifier, defaults to True
:type weight: torch.nn.Parameter, optional
:param bias: If set to ``False``, the layer will not learn an additive bias, defaults to True
:type bias: bool, optional
:param dtype: The dtype of parameters, defaults to None
:type dtype: torch.dtype, optional
:param weight_initializer: The intializer of weight, defaults to kaiming uniform initializer
:type weight_initializer: typing.Callable, optional
:param bias_initializer: The intializer of bias, defaults to xavier uniform initializer
:type bias_initializer: typing.Callable, optional
"""
def __init__(self,
in_features: int,
num_classes: int,
weight: nn.Parameter = None,
bias: bool = True,
dtype: torch.dtype = None,
weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
super().__init__()
self.in_features = in_features
self.num_classes = num_classes
if weight is not None:
self.weight = weight
self.has_weight = False
else:
self.weight = nn.Parameter(
torch.empty(self.num_classes, self.in_features, device=get_current_device(), dtype=dtype))
self.has_weight = True
if bias:
self.bias = nn.Parameter(torch.zeros(self.num_classes, device=get_current_device(), dtype=dtype))
else:
self.bias = None
self.reset_parameters(weight_initializer, bias_initializer)
def reset_parameters(self, weight_initializer, bias_initializer):
fan_in, fan_out = self.in_features, self.num_classes
if self.has_weight:
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
if self.bias is not None:
bias_initializer(self.bias, fan_in=fan_in)
def forward(self, input_: Tensor) -> Tensor:
return F.linear(input_, self.weight, self.bias)
from .lambda_wrapper import LambdaWrapper
from .pipeline_wrapper import PipelineSharedModuleWrapper
__all__ = ['LambdaWrapper', 'PipelineSharedModuleWrapper']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment