Commit 404ecbdc authored by zbian's avatar zbian
Browse files

Migrated project

parent 2ebaefc5
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
from typing import Tuple
import torch
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.registry import LAYERS
from colossalai.utils import get_current_device
from torch import Tensor, dtype
from torch.nn import Parameter
from .._common_utils import divide, set_tensor_parallel_attribute
from ._operation import Add_3D, Matmul_AB_3D, Mul_3D, Sum_3D
from ._utils import get_depth_from_env, get_last_group
@LAYERS.register_module
class LayerNorm3D(nn.Module):
def __init__(
self,
normalized_shape: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
eps: float = 1e-12,
dtype: dtype = None,
):
super().__init__()
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.depth = get_depth_from_env()
self.normalized_shape = normalized_shape
self.normalized_shape_per_partition = divide(normalized_shape,
self.depth**2)
self.weight = Parameter(
torch.ones(self.normalized_shape_per_partition,
device=get_current_device(),
dtype=dtype))
self.bias = Parameter(
torch.zeros(self.normalized_shape_per_partition,
device=get_current_device(),
dtype=dtype))
self.variance_epsilon = eps
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
set_tensor_parallel_attribute(self.bias)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.input_parallel_mode, self.weight_parallel_mode
def reset_parameters(self):
nn.init.zeros_(self.bias)
nn.init.ones_(self.weight)
def forward(self, input_: Tensor) -> Tensor:
'''x = weight * (x - mean) / sqrt(var + eps) + bias'''
# input: [m/q^2, n, h/q]
# [m/q^2, n, 1]
mean = Sum_3D.apply(input_, -1, self.depth, self.output_parallel_mode,
True) / self.normalized_shape
# [m/q^2, n, 1]
var = (input_ - mean).pow(2)
var = Sum_3D.apply(var, -1, self.depth, self.output_parallel_mode,
True) / self.normalized_shape
output = (input_ - mean) / torch.sqrt(var + self.variance_epsilon)
output = Mul_3D.apply(output, self.weight, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
output = Add_3D.apply(output, self.bias, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
return output
def extra_repr(self):
return '{}, eps={}'.format(self.normalized_shape,
self.variance_epsilon)
@LAYERS.register_module
class Linear3D(nn.Module):
def __init__(self,
in_features: int,
out_features: int,
input_parallel_mode: ParallelMode,
weight_parallel_mode: ParallelMode,
bias: bool = True,
dtype: dtype = None):
super().__init__()
self.in_features = in_features
self.out_features = out_features
self.input_parallel_mode = input_parallel_mode
self.weight_parallel_mode = weight_parallel_mode
self.output_parallel_mode = get_last_group(self.input_parallel_mode,
self.weight_parallel_mode)
self.with_bias = bias
self.depth = get_depth_from_env()
self.in_features_per_partition = divide(in_features, self.depth)
self.out_features_per_partition = divide(out_features, self.depth**2)
# [k/q, h/q^2]
self.weight = Parameter(
torch.empty(self.in_features_per_partition,
self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
# [h/q^2]
if bias:
self.bias = Parameter(
torch.zeros(self.out_features_per_partition,
device=get_current_device(),
dtype=dtype))
else:
self.register_parameter('bias', None)
self.reset_parameters()
self._set_tensor_parallel_attributes()
def _set_tensor_parallel_attributes(self):
set_tensor_parallel_attribute(self.weight)
if self.bias is not None:
set_tensor_parallel_attribute(self.bias)
def groups_for_next_layer(self) -> Tuple[ParallelMode, ParallelMode]:
return self.output_parallel_mode, self.weight_parallel_mode
def reset_parameters(self):
# setting
fan_in = self.in_features
a = math.sqrt(5)
nonlinearity = 'leaky_relu'
# init weight
std = nn.init.calculate_gain(nonlinearity, a) / math.sqrt(fan_in)
bound = math.sqrt(3.0) * std
with seed(ParallelMode.TENSOR):
nn.init.uniform_(self.weight, -bound, bound)
# init bias
if self.with_bias:
bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
with seed(ParallelMode.TENSOR):
nn.init.uniform_(self.bias, -bound, bound)
def forward(self, input_: Tensor) -> Tensor:
# input: [m/q^2, n, k/q]
# output: [m/q^2, n, h/q]
output = Matmul_AB_3D.apply(input_, self.weight, self.depth,
self.input_parallel_mode,
self.weight_parallel_mode,
self.output_parallel_mode)
if self.with_bias:
output = Add_3D.apply(output, self.bias, self.depth,
self.output_parallel_mode,
self.weight_parallel_mode,
self.input_parallel_mode)
return output
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.with_bias)
from ._operation import RingQK, RingAV
from .layers import TransformerSelfAttentionRing
__all__ = ['TransformerSelfAttentionRing', 'RingAV', 'RingQK']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from torch import distributed as dist
from colossalai.communication import ring_forward
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._utils import _calc_incoming_device_range, _calc_current_device_range
from colossalai.utils import get_current_device
class RingQK(torch.autograd.Function):
"""
Calculate QK in a ring-exchange style
"""
@staticmethod
def forward(ctx,
sub_q,
sub_k,
batch_size,
num_attention_heads,
sub_seq_length):
# save tensor for backward
ctx.save_for_backward(sub_q, sub_k)
ctx.sub_seq_length = sub_seq_length
# create local segment of attention score
attention_score = torch.empty(
batch_size * num_attention_heads,
sub_seq_length,
sub_seq_length * gpc.get_world_size(ParallelMode.SEQUENCE),
dtype=sub_q.dtype,
device=get_current_device()
)
# compute local QK^T
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
start_idx = local_rank * sub_seq_length
end_idx = (local_rank + 1) * sub_seq_length
attention_score[:, :, start_idx: end_idx] = part_a
# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
part_a = torch.matmul(sub_q, sub_k.transpose(2, 1))
attention_score[:, :, start_idx:end_idx] = part_a
return attention_score
@staticmethod
def backward(ctx, grad_output):
sub_q, sub_k, = ctx.saved_tensors
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
# calculate gradient of sub_k
grad_k = torch.matmul(
grad_output.transpose(2, 1),
sub_q
)
dist.all_reduce(grad_k, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_k = grad_k[:, local_rank * ctx.sub_seq_length: (local_rank + 1) * ctx.sub_seq_length]
grad_k /= local_world_size
# calculate gradient for sub_q
grad_q = torch.zeros_like(sub_q,
dtype=sub_q.dtype,
device=get_current_device(), )
# compute with local sub_k
start_idx, end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
grad_q += torch.matmul(grad_output[:, :, start_idx:end_idx], sub_k)
# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
sub_k = ring_forward(sub_k, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
grad_q += torch.matmul(grad_output[:, :, start_idx: end_idx], sub_k)
grad_q /= local_world_size
return grad_q, grad_k, None, None, None
class RingAV(torch.autograd.Function):
"""
Calculate AV in a ring-exchange style
"""
@staticmethod
def forward(ctx,
attention_score,
sub_v,
batch_size,
num_attention_heads,
attention_head_size,
sub_seq_length):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, sub_seq_length)
sub_attention_result = torch.zeros(
batch_size * num_attention_heads,
sub_seq_length,
attention_head_size,
device=get_current_device(),
dtype=attention_score.dtype)
# save tensors for backward
ctx.save_for_backward(attention_score, sub_v)
ctx.sub_seq_length = sub_seq_length
# compute local AV
part_av = torch.matmul(attention_score[:, :, local_start_idx:local_end_idx], sub_v)
sub_attention_result += part_av
# compute AV in ring - all - reduce style
for i in range(local_world_size - 1):
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, sub_seq_length)
# compute QK^T
part_av = torch.matmul(attention_score[:, :, start_idx:end_idx], sub_v)
sub_attention_result += part_av
return sub_attention_result
@staticmethod
def backward(ctx, grad_output):
local_rank = gpc.get_local_rank(ParallelMode.SEQUENCE)
local_world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
local_start_idx, local_end_idx = _calc_current_device_range(local_rank, ctx.sub_seq_length)
attention_scores, sub_v = ctx.saved_tensors
# calculate gradient of v
grad_v = torch.matmul(
attention_scores.transpose(2, 1),
grad_output
)
dist.all_reduce(grad_v, group=gpc.get_group(ParallelMode.SEQUENCE))
grad_v = grad_v[:, local_start_idx:local_end_idx]
grad_v /= local_world_size
# calculate gradient for attention score
grad_attention_score = torch.zeros_like(attention_scores,
dtype=grad_output.dtype,
device=get_current_device())
# compute with local sub_k
grad_attention_score[:, :, local_start_idx:local_end_idx] += torch.matmul(
grad_output,
sub_v.transpose(2, 1))
# compute QK^T in ring-all-reduce style
for i in range(local_world_size - 1):
sub_v = ring_forward(sub_v, ParallelMode.SEQUENCE)
start_idx, end_idx = _calc_incoming_device_range(i, local_rank, local_world_size, ctx.sub_seq_length)
# compute grad_q
grad_attention_score[:, :, start_idx:end_idx] += torch.matmul(
grad_output,
sub_v.transpose(2, 1))
return grad_attention_score, grad_v, None, None, None, None
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
def _calc_incoming_device_range(i, rank, world_size, sub_seq_length):
device_of_incoming_k = (rank - i - 1) % world_size
start_idx = sub_seq_length * device_of_incoming_k
end_idx = sub_seq_length * (device_of_incoming_k + 1)
return start_idx, end_idx
def _calc_current_device_range(rank, sub_seq_length):
start_idx = sub_seq_length * rank
end_idx = sub_seq_length * (rank + 1)
return start_idx, end_idx
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_sequence._operation import RingQK, RingAV
from colossalai.registry import LAYERS
@LAYERS.register_module
class TransformerSelfAttentionRing(nn.Module):
"""Parallel self-attention layer abstract class.
Self-attention layer takes input with size [b, s, h]
and returns output of the same size.
:param hidden_size: hidden size
:type hidden_size: int
:param kv_channels: channels of key/value tensor
:type kv_channels: int
:param num_attention_heads: number of attention heads
:type num_attention_heads: int
:param attention_dropout: dropout probability for attention layer
:type attention_dropout: float
"""
def __init__(self,
hidden_size,
kv_channels,
num_attention_heads,
attention_dropout,
):
super().__init__()
self.hidden_size = hidden_size
self.num_attention_heads = num_attention_heads
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = projection_size // num_attention_heads
self.world_size = gpc.get_world_size(ParallelMode.SEQUENCE)
# Strided linear layer.
self.query_key_value = nn.Linear(
hidden_size,
3 * projection_size,
)
# coeff = None
self.norm_factor = math.sqrt(self.hidden_size)
# TODO: add apply_query_key_layer_scaling when we have the kernel module
# if self.apply_query_key_layer_scaling:
# coeff = self.layer_number
# self.norm_factor *= coeff
# TODO: add fused scale mask softmax kernel when we have the kernel module
# self.scale_mask_softmax = FusedScaleMaskSoftmax(
# self.fp16, self.bf16,
# self.attn_mask_type,
# masked_softmax_fusion,
# attention_mask_func,
# self.attention_softmax_in_fp32,
# coeff)
self.attention_dropout = nn.Dropout(attention_dropout)
# Output.
self.dense = nn.Linear(
projection_size,
hidden_size,
bias=True)
def forward(self, hidden_states, attention_mask):
# hidden_states: [sq, b, h]
sub_seq_length, batch_size, hidden_size = hidden_states.size()
# =====================
# Query, Key, and Value
# =====================
# Attention heads [sq, b, h] --> [sq, b, (3 * hn * num_heads)]
mixed_x_layer = self.query_key_value(hidden_states)
# [sq, b, num_heads, 3 * hn] --> 3 [sq, b, num_heads, hn]
new_tensor_shape = mixed_x_layer.size()[:-1] + (self.num_attention_heads,
3 * self.hidden_size_per_attention_head)
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# split into query, key and value
last_dim = mixed_x_layer.dim() - 1
last_dim_value = mixed_x_layer.size()[-1]
assert last_dim_value % 3 == 0, 'the last dimension is not a multiple of 3, ' \
'cannot be divided into query, key and value'
partition_size = last_dim_value // 3
(query_layer, key_layer, value_layer) = torch.split(
mixed_x_layer, partition_size, dim=last_dim)
# ===================================
# Raw attention scores. [b, num_heads, s, s]
# ===================================
# [b, num_heads, sq, sk]
output_size = (query_layer.size(1),
query_layer.size(2),
query_layer.size(0),
key_layer.size(0) * self.world_size)
# [sq, b, num_heads, hn] -> [sq, b * num_heads, hn]
query_layer = query_layer.view(output_size[2],
output_size[0] * output_size[1], -1)
# [sk, b, num_heads, hn] -> [sk, b * num_heads, hn]
key_layer = key_layer.view(key_layer.size(0),
output_size[0] * output_size[1], -1)
# [b, sq, sk]
attention_scores = RingQK.apply(
# [b * num_heads, sq, hn]
query_layer.transpose(0, 1).contiguous(),
key_layer.transpose(0, 1).contiguous(), # [b * num_heads, sk, hn],
batch_size,
self.num_attention_heads,
sub_seq_length
)
attention_scores /= self.norm_factor
# change view to [b, num_heads, sq, sk]
attention_scores = attention_scores.view(*output_size)
attention_scores = attention_scores.unsqueeze(1)
attention_scores = attention_scores + attention_mask
attention_probs = F.softmax(attention_scores, dim=-1)
attention_probs = attention_probs.squeeze(1)
# This is actually dropping out entire tokens to attend to, which might
# seem a bit unusual, but is taken from the original Transformer paper.
# with mpu.get_cuda_rng_tracker().fork():
# TODO: check if a rng tracker is needed
attention_probs = self.attention_dropout(attention_probs)
# context layer shape: [b, num_heads, sq, hn]
output_size = (value_layer.size(1),
value_layer.size(2),
query_layer.size(0),
value_layer.size(3))
#
# # change view [sk, b * num_heads, hn]
value_layer = value_layer.contiguous().view(value_layer.size(0),
output_size[0] * output_size[1], -1)
# # change view [b * num_heads, sq, sk]
attention_probs = attention_probs.view(attention_probs.size(0) * attention_probs.size(1),
attention_probs.size(2),
attention_probs.size(3))
# matmul: [b*num_heads, sq, hn]
# context_layer = torch.bmm(attention_probs, value_layer.transpose(0, 1))
context_layer = RingAV.apply(
attention_probs,
value_layer.transpose(0, 1).contiguous(),
batch_size,
self.num_attention_heads,
self.hidden_size_per_attention_head,
sub_seq_length
)
# # change view [b, num_heads, sq, hn]
context_layer = context_layer.view(*output_size)
# # [b, np, sq, hn] --> [sq, b, np, hn]
context_layer = context_layer.permute(2, 0, 1, 3).contiguous()
# # [sq, b, np, hn] --> [sq, b, hp]
new_context_layer_shape = context_layer.size()[:-2] + (
self.hidden_size_per_attention_head * self.num_attention_heads,)
context_layer = context_layer.view(*new_context_layer_shape)
# context_layer = context_layer.transpose(1, 0).contiguous()
output = self.dense(context_layer)
bias = self.dense.bias
return output, bias
from .layers import ViTBlock
__all__ = ['ViTBlock']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from torch import nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class ViTBlock(nn.Module):
"""Vision Transformer block
:param attention_cfg: config of attention layer
:type attention_cfg: dict
:param droppath_cfg: config of drop path
:type droppath_cfg: dict
:param mlp_cfg: config of MLP layer
:type mlp_cfg: dict
:param norm_cfg: config of normlization layer
:type norm_cfg: dict
"""
def __init__(self,
attention_cfg: dict,
droppath_cfg: dict,
mlp_cfg: dict,
norm_cfg: dict,
):
super().__init__()
self.norm1 = build_layer(norm_cfg)
self.attn = build_layer(attention_cfg)
self.drop_path = build_layer(
droppath_cfg) if droppath_cfg['drop_path'] > 0. else nn.Identity()
self.norm2 = build_layer(norm_cfg)
self.mlp = build_layer(mlp_cfg)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
# x_ = x
# x_ = self.norm1(x_)
# if self.checkpoint:
# x_ = checkpoint(self.attn, x_)
# else:
# x_ = self.attn(x_)
# x_ = self.drop_path(x_)
# x = x + x_
#
# x_ = x
# x_ = self.norm2(x_)
# if self.checkpoint:
# x_ = checkpoint(self.mlp, x_)
# else:
# x_ = self.mlp(x_)
# x_ = self.drop_path(x_)
# x = x + x_
return x
from .basic_block import ResNetBasicBlock
from .bottleneck import ResNetBottleneck
from .reslayer import ResLayer
__all__ = ['ResLayer', 'ResNetBottleneck', 'ResNetBasicBlock']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3
@LAYERS.register_module
class ResNetBasicBlock(nn.Module):
"""Basic ResNet block
"""
expansion: int = 1
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
if groups != 1 or base_width != 64:
raise ValueError(
'BasicBlock only supports groups=1 and base_width=64')
if dilation > 1:
raise NotImplementedError(
"Dilation > 1 not supported in BasicBlock")
# Both self.conv1 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv3x3(inplanes, planes, stride)
self.bn1 = norm_layer(planes)
self.relu = nn.ReLU(inplace=True)
self.conv2 = conv3x3(planes, planes)
self.bn2 = norm_layer(planes)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from typing import Optional, Callable
import torch.nn as nn
from torch import Tensor
from colossalai.registry import LAYERS
from .conv import conv3x3, conv1x1
@LAYERS.register_module
class ResNetBottleneck(nn.Module):
# Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
# while original implementation places the stride at the first 1x1 convolution(self.conv1)
# according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
# This variant is also known as ResNet V1.5 and improves accuracy according to
# https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
expansion: int = 4
def __init__(
self,
inplanes: int,
planes: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
groups: int = 1,
base_width: int = 64,
dilation: int = 1,
norm_layer: Optional[Callable[..., nn.Module]] = None
) -> None:
super().__init__()
if norm_layer is None:
norm_layer = nn.BatchNorm2d
width = int(planes * (base_width / 64.)) * groups
# Both self.conv2 and self.downsample layers downsample the input when stride != 1
self.conv1 = conv1x1(inplanes, width)
self.bn1 = norm_layer(width)
self.conv2 = conv3x3(width, width, stride, groups, dilation)
self.bn2 = norm_layer(width)
self.conv3 = conv1x1(width, planes * self.expansion)
self.bn3 = norm_layer(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x: Tensor) -> Tensor:
identity = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
identity = self.downsample(x)
out += identity
out = self.relu(out)
return out
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
def conv3x3(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
"""3x3 convolution with padding"""
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
padding=dilation, groups=groups, bias=False, dilation=dilation)
def conv1x1(in_planes: int, out_planes: int, stride: int = 1) -> nn.Conv2d:
"""1x1 convolution"""
return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.registry import LAYERS
from .conv import conv1x1
@LAYERS.register_module
class ResLayer(nn.Module):
def __init__(self,
block_type: str,
norm_layer_type: str,
inplanes: int,
planes: int,
blocks: int,
groups: int,
base_width: int,
stride: int = 1,
dilation: int = 1,
dilate: bool = False,
):
super().__init__()
self.block = LAYERS.get_module(block_type)
self.norm_layer = LAYERS.get_module(norm_layer_type)
self.inplanes = inplanes
self.planes = planes
self.blocks = blocks
self.groups = groups
self.dilation = dilation
self.base_width = base_width
self.dilate = dilate
self.stride = stride
self.layer = self._make_layer()
def _make_layer(self):
norm_layer = self.norm_layer
downsample = None
previous_dilation = self.dilation
if self.dilate:
self.dilation *= self.stride
self.stride = 1
if self.stride != 1 or self.inplanes != self.planes * self.block.expansion:
downsample = nn.Sequential(
conv1x1(self.inplanes, self.planes * self.block.expansion, self.stride),
norm_layer(self.planes * self.block.expansion),
)
layers = []
layers.append(self.block(self.inplanes, self.planes, self.stride, downsample, self.groups,
self.base_width, previous_dilation, norm_layer))
self.inplanes = self.planes * self.block.expansion
for _ in range(1, self.blocks):
layers.append(self.block(self.inplanes, self.planes, groups=self.groups,
base_width=self.base_width, dilation=self.dilation,
norm_layer=norm_layer))
return nn.Sequential(*layers)
def forward(self, x):
return self.layer(x)
from .layers import (VanillaViTBlock, VanillaViTMLP, VanillaViTPatchEmbedding,
VanillaViTAttention, VanillaViTDropPath, VanillaViTHead)
__all__ = [
'VanillaViTBlock', 'VanillaViTMLP', 'VanillaViTPatchEmbedding',
'VanillaViTAttention', 'VanillaViTDropPath', 'VanillaViTHead'
]
import collections.abc
from itertools import repeat
import torch
from torch import nn as nn
from colossalai.registry import LAYERS
# From PyTorch internals
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
@LAYERS.register_module
class VanillaViTPatchEmbedding(nn.Module):
""" 2D Image to Patch Embedding
"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True, drop=0.):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
self.img_size = img_size
self.patch_size = patch_size
self.grid_size = (img_size[0] // patch_size[0],
img_size[1] // patch_size[1])
self.num_patches = self.grid_size[0] * self.grid_size[1]
self.flatten = flatten
self.proj = nn.Conv2d(in_chans, embed_dim,
kernel_size=patch_size, stride=patch_size)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))
self.pos_drop = nn.Dropout(p=drop)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
x = self.proj(x)
if self.flatten:
x = x.flatten(2).transpose(1, 2) # BCHW -> BNC
x = self.norm(x)
cls_token = self.cls_token.expand(x.shape[0], -1, -1)
x = torch.cat((cls_token, x), dim=1)
x = self.pos_drop(x + self.pos_embed)
return x
@LAYERS.register_module
class VanillaViTMLP(nn.Module):
""" MLP as used in Vision Transformer, MLP-Mixer and related networks
"""
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
x = self.drop(x)
x = self.fc2(x)
x = self.drop(x)
return x
def drop_path(x, drop_prob: float = 0., training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
'survival rate' as the argument.
"""
if drop_prob == 0. or not training:
return x
keep_prob = 1 - drop_prob
# work with diff dim tensors, not just 2D ConvNets
shape = (x.shape[0],) + (1,) * (x.ndim - 1)
random_tensor = keep_prob + \
torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
@LAYERS.register_module
class VanillaViTDropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
"""
def __init__(self, drop_prob=0.):
super().__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
@LAYERS.register_module
class VanillaViTAttention(nn.Module):
"""Vanilla attention layer of Vision Transformer
:param dim: dimension of input tensor
:type dim: int
:param num_heads: number of attention heads, defaults to 8
:type num_heads: int, optional
:param qkv_bias: enable bias for qkv if True, defaults to False
:type qkv_bias: bool, optional
:param attn_drop: dropout probability for attention layer, defaults to 0.
:type attn_drop: float, optional
:param proj_drop: dropout probability for linear layer, defaults to 0.
:type proj_drop: float, optional
"""
def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.scale = head_dim ** -0.5
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C //
self.num_heads).permute(2, 0, 3, 1, 4)
# make torchscript happy (cannot use tensor as tuple)
q, k, v = qkv[0], qkv[1], qkv[2]
attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)
return x
@LAYERS.register_module
class VanillaViTBlock(nn.Module):
"""Vanilla Vision Transformer block
:param dim: dimension of input tensor
:type dim: int
:param num_heads: number of attention heads
:type num_heads: int
:param mlp_ratio: hidden size of MLP divided by embedding dim, defaults to 4.
:type mlp_ratio: float, optional
:param qkv_bias: enable bias for qkv if True, defaults to False
:type qkv_bias: bool, optional
:param drop: dropout probability, defaults to 0.
:type drop: float, optional
:param attn_drop: dropout probability for attention layer, defaults to 0.
:type attn_drop: float, optional
:param drop_path: drop path probability, defaults to 0.
:type drop_path: float, optional
:param act_layer: activation function, defaults to nn.GELU
:type act_layer: torch.nn.Module, optional
:param norm_layer: normalization layer, defaults to nn.LayerNorm
:type norm_layer: torch.nn.Module, optional
"""
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
super().__init__()
self.norm1 = norm_layer(dim)
self.attn = LAYERS.get_module('VanillaViTAttention')(dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
attn_drop=attn_drop,
proj_drop=drop)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = LAYERS.get_module('VanillaViTDropPath')(
drop_path) if drop_path > 0. else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = LAYERS.get_module('VanillaViTMLP')(in_features=dim,
hidden_features=mlp_hidden_dim,
act_layer=act_layer,
drop=drop)
def forward(self, x):
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
@LAYERS.register_module
class VanillaViTHead(nn.Module):
"""Output layer of vanilla Vision Transformer
:param in_features: size of input tensor
:type in_features: int
:param intermediate_features: hidden size
:type intermediate_features: int
:param out_features: size of output tensor
:type out_features: int
:param bias: whether to add bias, defaults to True
:type bias: bool, optional
"""
def __init__(self,
in_features,
intermediate_features,
out_features,
bias=True
):
super().__init__()
self.linear_1 = nn.Linear(
in_features, intermediate_features, bias=bias)
self.act = nn.Tanh()
self.linear_2 = nn.Linear(
intermediate_features, out_features, bias=bias)
def forward(self, x):
x = x[:, 0, :].squeeze(1)
x = self.linear_1(x)
x = self.act(x)
x = self.linear_2(x)
return x
from .lambda_wrapper import LambdaWrapper
__all__ = ['LambdaWrapper']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch.nn as nn
from colossalai.builder import build_layer
from colossalai.registry import LAYERS
@LAYERS.register_module
class LambdaWrapper(nn.Module):
"""Wrap a function to nn.Module, which takes a config of layers and can fully access them
:param func: user customed function
:type func: Callable
:param layers_cfg: config of layers, defaults to None
:type layers_cfg: dict, optional
"""
def __init__(self, func, layers_cfg: dict = None):
super().__init__()
self.func = func
self.layers = self._build_layers(layers_cfg)
def _build_layers(self, layers_cfg: dict):
if layers_cfg is None:
return None
else:
layers = []
for cfg in layers_cfg:
layer = build_layer(cfg)
layers.append(layer)
return layers
def forward(self, *args, **kwargs):
return self.func(self, *args, **kwargs)
from .base_loss import BaseLoss
from .cross_entropy_2d import CrossEntropyLoss2D
from .cross_entropy_2p5d import CrossEntropyLoss2p5D
from .cross_entropy_3d import CrossEntropyLoss3D
__all__ = ['CrossEntropyLoss2D', 'CrossEntropyLoss2p5D', 'CrossEntropyLoss3D']
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from abc import ABC, abstractmethod
class BaseLoss(ABC):
"""Absctract loss class
"""
@abstractmethod
def calc_loss(self, *args, **kwargs):
pass
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
import torch.nn.functional as F
from torch.nn.modules.loss import _Loss
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_1d._utils import vocab_range_from_per_partition_vocab_size
class _VocabParallelCrossEntropy_1D(torch.autograd.Function):
@staticmethod
def forward(ctx, vocab_parallel_logits, target):
# Maximum value along vocab dimension across all GPUs.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
torch.distributed.all_reduce(logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Subtract the maximum value.
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
# Get the partition's vocab indecies
partition_vocab_size = vocab_parallel_logits.size()[-1]
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
world_size = gpc.get_world_size(ParallelMode.PARALLEL_1D)
vocab_start_index, vocab_end_index = vocab_range_from_per_partition_vocab_size(
partition_vocab_size, rank, world_size)
# Create a mask of valid vocab ids (1 means it needs to be masked).
target_mask = (target < vocab_start_index) | (target >= vocab_end_index)
masked_target = target.clone() - vocab_start_index
masked_target[target_mask] = 0
# Get predicted-logits = logits[target].
# For Simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
device=logits_2d.device)
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()
predicted_logits = predicted_logits_1d.view_as(target)
predicted_logits[target_mask] = 0.0
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits = vocab_parallel_logits
torch.exp(vocab_parallel_logits, out=exp_logits)
sum_exp_logits = exp_logits.sum(dim=-1)
torch.distributed.all_reduce(sum_exp_logits,
op=torch.distributed.ReduceOp.SUM,
group=gpc.get_group(ParallelMode.PARALLEL_1D))
# Loss = log(sum(exp(logits))) - predicted-logit.
loss = torch.log(sum_exp_logits) - predicted_logits
# Store softmax, target-mask and masked-target for backward pass.
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target_1d)
return loss
@staticmethod
def backward(ctx, grad_output):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target_1d = ctx.saved_tensors
# All the inputs have softmax as thier gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=grad_2d.device)
grad_2d[arange_1d, masked_target_1d] -= (
1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(grad_output.unsqueeze(dim=-1))
return grad_input, None
class LmLoss1D(_Loss):
def forward(self, lm_logits, lm_labels, loss_mask):
lm_loss = _VocabParallelCrossEntropy_1D.apply(lm_logits, lm_labels)
lm_loss = torch.sum(
lm_loss.view(-1) * loss_mask.reshape(-1)) / loss_mask.sum()
return lm_loss
class SopLoss1D(_Loss):
def forward(self, sop_logits, sentence_order):
sop_loss = F.cross_entropy(sop_logits.view(-1, 2).float(),
sentence_order.view(-1),
ignore_index=-1)
return sop_loss
class BERTDualHeadLoss(_Loss):
def __init__(self):
self.lm_loss = LmLoss1D()
self.sop_loss = SopLoss1D()
def forward(self, lm_logits, sop_logits, lm_labels, loss_mask, sentence_order):
lm_loss = self.lm_loss(lm_logits, lm_labels, loss_mask)
sop_loss = self.sop_loss(sop_logits, sentence_order)
return lm_loss + sop_loss
import torch
import torch.distributed as dist
from torch.nn.modules.loss import _Loss
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d._utils import assert_summa_initialization, get_summa_dim_from_env
from colossalai.registry import LOSSES
from colossalai.utils import get_current_device
class _ParallelCrossEntropyLossFunction_2D(torch.autograd.Function):
### Modified based on megatron.mpu.cross_entropy ###
@staticmethod
def forward(ctx, logits, targets):
# logits: [b/q, h/q]
# labels: [b/q]
# loss: [b/q]
# vocab_parallel_logits: [b/q, s, v/q]
# target: [b/q, s]
logits_max = torch.max(logits, dim=-1)[0]
torch.distributed.all_reduce(
logits_max,
op=torch.distributed.ReduceOp.MAX,
group=gpc.get_group(ParallelMode.PARALLEL_2D_ROW))
# Subtract the maximum value.
# vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
logits = logits - logits_max.unsqueeze(dim=-1)
vocab_size = logits.size(-1)
rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
vocab_start = rank * (vocab_size)
vocab_end = (rank + 1) * (vocab_size) - 1
target_mask = (targets < vocab_start) | (targets > vocab_end)
masked_target = targets.clone() - vocab_start
masked_target[target_mask] = 0
arange_1d = torch.arange(
start=0, end=logits.size()[0],
)
predicted_logits = logits[arange_1d, masked_target]
predicted_logits[target_mask] = 0.
dist.all_reduce(predicted_logits, group=gpc.get_group(
ParallelMode.PARALLEL_2D_ROW))
exp_logits = torch.exp(logits)
sum_exp_logits = exp_logits.sum(dim=1)
dist.all_reduce(sum_exp_logits, group=gpc.get_group(
ParallelMode.PARALLEL_2D_ROW))
loss = torch.log(sum_exp_logits) - predicted_logits
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
ctx.save_for_backward(exp_logits, target_mask, masked_target)
return loss
@staticmethod
def backward(ctx, output_grad):
# Retreive tensors from the forward path.
softmax, target_mask, masked_target = ctx.saved_tensors
# All the inputs have softmax as their gradient.
grad_input = softmax
# For simplicity, work with the 2D gradient.
partition_vocab_size = softmax.size()[-1]
grad_2d = grad_input.view(-1, partition_vocab_size)
# Add the gradient from matching classes.
arange_1d = torch.arange(start=0, end=grad_2d.size()[0],
device=get_current_device())
grad_2d[arange_1d,
masked_target] -= (1.0 - target_mask.view(-1).float())
# Finally elementwise multiplication with the output gradients.
grad_input.mul_(output_grad.unsqueeze(dim=-1))
return grad_input, None
class _ReduceByColumn(torch.autograd.Function):
"""All-reduce the input from the model parallel region."""
@staticmethod
def symbolic(graph, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
def forward(ctx, input_):
dist.all_reduce(input_, group=gpc.get_group(
ParallelMode.PARALLEL_2D_COL))
return input_
@staticmethod
def backward(ctx, grad_output):
return grad_output
@LOSSES.register_module
class CrossEntropyLoss2D(_Loss):
"""Cross entropy loss for 2D parallelism
:param reduction: whether to average the loss, defaults to True
:type reduction: bool, optional
"""
def __init__(self, reduction=True):
super().__init__()
assert_summa_initialization()
self.summa_dim = get_summa_dim_from_env()
self.row_rank = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)
self.reduction_mean = reduction
def forward(self, logits, targets):
targets = targets.chunk(self.summa_dim, dim=0)[self.row_rank]
loss = _ParallelCrossEntropyLossFunction_2D.apply(
logits, targets,
)
if self.reduction_mean:
loss = _ReduceByColumn.apply(loss) / self.summa_dim
dist_loss = loss.mean()
return dist_loss
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment